diff --git a/dubhe-server/.gitignore b/dubhe-server/.gitignore index ab86f52..a2e5604 100644 --- a/dubhe-server/.gitignore +++ b/dubhe-server/.gitignore @@ -48,6 +48,5 @@ output/ .classpath logs/ /dubhe-k8s/src/main/resources/kubeconfig -/dubhe-k8s/src/main/resources *.log /dubhe-admin/kubeconfig \ No newline at end of file diff --git a/dubhe-server/README.md b/dubhe-server/README.md index b6fa52c..333111d 100644 --- a/dubhe-server/README.md +++ b/dubhe-server/README.md @@ -1,13 +1,22 @@ -# 一站式开发平台-服务端 +# 之江天枢-服务端 -## 本地开发 +**之江天枢一站式人工智能开源平台**(简称:**之江天枢**),包括海量数据处理、交互式模型构建(包含Notebook和模型可视化)、AI模型高效训练。多维度产品形态满足从开发者到大型企业的不同需求,将提升人工智能技术的研发效率、扩大算法模型的应用范围,进一步构建人工智能生态“朋友圈”。 + +## 源码部署 ### 准备环境 安装如下软件环境。 - OpenJDK:1.8+ - Redis: 3.0+ - Maven: 3.0+ -- MYSQL: 5.5.0+ +- MYSQL: 5.7.0+ + +### 下载源码 +``` bash +git clone https://codeup.teambition.com/zhejianglab/dubhe-server.git +# 进入项目根目录 +cd dubhe-server +``` ### 创建DB 在MySQL中依次执行如下sql文件 @@ -20,14 +29,35 @@ sql/v1/02-Dubhe-DML.sql ### 配置 根据实际情况修改如下配置文件。 ``` -dubhe-admin/src/main/resources/config/application-dev.yml +dubhe-admin/src/main/resources/config/application-prod.yml ``` -### 启动: +### 构建 +``` bash +# 构建,生成的 jar 包位于 ./dubhe-admin/target/dubhe-admin-1.0.jar +mvn clean compile package ``` -mvn spring-boot:run + +### 启动 +``` bash +# 指定启动环境为 prod +## admin模块 +java -jar ./dubhe-admin/target/dubhe-admin-1.0-exec.jar --spring.profiles.active=prod + +## task模块 +java -jar ./dubhe-task/target/dubhe-task-1.0.jar --spring.profiles.active=prod ``` +## 本地开发 + +### 必要条件: + 导入maven项目,下载所需的依赖包 + mysql下创建数据库dubhe,初始化数据脚本 + 安装redis + +### 启动: + mvn spring-boot:run + ## 代码结构: ``` ├── common 公共模块 @@ -49,6 +79,7 @@ mvn spring-boot:run ├── dubhe-data 数据处理模块 ├── dubhe-model 模型管理模块 ├── dubhe-system 系统管理 +├── dubhe-task 定时任务模块 ``` ## docker服务器 diff --git a/dubhe-server/common/pom.xml b/dubhe-server/common/pom.xml index 21b9ea8..008f43c 100644 --- a/dubhe-server/common/pom.xml +++ b/dubhe-server/common/pom.xml @@ -29,11 +29,6 @@ guava 21.0 - - com.github.penggle - kaptcha - ${kaptcha.version} - org.apache.shiro @@ -95,6 +90,10 @@ commons-compress 1.20 + + com.github.whvcse + easy-captcha + diff --git a/dubhe-server/common/src/main/java/org/dubhe/annotation/DataPermission.java b/dubhe-server/common/src/main/java/org/dubhe/annotation/DataPermission.java index 96361bb..34de34a 100644 --- a/dubhe-server/common/src/main/java/org/dubhe/annotation/DataPermission.java +++ b/dubhe-server/common/src/main/java/org/dubhe/annotation/DataPermission.java @@ -14,31 +14,28 @@ * limitations under the License. * ============================================================= */ - package org.dubhe.annotation; -import java.lang.annotation.*; - /** - * 数据权限过滤Mapper拦截 - * - * @date 2020-06-22 + * @description 数据权限注解 + * @date 2020-09-24 */ +import java.lang.annotation.*; + @Target({ElementType.METHOD, ElementType.TYPE}) @Retention(RetentionPolicy.RUNTIME) @Documented public @interface DataPermission { /** - * 不需要数据权限的方法名 + * 只在类的注解上使用,代表方法的数据权限类型 + * @return */ - String[] ignores() default {}; + String permission() default ""; /** - * 只在方法的注解上使用,代表方法的数据权限类型,如果不加注解,只会识别带"select"方法名的方法 - * + * 不需要数据权限的方法名 * @return */ - String[] permission() default {}; - + String[] ignoresMethod() default {}; } diff --git a/dubhe-server/common/src/main/java/org/dubhe/annotation/DataPermissionMethod.java b/dubhe-server/common/src/main/java/org/dubhe/annotation/DataPermissionMethod.java new file mode 100644 index 0000000..0d4e087 --- /dev/null +++ b/dubhe-server/common/src/main/java/org/dubhe/annotation/DataPermissionMethod.java @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Zhejiang Lab. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================= + */ +package org.dubhe.annotation; + +import org.dubhe.enums.DatasetTypeEnum; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * @description 数据权限方法注解 + * @date 2020-09-24 + */ +@Target(ElementType.METHOD) +@Retention(RetentionPolicy.RUNTIME) +public @interface DataPermissionMethod { + + /** + * 是否需要拦截标识 true: 不拦截 false: 拦截 + * + * @return 拦截标识 + */ + boolean interceptFlag() default false; + + /** + * 数据类型 + * + * @return 数据集类型 + */ + DatasetTypeEnum dataType() default DatasetTypeEnum.PRIVATE; +} diff --git a/dubhe-server/common/src/main/java/org/dubhe/annotation/EnumValue.java b/dubhe-server/common/src/main/java/org/dubhe/annotation/EnumValue.java index 1fb7c93..d5f15b8 100644 --- a/dubhe-server/common/src/main/java/org/dubhe/annotation/EnumValue.java +++ b/dubhe-server/common/src/main/java/org/dubhe/annotation/EnumValue.java @@ -17,19 +17,20 @@ package org.dubhe.annotation; +import javax.validation.Constraint; +import javax.validation.ConstraintValidator; +import javax.validation.ConstraintValidatorContext; +import javax.validation.Payload; import java.lang.annotation.ElementType; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; -import javax.validation.Constraint; -import javax.validation.ConstraintValidator; -import javax.validation.ConstraintValidatorContext; -import javax.validation.Payload; /** - * @date: 2020-05-21 + * @description 接口枚举类检测标注类 + * @date 2020-05-21 */ @Target({ ElementType.METHOD, ElementType.FIELD, ElementType.ANNOTATION_TYPE }) @Retention(RetentionPolicy.RUNTIME) diff --git a/dubhe-server/common/src/main/java/org/dubhe/annotation/FlagValidator.java b/dubhe-server/common/src/main/java/org/dubhe/annotation/FlagValidator.java new file mode 100644 index 0000000..d47071b --- /dev/null +++ b/dubhe-server/common/src/main/java/org/dubhe/annotation/FlagValidator.java @@ -0,0 +1,66 @@ +/** + * Copyright 2020 Zhejiang Lab. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================= + */ +package org.dubhe.annotation; + +import javax.validation.Constraint; +import javax.validation.ConstraintValidator; +import javax.validation.ConstraintValidatorContext; +import javax.validation.Payload; +import java.lang.annotation.*; +import java.util.Arrays; + +/** + * @description 自定义状态校验注解(传入值是否在指定状态范围内) + * @date 2020-09-18 + */ +@Target({ElementType.FIELD, ElementType.PARAMETER}) +@Retention(RetentionPolicy.RUNTIME) +@Constraint(validatedBy = FlagValidator.Validator.class) +@Documented +public @interface FlagValidator { + + String[] value() default {}; + + String message() default "flag value is invalid"; + + Class[] groups() default {}; + + Class[] payload() default {}; + + /** + * @description 校验传入值是否在默认值范围校验逻辑 + * @date 2020-09-18 + */ + class Validator implements ConstraintValidator { + + private String[] values; + + @Override + public void initialize(FlagValidator flagValidator) { + this.values = flagValidator.value(); + } + + @Override + public boolean isValid(Integer value, ConstraintValidatorContext constraintValidatorContext) { + if (value == null) { + //当状态为空时,使用默认值 + return false; + } + return Arrays.stream(values).anyMatch(value::equals); + } + } +} diff --git a/dubhe-server/common/src/main/java/org/dubhe/aspect/LogAspect.java b/dubhe-server/common/src/main/java/org/dubhe/aspect/LogAspect.java old mode 100644 new mode 100755 index ce7485f..71e3c77 --- a/dubhe-server/common/src/main/java/org/dubhe/aspect/LogAspect.java +++ b/dubhe-server/common/src/main/java/org/dubhe/aspect/LogAspect.java @@ -15,8 +15,7 @@ */ package org.dubhe.aspect; -import java.util.UUID; - +import lombok.extern.slf4j.Slf4j; import org.aspectj.lang.JoinPoint; import org.aspectj.lang.ProceedingJoinPoint; import org.aspectj.lang.annotation.Around; @@ -28,10 +27,11 @@ import org.slf4j.MDC; import org.springframework.stereotype.Component; import org.springframework.util.StringUtils; -import lombok.extern.slf4j.Slf4j; +import java.util.UUID; /** - * @date 2020/04/10 + * @description 日志切面 + * @date 2020-04-10 */ @Component @Aspect @@ -54,7 +54,7 @@ public class LogAspect { public void taskAspect() { } - @Pointcut(" serviceAspect() || taskAspect() ") + @Pointcut(" serviceAspect() ") public void aroundAspect() { } diff --git a/dubhe-server/common/src/main/java/org/dubhe/aspect/PermissionAspect.java b/dubhe-server/common/src/main/java/org/dubhe/aspect/PermissionAspect.java new file mode 100644 index 0000000..cac6f48 --- /dev/null +++ b/dubhe-server/common/src/main/java/org/dubhe/aspect/PermissionAspect.java @@ -0,0 +1,116 @@ +/** + * Copyright 2020 Zhejiang Lab. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================= + */ +package org.dubhe.aspect; + +import org.aspectj.lang.ProceedingJoinPoint; +import org.aspectj.lang.annotation.Around; +import org.aspectj.lang.annotation.Aspect; +import org.aspectj.lang.annotation.Pointcut; +import org.aspectj.lang.reflect.MethodSignature; +import org.dubhe.annotation.DataPermissionMethod; +import org.dubhe.base.BaseService; +import org.dubhe.base.DataContext; +import org.dubhe.domain.dto.CommonPermissionDataDTO; +import org.dubhe.enums.DatasetTypeEnum; +import org.dubhe.enums.LogEnum; +import org.dubhe.utils.JwtUtils; +import org.dubhe.utils.LogUtil; +import org.springframework.stereotype.Component; + +import java.lang.reflect.Method; +import java.util.HashSet; +import java.util.Objects; +import java.util.Set; + +/** + * @description 数据权限切面 + * @date 2020-09-24 + */ +@Aspect +@Component +public class PermissionAspect { + + /** + * 公共数据的有用户ID + */ + public static final Long PUBLIC_DATA_USER_ID = 0L; + + /** + * 基于注解的切面方法 + */ + @Pointcut("@annotation(org.dubhe.annotation.DataPermissionMethod)") + private void cutMethod() { + + } + + /** + *环绕通知 + * @param joinPoint 切入参数对象 + * @return 返回方法结果集 + * @throws Throwable + */ + @Around("cutMethod()") + public Object around(ProceedingJoinPoint joinPoint) throws Throwable { + // 获取方法传入参数 + Object[] params = joinPoint.getArgs(); + DataPermissionMethod dataPermissionMethod = getDeclaredAnnotation(joinPoint); + + if (!Objects.isNull(JwtUtils.getCurrentUserDto()) && !Objects.isNull(dataPermissionMethod)) { + Set ids = new HashSet<>(); + ids.add(JwtUtils.getCurrentUserDto().getId()); + CommonPermissionDataDTO commonPermissionDataDTO = CommonPermissionDataDTO.builder().type(dataPermissionMethod.interceptFlag()).resourceUserIds(ids).build(); + if (DatasetTypeEnum.PUBLIC.equals(dataPermissionMethod.dataType())) { + ids.add(PUBLIC_DATA_USER_ID); + commonPermissionDataDTO.setResourceUserIds(ids); + } + DataContext.set(commonPermissionDataDTO); + } + // 执行源方法 + try { + return joinPoint.proceed(params); + } finally { + // 模拟进行验证 + BaseService.removeContext(); + } + } + + /** + * 获取方法中声明的注解 + * + * @param joinPoint 切入参数对象 + * @return DataPermissionMethod 方法注解类型 + */ + public DataPermissionMethod getDeclaredAnnotation(ProceedingJoinPoint joinPoint){ + // 获取方法名 + String methodName = joinPoint.getSignature().getName(); + // 反射获取目标类 + Class targetClass = joinPoint.getTarget().getClass(); + // 拿到方法对应的参数类型 + Class[] parameterTypes = ((MethodSignature) joinPoint.getSignature()).getParameterTypes(); + // 根据类、方法、参数类型(重载)获取到方法的具体信息 + Method objMethod = null; + try { + objMethod = targetClass.getMethod(methodName, parameterTypes); + } catch (NoSuchMethodException e) { + LogUtil.error(LogEnum.BIZ_DATASET,"获取注解方法参数异常 error:{}",e); + } + // 拿到方法定义的注解信息 + DataPermissionMethod annotation = objMethod.getDeclaredAnnotation(DataPermissionMethod.class); + // 返回 + return annotation; + } +} diff --git a/dubhe-server/common/src/main/java/org/dubhe/base/BaseImageDTO.java b/dubhe-server/common/src/main/java/org/dubhe/base/BaseImageDTO.java index 950e453..a726f2d 100644 --- a/dubhe-server/common/src/main/java/org/dubhe/base/BaseImageDTO.java +++ b/dubhe-server/common/src/main/java/org/dubhe/base/BaseImageDTO.java @@ -25,7 +25,7 @@ import java.io.Serializable; /** * @description 镜像基础类DTO - * @date: 2020-07-14 + * @date 2020-07-14 */ @Data @Accessors(chain = true) diff --git a/dubhe-server/common/src/main/java/org/dubhe/base/BaseService.java b/dubhe-server/common/src/main/java/org/dubhe/base/BaseService.java new file mode 100644 index 0000000..2a43bb9 --- /dev/null +++ b/dubhe-server/common/src/main/java/org/dubhe/base/BaseService.java @@ -0,0 +1,77 @@ +/** + * Copyright 2020 Zhejiang Lab. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================= + */ +package org.dubhe.base; + +import org.dubhe.constant.PermissionConstant; +import org.dubhe.domain.dto.UserDTO; +import org.dubhe.domain.entity.Role; +import org.dubhe.exception.BaseErrorCode; +import org.dubhe.exception.BusinessException; +import org.dubhe.utils.JwtUtils; +import org.springframework.util.CollectionUtils; + +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; + +/** + * @description 服务层基础数据公共方法类 + * @date 2020-03-27 + */ +public class BaseService { + + private BaseService (){} + + /** + * 校验是否具有管理员权限 + */ + public static void checkAdminPermission() { + if(!isAdmin()){ + throw new BusinessException(BaseErrorCode.DATASET_ADMIN_PERMISSION_ERROR); + } + } + + /** + * 校验是否是管理管理员 + * + * @return 校验标识 + */ + public static Boolean isAdmin() { + UserDTO currentUserDto = JwtUtils.getCurrentUserDto(); + if (currentUserDto != null && !CollectionUtils.isEmpty(currentUserDto.getRoles())) { + List roles = currentUserDto.getRoles(); + List roleList = roles.stream(). + filter(a -> a.getId().compareTo(PermissionConstant.ADMIN_USER_ID) == 0) + .collect(Collectors.toList()); + if (!CollectionUtils.isEmpty(roleList)) { + return true; + } + } + return false; + } + + + /** + * 清除本地线程数据权限数据 + */ + public static void removeContext(){ + if( !Objects.isNull(DataContext.get())){ + DataContext.remove(); + } + } + +} diff --git a/dubhe-server/common/src/main/java/org/dubhe/base/BaseVO.java b/dubhe-server/common/src/main/java/org/dubhe/base/BaseVO.java index 7a6fff0..4cf377e 100644 --- a/dubhe-server/common/src/main/java/org/dubhe/base/BaseVO.java +++ b/dubhe-server/common/src/main/java/org/dubhe/base/BaseVO.java @@ -24,8 +24,8 @@ import java.io.Serializable; import java.sql.Timestamp; /** - * @description: VO基础类 - * @date: 2020-05-22 + * @description VO基础类 + * @date 2020-05-22 */ @Data public class BaseVO implements Serializable { diff --git a/dubhe-server/common/src/main/java/org/dubhe/base/DataContext.java b/dubhe-server/common/src/main/java/org/dubhe/base/DataContext.java new file mode 100644 index 0000000..20c8c57 --- /dev/null +++ b/dubhe-server/common/src/main/java/org/dubhe/base/DataContext.java @@ -0,0 +1,62 @@ +/** + * Copyright 2020 Zhejiang Lab. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================= + */ + +package org.dubhe.base; + + +import org.dubhe.domain.dto.CommonPermissionDataDTO; + +/** + * @description 共享上下文数据集信息 + * @date 2020-04-10 + */ +public class DataContext { + + /** + * 私有化构造参数 + */ + private DataContext() { + } + + private static final ThreadLocal CONTEXT = new ThreadLocal<>(); + + /** + * 存放数据集信息 + * + * @param datasetVO + */ + public static void set(CommonPermissionDataDTO datasetVO) { + CONTEXT.set(datasetVO); + } + + /** + * 获取用户信息 + * + * @return + */ + public static CommonPermissionDataDTO get() { + return CONTEXT.get(); + } + + /** + * 清除当前线程内引用,防止内存泄漏 + */ + public static void remove() { + CONTEXT.remove(); + } + +} diff --git a/dubhe-server/common/src/main/java/org/dubhe/base/MagicNumConstant.java b/dubhe-server/common/src/main/java/org/dubhe/base/MagicNumConstant.java index a4376e7..0454bd0 100644 --- a/dubhe-server/common/src/main/java/org/dubhe/base/MagicNumConstant.java +++ b/dubhe-server/common/src/main/java/org/dubhe/base/MagicNumConstant.java @@ -18,7 +18,8 @@ package org.dubhe.base; /** - * @date: 2020-05-14 + * @description 常用常量类 + * @date 2020-05-14 */ public final class MagicNumConstant { @@ -86,6 +87,7 @@ public final class MagicNumConstant { public static final long TWELVE_LONG = 12L; public static final long SIXTY_LONG = 60L; + public static final long THOUSAND_LONG = 1000L; public static final long TEN_THOUSAND_LONG = 10000L; public static final long ONE_ZERO_ONE_ZERO_ONE_ZERO_LONG = 101010L; public static final long NINE_ZERO_NINE_ZERO_NINE_ZERO_LONG = 909090L; diff --git a/dubhe-server/common/src/main/java/org/dubhe/base/PageQueryBase.java b/dubhe-server/common/src/main/java/org/dubhe/base/PageQueryBase.java index e39aaa2..77752ed 100644 --- a/dubhe-server/common/src/main/java/org/dubhe/base/PageQueryBase.java +++ b/dubhe-server/common/src/main/java/org/dubhe/base/PageQueryBase.java @@ -26,8 +26,8 @@ import org.dubhe.constant.NumberConstant; import javax.validation.constraints.Min; /** - * @description: 分页基类 - * @date: 2020-05-8 + * @description 分页基类 + * @date 2020-05-08 */ @Data @Accessors(chain = true) diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/task/HarborProjectNameSyncTask.java b/dubhe-server/common/src/main/java/org/dubhe/base/ScheduleTaskHandler.java similarity index 52% rename from dubhe-server/dubhe-admin/src/main/java/org/dubhe/task/HarborProjectNameSyncTask.java rename to dubhe-server/common/src/main/java/org/dubhe/base/ScheduleTaskHandler.java index c2d2e9f..52da2a1 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/task/HarborProjectNameSyncTask.java +++ b/dubhe-server/common/src/main/java/org/dubhe/base/ScheduleTaskHandler.java @@ -14,32 +14,31 @@ * limitations under the License. * ============================================================= */ - -package org.dubhe.task; +package org.dubhe.base; import org.dubhe.enums.LogEnum; -import org.dubhe.service.PtImageService; import org.dubhe.utils.LogUtil; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.scheduling.annotation.Scheduled; -import org.springframework.stereotype.Component; /** - * @description 从harbor同步projectName - * @date 2020-6-23 - **/ -@Component -public class HarborProjectNameSyncTask { + * @description 定时任务处理器, 主要做日志标识 + * @date 2020-08-13 + */ +public class ScheduleTaskHandler { + + + public static void process(Handler handler) { + LogUtil.startScheduleTrace(); + try { + handler.run(); + } catch (Exception e) { + LogUtil.error(LogEnum.BIZ_SYS, "There is something wrong in schedule task handler :{}", e); + } finally { + LogUtil.cleanTrace(); + } + } - @Autowired - private PtImageService ptImageService; - /** - * 每天晚上11点开始同步 - **/ - @Scheduled(cron = "0 0 23 * * ?") - public void syncProjectName() { - LogUtil.info(LogEnum.BIZ_TRAIN, "开始到harbor同步projectName到harbor_project表。。。。。"); - ptImageService.harborImageNameSync(); + public interface Handler { + void run(); } } diff --git a/dubhe-server/common/src/main/java/org/dubhe/config/KaptchaConfig.java b/dubhe-server/common/src/main/java/org/dubhe/config/KaptchaConfig.java deleted file mode 100644 index 86bf416..0000000 --- a/dubhe-server/common/src/main/java/org/dubhe/config/KaptchaConfig.java +++ /dev/null @@ -1,65 +0,0 @@ -/** - * Copyright 2020 Zhejiang Lab. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================= - */ - -package org.dubhe.config; - -import com.google.code.kaptcha.impl.DefaultKaptcha; -import com.google.code.kaptcha.util.Config; -import org.springframework.context.annotation.Bean; -import org.springframework.stereotype.Component; - -import java.util.Properties; - -/** - * @description 验证码配置 - * @date 2020-02-23 - */ -@Component -public class KaptchaConfig { - private final static String CODE_LENGTH = "4"; - private final static String SESSION_KEY = "verification_session_key"; - - @Bean - public DefaultKaptcha defaultKaptcha() { - DefaultKaptcha defaultKaptcha = new DefaultKaptcha(); - Properties properties = new Properties(); - - // 设置边框 - properties.setProperty("kaptcha.border", "yes"); - // 设置边框颜色 - properties.setProperty("kaptcha.border.color", "105,179,90"); - // 设置字体颜色 - properties.setProperty("kaptcha.textproducer.font.color", "blue"); - // 设置图片宽度 - properties.setProperty("kaptcha.image.width", "108"); - // 设置图片高度 - properties.setProperty("kaptcha.image.height", "28"); - // 设置字体尺寸 - properties.setProperty("kaptcha.textproducer.font.size", "26"); - // 设置session key - properties.setProperty("kaptcha.session.key", SESSION_KEY); - // 设置验证码长度 - properties.setProperty("kaptcha.textproducer.char.length", CODE_LENGTH); - // 设置字体 - properties.setProperty("kaptcha.textproducer.font.names", "宋体,楷体,黑体"); - //去噪点 - properties.setProperty("kaptcha.noise.impl", "com.google.code.kaptcha.impl.NoNoise"); - Config config = new Config(properties); - defaultKaptcha.setConfig(config); - return defaultKaptcha; - } -} diff --git a/dubhe-server/common/src/main/java/org/dubhe/config/MetaHandlerConfig.java b/dubhe-server/common/src/main/java/org/dubhe/config/MetaHandlerConfig.java index b7e9860..f9d17fa 100644 --- a/dubhe-server/common/src/main/java/org/dubhe/config/MetaHandlerConfig.java +++ b/dubhe-server/common/src/main/java/org/dubhe/config/MetaHandlerConfig.java @@ -1,12 +1,12 @@ /** * Copyright 2020 Zhejiang Lab. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -30,21 +30,17 @@ import java.util.Objects; /** * @description 处理新增和更新的基础数据填充,配合BaseEntity和MyBatisPlusConfig使用 - * @date 2020-6-10 + * @date 2020-06-10 */ @Component public class MetaHandlerConfig implements MetaObjectHandler { - private final String LOCK_USER_ID = "LOCK_USER_ID"; - - /** * 新增数据执行 * - * @param metaObject + * @param metaObject 基础数据 */ - @Override public void insertFill(MetaObject metaObject) { if (Objects.isNull(getFieldValByName(StringConstant.CREATE_TIME, metaObject))) { @@ -53,13 +49,14 @@ public class MetaHandlerConfig implements MetaObjectHandler { if (Objects.isNull(getFieldValByName(StringConstant.UPDATE_TIME, metaObject))) { this.setFieldValByName(StringConstant.UPDATE_TIME, DateUtil.getCurrentTimestamp(), metaObject); } - synchronized (LOCK_USER_ID){ - if (Objects.isNull(getFieldValByName(StringConstant.UPDATE_USER_ID, metaObject))) { - this.setFieldValByName(StringConstant.UPDATE_USER_ID, getUserId(), metaObject); - } - if (Objects.isNull(getFieldValByName(StringConstant.CREATE_USER_ID, metaObject))) { - this.setFieldValByName(StringConstant.CREATE_USER_ID, getUserId(), metaObject); - } + if (Objects.isNull(getFieldValByName(StringConstant.UPDATE_USER_ID, metaObject))) { + this.setFieldValByName(StringConstant.UPDATE_USER_ID, getUserId(), metaObject); + } + if (Objects.isNull(getFieldValByName(StringConstant.CREATE_USER_ID, metaObject))) { + this.setFieldValByName(StringConstant.CREATE_USER_ID, getUserId(), metaObject); + } + if (Objects.isNull(getFieldValByName(StringConstant.ORIGIN_USER_ID, metaObject))) { + this.setFieldValByName(StringConstant.ORIGIN_USER_ID, getUserId(), metaObject); } if (Objects.isNull(getFieldValByName(StringConstant.DELETED, metaObject))) { this.setFieldValByName(StringConstant.DELETED, SwitchEnum.getBooleanValue(SwitchEnum.OFF.getValue()), metaObject); @@ -69,7 +66,7 @@ public class MetaHandlerConfig implements MetaObjectHandler { /** * 更新数据执行 * - * @param metaObject + * @param metaObject 基础数据 */ @Override public void updateFill(MetaObject metaObject) { diff --git a/dubhe-server/common/src/main/java/org/dubhe/config/MybatisPlusConfig.java b/dubhe-server/common/src/main/java/org/dubhe/config/MybatisPlusConfig.java index 18deab6..8235596 100644 --- a/dubhe-server/common/src/main/java/org/dubhe/config/MybatisPlusConfig.java +++ b/dubhe-server/common/src/main/java/org/dubhe/config/MybatisPlusConfig.java @@ -17,294 +17,27 @@ package org.dubhe.config; -import com.baomidou.mybatisplus.core.override.MybatisMapperProxy; -import com.baomidou.mybatisplus.core.parser.ISqlParser; -import com.baomidou.mybatisplus.core.parser.SqlParserHelper; -import com.baomidou.mybatisplus.extension.plugins.PaginationInterceptor; -import com.baomidou.mybatisplus.extension.plugins.tenant.TenantHandler; -import com.baomidou.mybatisplus.extension.plugins.tenant.TenantSqlParser; -import com.google.common.collect.Sets; -import net.sf.jsqlparser.expression.Expression; -import net.sf.jsqlparser.expression.LongValue; -import net.sf.jsqlparser.expression.operators.relational.ExpressionList; -import net.sf.jsqlparser.expression.operators.relational.InExpression; -import net.sf.jsqlparser.schema.Column; -import org.apache.ibatis.mapping.MappedStatement; -import org.apache.shiro.UnavailableSecurityManagerException; -import org.dubhe.annotation.DataPermission; -import org.dubhe.base.MagicNumConstant; -import org.dubhe.constant.PermissionConstant; -import org.dubhe.domain.dto.UserDTO; -import org.dubhe.domain.entity.Role; -import org.dubhe.enums.LogEnum; -import org.dubhe.utils.JwtUtils; -import org.dubhe.utils.LogUtil; -import org.springframework.beans.BeansException; -import org.springframework.context.ApplicationContext; -import org.springframework.context.ApplicationContextAware; -import org.springframework.context.ApplicationListener; +import org.dubhe.interceptor.PaginationInterceptor; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import org.springframework.context.event.ContextRefreshedEvent; -import org.springframework.core.annotation.AnnotationUtils; import org.springframework.transaction.annotation.EnableTransactionManagement; -import org.springframework.util.CollectionUtils; - -import java.lang.annotation.Annotation; -import java.lang.reflect.Field; -import java.lang.reflect.Method; -import java.lang.reflect.Proxy; -import java.util.*; -import java.util.stream.Collectors; - /** - * @description MybatisPlus配置类 - * @date 2020-06-24 + * @description mybatis plus拦截器 + * @date 2020-06-10 */ @EnableTransactionManagement @Configuration -public class MybatisPlusConfig implements ApplicationListener, ApplicationContextAware { - - /** - * 以此字段作为租户实现数据隔离 - */ - private static final String TENANT_ID_COLUMN = "create_user_id"; - /** - * 以0作为公共数据的标识 - */ - private static final long PUBLIC_TENANT_ID = MagicNumConstant.ZERO; - private static final Set PUBLIC_TENANT_ID_SET = new HashSet() {{ - add(PUBLIC_TENANT_ID); - }}; - private static final String PACKAGE_SEPARATOR = "."; - private static final Set SELECT_PERMISSION = new HashSet() {{ - add(PermissionConstant.SELECT); - }}; - private static final Set UPDATE_DELETE_PERMISSION = new HashSet() {{ - add(PermissionConstant.UPDATE); - add(PermissionConstant.DELETE); - }}; +public class MybatisPlusConfig { - private static final String SELECT_STR = "select"; /** - * 优先级高于dataFilters,如果ignore,则不进行sql注入 - */ - private Map> dataFilters = new HashMap<>(); - - private ApplicationContext applicationContext; - public Set tenantId; - - /** - * mybatis plus 分页插件 - * 其中增加了通过多租户实现了数据权限功能 + * 注入 MybatisPlus 分页拦截器 * - * @return + * @return 自定义MybatisPlus分页拦截器 */ @Bean public PaginationInterceptor paginationInterceptor() { PaginationInterceptor paginationInterceptor = new PaginationInterceptor(); - List sqlParserList = new ArrayList<>(); - TenantSqlParser tenantSqlParser = new TenantSqlParser(); - tenantSqlParser.setTenantHandler(new TenantHandler() { - @Override - public Expression getTenantId(boolean where) { - Set tenants = tenantId; - - final boolean multipleTenantIds = tenants.size() > MagicNumConstant.ONE; - if (multipleTenantIds) { - return multipleTenantIdCondition(tenants); - } else { - return singleTenantIdCondition(tenants); - } - } - - private Expression singleTenantIdCondition(Set tenants) { - return new LongValue((Long) tenants.toArray()[0]); - } - - private Expression multipleTenantIdCondition(Set tenants) { - final InExpression inExpression = new InExpression(); - inExpression.setLeftExpression(new Column(getTenantIdColumn())); - final ExpressionList itemsList = new ExpressionList(); - final List inValues = new ArrayList<>(tenants.size()); - tenants.forEach(i -> - inValues.add(new LongValue(i)) - ); - itemsList.setExpressions(inValues); - inExpression.setRightItemsList(itemsList); - return inExpression; - } - - @Override - public String getTenantIdColumn() { - return TENANT_ID_COLUMN; - } - - @Override - public boolean doTableFilter(String tableName) { - return false; - } - }); - sqlParserList.add(tenantSqlParser); - paginationInterceptor.setSqlParserList(sqlParserList); - paginationInterceptor.setSqlParserFilter(metaObject -> { - MappedStatement ms = SqlParserHelper.getMappedStatement(metaObject); - String method = ms.getId(); - if (!dataFilters.containsKey(method) || isAdmin()) { - return true; - } - Set permission = dataFilters.get(method); - tenantId = getTenantId(permission); - return false; - }); return paginationInterceptor; } - - /** - * 判断用户是否是管理员 - * 如果未登录,无法请求任何接口,所以不会到该层,因此匿名认为是定时任务,给予admin权限。 - * - * @return 判断用户是否是管理员 - */ - private boolean isAdmin() { - UserDTO user; - try { - user = JwtUtils.getCurrentUserDto(); - } catch (UnavailableSecurityManagerException e) { - return true; - } - if (Objects.isNull(user)) { - return true; - } - List roles; - if ((roles = user.getRoles()) == null) { - return false; - } - Set permissions = roles.stream().map(Role::getPermission).collect(Collectors.toSet()); - if (CollectionUtils.isEmpty(permissions)) { - return false; - } - return user.getId() == PermissionConstant.ANONYMOUS_USER || user.getId() == PermissionConstant.ADMIN_USER_ID; - } - - /** - * 如果是管理员,在前一步isAdmin已过滤; - * 如果是匿名用户,在shiro层被过滤; - * 因此只会是无角色、权限用户或普通用户 - * - * @return Set 租户ID集合 - */ - private Set getTenantId(Set permission) { - UserDTO user = JwtUtils.getCurrentUserDto(); - List roles; - if (Objects.isNull(user) || (roles = user.getRoles()) == null) { - if (permission.contains(PermissionConstant.SELECT)) { - return PUBLIC_TENANT_ID_SET; - } - return Collections.EMPTY_SET; - } - Set permissions = roles.stream().map(Role::getPermission).collect(Collectors.toSet()); - if (CollectionUtils.isEmpty(permissions)) { - if (permission.contains(PermissionConstant.SELECT)) { - return PUBLIC_TENANT_ID_SET; - } - return Collections.EMPTY_SET; - } - if (permission.contains(PermissionConstant.SELECT)) { - return new HashSet() {{ - add(PUBLIC_TENANT_ID); - add(user.getId()); - }}; - } - return new HashSet() {{ - add(user.getId()); - }}; - } - - /** - * 设置上下文 - * #需要通过上下文 获取SpringBean - * - * @param applicationContext spring上下文 - * @throws BeansException 找不到bean异常 - */ - @Override - public void setApplicationContext(ApplicationContext applicationContext) throws BeansException { - this.applicationContext = applicationContext; - } - - @Override - public void onApplicationEvent(ContextRefreshedEvent contextRefreshedEvent) { - Class annotationClass = DataPermission.class; - Map beanWithAnnotation = applicationContext.getBeansWithAnnotation(annotationClass); - Set> entitySet = beanWithAnnotation.entrySet(); - for (Map.Entry entry : entitySet) { - Proxy proxy = (Proxy) entry.getValue(); - Class clazz = getMapperClass(proxy); - populateDataFilters(clazz); - } - } - - /** - * 根据mapper对应代理对象获取Class - * - * @param proxy mapper对应代理对象 - * @return - */ - private Class getMapperClass(Proxy proxy) { - try { - Field field = proxy.getClass().getSuperclass().getDeclaredField("h"); - field.setAccessible(true); - MybatisMapperProxy mapperProxy = (MybatisMapperProxy) field.get(proxy); - field = mapperProxy.getClass().getDeclaredField("mapperInterface"); - field.setAccessible(true); - return (Class) field.get(mapperProxy); - } catch (NoSuchFieldException | IllegalAccessException e) { - LogUtil.error(LogEnum.BIZ_DATASET, "reflect error", e); - } - return null; - } - - /** - * 填充数据权限过滤,处理那些需要排除的方法 - * - * @param clazz 需要处理的类(mapper) - */ - private void populateDataFilters(Class clazz) { - if (clazz == null) { - return; - } - Method[] methods = clazz.getMethods(); - DataPermission dataPermission = AnnotationUtils.findAnnotation((Class) clazz, DataPermission.class); - Set ignores = Sets.newHashSet(dataPermission.ignores()); - for (Method method : methods) { - if (ignores.contains(method.getName())) { - continue; - } - Set permission = getDataPermission(method); - dataFilters.put(clazz.getName() + PACKAGE_SEPARATOR + method.getName(), permission); - } - } - - /** - * 获取方法上权限注解 - * 权限注解包含 - * 1.用户拥有指定权限才可以执行该方法:比如 PermissionConstant.SELECT 表示用户必须拥有select权限,才可以使用该方法 - * 2.方法权限校验排除:比如 ignores = {"insert"} 表示insert方法不做权限处理 - * - * @param method 方法对象 - * @return - */ - private Set getDataPermission(Method method) { - DataPermission dataPermission = AnnotationUtils.findAnnotation(method, DataPermission.class); - // 无注解时以方法名判断 - if (dataPermission == null) { - if (method.getName().contains(SELECT_STR)) { - return SELECT_PERMISSION; - } - return UPDATE_DELETE_PERMISSION; - } - return Sets.newHashSet(dataPermission.permission()); - } - } + diff --git a/dubhe-server/common/src/main/java/org/dubhe/config/RecycleConfig.java b/dubhe-server/common/src/main/java/org/dubhe/config/RecycleConfig.java new file mode 100644 index 0000000..418b97e --- /dev/null +++ b/dubhe-server/common/src/main/java/org/dubhe/config/RecycleConfig.java @@ -0,0 +1,62 @@ +/** + * Copyright 2020 Zhejiang Lab. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================= + */ +package org.dubhe.config; + +import lombok.Data; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.stereotype.Component; + +/** + * @description 垃圾回收机制配置常量 + * @date 2020-09-21 + */ +@Data +@Component +@ConfigurationProperties(prefix = "recycle.timeout") +public class RecycleConfig { + + /** + * 回收无效文件的默认有效时长 + */ + private Integer date; + + /** + * 用户上传文件至临时路径下后文件最大有效时长,以小时为单位 + */ + private Integer fileValid; + + /** + * 用户删除某一算法后,其算法文件最大有效时长,以天为单位 + */ + private Integer algorithmValid; + + /** + * 用户删除某一模型后,其模型文件最大有效时长,以天为单位 + */ + private Integer modelValid; + + /** + * 用户删除训练任务后,其训练管理文件最大有效时长,以天为单位 + */ + private Integer trainValid; + + /** + * 删除服务器无效文件(大文件) + * 示例:rsync --delete-before -d /空目录 /需要回收的源目录 + */ + public static final String DEL_COMMAND = "ssh %s@%s \"mkdir -p %s; rsync --delete-before -d %s %s; rmdir %s %s\""; +} \ No newline at end of file diff --git a/dubhe-server/common/src/main/java/org/dubhe/constant/TrainJobConstant.java b/dubhe-server/common/src/main/java/org/dubhe/config/TrainJobConfig.java similarity index 66% rename from dubhe-server/common/src/main/java/org/dubhe/constant/TrainJobConstant.java rename to dubhe-server/common/src/main/java/org/dubhe/config/TrainJobConfig.java index 196d6c5..bdb8a47 100644 --- a/dubhe-server/common/src/main/java/org/dubhe/constant/TrainJobConstant.java +++ b/dubhe-server/common/src/main/java/org/dubhe/config/TrainJobConfig.java @@ -15,77 +15,71 @@ * ============================================================= */ -package org.dubhe.constant; +package org.dubhe.config; import lombok.Data; -import org.springframework.beans.factory.annotation.Value; +import org.springframework.boot.context.properties.ConfigurationProperties; import org.springframework.stereotype.Component; /** * @description 训练常量 - * @create: 2020-05-12 + * @date 2020-05-12 */ @Component @Data -public class TrainJobConstant { +@ConfigurationProperties(prefix = "train-job") +public class TrainJobConfig { - @Value("${train-job.namespace}") private String namespace; - @Value("${train-job.version-label}") private String versionLabel; - @Value("${train-job.separator}") private String separator; - @Value("${train-job.pod-name}") private String podName; - @Value("${train-job.python-format}") private String pythonFormat; - @Value("${train-job.manage}") private String manage; - @Value("${train-job.out-path}") private String outPath; - @Value("${train-job.log-path}") private String logPath; - @Value("${train-job.visualized-log-path}") private String visualizedLogPath; - @Value("${train-job.docker-dataset-path}") private String dockerDatasetPath; - @Value("${train-job.docker-train-path}") private String dockerTrainPath; - @Value("${train-job.docker-out-path}") private String dockerOutPath; - @Value("${train-job.docker-log-path}") private String dockerLogPath; - @Value("${train-job.docker-dataset}") private String dockerDataset; - @Value("${train-job.docker-visualized-log-path}") + private String dockerModelPath; + + private String dockerValDatasetPath; + + private String loadValDatasetKey; + private String dockerVisualizedLogPath; - @Value("${train-job.load-path}") private String loadPath; - @Value("${train-job.load-key}") private String loadKey; - @Value("${train-job.eight}") private String eight; - @Value("${train-job.plus-eight}") private String plusEight; + private String nodeIps; + + private String nodeNum; + + private String gpuNumPerNode; + public static final String TRAIN_ID = "trainId"; public static final String TRAIN_VERSION = "trainVersion"; @@ -97,4 +91,5 @@ public class TrainJobConstant { public static final String CREATE_TIME = "createTime"; public static final String ALGORITHM_NAME = "algorithmName"; + } diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/config/TrainPoolConfig.java b/dubhe-server/common/src/main/java/org/dubhe/config/TrainPoolConfig.java similarity index 72% rename from dubhe-server/dubhe-admin/src/main/java/org/dubhe/config/TrainPoolConfig.java rename to dubhe-server/common/src/main/java/org/dubhe/config/TrainPoolConfig.java index 1a3d281..049aa1e 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/config/TrainPoolConfig.java +++ b/dubhe-server/common/src/main/java/org/dubhe/config/TrainPoolConfig.java @@ -16,9 +16,13 @@ */ package org.dubhe.config; +import org.dubhe.enums.LogEnum; +import org.dubhe.utils.LogUtil; +import org.springframework.aop.interceptor.AsyncUncaughtExceptionHandler; import org.springframework.beans.factory.annotation.Value; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; +import org.springframework.scheduling.annotation.AsyncConfigurer; import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; import java.util.concurrent.Executor; @@ -29,7 +33,7 @@ import java.util.concurrent.ThreadPoolExecutor; * @date 2020-07-17 */ @Configuration -public class TrainPoolConfig { +public class TrainPoolConfig implements AsyncConfigurer { @Value("${basepool.corePoolSize:40}") private Integer corePoolSize; @@ -44,8 +48,9 @@ public class TrainPoolConfig { * 训练任务异步处理线程池 * @return Executor 线程实例 */ - @Bean - public Executor trainJobAsyncExecutor() { + @Bean("trainExecutor") + @Override + public Executor getAsyncExecutor() { ThreadPoolTaskExecutor taskExecutor = new ThreadPoolTaskExecutor(); //核心线程数 taskExecutor.setCorePoolSize(corePoolSize); @@ -57,10 +62,18 @@ public class TrainPoolConfig { //配置队列大小 taskExecutor.setQueueCapacity(blockQueueSize); //配置线程池前缀 - taskExecutor.setThreadNamePrefix("async-train-job-"); + taskExecutor.setThreadNamePrefix("async-train-"); //拒绝策略 taskExecutor.setRejectedExecutionHandler(new ThreadPoolExecutor.AbortPolicy()); taskExecutor.initialize(); return taskExecutor; } + + @Override + public AsyncUncaughtExceptionHandler getAsyncUncaughtExceptionHandler() { + LogUtil.error(LogEnum.BIZ_TRAIN, "开始捕获训练管理异步任务异常信息-----》》》"); + return (ex, method, params) -> { + LogUtil.error(LogEnum.BIZ_TRAIN, "训练管理方法名{}的异步任务执行失败,参数信息:{},异常信息:{}", method.getName(), params, ex); + }; + } } diff --git a/dubhe-server/common/src/main/java/org/dubhe/constant/NumberConstant.java b/dubhe-server/common/src/main/java/org/dubhe/constant/NumberConstant.java index 08e124f..220fe22 100644 --- a/dubhe-server/common/src/main/java/org/dubhe/constant/NumberConstant.java +++ b/dubhe-server/common/src/main/java/org/dubhe/constant/NumberConstant.java @@ -18,8 +18,8 @@ package org.dubhe.constant; /** - * @Description 数字常量 - * @Date 2020-6-9 + * @description 数字常量 + * @date 2020-06-09 */ public class NumberConstant { @@ -32,6 +32,8 @@ public class NumberConstant { public final static int NUMBER_30 = 30; public final static int NUMBER_50 = 50; public final static int NUMBER_60 = 60; + public final static int NUMBER_1024 = 1024; + public final static int NUMBER_1000 = 1000; public final static int HOUR_SECOND = 60 * 60; public final static int DAY_SECOND = 60 * 60 * 24; public final static int WEEK_SECOND = 60 * 60 * 24 * 7; diff --git a/dubhe-server/common/src/main/java/org/dubhe/constant/PermissionConstant.java b/dubhe-server/common/src/main/java/org/dubhe/constant/PermissionConstant.java index 8989218..0d30a7d 100644 --- a/dubhe-server/common/src/main/java/org/dubhe/constant/PermissionConstant.java +++ b/dubhe-server/common/src/main/java/org/dubhe/constant/PermissionConstant.java @@ -21,8 +21,8 @@ import lombok.Data; import org.springframework.stereotype.Component; /** - * @description: 权限常量 - * @since: 2020-05-25 14:39 + * @description 权限常量 + * @date 2020-05-25 */ @Component @Data @@ -32,9 +32,10 @@ public class PermissionConstant { * 超级用户 */ public static final long ADMIN_USER_ID = 1L; - public static final long ANONYMOUS_USER = -1L; - public static final String SELECT = "select"; - public static final String UPDATE = "update"; - public static final String DELETE = "delete"; + + /** + * 数据集模块类型 + */ + public static final Integer RESOURCE_DATA_MODEL = 1; } diff --git a/dubhe-server/common/src/main/java/org/dubhe/constant/StringConstant.java b/dubhe-server/common/src/main/java/org/dubhe/constant/StringConstant.java index bd97aad..414e5fe 100644 --- a/dubhe-server/common/src/main/java/org/dubhe/constant/StringConstant.java +++ b/dubhe-server/common/src/main/java/org/dubhe/constant/StringConstant.java @@ -23,24 +23,29 @@ package org.dubhe.constant; */ public final class StringConstant { - public static final String MSIE = "MSIE"; - public static final String MOZILLA = "Mozilla"; + public static final String MSIE = "MSIE"; + public static final String MOZILLA = "Mozilla"; + public static final String REQUEST_METHOD_GET = "GET"; - /** - * 公共字段 - */ - public static final String CREATE_TIME = "createTime"; - public static final String UPDATE_TIME = "updateTime"; - public static final String UPDATE_USER_ID = "updateUserId"; - public static final String CREATE_USER_ID = "createUserId"; - public static final String DELETED = "deleted"; - public static final String UTF8 = "utf-8"; + /** + * 公共字段 + */ + public static final String CREATE_TIME = "createTime"; + public static final String UPDATE_TIME = "updateTime"; + public static final String UPDATE_USER_ID = "updateUserId"; + public static final String CREATE_USER_ID = "createUserId"; + public static final String ORIGIN_USER_ID = "originUserId"; + public static final String DELETED = "deleted"; + public static final String UTF8 = "utf-8"; + public static final String JSON_REQUEST = "application/json"; + public static final String K8S_CALLBACK_URI = "/api/k8s/callback/pod"; + public static final String MULTIPART = "multipart/form-data"; - /** - * 测试环境 - */ - public static final String PROFILE_ACTIVE_TEST = "test"; + /** + * 测试环境 + */ + public static final String PROFILE_ACTIVE_TEST = "test"; - private StringConstant() { - } + private StringConstant() { + } } diff --git a/dubhe-server/common/src/main/java/org/dubhe/constant/SymbolConstant.java b/dubhe-server/common/src/main/java/org/dubhe/constant/SymbolConstant.java index 1e0476a..6cb1be9 100644 --- a/dubhe-server/common/src/main/java/org/dubhe/constant/SymbolConstant.java +++ b/dubhe-server/common/src/main/java/org/dubhe/constant/SymbolConstant.java @@ -19,7 +19,7 @@ package org.dubhe.constant; /** * @description 符号常量 - * @Date 2020-5-29 + * @date 2020-5-29 */ public class SymbolConstant { public static final String SLASH = "/"; @@ -40,6 +40,8 @@ public class SymbolConstant { public static final String DOUBLE_MARK= "\"\""; public static final String MARK= "\""; + public static final String FLAG_EQUAL = "="; + private SymbolConstant() { } diff --git a/dubhe-server/common/src/main/java/org/dubhe/constant/UserAuxiliaryInfoConstant.java b/dubhe-server/common/src/main/java/org/dubhe/constant/UserAuxiliaryInfoConstant.java index 27b19d8..8975b02 100644 --- a/dubhe-server/common/src/main/java/org/dubhe/constant/UserAuxiliaryInfoConstant.java +++ b/dubhe-server/common/src/main/java/org/dubhe/constant/UserAuxiliaryInfoConstant.java @@ -17,15 +17,12 @@ package org.dubhe.constant; -import org.springframework.stereotype.Component; - import lombok.Data; /** * @description 算法用途 - * @date: 2020-06-23 + * @date 2020-06-23 */ -@Component @Data public class UserAuxiliaryInfoConstant { diff --git a/dubhe-server/common/src/main/java/org/dubhe/domain/dto/CommonPermissionDataDTO.java b/dubhe-server/common/src/main/java/org/dubhe/domain/dto/CommonPermissionDataDTO.java new file mode 100644 index 0000000..c1cfcc4 --- /dev/null +++ b/dubhe-server/common/src/main/java/org/dubhe/domain/dto/CommonPermissionDataDTO.java @@ -0,0 +1,52 @@ +/** + * Copyright 2020 Zhejiang Lab. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================= + */ +package org.dubhe.domain.dto; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +import java.io.Serializable; +import java.util.Set; + +/** + * @description 公共权限信息DTO + * @date 2020-09-24 + */ +@AllArgsConstructor +@NoArgsConstructor +@Builder +@Data +public class CommonPermissionDataDTO implements Serializable { + + /** + * 资源拥有者ID + */ + private Long id; + + /** + * 公共类型 + */ + private Boolean type; + /** + * 资源所属用户ids + */ + private Set resourceUserIds; + + +} diff --git a/dubhe-server/common/src/main/java/org/dubhe/domain/entity/LogInfo.java b/dubhe-server/common/src/main/java/org/dubhe/domain/entity/LogInfo.java index d37791d..f01468d 100644 --- a/dubhe-server/common/src/main/java/org/dubhe/domain/entity/LogInfo.java +++ b/dubhe-server/common/src/main/java/org/dubhe/domain/entity/LogInfo.java @@ -16,15 +16,14 @@ */ package org.dubhe.domain.entity; -import java.io.Serializable; - -import org.dubhe.base.MagicNumConstant; - -import com.alibaba.fastjson.annotation.JSONField; import cn.hutool.core.date.DateUtil; +import com.alibaba.fastjson.annotation.JSONField; import lombok.Data; import lombok.experimental.Accessors; +import org.dubhe.base.MagicNumConstant; + +import java.io.Serializable; /** * @description 日志对象封装类 @@ -34,31 +33,27 @@ import lombok.experimental.Accessors; @Accessors(chain = true) public class LogInfo implements Serializable { - @JSONField(ordinal = MagicNumConstant.ONE) - private String traceId; + private static final long serialVersionUID = 5250395474667395607L; + + @JSONField(ordinal = MagicNumConstant.ONE) + private String traceId; - @JSONField(ordinal = MagicNumConstant.TWO) - private String type; + @JSONField(ordinal = MagicNumConstant.TWO) + private String type; - @JSONField(ordinal = MagicNumConstant.THREE) - private String level; + @JSONField(ordinal = MagicNumConstant.THREE) + private String level; - @JSONField(ordinal = MagicNumConstant.FOUR) - private String cName; + @JSONField(ordinal = MagicNumConstant.FOUR) + private String location; - @JSONField(ordinal = MagicNumConstant.FIVE) - private String mName; - - @JSONField(ordinal = MagicNumConstant.SIX) - private String line; - - @JSONField(ordinal = MagicNumConstant.SEVEN) - private String time = DateUtil.now(); + @JSONField(ordinal = MagicNumConstant.FIVE) + private String time = DateUtil.now(); - @JSONField(ordinal = MagicNumConstant.EIGHT) - private Object info; + @JSONField(ordinal = MagicNumConstant.SIX) + private Object info; - public void setInfo(Object info) { - this.info = info; - } + public void setInfo(Object info) { + this.info = info; + } } diff --git a/dubhe-server/common/src/main/java/org/dubhe/domain/entity/Menu.java b/dubhe-server/common/src/main/java/org/dubhe/domain/entity/Menu.java index 32e3162..4136d11 100644 --- a/dubhe-server/common/src/main/java/org/dubhe/domain/entity/Menu.java +++ b/dubhe-server/common/src/main/java/org/dubhe/domain/entity/Menu.java @@ -22,12 +22,9 @@ import lombok.Builder; import lombok.Data; import lombok.NoArgsConstructor; import org.dubhe.base.BaseEntity; -import org.hibernate.validator.constraints.Length; import javax.validation.constraints.NotBlank; -import javax.validation.constraints.NotNull; import java.io.Serializable; -import java.sql.Timestamp; import java.util.Objects; /** diff --git a/dubhe-server/common/src/main/java/org/dubhe/domain/entity/Role.java b/dubhe-server/common/src/main/java/org/dubhe/domain/entity/Role.java index 77227ec..f50cd86 100644 --- a/dubhe-server/common/src/main/java/org/dubhe/domain/entity/Role.java +++ b/dubhe-server/common/src/main/java/org/dubhe/domain/entity/Role.java @@ -22,12 +22,7 @@ import lombok.Builder; import lombok.Data; import lombok.NoArgsConstructor; import org.dubhe.base.BaseEntity; -import org.hibernate.validator.constraints.Length; - -import javax.validation.constraints.NotBlank; -import javax.validation.constraints.NotEmpty; import java.io.Serializable; -import java.sql.Timestamp; import java.util.Objects; import java.util.Set; diff --git a/dubhe-server/common/src/main/java/org/dubhe/dto/GlobalRequestRecordDTO.java b/dubhe-server/common/src/main/java/org/dubhe/dto/GlobalRequestRecordDTO.java new file mode 100644 index 0000000..6937ae2 --- /dev/null +++ b/dubhe-server/common/src/main/java/org/dubhe/dto/GlobalRequestRecordDTO.java @@ -0,0 +1,73 @@ +/** + * Copyright 2020 Zhejiang Lab. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================= + */ + +package org.dubhe.dto; + +import lombok.Data; + +/** + * @description 全局请求日志信息 + * @date 2020-08-13 + */ +@Data +public class GlobalRequestRecordDTO { + /** + * 客户主机地址 + */ + private String clientHost; + /** + * 请求地址 + */ + private String uri; + /** + * 授权信息 + */ + private String authorization; + /** + * 用户名 + */ + private String username; + /** + * form参数 + */ + private String params; + /** + * 返回值类型 + */ + private String contentType; + /** + * 返回状态 + */ + private Integer status; + /** + * 时间耗费 + */ + private Long timeCost; + /** + * 请求方式 + */ + private String method; + /** + * 请求体body参数 + */ + private String requestBody; + /** + * 返回值json数据 + */ + private String responseBody; + +} diff --git a/dubhe-server/common/src/main/java/org/dubhe/dto/callback/BaseK8sPodCallbackCreateDTO.java b/dubhe-server/common/src/main/java/org/dubhe/dto/callback/BaseK8sPodCallbackCreateDTO.java index bc5d4f8..ffc4a9f 100644 --- a/dubhe-server/common/src/main/java/org/dubhe/dto/callback/BaseK8sPodCallbackCreateDTO.java +++ b/dubhe-server/common/src/main/java/org/dubhe/dto/callback/BaseK8sPodCallbackCreateDTO.java @@ -44,6 +44,14 @@ public class BaseK8sPodCallbackCreateDTO { @NotEmpty(message = "podName 不能为空!") private String podName; + @ApiModelProperty(required = true,value = "k8s pod parent type") + @NotEmpty(message = "podParentType 不能为空!") + private String podParentType; + + @ApiModelProperty(required = true,value = "k8s pod parent name") + @NotEmpty(message = "podParentName 不能为空!") + private String podParentName; + @ApiModelProperty(value = "k8s pod phase",notes = "对应PodPhaseEnum") @NotEmpty(message = "phase 不能为空!") private String phase; @@ -55,10 +63,12 @@ public class BaseK8sPodCallbackCreateDTO { } - public BaseK8sPodCallbackCreateDTO(String namespace,String resourceName,String podName,String phase,String messages){ + public BaseK8sPodCallbackCreateDTO(String namespace,String resourceName,String podName,String podParentType,String podParentName,String phase,String messages){ this.namespace = namespace; this.resourceName = resourceName; this.podName = podName; + this.podParentType = podParentType; + this.podParentName = podParentName; this.phase = phase; this.messages = messages; } @@ -69,6 +79,8 @@ public class BaseK8sPodCallbackCreateDTO { "namespace='" + namespace + '\'' + ", resourceName='" + resourceName + '\'' + ", podName='" + podName + '\'' + + ", podParentType='" + podParentType + '\'' + + ", podParentName='" + podParentName + '\'' + ", phase='" + phase + '\'' + ", messages=" + messages + '}'; diff --git a/dubhe-server/common/src/main/java/org/dubhe/enums/BizEnum.java b/dubhe-server/common/src/main/java/org/dubhe/enums/BizEnum.java index 8a87acb..691c0d2 100644 --- a/dubhe-server/common/src/main/java/org/dubhe/enums/BizEnum.java +++ b/dubhe-server/common/src/main/java/org/dubhe/enums/BizEnum.java @@ -23,9 +23,8 @@ import java.util.HashMap; import java.util.Map; /** - * @desc: 业务模块 - * - * @date 2020.05.25 + * @description 业务模块 + * @date 2020-05-25 */ @Getter public enum BizEnum { @@ -38,6 +37,10 @@ public enum BizEnum { * 算法管理 */ ALGORITHM("算法管理","algorithm",1), + /** + * 模型管理 + */ + MODEL("模型管理","model",2), ; /** diff --git a/dubhe-server/common/src/main/java/org/dubhe/enums/BizNfsEnum.java b/dubhe-server/common/src/main/java/org/dubhe/enums/BizNfsEnum.java index 2b06087..796f5b7 100644 --- a/dubhe-server/common/src/main/java/org/dubhe/enums/BizNfsEnum.java +++ b/dubhe-server/common/src/main/java/org/dubhe/enums/BizNfsEnum.java @@ -21,8 +21,8 @@ import java.util.HashMap; import java.util.Map; /** - * @desc: 业务NFS路径枚举 - * @date 2020.05.13 + * @description 业务NFS路径枚举 + * @date 2020-05-13 */ public enum BizNfsEnum { /** @@ -33,6 +33,10 @@ public enum BizNfsEnum { * 算法管理 NFS 路径命名 */ ALGORITHM(BizEnum.ALGORITHM, "algorithm-manage"), + /** + * 模型管理 NFS 路径命名 + */ + MODEL(BizEnum.MODEL, "model"), ; BizNfsEnum(BizEnum bizEnum, String bizNfsPath) { @@ -81,11 +85,11 @@ public enum BizNfsEnum { return bizNfsPath; } - public BizEnum getBizEnum(){ + public BizEnum getBizEnum() { return bizEnum; } public String getBizCode() { - return bizEnum == null ? null :bizEnum.getBizCode(); + return bizEnum == null ? null : bizEnum.getBizCode(); } } diff --git a/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/DatasetTypeEnum.java b/dubhe-server/common/src/main/java/org/dubhe/enums/DatasetTypeEnum.java similarity index 97% rename from dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/DatasetTypeEnum.java rename to dubhe-server/common/src/main/java/org/dubhe/enums/DatasetTypeEnum.java index cd4546f..5f72aaa 100644 --- a/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/DatasetTypeEnum.java +++ b/dubhe-server/common/src/main/java/org/dubhe/enums/DatasetTypeEnum.java @@ -15,7 +15,7 @@ * ============================================================= */ -package org.dubhe.data.constant; +package org.dubhe.enums; import lombok.Getter; diff --git a/dubhe-server/common/src/main/java/org/dubhe/enums/LogEnum.java b/dubhe-server/common/src/main/java/org/dubhe/enums/LogEnum.java index 2b985bc..8fbd401 100644 --- a/dubhe-server/common/src/main/java/org/dubhe/enums/LogEnum.java +++ b/dubhe-server/common/src/main/java/org/dubhe/enums/LogEnum.java @@ -26,36 +26,52 @@ import lombok.Getter; @Getter public enum LogEnum { - // 系统报错日志 - SYS_ERR, - // 用户请求日志 - REST_REQ, - // 训练模块 - BIZ_TRAIN, - // 系统模块 - BIZ_SYS, - // 模型模块 - BIZ_MODEL, - // 数据集模块 - BIZ_DATASET, - // k8s模块 - BIZ_K8S, - //note book - NOTE_BOOK, - //NFS UTILS - NFS_UTIL; + // 系统报错日志 + SYS_ERR, + // 用户请求日志 + REST_REQ, + //全局请求日志 + GLOBAL_REQ, + // 训练模块 + BIZ_TRAIN, + // 系统模块 + BIZ_SYS, + // 模型模块 + BIZ_MODEL, + // 数据集模块 + BIZ_DATASET, + // k8s模块 + BIZ_K8S, + //note book + NOTE_BOOK, + //NFS UTILS + NFS_UTIL, + //localFileUtil + LOCAL_FILE_UTIL, + //FILE UTILS + FILE_UTIL, + //FILE UTILS + UPLOAD_TEMP, + //STATE MACHINE + STATE_MACHINE, + //全局垃圾回收 + GARBAGE_RECYCLE, + //DATA_SEQUENCE + DATA_SEQUENCE, + //IO UTIL + IO_UTIL; - /** - * 判断日志类型不能为空 - * - * @param logType 日志类型 - * @return boolean 返回类型 - */ - public static boolean isLogType(LogEnum logType) { + /** + * 判断日志类型不能为空 + * + * @param logType 日志类型 + * @return boolean 返回类型 + */ + public static boolean isLogType(LogEnum logType) { - if (logType != null) { - return true; - } - return false; - } + if (logType != null) { + return true; + } + return false; + } } diff --git a/dubhe-server/common/src/main/java/org/dubhe/enums/OperationTypeEnum.java b/dubhe-server/common/src/main/java/org/dubhe/enums/OperationTypeEnum.java new file mode 100644 index 0000000..59b284f --- /dev/null +++ b/dubhe-server/common/src/main/java/org/dubhe/enums/OperationTypeEnum.java @@ -0,0 +1,72 @@ +/** + * Copyright 2020 Zhejiang Lab. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================= + */ + +package org.dubhe.enums; + +import lombok.Getter; +import lombok.ToString; + +/** + * @Description 操作类型枚举 + * @Date 2020-08-24 + */ +@ToString +@Getter +public enum OperationTypeEnum { + /** + * SELECT 查询类型 + */ + SELECT("select", "查询"), + + /** + * UPDATE 修改类型 + */ + UPDATE("update", "修改"), + + /** + * DELETE 删除类型 + */ + DELETE("delete", "删除"), + + /** + * LIMIT 禁止操作类型 + */ + LIMIT("limit", "禁止操作"), + + /** + * INSERT 新增类型 + */ + INSERT("insert", "新增类型"), + + ; + + /** + * 操作类型值 + */ + private String type; + + /** + * 操作类型备注 + */ + private String desc; + + OperationTypeEnum(String type, String desc) { + this.type = type; + this.desc = desc; + } + +} diff --git a/dubhe-server/common/src/main/java/org/dubhe/enums/RecycleResourceEnum.java b/dubhe-server/common/src/main/java/org/dubhe/enums/RecycleResourceEnum.java new file mode 100644 index 0000000..1d6b257 --- /dev/null +++ b/dubhe-server/common/src/main/java/org/dubhe/enums/RecycleResourceEnum.java @@ -0,0 +1,54 @@ +/** + * Copyright 2020 Zhejiang Lab. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================= + */ + +package org.dubhe.enums; + +import lombok.Getter; + +import java.util.HashSet; +import java.util.Set; + +/** + * @description 资源回收枚举类 + * @date 2020-10-10 + */ +@Getter +public enum RecycleResourceEnum { + + /** + * 数据集文件回收 + */ + DATASET_RECYCLE_FILE("datasetRecycleFile", "数据集文件回收"), + /** + * 数据集版本文件回收 + */ + DATASET_RECYCLE_VERSION_FILE("datasetRecycleVersionFile", "数据集版本文件回收"), + + ; + + private String className; + + private String message; + + RecycleResourceEnum(String className, String message) { + this.className = className; + this.message = message; + } + + + +} diff --git a/dubhe-server/common/src/main/java/org/dubhe/enums/TrainJobStatusEnum.java b/dubhe-server/common/src/main/java/org/dubhe/enums/TrainJobStatusEnum.java index 560d4ce..58a91d4 100644 --- a/dubhe-server/common/src/main/java/org/dubhe/enums/TrainJobStatusEnum.java +++ b/dubhe-server/common/src/main/java/org/dubhe/enums/TrainJobStatusEnum.java @@ -19,8 +19,8 @@ package org.dubhe.enums; import lombok.Getter; -import java.util.Arrays; -import java.util.List; +import java.util.HashSet; +import java.util.Set; /** * @description 训练任务枚举类 @@ -72,7 +72,13 @@ public enum TrainJobStatusEnum { this.message = message; } - public static TrainJobStatusEnum get(String msg) { + /** + * 根据信息获取枚举类对象 + * + * @param msg 信息 + * @return 枚举类对象 + */ + public static TrainJobStatusEnum getByMessage(String msg) { for (TrainJobStatusEnum statusEnum : values()) { if (statusEnum.message.equalsIgnoreCase(msg)) { return statusEnum; @@ -81,21 +87,63 @@ public enum TrainJobStatusEnum { return UNKNOWN; } + /** + * 回调状态转换 若是DELETED则转换为STOP,避免状态不统一 + * @param phase k8s pod phase + * @return + */ + public static TrainJobStatusEnum transferStatus(String phase) { + TrainJobStatusEnum enums = getByMessage(phase); + if (enums != DELETED) { + return enums; + } + return STOP; + } + + /** + * 根据状态获取枚举类对象 + * + * @param status 状态 + * @return 枚举类对象 + */ + public static TrainJobStatusEnum getByStatus(Integer status) { + for (TrainJobStatusEnum statusEnum : values()) { + if (statusEnum.status.equals(status)) { + return statusEnum; + } + } + return UNKNOWN; + } + + + /** + * 结束状态枚举集合 + */ + public static final Set END_TRAIN_JOB_STATUS; + + static { + END_TRAIN_JOB_STATUS = new HashSet<>(); + END_TRAIN_JOB_STATUS.add(SUCCEEDED); + END_TRAIN_JOB_STATUS.add(FAILED); + END_TRAIN_JOB_STATUS.add(STOP); + END_TRAIN_JOB_STATUS.add(CREATE_FAILED); + END_TRAIN_JOB_STATUS.add(DELETED); + } + public static boolean isEnd(String msg) { - List endList = Arrays.asList("SUCCEEDED", "FAILED", "STOP", "CREATE_FAILED"); - return endList.stream().anyMatch(s -> s.equalsIgnoreCase(msg)); + return END_TRAIN_JOB_STATUS.contains(getByMessage(msg)); } - public static boolean isEnd(Integer num) { - List endList = Arrays.asList(2, 3, 4, 7); - return endList.stream().anyMatch(s -> s.equals(num)); + public static boolean isEnd(Integer status) { + return END_TRAIN_JOB_STATUS.contains(getByStatus(status)); } public static boolean checkStopStatus(Integer num) { - return SUCCEEDED.getStatus().equals(num) || - FAILED.getStatus().equals(num) || - STOP.getStatus().equals(num) || - CREATE_FAILED.getStatus().equals(num) || - DELETED.getStatus().equals(num); + return isEnd(num); + } + + public static boolean checkRunStatus(Integer num) { + return PENDING.getStatus().equals(num) || + RUNNING.getStatus().equals(num); } } diff --git a/dubhe-server/common/src/main/java/org/dubhe/exception/BaseErrorCode.java b/dubhe-server/common/src/main/java/org/dubhe/exception/BaseErrorCode.java index dcd18b2..233f35c 100644 --- a/dubhe-server/common/src/main/java/org/dubhe/exception/BaseErrorCode.java +++ b/dubhe-server/common/src/main/java/org/dubhe/exception/BaseErrorCode.java @@ -54,6 +54,7 @@ public enum BaseErrorCode implements ErrorCode { SYSTEM_USER_CANNOT_DELETE(20014, "系统默认用户不可删除!"), SYSTEM_ROLE_CANNOT_DELETE(20015, "系统默认角色不可删除!"), + DATASET_ADMIN_PERMISSION_ERROR(1310,"无此权限,请联系管理员"), ; diff --git a/dubhe-server/common/src/main/java/org/dubhe/exception/DataSequenceException.java b/dubhe-server/common/src/main/java/org/dubhe/exception/DataSequenceException.java new file mode 100644 index 0000000..ca6869b --- /dev/null +++ b/dubhe-server/common/src/main/java/org/dubhe/exception/DataSequenceException.java @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Zhejiang Lab. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================= + */ + +package org.dubhe.exception; + +import lombok.Getter; + +/** + * @description 获取序列异常 + * @date 2020-09-23 + */ +@Getter +public class DataSequenceException extends BusinessException { + + private static final long serialVersionUID = 1L; + + public DataSequenceException(String msg) { + super(msg); + } + + public DataSequenceException(String msg, Throwable cause) { + super(msg,cause); + } + + public DataSequenceException(Throwable cause) { + super(cause); + } +} diff --git a/dubhe-server/common/src/main/java/org/dubhe/exception/NotebookBizException.java b/dubhe-server/common/src/main/java/org/dubhe/exception/NotebookBizException.java index 9f48beb..cd9c067 100644 --- a/dubhe-server/common/src/main/java/org/dubhe/exception/NotebookBizException.java +++ b/dubhe-server/common/src/main/java/org/dubhe/exception/NotebookBizException.java @@ -20,8 +20,7 @@ package org.dubhe.exception; import lombok.Getter; /** - * @description: Notebook 业务处理异常 - * + * @description Notebook 业务处理异常 * @date 2020.04.27 */ @Getter diff --git a/dubhe-server/common/src/main/java/org/dubhe/exception/handler/GlobalExceptionHandler.java b/dubhe-server/common/src/main/java/org/dubhe/exception/handler/GlobalExceptionHandler.java index b042abc..b94613f 100644 --- a/dubhe-server/common/src/main/java/org/dubhe/exception/handler/GlobalExceptionHandler.java +++ b/dubhe-server/common/src/main/java/org/dubhe/exception/handler/GlobalExceptionHandler.java @@ -17,8 +17,8 @@ package org.dubhe.exception.handler; -import java.util.Objects; - +import lombok.extern.slf4j.Slf4j; +import org.apache.ibatis.exceptions.IbatisException; import org.apache.shiro.ShiroException; import org.apache.shiro.authc.AuthenticationException; import org.apache.shiro.authc.IncorrectCredentialsException; @@ -27,11 +27,7 @@ import org.apache.shiro.authc.UnknownAccountException; import org.dubhe.base.DataResponseBody; import org.dubhe.base.ResponseCode; import org.dubhe.enums.LogEnum; -import org.dubhe.exception.BusinessException; -import org.dubhe.exception.CaptchaException; -import org.dubhe.exception.LoginException; -import org.dubhe.exception.NotebookBizException; -import org.dubhe.exception.UnauthorizedException; +import org.dubhe.exception.*; import org.dubhe.utils.LogUtil; import org.springframework.http.HttpStatus; import org.springframework.http.ResponseEntity; @@ -41,7 +37,7 @@ import org.springframework.web.bind.MethodArgumentNotValidException; import org.springframework.web.bind.annotation.ExceptionHandler; import org.springframework.web.bind.annotation.RestControllerAdvice; -import lombok.extern.slf4j.Slf4j; +import java.util.Objects; /** * @description 处理异常 @@ -51,140 +47,144 @@ import lombok.extern.slf4j.Slf4j; @RestControllerAdvice public class GlobalExceptionHandler { - /** - * 处理所有不可知的异常 - */ - @ExceptionHandler(Throwable.class) - public ResponseEntity handleException(Throwable e) { - // 打印堆栈信息 - LogUtil.error(LogEnum.SYS_ERR, e); - return buildResponseEntity(HttpStatus.INTERNAL_SERVER_ERROR, - new DataResponseBody(ResponseCode.ERROR, e.getMessage())); - } - - /** - * UnauthorizedException - */ - @ExceptionHandler(UnauthorizedException.class) - public ResponseEntity badCredentialsException(UnauthorizedException e) { - // 打印堆栈信息 - LogUtil.error(LogEnum.SYS_ERR, e); - String message = "坏的凭证".equals(e.getMessage()) ? "用户名或密码不正确" : e.getMessage(); - return buildResponseEntity(HttpStatus.UNAUTHORIZED, new DataResponseBody(ResponseCode.ERROR, message)); - } - - /** - * 处理自定义异常 - */ - @ExceptionHandler(value = BusinessException.class) - public ResponseEntity badRequestException(BusinessException e) { - // 打印堆栈信息 - LogUtil.error(LogEnum.SYS_ERR, e); - return buildResponseEntity(HttpStatus.OK, e.getResponseBody()); - } - - /** - * 处理自定义异常 - */ - @ExceptionHandler(value = AuthenticationException.class) - public ResponseEntity badRequestException(AuthenticationException e) { - // 打印堆栈信息 - LogUtil.error(LogEnum.SYS_ERR, e); - return buildResponseEntity(HttpStatus.OK, new DataResponseBody(ResponseCode.UNAUTHORIZED, "无权访问")); - } - - /** - * shiro 异常捕捉 - */ - @ExceptionHandler(value = ShiroException.class) - public ResponseEntity accountException(ShiroException e) { - // 打印堆栈信息 - LogUtil.error(LogEnum.SYS_ERR, e); - ResponseEntity responseEntity; - if (e instanceof IncorrectCredentialsException) { - responseEntity = buildResponseEntity(HttpStatus.OK, new DataResponseBody(ResponseCode.ERROR, "密码不正确")); - } else if (e instanceof UnknownAccountException) { - responseEntity = buildResponseEntity(HttpStatus.OK, new DataResponseBody(ResponseCode.ERROR, "此账户不存在")); - } - - else if (e instanceof LockedAccountException) { - responseEntity = buildResponseEntity(HttpStatus.OK, new DataResponseBody(ResponseCode.ERROR, "未知的账号")); - } - - else if (e instanceof UnknownAccountException) { - responseEntity = buildResponseEntity(HttpStatus.OK, new DataResponseBody(ResponseCode.ERROR, "账户已被禁用")); - } - - else { - responseEntity = buildResponseEntity(HttpStatus.OK, - new DataResponseBody(ResponseCode.UNAUTHORIZED, "无权访问")); - } - return responseEntity; - } - - /** - * 处理自定义异常 - */ - @ExceptionHandler(value = LoginException.class) - public ResponseEntity loginException(LoginException e) { - // 打印堆栈信息 - LogUtil.error(LogEnum.SYS_ERR, e); - return buildResponseEntity(HttpStatus.UNAUTHORIZED, e.getResponseBody()); - } - - /** - * 处理自定义异常 - */ - @ExceptionHandler(value = CaptchaException.class) - public ResponseEntity captchaException(CaptchaException e) { - // 打印堆栈信息 - LogUtil.error(LogEnum.SYS_ERR, e); - return buildResponseEntity(HttpStatus.OK, e.getResponseBody()); - } - - /** - * 处理自定义异常 - */ - @ExceptionHandler(value = NotebookBizException.class) - public ResponseEntity captchaException(NotebookBizException e) { - // 打印堆栈信息 - LogUtil.error(LogEnum.SYS_ERR, e); - return buildResponseEntity(HttpStatus.OK, e.getResponseBody()); - } - - /** - * 处理所有接口数据验证异常 - */ - @ExceptionHandler(MethodArgumentNotValidException.class) - public ResponseEntity handleMethodArgumentNotValidException(MethodArgumentNotValidException e) { - // 打印堆栈信息 - LogUtil.error(LogEnum.SYS_ERR, e); - String[] str = Objects.requireNonNull(e.getBindingResult().getAllErrors().get(0).getCodes())[1].split("\\."); - String message = e.getBindingResult().getAllErrors().get(0).getDefaultMessage(); - String msg = "不能为空"; - if (msg.equals(message)) { - message = str[1] + ":" + message; - } - return buildResponseEntity(HttpStatus.BAD_REQUEST, new DataResponseBody(ResponseCode.ERROR, message)); - } - - @ExceptionHandler(BindException.class) - public ResponseEntity bindException(BindException e) { - // 打印堆栈信息 - LogUtil.error(LogEnum.SYS_ERR, e); - ObjectError error = e.getAllErrors().get(0); - return buildResponseEntity(HttpStatus.BAD_REQUEST, - new DataResponseBody(ResponseCode.ERROR, error.getDefaultMessage())); - } - - /** - * 统一返回 - * - * @param httpStatus - * @param responseBody - * @return - */ - private ResponseEntity buildResponseEntity(HttpStatus httpStatus, DataResponseBody responseBody) { - return new ResponseEntity<>(responseBody, httpStatus); - } + /** + * 处理所有不可知的异常 + */ + @ExceptionHandler(Throwable.class) + public ResponseEntity handleException(Throwable e) { + // 打印堆栈信息 + LogUtil.error(LogEnum.SYS_ERR, "引起异常的堆栈信息:{}", e); + return buildResponseEntity(HttpStatus.INTERNAL_SERVER_ERROR, + new DataResponseBody(ResponseCode.ERROR, e.getMessage())); + } + + /** + * UnauthorizedException + */ + @ExceptionHandler(UnauthorizedException.class) + public ResponseEntity badCredentialsException(UnauthorizedException e) { + // 打印堆栈信息 + LogUtil.error(LogEnum.SYS_ERR, "引起异常的堆栈信息:{}", e); + String message = "坏的凭证".equals(e.getMessage()) ? "用户名或密码不正确" : e.getMessage(); + return buildResponseEntity(HttpStatus.UNAUTHORIZED, new DataResponseBody(ResponseCode.ERROR, message)); + } + + /** + * 处理自定义异常 + */ + @ExceptionHandler(value = BusinessException.class) + public ResponseEntity badRequestException(BusinessException e) { + // 打印堆栈信息 + LogUtil.error(LogEnum.SYS_ERR, "引起异常的堆栈信息:{}", e); + return buildResponseEntity(HttpStatus.OK, e.getResponseBody()); + } + + /** + * 处理自定义异常 + */ + @ExceptionHandler(value = IbatisException.class) + public ResponseEntity persistenceException(IbatisException e) { + // 打印堆栈信息 + LogUtil.error(LogEnum.SYS_ERR, "引起异常的堆栈信息:{}", e); + return buildResponseEntity(HttpStatus.OK, new DataResponseBody(ResponseCode.ERROR, e.getMessage())); + } + + /** + * 处理自定义异常 + */ + @ExceptionHandler(value = AuthenticationException.class) + public ResponseEntity badRequestException(AuthenticationException e) { + // 打印堆栈信息 + LogUtil.error(LogEnum.SYS_ERR, "引起异常的堆栈信息:{}", e); + return buildResponseEntity(HttpStatus.OK, new DataResponseBody(ResponseCode.UNAUTHORIZED, "无权访问")); + } + + /** + * shiro 异常捕捉 + */ + @ExceptionHandler(value = ShiroException.class) + public ResponseEntity accountException(ShiroException e) { + // 打印堆栈信息 + LogUtil.error(LogEnum.SYS_ERR, "引起异常的堆栈信息:{}", e); + ResponseEntity responseEntity; + if (e instanceof IncorrectCredentialsException) { + responseEntity = buildResponseEntity(HttpStatus.OK, new DataResponseBody(ResponseCode.ERROR, "密码不正确")); + } else if (e instanceof UnknownAccountException) { + responseEntity = buildResponseEntity(HttpStatus.OK, new DataResponseBody(ResponseCode.ERROR, "此账户不存在")); + } else if (e instanceof LockedAccountException) { + responseEntity = buildResponseEntity(HttpStatus.OK, new DataResponseBody(ResponseCode.ERROR, "未知的账号")); + } else if (e instanceof UnknownAccountException) { + responseEntity = buildResponseEntity(HttpStatus.OK, new DataResponseBody(ResponseCode.ERROR, "账户已被禁用")); + } else { + responseEntity = buildResponseEntity(HttpStatus.OK, + new DataResponseBody(ResponseCode.UNAUTHORIZED, "无权访问")); + } + return responseEntity; + } + + /** + * 处理自定义异常 + */ + @ExceptionHandler(value = LoginException.class) + public ResponseEntity loginException(LoginException e) { + // 打印堆栈信息 + LogUtil.error(LogEnum.SYS_ERR, "引起异常的堆栈信息:{}", e); + return buildResponseEntity(HttpStatus.UNAUTHORIZED, e.getResponseBody()); + } + + /** + * 处理自定义异常 + */ + @ExceptionHandler(value = CaptchaException.class) + public ResponseEntity captchaException(CaptchaException e) { + // 打印堆栈信息 + LogUtil.error(LogEnum.SYS_ERR, "引起异常的堆栈信息:{}", e); + return buildResponseEntity(HttpStatus.OK, e.getResponseBody()); + } + + /** + * 处理自定义异常 + */ + @ExceptionHandler(value = NotebookBizException.class) + public ResponseEntity captchaException(NotebookBizException e) { + // 打印堆栈信息 + LogUtil.error(LogEnum.SYS_ERR, "引起异常的堆栈信息:{}", e); + return buildResponseEntity(HttpStatus.OK, e.getResponseBody()); + } + + /** + * 处理所有接口数据验证异常 + */ + @ExceptionHandler(MethodArgumentNotValidException.class) + public ResponseEntity handleMethodArgumentNotValidException(MethodArgumentNotValidException e) { + // 打印堆栈信息 + LogUtil.error(LogEnum.SYS_ERR, "引起异常的堆栈信息:{}", e); + String[] str = Objects.requireNonNull(e.getBindingResult().getAllErrors().get(0).getCodes())[1].split("\\."); + String message = e.getBindingResult().getAllErrors().get(0).getDefaultMessage(); + String msg = "不能为空"; + if (msg.equals(message)) { + message = str[1] + ":" + message; + } + return buildResponseEntity(HttpStatus.BAD_REQUEST, new DataResponseBody(ResponseCode.ERROR, message)); + } + + @ExceptionHandler(BindException.class) + public ResponseEntity bindException(BindException e) { + // 打印堆栈信息 + LogUtil.error(LogEnum.SYS_ERR, "引起异常的堆栈信息:{}", e); + ObjectError error = e.getAllErrors().get(0); + return buildResponseEntity(HttpStatus.BAD_REQUEST, + new DataResponseBody(ResponseCode.ERROR, error.getDefaultMessage())); + } + + /** + * 统一返回 + * + * @param httpStatus + * @param responseBody + * @return + */ + private ResponseEntity buildResponseEntity(HttpStatus httpStatus, DataResponseBody responseBody) { + return new ResponseEntity<>(responseBody, httpStatus); + } } diff --git a/dubhe-server/common/src/main/java/org/dubhe/filter/BaseLogFilter.java b/dubhe-server/common/src/main/java/org/dubhe/filter/BaseLogFilter.java new file mode 100644 index 0000000..3b83331 --- /dev/null +++ b/dubhe-server/common/src/main/java/org/dubhe/filter/BaseLogFilter.java @@ -0,0 +1,74 @@ +/** + * Copyright 2020 Zhejiang Lab. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================= + */ + +package org.dubhe.filter; + +import ch.qos.logback.classic.Level; +import ch.qos.logback.classic.spi.ILoggingEvent; +import ch.qos.logback.core.filter.AbstractMatcherFilter; +import ch.qos.logback.core.spi.FilterReply; +import cn.hutool.core.util.StrUtil; +import org.slf4j.Marker; + +/** + * @description 自定义日志过滤器 + * @date 2020-07-21 + */ +public class BaseLogFilter extends AbstractMatcherFilter { + + Level level; + + /** + * 重写decide方法 + * + * @param iLoggingEvent event to decide upon. + * @return FilterReply + */ + @Override + public FilterReply decide(ILoggingEvent iLoggingEvent) { + if (!isStarted()) { + return FilterReply.NEUTRAL; + } + final String msg = iLoggingEvent.getMessage(); + //自定义级别 + if (checkLevel(iLoggingEvent) && msg != null && msg.startsWith(StrUtil.DELIM_START) && msg.endsWith(StrUtil.DELIM_END)) { + final Marker marker = iLoggingEvent.getMarker(); + if (marker != null && this.getName() != null && this.getName().contains(marker.getName())) { + return onMatch; + } + } + + return onMismatch; + } + + protected boolean checkLevel(ILoggingEvent iLoggingEvent) { + return this.level != null + && iLoggingEvent.getLevel() != null + && iLoggingEvent.getLevel().toInt() == this.level.toInt(); + } + + public void setLevel(Level level) { + this.level = level; + } + + @Override + public void start() { + if (this.level != null) { + super.start(); + } + } +} \ No newline at end of file diff --git a/dubhe-server/common/src/main/java/org/dubhe/filter/ConsoleLogFilter.java b/dubhe-server/common/src/main/java/org/dubhe/filter/ConsoleLogFilter.java new file mode 100644 index 0000000..69b1f78 --- /dev/null +++ b/dubhe-server/common/src/main/java/org/dubhe/filter/ConsoleLogFilter.java @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Zhejiang Lab. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================= + */ + +package org.dubhe.filter; + +import ch.qos.logback.classic.spi.ILoggingEvent; +import ch.qos.logback.core.spi.FilterReply; +import org.dubhe.utils.LogUtil; +import org.slf4j.MarkerFactory; + +/** + * @description 自定义日志过滤器 + * @date 2020-07-21 + */ +public class ConsoleLogFilter extends BaseLogFilter { + + @Override + public FilterReply decide(ILoggingEvent iLoggingEvent) { + if (!isStarted()) { + return FilterReply.NEUTRAL; + } + return checkLevel(iLoggingEvent) ? onMatch : onMismatch; + } + + protected boolean checkLevel(ILoggingEvent iLoggingEvent) { + + + return this.level != null + && iLoggingEvent.getLevel() != null + && iLoggingEvent.getLevel().toInt() >= this.level.toInt() + && !MarkerFactory.getMarker(LogUtil.K8S_CALLBACK_LEVEL).equals(iLoggingEvent.getMarker()) + && !MarkerFactory.getMarker(LogUtil.SCHEDULE_LEVEL).equals(iLoggingEvent.getMarker()) + && !"log4jdbc.log4j2".equals(iLoggingEvent.getLoggerName()); + } +} \ No newline at end of file diff --git a/dubhe-server/common/src/main/java/org/dubhe/filter/FileLogFilter.java b/dubhe-server/common/src/main/java/org/dubhe/filter/FileLogFilter.java deleted file mode 100644 index 9fe2b47..0000000 --- a/dubhe-server/common/src/main/java/org/dubhe/filter/FileLogFilter.java +++ /dev/null @@ -1,60 +0,0 @@ -/** - * Copyright 2019-2020 Zheng Jie - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.dubhe.filter; - -import ch.qos.logback.classic.Level; -import ch.qos.logback.classic.spi.ILoggingEvent; -import ch.qos.logback.core.filter.AbstractMatcherFilter; -import ch.qos.logback.core.spi.FilterReply; - -/** - * @description 自定义日志过滤器 - * @date 2020-07-21 - */ -public class FileLogFilter extends AbstractMatcherFilter { - - Level level; - - /** - * 重写decide方法 - * - * @param iLoggingEvent event to decide upon. - * @return FilterReply - */ - @Override - public FilterReply decide(ILoggingEvent iLoggingEvent) { - if (!isStarted()) { - return FilterReply.NEUTRAL; - } - if (iLoggingEvent.getLevel().equals(level) && iLoggingEvent.getMessage() != null - && iLoggingEvent.getMessage().startsWith("{") && iLoggingEvent.getMessage().endsWith("}")) { - return onMatch; - } - return onMismatch; - - } - - public void setLevel(Level level) { - this.level = level; - } - - @Override - public void start() { - if (this.level != null) { - super.start(); - } - } -} \ No newline at end of file diff --git a/dubhe-server/common/src/main/java/org/dubhe/filter/GlobalRequestLogFilter.java b/dubhe-server/common/src/main/java/org/dubhe/filter/GlobalRequestLogFilter.java new file mode 100644 index 0000000..ae8929b --- /dev/null +++ b/dubhe-server/common/src/main/java/org/dubhe/filter/GlobalRequestLogFilter.java @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Zhejiang Lab. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================= + */ + +package org.dubhe.filter; + +import ch.qos.logback.classic.spi.ILoggingEvent; + +/** + * @description 全局请求 日志过滤器 + * @date 2020-08-13 + */ +public class GlobalRequestLogFilter extends BaseLogFilter { + + + @Override + public boolean checkLevel(ILoggingEvent iLoggingEvent) { + return this.level != null; + } + +} \ No newline at end of file diff --git a/dubhe-server/common/src/main/java/org/dubhe/interceptor/MySqlInterceptor.java b/dubhe-server/common/src/main/java/org/dubhe/interceptor/MySqlInterceptor.java new file mode 100644 index 0000000..534a650 --- /dev/null +++ b/dubhe-server/common/src/main/java/org/dubhe/interceptor/MySqlInterceptor.java @@ -0,0 +1,113 @@ +/** + * Copyright 2020 Zhejiang Lab. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================= + */ +package org.dubhe.interceptor; + +import org.apache.ibatis.executor.statement.StatementHandler; +import org.apache.ibatis.mapping.BoundSql; +import org.apache.ibatis.mapping.MappedStatement; +import org.apache.ibatis.plugin.*; +import org.apache.ibatis.reflection.DefaultReflectorFactory; +import org.apache.ibatis.reflection.MetaObject; +import org.apache.ibatis.reflection.SystemMetaObject; +import org.dubhe.annotation.DataPermission; +import org.dubhe.base.DataContext; +import org.dubhe.domain.dto.UserDTO; +import org.dubhe.enums.OperationTypeEnum; +import org.dubhe.utils.JwtUtils; +import org.dubhe.utils.SqlUtil; +import org.springframework.stereotype.Component; + +import java.lang.reflect.Field; +import java.sql.Connection; +import java.util.Arrays; +import java.util.Objects; +import java.util.Properties; + + +/** + * @description mybatis拦截器 + * @date 2020-06-10 + */ +@Component +@Intercepts({@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class})}) +public class MySqlInterceptor implements Interceptor { + @Override + public Object intercept(Invocation invocation) throws Throwable { + + StatementHandler statementHandler = (StatementHandler) invocation.getTarget(); + MetaObject metaObject = MetaObject.forObject(statementHandler, SystemMetaObject.DEFAULT_OBJECT_FACTORY, + SystemMetaObject.DEFAULT_OBJECT_WRAPPER_FACTORY, new DefaultReflectorFactory()); + /* + * 先拦截到RoutingStatementHandler,里面有个StatementHandler类型的delegate变量,其实现类是BaseStatementHandler, + * 然后就到BaseStatementHandler的成员变量mappedStatement + */ + MappedStatement mappedStatement = (MappedStatement) metaObject.getValue("delegate.mappedStatement"); + //id为执行的mapper方法的全路径名,如com.uv.dao.UserDao.selectPageVo + String id = mappedStatement.getId(); + //sql语句类型 select、delete、insert、update + String sqlCommandType = mappedStatement.getSqlCommandType().toString(); + BoundSql boundSql = statementHandler.getBoundSql(); + + //获取到原始sql语句 + String sql = boundSql.getSql(); + String mSql = sql; + + //注解逻辑判断 添加注解了才拦截 + Class classType = Class.forName(mappedStatement.getId().substring(0, mappedStatement.getId().lastIndexOf("."))); + String mName = mappedStatement.getId().substring(mappedStatement.getId().lastIndexOf(".") + 1, mappedStatement.getId().length()); + UserDTO currentUserDto = JwtUtils.getCurrentUserDto(); + + //获取类注解 获取需要忽略拦截的方法名称 + DataPermission dataAnnotation = classType.getAnnotation(DataPermission.class); + if (!Objects.isNull(dataAnnotation)) { + + String[] ignores = dataAnnotation.ignoresMethod(); + //校验拦截忽略方法名 忽略新增方法 忽略回调/定时方法 + if ((!Objects.isNull(ignores) && Arrays.asList(ignores).contains(mName)) + || OperationTypeEnum.INSERT.getType().equals(sqlCommandType.toLowerCase()) + || Objects.isNull(currentUserDto) + || (!Objects.isNull(DataContext.get()) && DataContext.get().getType()) + ) { + return invocation.proceed(); + } else { + //拦截所有sql操作类型 + mSql = SqlUtil.buildTargetSql(sql, SqlUtil.getResourceIds()); + } + } + + //通过反射修改sql语句 + Field field = boundSql.getClass().getDeclaredField("sql"); + field.setAccessible(true); + field.set(boundSql, mSql); + return invocation.proceed(); + } + + @Override + public Object plugin(Object target) { + if (target instanceof StatementHandler) { + return Plugin.wrap(target, this); + } else { + return target; + } + } + + @Override + public void setProperties(Properties properties) { + + } + +} \ No newline at end of file diff --git a/dubhe-server/common/src/main/java/org/dubhe/interceptor/PaginationInterceptor.java b/dubhe-server/common/src/main/java/org/dubhe/interceptor/PaginationInterceptor.java new file mode 100644 index 0000000..0eeb8f4 --- /dev/null +++ b/dubhe-server/common/src/main/java/org/dubhe/interceptor/PaginationInterceptor.java @@ -0,0 +1,457 @@ +/** + * Copyright 2020 Zhejiang Lab. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================= + */ +package org.dubhe.interceptor; + +import com.baomidou.mybatisplus.annotation.DbType; +import com.baomidou.mybatisplus.core.MybatisDefaultParameterHandler; +import com.baomidou.mybatisplus.core.metadata.IPage; +import com.baomidou.mybatisplus.core.metadata.OrderItem; +import com.baomidou.mybatisplus.core.parser.ISqlParser; +import com.baomidou.mybatisplus.core.parser.SqlInfo; +import com.baomidou.mybatisplus.core.toolkit.CollectionUtils; +import com.baomidou.mybatisplus.core.toolkit.ExceptionUtils; +import com.baomidou.mybatisplus.core.toolkit.PluginUtils; +import com.baomidou.mybatisplus.core.toolkit.StringUtils; +import com.baomidou.mybatisplus.extension.handlers.AbstractSqlParserHandler; +import com.baomidou.mybatisplus.extension.plugins.pagination.DialectFactory; +import com.baomidou.mybatisplus.extension.plugins.pagination.DialectModel; +import com.baomidou.mybatisplus.extension.plugins.pagination.dialects.IDialect; +import com.baomidou.mybatisplus.extension.toolkit.JdbcUtils; +import com.baomidou.mybatisplus.extension.toolkit.SqlParserUtils; +import net.sf.jsqlparser.JSQLParserException; +import net.sf.jsqlparser.parser.CCJSqlParserUtil; +import net.sf.jsqlparser.schema.Column; +import net.sf.jsqlparser.statement.select.*; +import org.apache.ibatis.executor.statement.StatementHandler; +import org.apache.ibatis.logging.Log; +import org.apache.ibatis.logging.LogFactory; +import org.apache.ibatis.mapping.*; +import org.apache.ibatis.plugin.*; +import org.apache.ibatis.reflection.MetaObject; +import org.apache.ibatis.reflection.SystemMetaObject; +import org.apache.ibatis.scripting.defaults.DefaultParameterHandler; +import org.apache.ibatis.session.Configuration; +import org.dubhe.annotation.DataPermission; +import org.dubhe.base.DataContext; +import org.dubhe.domain.dto.UserDTO; +import org.dubhe.enums.OperationTypeEnum; +import org.dubhe.utils.JwtUtils; +import org.dubhe.utils.SqlUtil; + +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.util.*; +import java.util.stream.Collectors; + +/** + * @description MybatisPlus 分页拦截器 + * @date 2020-10-09 + */ +@Intercepts({@Signature( + type = StatementHandler.class, + method = "prepare", + args = {Connection.class, Integer.class} +)}) +public class PaginationInterceptor extends AbstractSqlParserHandler implements Interceptor { + protected static final Log logger = LogFactory.getLog(PaginationInterceptor.class); + + /** + * COUNT SQL 解析 + */ + protected ISqlParser countSqlParser; + /** + * 溢出总页数,设置第一页 + */ + protected boolean overflow = false; + /** + * 单页限制 500 条,小于 0 如 -1 不受限制 + */ + protected long limit = 500L; + /** + * 数据类型 + */ + private DbType dbType; + /** + * 方言 + */ + private IDialect dialect; + /** + * 方言类型 + */ + @Deprecated + protected String dialectType; + /** + * 方言实现类 + */ + @Deprecated + protected String dialectClazz; + + public PaginationInterceptor() { + } + + /** + * 构建分页sql + * + * @param originalSql 原生sql + * @param page 分页参数 + * @return 构建后 sql + */ + public static String concatOrderBy(String originalSql, IPage page) { + if (CollectionUtils.isNotEmpty(page.orders())) { + try { + List orderList = page.orders(); + Select selectStatement = (Select) CCJSqlParserUtil.parse(originalSql); + List orderByElements; + List orderByElementsReturn; + if (selectStatement.getSelectBody() instanceof PlainSelect) { + PlainSelect plainSelect = (PlainSelect) selectStatement.getSelectBody(); + orderByElements = plainSelect.getOrderByElements(); + orderByElementsReturn = addOrderByElements(orderList, orderByElements); + plainSelect.setOrderByElements(orderByElementsReturn); + return plainSelect.toString(); + } + + if (selectStatement.getSelectBody() instanceof SetOperationList) { + SetOperationList setOperationList = (SetOperationList) selectStatement.getSelectBody(); + orderByElements = setOperationList.getOrderByElements(); + orderByElementsReturn = addOrderByElements(orderList, orderByElements); + setOperationList.setOrderByElements(orderByElementsReturn); + return setOperationList.toString(); + } + + if (selectStatement.getSelectBody() instanceof WithItem) { + return originalSql; + } + + return originalSql; + } catch (JSQLParserException var7) { + logger.error("failed to concat orderBy from IPage, exception=", var7); + } + } + + return originalSql; + } + + /** + * 添加分页排序规则 + * + * @param orderList 分页规则 + * @param orderByElements 分页排序元素 + * @return 分页规则 + */ + private static List addOrderByElements(List orderList, List orderByElements) { + orderByElements = CollectionUtils.isEmpty(orderByElements) ? new ArrayList(orderList.size()) : orderByElements; + List orderByElementList = (List) orderList.stream().filter((item) -> { + return StringUtils.isNotBlank(item.getColumn()); + }).map((item) -> { + OrderByElement element = new OrderByElement(); + element.setExpression(new Column(item.getColumn())); + element.setAsc(item.isAsc()); + element.setAscDescPresent(true); + return element; + }).collect(Collectors.toList()); + ((List) orderByElements).addAll(orderByElementList); + return (List) orderByElements; + } + + /** + * 执行sql查询逻辑 + * + * @param invocation mybatis 调用类 + * @return + * @throws Throwable + */ + @Override + public Object intercept(Invocation invocation) throws Throwable { + StatementHandler statementHandler = (StatementHandler) PluginUtils.realTarget(invocation.getTarget()); + MetaObject metaObject = SystemMetaObject.forObject(statementHandler); + this.sqlParser(metaObject); + MappedStatement mappedStatement = (MappedStatement) metaObject.getValue("delegate.mappedStatement"); + if (SqlCommandType.SELECT == mappedStatement.getSqlCommandType() && StatementType.CALLABLE != mappedStatement.getStatementType()) { + BoundSql boundSql = (BoundSql) metaObject.getValue("delegate.boundSql"); + Object paramObj = boundSql.getParameterObject(); + IPage page = null; + if (paramObj instanceof IPage) { + page = (IPage) paramObj; + } else if (paramObj instanceof Map) { + Iterator var8 = ((Map) paramObj).values().iterator(); + + while (var8.hasNext()) { + Object arg = var8.next(); + if (arg instanceof IPage) { + page = (IPage) arg; + break; + } + } + } + + if (null != page && page.getSize() >= 0L) { + if (this.limit > 0L && this.limit <= page.getSize()) { + this.handlerLimit(page); + } + + String originalSql = boundSql.getSql(); + + //注解逻辑判断 添加注解了才拦截 + Class classType = Class.forName(mappedStatement.getId().substring(0, mappedStatement.getId().lastIndexOf("."))); + String mName = mappedStatement.getId().substring(mappedStatement.getId().lastIndexOf(".") + 1, mappedStatement.getId().length()); + UserDTO currentUserDto = JwtUtils.getCurrentUserDto(); + + String sqlCommandType = mappedStatement.getSqlCommandType().toString(); + //获取类注解 获取需要忽略拦截的方法名称 + DataPermission dataAnnotation = classType.getAnnotation(DataPermission.class); + if (!Objects.isNull(dataAnnotation)) { + + String[] ignores = dataAnnotation.ignoresMethod(); + //校验拦截忽略方法名 忽略新增方法 忽略回调/定时方法 + if (!((!Objects.isNull(ignores) && Arrays.asList(ignores).contains(mName)) + || OperationTypeEnum.INSERT.getType().equals(sqlCommandType.toLowerCase()) + || Objects.isNull(currentUserDto) + || (!Objects.isNull(DataContext.get()) && DataContext.get().getType())) + + + ) { + originalSql = SqlUtil.buildTargetSql(originalSql, SqlUtil.getResourceIds()); + } + } + + Connection connection = (Connection) invocation.getArgs()[0]; + if (page.isSearchCount() && !page.isHitCount()) { + SqlInfo sqlInfo = SqlParserUtils.getOptimizeCountSql(page.optimizeCountSql(), this.countSqlParser, originalSql); + this.queryTotal(sqlInfo.getSql(), mappedStatement, boundSql, page, connection); + if (page.getTotal() <= 0L) { + return null; + } + } + + DbType dbType = Optional.ofNullable(this.dbType).orElse(JdbcUtils.getDbType(connection.getMetaData().getURL())); + IDialect dialect = Optional.ofNullable(this.dialect).orElse(DialectFactory.getDialect(dbType)); + String buildSql = concatOrderBy(originalSql, page); + DialectModel model = dialect.buildPaginationSql(buildSql, page.offset(), page.getSize()); + Configuration configuration = mappedStatement.getConfiguration(); + List mappings = new ArrayList(boundSql.getParameterMappings()); + Map additionalParameters = (Map) metaObject.getValue("delegate.boundSql.additionalParameters"); + model.consumers(mappings, configuration, additionalParameters); + metaObject.setValue("delegate.boundSql.sql", model.getDialectSql()); + metaObject.setValue("delegate.boundSql.parameterMappings", mappings); + return invocation.proceed(); + } else { + return invocation.proceed(); + } + } else { + return invocation.proceed(); + } + } + + /** + * 处理分页数量 + * + * @param page 分页参数 + */ + protected void handlerLimit(IPage page) { + page.setSize(this.limit); + } + + /** + * 查询总数量 + * + * @param sql sql语句 + * @param mappedStatement 映射语句包装类 + * @param boundSql sql包装类 + * @param page 分页参数 + * @param connection JDBC连接包装类 + */ + protected void queryTotal(String sql, MappedStatement mappedStatement, BoundSql boundSql, IPage page, Connection connection) { + try { + PreparedStatement statement = connection.prepareStatement(sql); + Throwable var7 = null; + + try { + DefaultParameterHandler parameterHandler = new MybatisDefaultParameterHandler(mappedStatement, boundSql.getParameterObject(), boundSql); + parameterHandler.setParameters(statement); + long total = 0L; + ResultSet resultSet = statement.executeQuery(); + Throwable var12 = null; + + try { + if (resultSet.next()) { + total = resultSet.getLong(1); + } + } catch (Throwable var37) { + var12 = var37; + throw var37; + } finally { + if (resultSet != null) { + if (var12 != null) { + try { + resultSet.close(); + } catch (Throwable var36) { + var12.addSuppressed(var36); + } + } else { + resultSet.close(); + } + } + + } + + page.setTotal(total); + if (this.overflow && page.getCurrent() > page.getPages()) { + this.handlerOverflow(page); + } + } catch (Throwable var39) { + var7 = var39; + throw var39; + } finally { + if (statement != null) { + if (var7 != null) { + try { + statement.close(); + } catch (Throwable var35) { + var7.addSuppressed(var35); + } + } else { + statement.close(); + } + } + + } + + } catch (Exception var41) { + throw ExceptionUtils.mpe("Error: Method queryTotal execution error of sql : \n %s \n", var41, new Object[]{sql}); + } + } + + /** + * 设置默认当前页 + * + * @param page 分页参数 + */ + protected void handlerOverflow(IPage page) { + page.setCurrent(1L); + } + + /** + * MybatisPlus拦截器实现自定义插件 + * + * @param target 拦截目标对象 + * @return + */ + @Override + public Object plugin(Object target) { + return target instanceof StatementHandler ? Plugin.wrap(target, this) : target; + } + + /** + * MybatisPlus拦截器实现自定义属性设置 + * + * @param prop 属性参数 + */ + @Override + public void setProperties(Properties prop) { + String dialectType = prop.getProperty("dialectType"); + String dialectClazz = prop.getProperty("dialectClazz"); + if (StringUtils.isNotBlank(dialectType)) { + this.setDialectType(dialectType); + } + + if (StringUtils.isNotBlank(dialectClazz)) { + this.setDialectClazz(dialectClazz); + } + + } + + /** + * 设置数据源类型 + * + * @param dialectType 数据源类型 + */ + @Deprecated + public void setDialectType(String dialectType) { + this.setDbType(DbType.getDbType(dialectType)); + } + + + /** + * 设置方言实现类配置 + * + * @param dialectClazz 方言实现类 + */ + @Deprecated + public void setDialectClazz(String dialectClazz) { + this.setDialect(DialectFactory.getDialect(dialectClazz)); + } + + /** + * 设置获取总数的sql解析器 + * + * @param countSqlParser 总数的sql解析器 + * @return 自定义MybatisPlus拦截器 + */ + public PaginationInterceptor setCountSqlParser(final ISqlParser countSqlParser) { + this.countSqlParser = countSqlParser; + return this; + } + + /** + * 溢出总页数,设置第一页 + * + * @param overflow 溢出总页数 + * @return 自定义MybatisPlus拦截器 + */ + public PaginationInterceptor setOverflow(final boolean overflow) { + this.overflow = overflow; + return this; + } + + /** + * 设置分页规则 + * + * @param limit 分页数量 + * @return 自定义MybatisPlus拦截器 + */ + public PaginationInterceptor setLimit(final long limit) { + this.limit = limit; + return this; + } + + /** + * 设置数据类型 + * + * @param dbType 数据类型 + * @return 自定义MybatisPlus拦截器 + */ + public PaginationInterceptor setDbType(final DbType dbType) { + this.dbType = dbType; + return this; + } + + /** + * 设置方言 + * + * @param dialect 方言 + * @return 自定义MybatisPlus拦截器 + */ + public PaginationInterceptor setDialect(final IDialect dialect) { + this.dialect = dialect; + return this; + } + + +} + diff --git a/dubhe-server/common/src/main/java/org/dubhe/utils/DateUtil.java b/dubhe-server/common/src/main/java/org/dubhe/utils/DateUtil.java index 5db3ee9..150d09f 100644 --- a/dubhe-server/common/src/main/java/org/dubhe/utils/DateUtil.java +++ b/dubhe-server/common/src/main/java/org/dubhe/utils/DateUtil.java @@ -18,16 +18,25 @@ package org.dubhe.utils; import java.sql.Timestamp; +import java.text.DateFormat; +import java.text.SimpleDateFormat; import java.time.Instant; import java.time.LocalDate; import java.time.LocalDateTime; import java.time.ZoneId; +import java.util.Date; /** * @description 日期工具类 - * @date 2020-6-10 + * @date 2020-06-10 */ public class DateUtil { + + private DateUtil(){ + + } + + /** * 获取当前时间戳 * @@ -77,4 +86,13 @@ public class DateUtil { return (milli-l1); } + /** + * @return 当前字符串时间yyyy-MM-dd HH:mm:ss SSS + */ + public static String getCurrentTimeStr(){ + Date date = new Date(); + DateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss SSS"); + return dateFormat.format(date); + } + } diff --git a/dubhe-server/common/src/main/java/org/dubhe/utils/FileUtil.java b/dubhe-server/common/src/main/java/org/dubhe/utils/FileUtil.java index 430f2c6..7dde12f 100644 --- a/dubhe-server/common/src/main/java/org/dubhe/utils/FileUtil.java +++ b/dubhe-server/common/src/main/java/org/dubhe/utils/FileUtil.java @@ -19,10 +19,12 @@ package org.dubhe.utils; import cn.hutool.core.codec.Base64; import cn.hutool.core.io.IoUtil; +import cn.hutool.core.util.CharsetUtil; import cn.hutool.core.util.IdUtil; import cn.hutool.poi.excel.BigExcelWriter; import cn.hutool.poi.excel.ExcelUtil; import org.apache.poi.util.IOUtils; +import org.dubhe.enums.LogEnum; import org.dubhe.exception.BusinessException; import org.springframework.web.multipart.MultipartFile; @@ -352,4 +354,56 @@ public class FileUtil extends cn.hutool.core.io.FileUtil { return getMd5(getByte(file)); } + + /** + * 生成文件 + * @param filePath 文件绝对路径 + * @param content 文件内容 + * @param append 文件是否是追加 + * @return + */ + public static boolean generateFile(String filePath,String content,boolean append){ + File file = new File(filePath); + FileOutputStream outputStream = null; + try { + if (!file.exists()){ + file.createNewFile(); + } + outputStream = new FileOutputStream(file,append); + outputStream.write(content.getBytes(CharsetUtil.defaultCharset())); + outputStream.flush(); + }catch (IOException e) { + LogUtil.error(LogEnum.FILE_UTIL,e); + return false; + }finally { + if (outputStream != null){ + try { + outputStream.close(); + } catch (IOException e) { + LogUtil.error(LogEnum.FILE_UTIL,e); + } + } + } + return true; + } + + + /** + * 压缩文件目录 + * + * @param zipDir 待压缩文件夹路径 + * @param zipFile 压缩完成zip文件绝对路径 + * @return + */ + public static boolean zipPath(String zipDir,String zipFile) { + if (zipDir == null) { + return false; + } + File zip = new File(zipFile); + cn.hutool.core.util.ZipUtil.zip(zip, CharsetUtil.defaultCharset(), true, + (f) -> !f.isDirectory(), + new File(zipDir).listFiles()); + return true; + } + } diff --git a/dubhe-server/common/src/main/java/org/dubhe/utils/HttpClientUtils.java b/dubhe-server/common/src/main/java/org/dubhe/utils/HttpClientUtils.java index a2250b5..e16145d 100644 --- a/dubhe-server/common/src/main/java/org/dubhe/utils/HttpClientUtils.java +++ b/dubhe-server/common/src/main/java/org/dubhe/utils/HttpClientUtils.java @@ -20,6 +20,7 @@ import org.apache.commons.io.IOUtils; import org.dubhe.enums.LogEnum; + import javax.net.ssl.HttpsURLConnection; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLSocketFactory; @@ -31,13 +32,13 @@ import java.io.InputStreamReader; import java.net.URL; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; - +import org.apache.commons.codec.binary.Base64; import static org.dubhe.constant.StringConstant.UTF8; import static org.dubhe.constant.SymbolConstant.BLANK; /** - * @description: httpClient工具类,不校验SSL证书 - * @date: 2020-5-21 + * @description httpClient工具类,不校验SSL证书 + * @date 2020-05-21 */ public class HttpClientUtils { @@ -45,7 +46,7 @@ public class HttpClientUtils { InputStream inputStream = null; BufferedReader bufferedReader = null; InputStreamReader inputStreamReader = null; - StringBuilder stringBuider = new StringBuilder(); + StringBuilder stringBuilder = new StringBuilder(); String result = BLANK; HttpsURLConnection con = null; try { @@ -59,10 +60,10 @@ public class HttpClientUtils { String str = null; while ((str = bufferedReader.readLine()) != null) { - stringBuider.append(str); + stringBuilder.append(str); } - result = stringBuider.toString(); + result = stringBuilder.toString(); LogUtil.info(LogEnum.BIZ_SYS,"Request path:{}, SUCCESS, result:{}", path, result); } catch (Exception e) { @@ -74,17 +75,54 @@ public class HttpClientUtils { return result; } + public static String sendHttpsDelete(String path,String username,String password) { + InputStream inputStream = null; + BufferedReader bufferedReader = null; + InputStreamReader inputStreamReader = null; + StringBuilder stringBuilder = new StringBuilder(); + String result = BLANK; + HttpsURLConnection con = null; + try { + con = getConnection(path); + String input =username+ ":" +password; + String encoding=Base64.encodeBase64String(input.getBytes()); + con.setRequestProperty(JwtUtils.AUTH_HEADER, "Basic " + encoding); + con.setRequestMethod("DELETE"); + con.connect(); + /**将返回的输入流转换成字符串**/ + inputStream = con.getInputStream(); + inputStreamReader = new InputStreamReader(inputStream, UTF8); + bufferedReader = new BufferedReader(inputStreamReader); - private static void closeResource(BufferedReader bufferedReader,InputStreamReader inputStreamReader,InputStream inputStream,HttpsURLConnection con) { + String str = null; + while ((str = bufferedReader.readLine()) != null) { + stringBuilder.append(str); + } - IOUtils.closeQuietly(bufferedReader); + result = stringBuilder.toString(); + LogUtil.info(LogEnum.BIZ_SYS,"Request path:{}, SUCCESS, result:{}", path, result); - if (inputStreamReader != null) { - IOUtils.closeQuietly(inputStreamReader); + } catch (Exception e) { + LogUtil.error(LogEnum.BIZ_SYS,"Request path:{}, ERROR, exception:{}", path, e); + return result; + } finally { + closeResource(bufferedReader,inputStreamReader,inputStream,con); } + + return result; + } + + private static void closeResource(BufferedReader bufferedReader,InputStreamReader inputStreamReader,InputStream inputStream,HttpsURLConnection con) { if (inputStream != null) { IOUtils.closeQuietly(inputStream); } + + if (inputStreamReader != null) { + IOUtils.closeQuietly(inputStreamReader); + } + if (bufferedReader != null) { + IOUtils.closeQuietly(bufferedReader); + } if (con != null) { con.disconnect(); } diff --git a/dubhe-server/common/src/main/java/org/dubhe/utils/HttpUtils.java b/dubhe-server/common/src/main/java/org/dubhe/utils/HttpUtils.java index 30e8309..70a72db 100644 --- a/dubhe-server/common/src/main/java/org/dubhe/utils/HttpUtils.java +++ b/dubhe-server/common/src/main/java/org/dubhe/utils/HttpUtils.java @@ -20,8 +20,8 @@ package org.dubhe.utils; import lombok.extern.slf4j.Slf4j; /** - * @description: HttpUtil - * @date 2020.04.30 + * @description HttpUtil + * @date 2020-04-30 */ @Slf4j public class HttpUtils { diff --git a/dubhe-server/common/src/main/java/org/dubhe/utils/IOUtil.java b/dubhe-server/common/src/main/java/org/dubhe/utils/IOUtil.java new file mode 100644 index 0000000..dd6a94b --- /dev/null +++ b/dubhe-server/common/src/main/java/org/dubhe/utils/IOUtil.java @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Zhejiang Lab. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================= + */ +package org.dubhe.utils; + +import org.dubhe.enums.LogEnum; + +import java.io.Closeable; +import java.io.IOException; + +/** + * @description IO流操作工具类 + * @date 2020-10-14 + */ +public class IOUtil { + + /** + * 循环的依次关闭流 + * + * @param closeableList 要被关闭的流集合 + */ + public static void close(Closeable... closeableList) { + for (Closeable closeable : closeableList) { + try { + if (closeable != null) { + closeable.close(); + } + } catch (IOException e) { + LogUtil.error(LogEnum.IO_UTIL, "关闭流异常,异常信息:{}", e); + } + } + } +} diff --git a/dubhe-server/common/src/main/java/org/dubhe/utils/JwtUtils.java b/dubhe-server/common/src/main/java/org/dubhe/utils/JwtUtils.java index 573880d..afaaed5 100644 --- a/dubhe-server/common/src/main/java/org/dubhe/utils/JwtUtils.java +++ b/dubhe-server/common/src/main/java/org/dubhe/utils/JwtUtils.java @@ -139,7 +139,7 @@ public class JwtUtils { public static boolean isTokenExpired(String token) { Date now = Calendar.getInstance().getTime(); DecodedJWT jwt = JWT.decode(token); - return jwt.getExpiresAt().before(now); + return jwt.getExpiresAt() == null || jwt.getExpiresAt().before(now); } /** diff --git a/dubhe-server/common/src/main/java/org/dubhe/utils/K8sNameTool.java b/dubhe-server/common/src/main/java/org/dubhe/utils/K8sNameTool.java index af25452..1f8846d 100644 --- a/dubhe-server/common/src/main/java/org/dubhe/utils/K8sNameTool.java +++ b/dubhe-server/common/src/main/java/org/dubhe/utils/K8sNameTool.java @@ -67,12 +67,12 @@ public class K8sNameTool { } /** - * 生成 Notebook的NameSpace + * 生成 Notebook的Namespace * * @param userId * @return namespace */ - public String generateNameSpace(long userId) { + public String generateNamespace(long userId) { return this.k8sNameConfig.getNamespace() + SEPARATOR + userId; } @@ -96,7 +96,7 @@ public class K8sNameTool { * @param namespace * @return Long */ - public Long getUserIdFromNameSpace(String namespace) { + public Long getUserIdFromNamespace(String namespace) { if (StringUtils.isEmpty(namespace) || !namespace.contains(this.k8sNameConfig.getNamespace() + SEPARATOR)) { return null; } diff --git a/dubhe-server/common/src/main/java/org/dubhe/utils/LocalFileUtil.java b/dubhe-server/common/src/main/java/org/dubhe/utils/LocalFileUtil.java new file mode 100644 index 0000000..e5bad3d --- /dev/null +++ b/dubhe-server/common/src/main/java/org/dubhe/utils/LocalFileUtil.java @@ -0,0 +1,307 @@ +/** + * Copyright 2020 Zhejiang Lab. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================= + */ +package org.dubhe.utils; + +import cn.hutool.core.util.StrUtil; +import lombok.Getter; +import org.apache.commons.compress.archivers.zip.ZipArchiveEntry; +import org.apache.commons.compress.archivers.zip.ZipFile; +import org.apache.commons.io.IOUtils; +import org.dubhe.base.MagicNumConstant; +import org.dubhe.config.NfsConfig; +import org.dubhe.enums.LogEnum; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.stereotype.Component; +import org.springframework.util.FileCopyUtils; + +import java.io.*; +import java.util.Enumeration; +import java.util.zip.ZipEntry; + +/** + * @description 本地文件操作工具类 + * @date 2020-08-19 + */ +@Component +@Getter +public class LocalFileUtil { + + @Autowired + private NfsConfig nfsConfig; + + private static final String FILE_SEPARATOR = File.separator; + + private static final String ZIP = ".zip"; + + private static final String CHARACTER_GBK = "GBK"; + + private static final String OS_NAME = "os.name"; + + private static final String WINDOWS = "Windows"; + + @Value("${k8s.nfs-root-path}") + private String nfsRootPath; + + @Value("${k8s.nfs-root-windows-path}") + private String nfsRootWindowsPath; + + /** + * windows 与 linux 的路径兼容 + * + * @param path linux下的路径 + * @return path 兼容windows后的路径 + */ + private String compatiblePath(String path) { + if (path == null) { + return null; + } + if (System.getProperties().getProperty(OS_NAME).contains(WINDOWS)) { + path = path.replace(nfsRootPath, StrUtil.SLASH); + path = path.replace(StrUtil.SLASH, FILE_SEPARATOR); + path = nfsRootWindowsPath + path; + } + return path; + } + + + /** + * 本地解压zip包并删除压缩文件 + * + * @param sourcePath zip源文件 例如:/abc/z.zip + * @param targetPath 解压后的目标文件夹 例如:/abc/ + * @return boolean + */ + public boolean unzipLocalPath(String sourcePath, String targetPath) { + if (StringUtils.isEmpty(sourcePath) || StringUtils.isEmpty(targetPath)) { + return false; + } + if (!sourcePath.toLowerCase().endsWith(ZIP)) { + return false; + } + //绝对路径 + String sourceAbsolutePath = nfsConfig.getRootDir() + sourcePath; + String targetPathAbsolutePath = nfsConfig.getRootDir() + targetPath; + ZipFile zipFile = null; + InputStream in = null; + OutputStream out = null; + File sourceFile = new File(compatiblePath(sourceAbsolutePath)); + File targetFileDir = new File(compatiblePath(targetPathAbsolutePath)); + if (!targetFileDir.exists()) { + boolean targetMkdir = targetFileDir.mkdirs(); + if (!targetMkdir) { + LogUtil.error(LogEnum.LOCAL_FILE_UTIL, "{}failed to create target folder before decompression", sourceAbsolutePath); + } + } + try { + zipFile = new ZipFile(sourceFile); + //判断压缩文件编码方式,并重新获取文件对象 + try { + zipFile.close(); + zipFile = new ZipFile(sourceFile, CHARACTER_GBK); + } catch (Exception e) { + zipFile.close(); + zipFile = new ZipFile(sourceFile); + LogUtil.error(LogEnum.LOCAL_FILE_UTIL, "{}the encoding mode of decompressed compressed file is changed to UTF-8:{}", sourceAbsolutePath, e); + } + ZipEntry entry; + Enumeration enumeration = zipFile.getEntries(); + while (enumeration.hasMoreElements()) { + entry = (ZipEntry) enumeration.nextElement(); + String entryName = entry.getName(); + File fileDir; + if (entry.isDirectory()) { + fileDir = new File(targetPathAbsolutePath + entry.getName()); + if (!fileDir.exists()) { + boolean fileMkdir = fileDir.mkdirs(); + if (!fileMkdir) { + LogUtil.error(LogEnum.LOCAL_FILE_UTIL, "failed to create folder {} while decompressing {}", fileDir, sourceAbsolutePath); + } + } + } else { + //若文件夹未创建则创建文件夹 + if (entryName.contains(FILE_SEPARATOR)) { + String zipDirName = entryName.substring(MagicNumConstant.ZERO, entryName.lastIndexOf(FILE_SEPARATOR)); + fileDir = new File(targetPathAbsolutePath + zipDirName); + if (!fileDir.exists()) { + boolean fileMkdir = fileDir.mkdirs(); + if (!fileMkdir) { + LogUtil.error(LogEnum.LOCAL_FILE_UTIL, "failed to create folder {} while decompressing {}", fileDir, sourceAbsolutePath); + } + } + } + in = zipFile.getInputStream((ZipArchiveEntry) entry); + out = new FileOutputStream(new File(targetPathAbsolutePath, entryName)); + IOUtils.copyLarge(in, out); + in.close(); + out.close(); + } + } + boolean deleteZipFile = sourceFile.delete(); + if (!deleteZipFile) { + LogUtil.error(LogEnum.LOCAL_FILE_UTIL, "{}compressed file deletion failed after decompression", sourceAbsolutePath); + } + return true; + } catch (IOException e) { + LogUtil.error(LogEnum.LOCAL_FILE_UTIL, "{}decompression failed: {}", sourceAbsolutePath, e); + return false; + } finally { + //关闭未关闭的io流 + closeIoFlow(sourceAbsolutePath, zipFile, in, out); + } + + } + + /** + * 关闭未关闭的io流 + * + * @param sourceAbsolutePath 源路径 + * @param zipFile 压缩文件对象 + * @param in 输入流 + * @param out 输出流 + */ + private void closeIoFlow(String sourceAbsolutePath, ZipFile zipFile, InputStream in, OutputStream out) { + if (in != null) { + try { + in.close(); + } catch (IOException e) { + LogUtil.error(LogEnum.LOCAL_FILE_UTIL, "{}input stream shutdown failed: {}", sourceAbsolutePath, e); + } + } + if (out != null) { + try { + out.close(); + } catch (IOException e) { + LogUtil.error(LogEnum.LOCAL_FILE_UTIL, "{}output stream shutdown failed: {}", sourceAbsolutePath, e); + } + } + if (zipFile != null) { + try { + zipFile.close(); + } catch (IOException e) { + LogUtil.error(LogEnum.LOCAL_FILE_UTIL, "{}input stream shutdown failed: {}", sourceAbsolutePath, e); + } + } + } + + /** + * NFS 复制目录到指定目录下 多个文件 包含目录与文件并存情况 + * + * 通过本地文件复制方式 + * + * @param sourcePath 需要复制的文件目录 例如:/abc/def + * @param targetPath 需要放置的目标目录 例如:/abc/dd + * @return boolean + */ + public boolean copyPath(String sourcePath, String targetPath) { + if (StringUtils.isEmpty(sourcePath) || StringUtils.isEmpty(targetPath)) { + return false; + } + sourcePath = formatPath(sourcePath); + targetPath = formatPath(targetPath); + try { + return copyLocalPath(nfsConfig.getRootDir() + sourcePath, nfsConfig.getRootDir() + targetPath); + } catch (Exception e) { + LogUtil.error(LogEnum.LOCAL_FILE_UTIL, " failed to Copy file original path: {} ,target path: {} ,copyPath: {}", sourcePath, targetPath, e); + return false; + } + } + + /** + * 复制文件到指定目录下 单个文件 + * + * @param sourcePath 需要复制的文件 例如:/abc/def/cc.txt + * @param targetPath 需要放置的目标目录 例如:/abc/dd + * @return boolean + */ + private boolean copyLocalFile(String sourcePath, String targetPath) { + if (StringUtils.isEmpty(sourcePath) || StringUtils.isEmpty(targetPath)) { + return false; + } + sourcePath = formatPath(sourcePath); + targetPath = formatPath(targetPath); + try (InputStream input = new FileInputStream(sourcePath); + FileOutputStream output = new FileOutputStream(targetPath)) { + FileCopyUtils.copy(input, output); + return true; + } catch (IOException e) { + LogUtil.error(LogEnum.LOCAL_FILE_UTIL, " failed to copy file original path: {} ,target path: {} ,copyLocalFile:{} ", sourcePath, targetPath, e); + return false; + } + } + + + /** + * 复制文件 到指定目录下 多个文件 包含目录与文件并存情况 + * + * @param sourcePath 需要复制的文件目录 例如:/abc/def + * @param targetPath 需要放置的目标目录 例如:/abc/dd + * @return boolean + */ + private boolean copyLocalPath(String sourcePath, String targetPath) { + if (!StringUtils.isEmpty(sourcePath) && !StringUtils.isEmpty(targetPath)) { + sourcePath = formatPath(sourcePath); + if (sourcePath.endsWith(FILE_SEPARATOR)) { + sourcePath = sourcePath.substring(MagicNumConstant.ZERO, sourcePath.lastIndexOf(FILE_SEPARATOR)); + } + targetPath = formatPath(targetPath); + File sourceFile = new File(sourcePath); + if (sourceFile.exists()) { + File[] files = sourceFile.listFiles(); + if (files != null && files.length != 0) { + for (File file : files) { + try { + if (file.isDirectory()) { + File fileDir = new File(targetPath + FILE_SEPARATOR + file.getName()); + if (!fileDir.exists()) { + fileDir.mkdirs(); + } + copyLocalPath(sourcePath + FILE_SEPARATOR + file.getName(), targetPath + FILE_SEPARATOR + file.getName()); + } + if (file.isFile()) { + File fileTargetPath = new File(targetPath); + if (!fileTargetPath.exists()) { + fileTargetPath.mkdirs(); + } + copyLocalFile(file.getAbsolutePath(), targetPath + FILE_SEPARATOR + file.getName()); + } + } catch (Exception e) { + LogUtil.error(LogEnum.LOCAL_FILE_UTIL, "failed to copy folder original path: {} , target path : {} ,copyLocalPath: {}", sourcePath, targetPath, e); + return false; + } + } + } + return true; + } + } + return false; + } + + /** + * 替换路径中多余的 "/" + * + * @param path + * @return String + */ + public String formatPath(String path) { + if (!StringUtils.isEmpty(path)) { + return path.replaceAll("///*", FILE_SEPARATOR); + } + return path; + } + +} \ No newline at end of file diff --git a/dubhe-server/common/src/main/java/org/dubhe/utils/LogUtil.java b/dubhe-server/common/src/main/java/org/dubhe/utils/LogUtil.java index 36b9dc0..25b273a 100644 --- a/dubhe-server/common/src/main/java/org/dubhe/utils/LogUtil.java +++ b/dubhe-server/common/src/main/java/org/dubhe/utils/LogUtil.java @@ -23,10 +23,10 @@ import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.exception.ExceptionUtils; import org.dubhe.aspect.LogAspect; import org.dubhe.base.MagicNumConstant; -import org.dubhe.constant.SymbolConstant; import org.dubhe.domain.entity.LogInfo; import org.dubhe.enums.LogEnum; import org.slf4j.MDC; +import org.slf4j.MarkerFactory; import org.slf4j.helpers.MessageFormatter; import java.util.Arrays; @@ -38,218 +38,283 @@ import java.util.UUID; */ @Slf4j public class LogUtil { - /** - * info级别的日志 - * - * @param logType 日志类型 - * @param object 打印的日志参数 - */ - public static void info(LogEnum logType, Object... object) { - - logHandle(logType, Level.INFO, object); - } - - /** - * debug级别的日志 - * - * @param logType 日志类型 - * @param object 打印的日志参数 - */ - public static void debug(LogEnum logType, Object... object) { - logHandle(logType, Level.DEBUG, object); - } - - /** - * error级别的日志 - * - * @param logType 日志类型 - * @param object 打印的日志参数 - */ - public static void error(LogEnum logType, Object... object) { - errorObjectHandle(object); - logHandle(logType, Level.ERROR, object); - } - - /** - * warn级别的日志 - * - * @param logType 日志类型 - * @param object 打印的日志参数 - */ - public static void warn(LogEnum logType, Object... object) { - logHandle(logType, Level.WARN, object); - } - - /** - * trace级别的日志 - * - * @param logType 日志类型 - * @param object 打印的日志参数 - */ - public static void trace(LogEnum logType, Object... object) { - logHandle(logType, Level.TRACE, object); - } - - /** - * 日志处理 - * - * @param logType 日志类型 - * @param level 日志级别 - * @param object 打印的日志参数 - */ - private static void logHandle(LogEnum logType, Level level, Object[] object) { - - LogInfo logInfo = generateLogInfo(logType, level, object); - String logInfoJsonStr = logJsonStringLengthLimit(logInfo); - switch (Level.toLevel(logInfo.getLevel()).levelInt) { - case Level.TRACE_INT: - log.trace(logInfoJsonStr); - break; - case Level.DEBUG_INT: - log.debug(logInfoJsonStr); - break; - case Level.INFO_INT: - log.info(logInfoJsonStr); - break; - case Level.WARN_INT: - log.warn(logInfoJsonStr); - break; - case Level.ERROR_INT: - log.error(logInfoJsonStr); - break; - default: - } - - } - - /** - * 日志信息组装的内部方法 - * - * @param logType 日志类型 - * @param level 日志级别 - * @param object 打印的日志参数 - * @return LogInfo日志对象信息 - */ - private static LogInfo generateLogInfo(LogEnum logType, Level level, Object[] object) { - LogInfo logInfo = new LogInfo(); - // 日志类型检测 - if (!LogEnum.isLogType(logType)) { - level = Level.ERROR; - object = new Object[MagicNumConstant.ONE]; - object[MagicNumConstant.ZERO] = String.valueOf("logType【").concat(String.valueOf(logType)) - .concat("】is error !"); - logType = LogEnum.SYS_ERR; - } - - // 获取trace_id - if (StringUtils.isEmpty(MDC.get(LogAspect.TRACE_ID))) { - MDC.put(LogAspect.TRACE_ID, UUID.randomUUID().toString()); - } - // 设置logInfo的level,type,traceId属性 - logInfo.setLevel(level.levelStr).setType(logType.toString()).setTraceId(MDC.get(LogAspect.TRACE_ID)); - // 设置logInfo的堆栈信息 - setLogStackInfo(logInfo); - // 设置logInfo的info信息 - setLogInfo(logInfo, object); - // 截取loginfo的长度并转换成json字符串 - return logInfo; - - } - - /** - * 设置loginfo的堆栈信息 - * - * @param logInfo 日志对象 - */ - private static void setLogStackInfo(LogInfo logInfo) { - StackTraceElement[] elements = Thread.currentThread().getStackTrace(); - if (elements.length >= MagicNumConstant.SIX) { - logInfo.setCName(elements[MagicNumConstant.FIVE].getClassName()) - .setMName(elements[MagicNumConstant.FIVE].getMethodName()) - .setLine(String.valueOf(elements[MagicNumConstant.FIVE].getLineNumber())); - } - } - - /** - * 限制log日志的长度并转换成json - * - * @param logInfo 日志对象 - * @return String 日志对象Json字符串 - */ - private static String logJsonStringLengthLimit(LogInfo logInfo) { - try { - String jsonString = JSON.toJSONString(logInfo.getInfo()); - if (jsonString.length() > MagicNumConstant.TEN_THOUSAND) { - jsonString = jsonString.substring(MagicNumConstant.ZERO, MagicNumConstant.TEN_THOUSAND); - } - - logInfo.setInfo(jsonString); - jsonString = JSON.toJSONString(logInfo); - jsonString = jsonString.replace(SymbolConstant.BACKSLASH_MARK, SymbolConstant.MARK) - .replace(SymbolConstant.DOUBLE_MARK, SymbolConstant.MARK) - .replace(SymbolConstant.BRACKETS, SymbolConstant.BLANK); - - return jsonString; - - } catch (Exception e) { - logInfo.setLevel(Level.ERROR.levelStr).setType(LogEnum.SYS_ERR.toString()) - .setInfo("cannot serialize exception: " + ExceptionUtils.getStackTrace(e)); - return JSON.toJSONString(logInfo); - } - } - - /** - * 设置日志对象的info信息 - * - * @param logInfo 日志对象 - * @param object 打印的日志参数 - */ - private static void setLogInfo(LogInfo logInfo, Object[] object) { - for (Object obj : object) { - if (obj instanceof Exception) { - log.error((ExceptionUtils.getStackTrace((Throwable) obj))); - } - } - - if (object.length > MagicNumConstant.ONE) { - logInfo.setInfo(MessageFormatter.arrayFormat(object[MagicNumConstant.ZERO].toString(), - Arrays.copyOfRange(object, MagicNumConstant.ONE, object.length)).getMessage()); - - } else if (object.length == MagicNumConstant.ONE && object[MagicNumConstant.ZERO] instanceof Exception) { - logInfo.setInfo((ExceptionUtils.getStackTrace((Exception) object[MagicNumConstant.ZERO]))); - } else if (object.length == MagicNumConstant.ONE) { - logInfo.setInfo( - object[MagicNumConstant.ZERO] == null ? SymbolConstant.BLANK : object[MagicNumConstant.ZERO]); - } else { - logInfo.setInfo(SymbolConstant.BLANK); - } - - } - - /** - * 处理Exception的情况 - * - * @param object 打印的日志参数 - */ - private static void errorObjectHandle(Object[] object) { - if (object.length >= MagicNumConstant.TWO) { - object[MagicNumConstant.ZERO] = String.valueOf(object[MagicNumConstant.ZERO]) - .concat(SymbolConstant.BRACKETS); - } - - if (object.length == MagicNumConstant.TWO && object[MagicNumConstant.ONE] instanceof Exception) { - log.error((ExceptionUtils.getStackTrace((Throwable) object[MagicNumConstant.ONE]))); - object[MagicNumConstant.ONE] = ExceptionUtils.getStackTrace((Exception) object[MagicNumConstant.ONE]); - - } else if (object.length >= MagicNumConstant.THREE) { - for (int i = 0; i < object.length; i++) { - if (object[i] instanceof Exception) { - log.error((ExceptionUtils.getStackTrace((Throwable) object[i]))); - object[i] = ExceptionUtils.getStackTrace((Exception) object[i]); - } - - } - } - } + private static final String TRACE_TYPE = "TRACE_TYPE"; + + public static final String SCHEDULE_LEVEL = "SCHEDULE"; + + public static final String K8S_CALLBACK_LEVEL = "K8S_CALLBACK"; + + private static final String GLOBAL_REQUEST_LEVEL = "GLOBAL_REQUEST"; + + private static final String TRACE_LEVEL = "TRACE"; + + private static final String DEBUG_LEVEL = "DEBUG"; + + private static final String INFO_LEVEL = "INFO"; + + private static final String WARN_LEVEL = "WARN"; + + private static final String ERROR_LEVEL = "ERROR"; + + + public static void startScheduleTrace() { + MDC.put(TRACE_TYPE, SCHEDULE_LEVEL); + } + + public static void startK8sCallbackTrace() { + MDC.put(TRACE_TYPE, K8S_CALLBACK_LEVEL); + } + + public static void cleanTrace() { + MDC.clear(); + } + + /** + * info级别的日志 + * + * @param logType 日志类型 + * @param object 打印的日志参数 + * @return void + */ + + public static void info(LogEnum logType, Object... object) { + + logHandle(logType, Level.INFO, object); + } + + /** + * debug级别的日志 + * + * @param logType 日志类型 + * @param object 打印的日志参数 + * @return void + */ + public static void debug(LogEnum logType, Object... object) { + logHandle(logType, Level.DEBUG, object); + } + + /** + * error级别的日志 + * + * @param logType 日志类型 + * @param object 打印的日志参数 + * @return void + */ + public static void error(LogEnum logType, Object... object) { + errorObjectHandle(object); + logHandle(logType, Level.ERROR, object); + } + + /** + * warn级别的日志 + * + * @param logType 日志类型 + * @param object 打印的日志参数 + * @return void + */ + public static void warn(LogEnum logType, Object... object) { + logHandle(logType, Level.WARN, object); + } + + /** + * trace级别的日志 + * + * @param logType 日志类型 + * @param object 打印的日志参数 + * @return void + */ + public static void trace(LogEnum logType, Object... object) { + logHandle(logType, Level.TRACE, object); + } + + /** + * 日志处理 + * + * @param logType 日志类型 + * @param level 日志级别 + * @param object 打印的日志参数 + * @return void + */ + private static void logHandle(LogEnum logType, Level level, Object[] object) { + + LogInfo logInfo = generateLogInfo(logType, level, object); + + switch (logInfo.getLevel()) { + case TRACE_LEVEL: + log.trace(MarkerFactory.getMarker(TRACE_LEVEL), logJsonStringLengthLimit(logInfo)); + break; + case DEBUG_LEVEL: + log.debug(MarkerFactory.getMarker(DEBUG_LEVEL), logJsonStringLengthLimit(logInfo)); + break; + case GLOBAL_REQUEST_LEVEL: + logInfo.setLevel(null); + logInfo.setType(null); + logInfo.setLocation(null); + log.info(MarkerFactory.getMarker(GLOBAL_REQUEST_LEVEL), logJsonStringLengthLimit(logInfo)); + break; + case SCHEDULE_LEVEL: + log.info(MarkerFactory.getMarker(SCHEDULE_LEVEL), logJsonStringLengthLimit(logInfo)); + break; + case K8S_CALLBACK_LEVEL: + log.info(MarkerFactory.getMarker(K8S_CALLBACK_LEVEL), logJsonStringLengthLimit(logInfo)); + break; + case INFO_LEVEL: + log.info(MarkerFactory.getMarker(INFO_LEVEL), logJsonStringLengthLimit(logInfo)); + break; + case WARN_LEVEL: + log.warn(MarkerFactory.getMarker(WARN_LEVEL), logJsonStringLengthLimit(logInfo)); + break; + case ERROR_LEVEL: + log.error(MarkerFactory.getMarker(ERROR_LEVEL), logJsonStringLengthLimit(logInfo)); + break; + default: + } + + } + + + /** + * 日志信息组装的内部方法 + * + * @param logType 日志类型 + * @param level 日志级别 + * @param object 打印的日志参数 + * @return LogInfo + */ + private static LogInfo generateLogInfo(LogEnum logType, Level level, Object[] object) { + + + LogInfo logInfo = new LogInfo(); + // 日志类型检测 + if (!LogEnum.isLogType(logType)) { + level = Level.ERROR; + object = new Object[MagicNumConstant.ONE]; + object[MagicNumConstant.ZERO] = "日志类型【".concat(String.valueOf(logType)).concat("】不正确!"); + logType = LogEnum.SYS_ERR; + } + + // 获取trace_id + if (StringUtils.isEmpty(MDC.get(LogAspect.TRACE_ID))) { + MDC.put(LogAspect.TRACE_ID, UUID.randomUUID().toString()); + } + // 设置logInfo的level,type,traceId属性 + logInfo.setLevel(level.levelStr) + .setType(logType.toString()) + .setTraceId(MDC.get(LogAspect.TRACE_ID)); + + + //自定义日志级别 + //LogEnum、 MDC中的 TRACE_TYPE 做日志分流标识 + if (Level.INFO.toInt() == level.toInt()) { + if (LogEnum.GLOBAL_REQ.equals(logType)) { + //info全局请求 + logInfo.setLevel(GLOBAL_REQUEST_LEVEL); + } else if (LogEnum.BIZ_K8S.equals(logType)) { + logInfo.setLevel(K8S_CALLBACK_LEVEL); + } else { + //schedule定时等 链路记录 + String traceType = MDC.get(TRACE_TYPE); + if (StringUtils.isNotBlank(traceType)) { + logInfo.setLevel(traceType); + } + } + } + + // 设置logInfo的堆栈信息 + setLogStackInfo(logInfo); + // 设置logInfo的info信息 + setLogInfo(logInfo, object); + // 截取logInfo的长度并转换成json字符串 + return logInfo; + } + + /** + * 设置loginfo的堆栈信息 + * + * @param logInfo 日志对象 + * @return void + */ + private static void setLogStackInfo(LogInfo logInfo) { + StackTraceElement[] elements = Thread.currentThread().getStackTrace(); + if (elements.length >= MagicNumConstant.SIX) { + StackTraceElement element = elements[MagicNumConstant.FIVE]; + logInfo.setLocation(String.format("%s#%s:%s", element.getClassName(), element.getMethodName(), element.getLineNumber())); + } + } + + /** + * 限制log日志的长度并转换成json + * + * @param logInfo 日志对象 + * @return String + */ + private static String logJsonStringLengthLimit(LogInfo logInfo) { + try { + + String jsonString = JSON.toJSONString(logInfo); + if (StringUtils.isBlank(jsonString)) { + return ""; + } + if (jsonString.length() > MagicNumConstant.TEN_THOUSAND) { + String trunk = logInfo.getInfo().toString().substring(MagicNumConstant.ZERO, MagicNumConstant.NINE_THOUSAND); + logInfo.setInfo(trunk); + jsonString = JSON.toJSONString(logInfo); + } + return jsonString; + + } catch (Exception e) { + logInfo.setLevel(Level.ERROR.levelStr).setType(LogEnum.SYS_ERR.toString()) + .setInfo("cannot serialize exception: " + ExceptionUtils.getStackTrace(e)); + return JSON.toJSONString(logInfo); + } + } + + /** + * 设置日志对象的info信息 + * + * @param logInfo 日志对象 + * @param object 打印的日志参数 + * @return void + */ + private static void setLogInfo(LogInfo logInfo, Object[] object) { + + if (object.length > MagicNumConstant.ONE) { + logInfo.setInfo(MessageFormatter.arrayFormat(object[MagicNumConstant.ZERO].toString(), + Arrays.copyOfRange(object, MagicNumConstant.ONE, object.length)).getMessage()); + + } else if (object.length == MagicNumConstant.ONE && object[MagicNumConstant.ZERO] instanceof Exception) { + logInfo.setInfo((ExceptionUtils.getStackTrace((Exception) object[MagicNumConstant.ZERO]))); + log.error((ExceptionUtils.getStackTrace((Exception) object[MagicNumConstant.ZERO]))); + } else if (object.length == MagicNumConstant.ONE) { + logInfo.setInfo(object[MagicNumConstant.ZERO] == null ? "" : object[MagicNumConstant.ZERO]); + } else { + logInfo.setInfo(""); + } + + } + + /** + * 处理Exception的情况 + * + * @param object 打印的日志参数 + * @return void + */ + private static void errorObjectHandle(Object[] object) { + + if (object.length == MagicNumConstant.TWO && object[MagicNumConstant.ONE] instanceof Exception) { + log.error(String.valueOf(object[MagicNumConstant.ZERO]), (Exception) object[MagicNumConstant.ONE]); + object[MagicNumConstant.ONE] = ExceptionUtils.getStackTrace((Exception) object[MagicNumConstant.ONE]); + + } else if (object.length >= MagicNumConstant.THREE) { + log.error(String.valueOf(object[MagicNumConstant.ZERO]), + Arrays.copyOfRange(object, MagicNumConstant.ONE, object.length)); + for (int i = 0; i < object.length; i++) { + if (object[i] instanceof Exception) { + object[i] = ExceptionUtils.getStackTrace((Exception) object[i]); + } + + } + } + } } diff --git a/dubhe-server/common/src/main/java/org/dubhe/utils/MathUtils.java b/dubhe-server/common/src/main/java/org/dubhe/utils/MathUtils.java index 250cf84..d0f9bf3 100644 --- a/dubhe-server/common/src/main/java/org/dubhe/utils/MathUtils.java +++ b/dubhe-server/common/src/main/java/org/dubhe/utils/MathUtils.java @@ -20,8 +20,8 @@ package org.dubhe.utils; import org.dubhe.base.MagicNumConstant; /** - * @description: 计算工具类 - * @create: 2020/6/4 14:53 + * @description 计算工具类 + * @date 2020-06-04 */ public class MathUtils { diff --git a/dubhe-server/common/src/main/java/org/dubhe/utils/MinioUtil.java b/dubhe-server/common/src/main/java/org/dubhe/utils/MinioUtil.java index e526ef2..42c934f 100644 --- a/dubhe-server/common/src/main/java/org/dubhe/utils/MinioUtil.java +++ b/dubhe-server/common/src/main/java/org/dubhe/utils/MinioUtil.java @@ -68,9 +68,9 @@ public class MinioUtil { try { client = new MinioClient(url, accessKey, secretKey); } catch (InvalidEndpointException e) { - LogUtil.warn(LogEnum.BIZ_DATASET, "MinIO endpoint invalid. e:", e); + LogUtil.warn(LogEnum.BIZ_DATASET, "MinIO endpoint invalid. e, {}", e); } catch (InvalidPortException e) { - LogUtil.warn(LogEnum.BIZ_DATASET, "MinIO endpoint port invalid. e:", e); + LogUtil.warn(LogEnum.BIZ_DATASET, "MinIO endpoint port invalid. e, {}", e); } } @@ -94,7 +94,7 @@ public class MinioUtil { /** * 读取文件 * - * @param bucket + * @param bucket 桶 * @param fullFilePath 文件存储的全路径,包括文件名,非'/'开头. e.g. dataset/12/annotation/test.txt * @return String */ @@ -107,7 +107,7 @@ public class MinioUtil { /** * 文件删除 * - * @param bucket + * @param bucket 桶 * @param fullFilePath 文件存储的全路径,包括文件名,非'/'开头. e.g. dataset/12/annotation/test.txt */ public void del(String bucket, String fullFilePath) throws Exception { @@ -125,8 +125,8 @@ public class MinioUtil { /** * 批量删除文件 * - * @param bucket - * @param objectNames + * @param bucket 桶 + * @param objectNames 对象名称 */ public void delFiles(String bucket,List objectNames) throws Exception{ Iterable> results = client.removeObjects(bucket, objectNames); @@ -138,6 +138,10 @@ public class MinioUtil { /** * 获取对象名称 * + * @param bucketName 桶名称 + * @param prefix 前缀 + * @return + * @throws Exception */ public List getObjects(String bucketName, String prefix)throws Exception{ List fileNames = new ArrayList<>(); @@ -152,6 +156,10 @@ public class MinioUtil { /** * 获取文件流 * + * @param bucket 桶 + * @param objectName 对象名称 + * @return + * @throws Exception */ public InputStream getObjectInputStream(String bucket,String objectName)throws Exception{ return client.getObject(bucket, objectName); @@ -160,9 +168,9 @@ public class MinioUtil { /** * 文件夹复制 * - * @param bucket + * @param bucket 桶 * @param sourceFiles 源文件 - * @param targetDir 目标文件夹 + * @param targetDir 目标文件夹 */ public void copyDir(String bucket, List sourceFiles, String targetDir) { sourceFiles.forEach(sourceFile -> { @@ -211,10 +219,10 @@ public class MinioUtil { /** * 生成文件下载请求参数方法 * - * @param bucketName - * @param prefix - * @param objects - * @return MinioDownloadDto + * @param bucketName 桶名称 + * @param prefix 前缀 + * @param objects 对象名称 + * @return MinioDownloadDto 下载请求参数 */ public MinioDownloadDto getDownloadParam(String bucketName, String prefix, List objects, String zipName) { String paramTemplate = "{\"id\":%d,\"jsonrpc\":\"%s\",\"params\":{\"username\":\"%s\",\"password\":\"%s\"},\"method\":\"%s\"}"; diff --git a/dubhe-server/common/src/main/java/org/dubhe/utils/NfsUtil.java b/dubhe-server/common/src/main/java/org/dubhe/utils/NfsUtil.java index f7c7b41..0856532 100644 --- a/dubhe-server/common/src/main/java/org/dubhe/utils/NfsUtil.java +++ b/dubhe-server/common/src/main/java/org/dubhe/utils/NfsUtil.java @@ -17,6 +17,7 @@ package org.dubhe.utils; +import cn.hutool.core.util.StrUtil; import com.emc.ecs.nfsclient.nfs.io.Nfs3File; import com.emc.ecs.nfsclient.nfs.io.NfsFileInputStream; import com.emc.ecs.nfsclient.nfs.io.NfsFileOutputStream; @@ -31,7 +32,6 @@ import org.dubhe.exception.NfsBizException; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Component; import org.springframework.util.CollectionUtils; -import org.springframework.util.FileCopyUtils; import java.io.*; import java.util.ArrayList; @@ -130,14 +130,14 @@ public class NfsUtil { } /** - * 校验文件或文件夹是否存在 + * 校验文件或文件夹是否不存在 * * @param path 文件路径 * @return boolean */ public boolean fileOrDirIsEmpty(String path) { if (!StringUtils.isEmpty(path)) { - path = formatPath(path); + path = formatPath(path.startsWith(nfsConfig.getRootDir()) ? path.replaceFirst(nfsConfig.getRootDir(), StrUtil.SLASH) : path); Nfs3File nfs3File = getNfs3File(path); try { if (nfs3File.exists()) { @@ -398,30 +398,6 @@ public class NfsUtil { } - /** - * NFS 复制目录到指定目录下 多个文件 包含目录与文件并存情况 - * - * 通过本地文件复制方式 - * - * @param sourcePath 需要复制的文件目录 例如:/abc/def - * @param targetPath 需要放置的目标目录 例如:/abc/dd - * @return boolean - */ - public boolean copyPath(String sourcePath, String targetPath) { - if (StringUtils.isEmpty(sourcePath) || StringUtils.isEmpty(targetPath)) { - return false; - } - sourcePath = formatPath(sourcePath); - targetPath = formatPath(targetPath); - try { - return copyLocalPath(nfsConfig.getRootDir() + sourcePath, nfsConfig.getRootDir() + targetPath); - } catch (Exception e) { - LogUtil.error(LogEnum.NFS_UTIL, " copyPath 复制失败: ", e); - return false; - } - } - - /** * NFS 复制目录到指定目录下 多个文件 包含目录与文件并存情况 * @@ -468,126 +444,15 @@ public class NfsUtil { } } - - /** - * 复制文件到指定目录下 单个文件 - * - * @param sourcePath 需要复制的文件 例如:/abc/def/cc.txt - * @param targetPath 需要放置的目标目录 例如:/abc/dd - * @return boolean - */ - public boolean copyLocalFile(String sourcePath, String targetPath) { - LogUtil.info(LogEnum.NFS_UTIL, "复制文件原路径: {} ,目标路径: {}", sourcePath, targetPath); - if (StringUtils.isEmpty(sourcePath) || StringUtils.isEmpty(targetPath)) { - return false; - } - sourcePath = formatPath(sourcePath); - targetPath = formatPath(targetPath); - LogUtil.info(LogEnum.NFS_UTIL, "过滤后文件原路径: {} ,目标路径:{}", sourcePath, targetPath); - try (InputStream input = new FileInputStream(sourcePath); - FileOutputStream output = new FileOutputStream(targetPath)) { - FileCopyUtils.copy(input, output); - LogUtil.info(LogEnum.NFS_UTIL, "复制文件成功"); - return true; - } catch (IOException e) { - LogUtil.error(LogEnum.NFS_UTIL, " copyLocalFile 复制失败: ", e); - return false; - } - } - - - /** - * 复制文件 到指定目录下 多个文件 包含目录与文件并存情况 - * - * @param sourcePath 需要复制的文件目录 例如:/abc/def - * @param targetPath 需要放置的目标目录 例如:/abc/dd - * @return boolean - */ - public boolean copyLocalPath(String sourcePath, String targetPath) { - if (!StringUtils.isEmpty(sourcePath) && !StringUtils.isEmpty(targetPath)) { - sourcePath = formatPath(sourcePath); - if (sourcePath.endsWith(FILE_SEPARATOR)) { - sourcePath = sourcePath.substring(MagicNumConstant.ZERO, sourcePath.lastIndexOf(FILE_SEPARATOR)); - } - targetPath = formatPath(targetPath); - LogUtil.info(LogEnum.NFS_UTIL, "复制文件夹 原路径: {} , 目标路径 : {} ", sourcePath, targetPath); - File[] files = new File(sourcePath).listFiles(); - LogUtil.info(LogEnum.NFS_UTIL, "需要复制的文件数量为: {}", files.length); - if (files.length != 0) { - for (File file : files) { - try { - if (file.isDirectory()) { - LogUtil.info(LogEnum.NFS_UTIL, "需要复制夹: {}", file.getAbsolutePath()); - LogUtil.info(LogEnum.NFS_UTIL, "目标文件夹: {}", targetPath + FILE_SEPARATOR + file.getName()); - File fileDir = new File(targetPath + FILE_SEPARATOR + file.getName()); - if(!fileDir.exists()){ - fileDir.mkdirs(); - } - copyLocalPath(sourcePath + FILE_SEPARATOR + file.getName(), targetPath + FILE_SEPARATOR + file.getName()); - } - if (file.isFile()) { - File fileTargetPath = new File(targetPath); - if(!fileTargetPath.exists()){ - fileTargetPath.mkdirs(); - } - LogUtil.info(LogEnum.NFS_UTIL, "需要复制文件: {}", file.getAbsolutePath()); - LogUtil.info(LogEnum.NFS_UTIL, "需要复制文件名称: {}", file.getName()); - copyLocalFile(file.getAbsolutePath() , targetPath + FILE_SEPARATOR + file.getName()); - } - }catch (Exception e){ - LogUtil.error(LogEnum.NFS_UTIL, "复制文件夹失败: {}", e); - return false; - } - } - } - return true; - } - return false; - } - - /** - * 解压前清理同路径下其他文件(目前只支持路径下无文件夹,文件均为zip文件) - * 上传路径垃圾文件清理 - * - * @param zipFilePath zip源文件 例如:/abc/z.zip - * @param path 文件夹 例如:/abc/ - * @return boolean - */ - public boolean cleanPath(String zipFilePath, String path) { - if (!StringUtils.isEmpty(zipFilePath) && !StringUtils.isEmpty(path) && zipFilePath.toLowerCase().endsWith(ZIP)) { - zipFilePath = formatPath(zipFilePath); - path = formatPath(path); - Nfs3File nfs3Files = getNfs3File(path); - try { - String zipName = zipFilePath.substring(zipFilePath.lastIndexOf(FILE_SEPARATOR) + MagicNumConstant.ONE); - if (!StringUtils.isEmpty(zipName)) { - List nfs3FilesList = nfs3Files.listFiles(); - if (!CollectionUtils.isEmpty(nfs3FilesList)) { - for (Nfs3File nfs3File : nfs3FilesList) { - if (!zipName.equals(nfs3File.getName())) { - nfs3File.delete(); - } - } - return true; - } - } - } catch (Exception e) { - LogUtil.error(LogEnum.NFS_UTIL, "路径{}清理失败,错误原因为:{} ", path, e); - return false; - } finally { - nfsPool.revertNfs(nfs3Files.getNfs()); - } - } - return false; - } - /** * zip解压并删除压缩文件 - * + * 当压缩包文件较多时,可能会因为RPC超时而解压失败 + * 该方法已废弃,请使用org.dubhe.utils.LocalFileUtil类的unzipLocalPath方法来替代 * @param sourcePath zip源文件 例如:/abc/z.zip * @param targetPath 解压后的目标文件夹 例如:/abc/ * @return boolean */ + @Deprecated public boolean unzip(String sourcePath, String targetPath) { if (StringUtils.isEmpty(sourcePath) || StringUtils.isEmpty(targetPath)) { return false; @@ -894,11 +759,15 @@ public class NfsUtil { * @param path * @return String */ - private String formatPath(String path) { + public String formatPath(String path) { if (!StringUtils.isEmpty(path)) { return path.replaceAll("///*", FILE_SEPARATOR); } return path; } + public String getAbsolutePath(String relativePath) { + return nfsConfig.getRootDir() + nfsConfig.getBucket() + relativePath; + } + } diff --git a/dubhe-server/common/src/main/java/org/dubhe/utils/RedisUtils.java b/dubhe-server/common/src/main/java/org/dubhe/utils/RedisUtils.java index c5f1d5a..170a030 100644 --- a/dubhe-server/common/src/main/java/org/dubhe/utils/RedisUtils.java +++ b/dubhe-server/common/src/main/java/org/dubhe/utils/RedisUtils.java @@ -26,6 +26,9 @@ import org.springframework.data.redis.core.Cursor; import org.springframework.data.redis.core.RedisConnectionUtils; import org.springframework.data.redis.core.RedisTemplate; import org.springframework.data.redis.core.ScanOptions; +import org.springframework.data.redis.core.script.DefaultRedisScript; +import org.springframework.data.redis.core.script.RedisScript; +import org.springframework.data.redis.serializer.Jackson2JsonRedisSerializer; import org.springframework.stereotype.Component; import org.springframework.util.CollectionUtils; @@ -33,7 +36,7 @@ import java.util.*; import java.util.concurrent.TimeUnit; /** - * @description redis工具类 + * @description redis工具类 * @date 2020-03-13 */ @Component @@ -62,7 +65,7 @@ public class RedisUtils { redisTemplate.expire(key, time, TimeUnit.SECONDS); } } catch (Exception e) { - LogUtil.error(LogEnum.SYS_ERR,"RedisUtils expire key {} time {} error:{}",key,time,e.getMessage(),e); + LogUtil.error(LogEnum.SYS_ERR, "RedisUtils expire key {} time {} error:{}", key, time, e.getMessage(), e); return false; } return true; @@ -82,7 +85,7 @@ public class RedisUtils { * 查找匹配key * * @param pattern key - * @return / + * @return List 匹配的key集合 */ public List scan(String pattern) { ScanOptions options = ScanOptions.scanOptions().match(pattern).build(); @@ -94,9 +97,9 @@ public class RedisUtils { result.add(new String(cursor.next())); } try { - RedisConnectionUtils.releaseConnection(rc, factory,true); + RedisConnectionUtils.releaseConnection(rc, factory, true); } catch (Exception e) { - LogUtil.error(LogEnum.SYS_ERR,"RedisUtils scan pattern {} error:{}",pattern,e.getMessage(),e); + LogUtil.error(LogEnum.SYS_ERR, "RedisUtils scan pattern {} error:{}", pattern, e.getMessage(), e); } return result; } @@ -107,7 +110,7 @@ public class RedisUtils { * @param patternKey key * @param page 页码 * @param size 每页数目 - * @return / + * @return 匹配到的key集合 */ public List findKeysForPage(String patternKey, int page, int size) { ScanOptions options = ScanOptions.scanOptions().match(patternKey).build(); @@ -132,9 +135,9 @@ public class RedisUtils { cursor.next(); } try { - RedisConnectionUtils.releaseConnection(rc, factory,true); + RedisConnectionUtils.releaseConnection(rc, factory, true); } catch (Exception e) { - LogUtil.error(LogEnum.SYS_ERR,"RedisUtils findKeysForPage patternKey {} page {} size {} error:{}",patternKey,page,size,e.getMessage(),e); + LogUtil.error(LogEnum.SYS_ERR, "RedisUtils findKeysForPage patternKey {} page {} size {} error:{}", patternKey, page, size, e.getMessage(), e); } return result; } @@ -149,7 +152,7 @@ public class RedisUtils { try { return redisTemplate.hasKey(key); } catch (Exception e) { - LogUtil.error(LogEnum.SYS_ERR,"RedisUtils hasKey key {} error:{}",key,e.getMessage(),e); + LogUtil.error(LogEnum.SYS_ERR, "RedisUtils hasKey key {} error:{}", key, e.getMessage(), e); return false; } } @@ -175,7 +178,7 @@ public class RedisUtils { * 普通缓存获取 * * @param key 键 - * @return 值 + * @return key对应的value值 */ public Object get(String key) { @@ -185,8 +188,8 @@ public class RedisUtils { /** * 批量获取 * - * @param keys - * @return + * @param keys key集合 + * @return key集合对应的value集合 */ public List multiGet(List keys) { Object obj = redisTemplate.opsForValue().multiGet(Collections.singleton(keys)); @@ -205,7 +208,7 @@ public class RedisUtils { redisTemplate.opsForValue().set(key, value); return true; } catch (Exception e) { - LogUtil.error(LogEnum.SYS_ERR,"RedisUtils set key {} value {} error:{}",key,value,e.getMessage(),e); + LogUtil.error(LogEnum.SYS_ERR, "RedisUtils set key {} value {} error:{}", key, value, e.getMessage(), e); return false; } } @@ -227,7 +230,7 @@ public class RedisUtils { } return true; } catch (Exception e) { - LogUtil.error(LogEnum.SYS_ERR,"RedisUtils set key {} value {} time {} error:{}",key,value,time,e.getMessage(),e); + LogUtil.error(LogEnum.SYS_ERR, "RedisUtils set key {} value {} time {} error:{}", key, value, time, e.getMessage(), e); return false; } } @@ -250,11 +253,56 @@ public class RedisUtils { } return true; } catch (Exception e) { - LogUtil.error(LogEnum.SYS_ERR,"RedisUtils set key {} value {} time {} timeUnit {} error:{}",key,value,time,timeUnit,e.getMessage(),e); + LogUtil.error(LogEnum.SYS_ERR, "RedisUtils set key {} value {} time {} timeUnit {} error:{}", key, value, time, timeUnit, e.getMessage(), e); return false; } } + //===============================Lock================================= + + /** + * 加锁 + * @param key 键 + * @param requestId 请求id用以释放锁 + * @param expireTime 超时时间(秒) + * @return + */ + public boolean getDistributedLock(String key, String requestId, long expireTime) { + String script = "if redis.call('setNx',KEYS[1],ARGV[1]) == 1 then if redis.call('get',KEYS[1]) == ARGV[1] then return redis.call('expire',KEYS[1],ARGV[2]) else return 0 end else return 0 end"; + Object result = executeRedisScript(script, key, requestId, expireTime); + return result != null && result.equals(MagicNumConstant.ONE_LONG); + } + + /** + * 释放锁 + * @param key 键 + * @param requestId 请求id用以释放锁 + * @return + */ + public boolean releaseDistributedLock(String key, String requestId) { + String script = "if redis.call('get', KEYS[1]) == ARGV[1] then return redis.call('del', KEYS[1]) else return 0 end"; + Object result = executeRedisScript(script, key, requestId); + return result != null && result.equals(MagicNumConstant.ONE_LONG); + } + + /** + * + * @param script 脚本字符串 + * @param key 键 + * @param args 脚本其他参数 + * @return + */ + public Object executeRedisScript(String script, String key, Object... args) { + try { + RedisScript redisScript = new DefaultRedisScript<>(script, Long.class); + redisTemplate.setValueSerializer(new Jackson2JsonRedisSerializer<>(Object.class)); + return redisTemplate.execute(redisScript, Collections.singletonList(key), args); + } catch (Exception e) { + LogUtil.error(LogEnum.SYS_ERR, "executeRedisScript script {} key {} expireTime {} args {} error:{}", script, key, args, e); + return MagicNumConstant.ZERO_LONG; + } + } + // ================================Map================================= /** @@ -291,7 +339,7 @@ public class RedisUtils { redisTemplate.opsForHash().putAll(key, map); return true; } catch (Exception e) { - LogUtil.error(LogEnum.SYS_ERR,"RedisUtils hmset key {} map {} error:{}",key,map,e.getMessage(),e); + LogUtil.error(LogEnum.SYS_ERR, "RedisUtils hmset key {} map {} error:{}", key, map, e.getMessage(), e); return false; } } @@ -312,7 +360,7 @@ public class RedisUtils { } return true; } catch (Exception e) { - LogUtil.error(LogEnum.SYS_ERR,"RedisUtils hmset key {} map {} time {} error:{}",key,map,time,e.getMessage(),e); + LogUtil.error(LogEnum.SYS_ERR, "RedisUtils hmset key {} map {} time {} error:{}", key, map, time, e.getMessage(), e); return false; } } @@ -330,7 +378,7 @@ public class RedisUtils { redisTemplate.opsForHash().put(key, item, value); return true; } catch (Exception e) { - LogUtil.error(LogEnum.SYS_ERR,"RedisUtils hset key {} item {} value {} error:{}",key,item,value,e.getMessage(),e); + LogUtil.error(LogEnum.SYS_ERR, "RedisUtils hset key {} item {} value {} error:{}", key, item, value, e.getMessage(), e); return false; } } @@ -352,7 +400,7 @@ public class RedisUtils { } return true; } catch (Exception e) { - LogUtil.error(LogEnum.SYS_ERR,"RedisUtils hset key {} item {} value {} time {} error:{}",key,item,value,time,e.getMessage(),e); + LogUtil.error(LogEnum.SYS_ERR, "RedisUtils hset key {} item {} value {} time {} error:{}", key, item, value, time, e.getMessage(), e); return false; } } @@ -414,7 +462,7 @@ public class RedisUtils { try { return redisTemplate.opsForSet().members(key); } catch (Exception e) { - LogUtil.error(LogEnum.SYS_ERR,"RedisUtils sGet key {} error:{}",key,e.getMessage(),e); + LogUtil.error(LogEnum.SYS_ERR, "RedisUtils sGet key {} error:{}", key, e.getMessage(), e); return null; } } @@ -430,7 +478,7 @@ public class RedisUtils { try { return redisTemplate.opsForSet().isMember(key, value); } catch (Exception e) { - LogUtil.error(LogEnum.SYS_ERR,"RedisUtils sHasKey key {} value {} error:{}",key,value,e.getMessage(),e); + LogUtil.error(LogEnum.SYS_ERR, "RedisUtils sHasKey key {} value {} error:{}", key, value, e.getMessage(), e); return false; } } @@ -446,7 +494,7 @@ public class RedisUtils { try { return redisTemplate.opsForSet().add(key, values); } catch (Exception e) { - LogUtil.error(LogEnum.SYS_ERR,"RedisUtils sSet key {} values {} error:{}",key,values,e.getMessage(),e); + LogUtil.error(LogEnum.SYS_ERR, "RedisUtils sSet key {} values {} error:{}", key, values, e.getMessage(), e); return 0; } } @@ -467,7 +515,7 @@ public class RedisUtils { } return count; } catch (Exception e) { - LogUtil.error(LogEnum.SYS_ERR,"RedisUtils sSetAndTime key {} time {} values {} error:{}",key,time,values,e.getMessage(),e); + LogUtil.error(LogEnum.SYS_ERR, "RedisUtils sSetAndTime key {} time {} values {} error:{}", key, time, values, e.getMessage(), e); return 0; } } @@ -482,7 +530,7 @@ public class RedisUtils { try { return redisTemplate.opsForSet().size(key); } catch (Exception e) { - LogUtil.error(LogEnum.SYS_ERR,"RedisUtils sGetSetSize key {} error:{}",key,e.getMessage(),e); + LogUtil.error(LogEnum.SYS_ERR, "RedisUtils sGetSetSize key {} error:{}", key, e.getMessage(), e); return 0; } } @@ -499,7 +547,7 @@ public class RedisUtils { Long count = redisTemplate.opsForSet().remove(key, values); return count; } catch (Exception e) { - LogUtil.error(LogEnum.SYS_ERR,"RedisUtils setRemove key {} values {} error:{}",key,values,e.getMessage(),e); + LogUtil.error(LogEnum.SYS_ERR, "RedisUtils setRemove key {} values {} error:{}", key, values, e.getMessage(), e); return 0; } } @@ -507,22 +555,22 @@ public class RedisUtils { // ===============================sorted set================================= /** - *将zSet数据放入缓存 + * 将zSet数据放入缓存 * * @param key * @param time * @param values * @return Boolean */ - public Boolean zSet(String key, long time, Object value){ + public Boolean zSet(String key, long time, Object value) { try { - Boolean success = redisTemplate.opsForZSet().add(key, value,System.currentTimeMillis()); + Boolean success = redisTemplate.opsForZSet().add(key, value, System.currentTimeMillis()); if (success) { expire(key, time); } return success; } catch (Exception e) { - LogUtil.error(LogEnum.SYS_ERR,"RedisUtils zSet key {} time {} value {} error:{}",key,time,value,e.getMessage(),e); + LogUtil.error(LogEnum.SYS_ERR, "RedisUtils zSet key {} time {} value {} error:{}", key, time, value, e.getMessage(), e); return false; } } @@ -533,17 +581,16 @@ public class RedisUtils { * @param key * @return Set */ - public Set zGet(String key){ + public Set zGet(String key) { try { - return redisTemplate.opsForZSet().reverseRange(key,Long.MIN_VALUE, Long.MAX_VALUE); + return redisTemplate.opsForZSet().reverseRange(key, Long.MIN_VALUE, Long.MAX_VALUE); } catch (Exception e) { - LogUtil.error(LogEnum.SYS_ERR,"RedisUtils zGet key {} error:{}",key,e.getMessage(),e); + LogUtil.error(LogEnum.SYS_ERR, "RedisUtils zGet key {} error:{}", key, e.getMessage(), e); return null; } } - // ===============================list================================= /** @@ -558,7 +605,7 @@ public class RedisUtils { try { return redisTemplate.opsForList().range(key, start, end); } catch (Exception e) { - LogUtil.error(LogEnum.SYS_ERR,"RedisUtils lGetIndex key {} start {} end {} error:{}",key,start,end,e.getMessage(),e); + LogUtil.error(LogEnum.SYS_ERR, "RedisUtils lGetIndex key {} start {} end {} error:{}", key, start, end, e.getMessage(), e); return null; } } @@ -573,7 +620,7 @@ public class RedisUtils { try { return redisTemplate.opsForList().size(key); } catch (Exception e) { - LogUtil.error(LogEnum.SYS_ERR,"RedisUtils lGetListSize key {} error:{}",key,e.getMessage(),e); + LogUtil.error(LogEnum.SYS_ERR, "RedisUtils lGetListSize key {} error:{}", key, e.getMessage(), e); return 0; } } @@ -589,7 +636,7 @@ public class RedisUtils { try { return redisTemplate.opsForList().index(key, index); } catch (Exception e) { - LogUtil.error(LogEnum.SYS_ERR,"RedisUtils lGetIndex key {} index {} error:{}",key,index,e.getMessage(),e); + LogUtil.error(LogEnum.SYS_ERR, "RedisUtils lGetIndex key {} index {} error:{}", key, index, e.getMessage(), e); return null; } } @@ -606,7 +653,7 @@ public class RedisUtils { redisTemplate.opsForList().rightPush(key, value); return true; } catch (Exception e) { - LogUtil.error(LogEnum.SYS_ERR,"RedisUtils lSet key {} value {} error:{}",key,value,e.getMessage(),e); + LogUtil.error(LogEnum.SYS_ERR, "RedisUtils lSet key {} value {} error:{}", key, value, e.getMessage(), e); return false; } } @@ -617,7 +664,7 @@ public class RedisUtils { * @param key 键 * @param value 值 * @param time 时间(秒) - * @return + * @return 是否存储成功 */ public boolean lSet(String key, Object value, long time) { try { @@ -627,7 +674,7 @@ public class RedisUtils { } return true; } catch (Exception e) { - LogUtil.error(LogEnum.SYS_ERR,"RedisUtils lSet key {} value {} time {} error:{}",key,value,time,e.getMessage(),e); + LogUtil.error(LogEnum.SYS_ERR, "RedisUtils lSet key {} value {} time {} error:{}", key, value, time, e.getMessage(), e); return false; } } @@ -637,14 +684,14 @@ public class RedisUtils { * * @param key 键 * @param value 值 - * @return + * @return 是否存储成功 */ public boolean lSet(String key, List value) { try { redisTemplate.opsForList().rightPushAll(key, value); return true; } catch (Exception e) { - LogUtil.error(LogEnum.SYS_ERR,"RedisUtils lSet key {} value {} error:{}",key,value,e.getMessage(),e); + LogUtil.error(LogEnum.SYS_ERR, "RedisUtils lSet key {} value {} error:{}", key, value, e.getMessage(), e); return false; } } @@ -655,7 +702,7 @@ public class RedisUtils { * @param key 键 * @param value 值 * @param time 时间(秒) - * @return + * @return 是否存储成功 */ public boolean lSet(String key, List value, long time) { try { @@ -665,7 +712,7 @@ public class RedisUtils { } return true; } catch (Exception e) { - LogUtil.error(LogEnum.SYS_ERR,"RedisUtils lSet key {} value {} time {} error:{}",key,value,time,e.getMessage(),e); + LogUtil.error(LogEnum.SYS_ERR, "RedisUtils lSet key {} value {} time {} error:{}", key, value, time, e.getMessage(), e); return false; } } @@ -676,14 +723,14 @@ public class RedisUtils { * @param key 键 * @param index 索引 * @param value 值 - * @return / + * @return 更新数据标识 */ public boolean lUpdateIndex(String key, long index, Object value) { try { redisTemplate.opsForList().set(key, index, value); return true; } catch (Exception e) { - LogUtil.error(LogEnum.SYS_ERR,"RedisUtils lUpdateIndex key {} index {} value {} error:{}",key,index,value,e.getMessage(),e); + LogUtil.error(LogEnum.SYS_ERR, "RedisUtils lUpdateIndex key {} index {} value {} error:{}", key, index, value, e.getMessage(), e); return false; } } @@ -700,8 +747,23 @@ public class RedisUtils { try { return redisTemplate.opsForList().remove(key, count, value); } catch (Exception e) { - LogUtil.error(LogEnum.SYS_ERR,"RedisUtils lRemove key {} count {} value {} error:{}",key,count,value,e.getMessage(),e); + LogUtil.error(LogEnum.SYS_ERR, "RedisUtils lRemove key {} count {} value {} error:{}", key, count, value, e.getMessage(), e); return 0; } } + + /** + * 队列从左弹出数据 + * + * @param key + * @return key对应的value值 + */ + public Object lpop(String key) { + try { + return redisTemplate.opsForList().leftPop(key); + } catch (Exception e) { + LogUtil.error(LogEnum.SYS_ERR, "RedisUtils lRemove key {} error:{}", key, e.getMessage(), e); + return null; + } + } } diff --git a/dubhe-server/common/src/main/java/org/dubhe/utils/ReflectionUtils.java b/dubhe-server/common/src/main/java/org/dubhe/utils/ReflectionUtils.java index c5ea13b..0859db5 100644 --- a/dubhe-server/common/src/main/java/org/dubhe/utils/ReflectionUtils.java +++ b/dubhe-server/common/src/main/java/org/dubhe/utils/ReflectionUtils.java @@ -18,7 +18,8 @@ package org.dubhe.utils; import java.lang.reflect.Field; -import java.util.*; +import java.util.ArrayList; +import java.util.List; /** * @description 反射工具类 diff --git a/dubhe-server/common/src/main/java/org/dubhe/utils/RegexUtil.java b/dubhe-server/common/src/main/java/org/dubhe/utils/RegexUtil.java index 0d97bd2..d320c47 100644 --- a/dubhe-server/common/src/main/java/org/dubhe/utils/RegexUtil.java +++ b/dubhe-server/common/src/main/java/org/dubhe/utils/RegexUtil.java @@ -23,8 +23,8 @@ import java.util.regex.Matcher; import java.util.regex.Pattern; /** - * @description: 正则匹配工具类 - * @create: 2020/4/23 13:51 + * @description 正则匹配工具类 + * @date 2020-04-23 */ @Slf4j public class RegexUtil { diff --git a/dubhe-server/common/src/main/java/org/dubhe/utils/SqlUtil.java b/dubhe-server/common/src/main/java/org/dubhe/utils/SqlUtil.java index 6c0267e..d21756d 100644 --- a/dubhe-server/common/src/main/java/org/dubhe/utils/SqlUtil.java +++ b/dubhe-server/common/src/main/java/org/dubhe/utils/SqlUtil.java @@ -17,11 +17,18 @@ package org.dubhe.utils; +import org.dubhe.base.BaseService; +import org.dubhe.base.DataContext; + +import java.util.HashSet; +import java.util.Objects; +import java.util.Set; + /** * @description sql语句转换的工具类 * @date 2020-07-06 */ - + public class SqlUtil { /** @@ -46,4 +53,46 @@ public class SqlUtil { return ""; } + + /** + * 获取资源拥有着ID + * + * @return 资源拥有者id集合 + */ + public static Set getResourceIds() { + if (!Objects.isNull(DataContext.get())) { + return DataContext.get().getResourceUserIds(); + } + Set ids = new HashSet<>(); + Long id = JwtUtils.getCurrentUserDto().getId(); + ids.add(id); + return ids; + + } + + + /** + * 构建目标sql语句 + * + * @param originSql 原生sql + * @param resourceUserIds 所属资源用户ids + * @return 目标sql + */ + public static String buildTargetSql(String originSql, Set resourceUserIds) { + if (BaseService.isAdmin()) { + return originSql; + } + String sqlWhereBefore = org.dubhe.utils.StringUtils.substringBefore(originSql.toLowerCase(), "where"); + String sqlWhereAfter = org.dubhe.utils.StringUtils.substringAfter(originSql.toLowerCase(), "where"); + StringBuffer buffer = new StringBuffer(); + //操作的sql拼接 + String targetSql = buffer.append(sqlWhereBefore).append(" where ").append(" origin_user_id in (") + .append(org.dubhe.utils.StringUtils.join(resourceUserIds, ",")).append(") and ").append(sqlWhereAfter).toString(); + + return targetSql; + } + + + + } diff --git a/dubhe-server/common/src/main/java/org/dubhe/utils/StringUtils.java b/dubhe-server/common/src/main/java/org/dubhe/utils/StringUtils.java index 5bfa187..1253e4f 100644 --- a/dubhe-server/common/src/main/java/org/dubhe/utils/StringUtils.java +++ b/dubhe-server/common/src/main/java/org/dubhe/utils/StringUtils.java @@ -273,4 +273,41 @@ public class StringUtils extends org.apache.commons.lang3.StringUtils { matcher.appendTail(sb); return sb.toString(); } + + + /** + * 字符串截取前 + * @param str + * @return + */ + public static String substringBefore(String str, String separator){ + + if (!isEmpty(str) && separator != null) { + if (separator.isEmpty()) { + return ""; + } else { + int pos = str.indexOf(separator); + return pos == -1 ? str : str.substring(0, pos); + } + } else { + return str; + } + } + + /** + * 字符串截取后 + * @param str + * @return + */ + public static String substringAfter(String str, String separator){ + + if (isEmpty(str)) { + return str; + } else if (separator == null) { + return ""; + } else { + int pos = str.indexOf(separator); + return pos == -1 ? "" : str.substring(pos + separator.length()); + } + } } diff --git a/dubhe-server/common/src/main/java/org/dubhe/utils/TimeTransferUtil.java b/dubhe-server/common/src/main/java/org/dubhe/utils/TimeTransferUtil.java index 0423246..481f977 100644 --- a/dubhe-server/common/src/main/java/org/dubhe/utils/TimeTransferUtil.java +++ b/dubhe-server/common/src/main/java/org/dubhe/utils/TimeTransferUtil.java @@ -17,42 +17,33 @@ package org.dubhe.utils; - -import lombok.extern.slf4j.Slf4j; - -import java.text.ParseException; import java.text.SimpleDateFormat; import java.util.Calendar; import java.util.Date; +import static org.dubhe.base.MagicNumConstant.EIGHT; + /** - * @description: UTC时间转换CST时间工具类 - * @create: 2020/5/20 12:10 + * @description 时间格式转换工具类 + * @date 2020-05-20 */ -@Slf4j public class TimeTransferUtil { + + private static final String UTC_FORMAT = "yyyy-MM-dd'T'HH:mm:ss.sss'Z'"; + /** - * @param utcTime - * @return cstTime + * Date转换为UTC时间 + * + * @param date + * @return utcTime */ - public static String cstTransfer(String utcTime){ - Date utcDate = null; - /**2020-05-20T03:13:22Z 对应的时间格式 yyyy-MM-dd'T'HH:mm:ss'Z'**/ - SimpleDateFormat utcSimpleDateFormat = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss'Z'"); - - try { - utcDate = utcSimpleDateFormat.parse(utcTime); - } catch (ParseException e) { - log.info(e.getMessage()); - return null; - } - /**System.out.println("UTC时间:"+date);**/ + public static String dateTransferToUtc(Date date){ Calendar calendar = Calendar.getInstance(); - calendar.setTime(utcDate); - calendar.set(Calendar.HOUR,calendar.get(Calendar.HOUR)+8); - SimpleDateFormat cstSimpleDateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"); - Date cstDate = calendar.getTime(); - String cstTime = cstSimpleDateFormat.format(calendar.getTime()); - return cstTime; + calendar.setTime(date); + /**UTC时间与CST时间相差8小时**/ + calendar.set(Calendar.HOUR,calendar.get(Calendar.HOUR) - EIGHT); + SimpleDateFormat utcSimpleDateFormat = new SimpleDateFormat(UTC_FORMAT); + Date utcDate = calendar.getTime(); + return utcSimpleDateFormat.format(utcDate); } } diff --git a/dubhe-server/common/src/main/java/org/dubhe/utils/UniqueKeyGenerator.java b/dubhe-server/common/src/main/java/org/dubhe/utils/UniqueKeyGenerator.java index 1a8eec7..8ddfd6c 100644 --- a/dubhe-server/common/src/main/java/org/dubhe/utils/UniqueKeyGenerator.java +++ b/dubhe-server/common/src/main/java/org/dubhe/utils/UniqueKeyGenerator.java @@ -25,9 +25,8 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; /** - * @description: 唯一码生成器 (依赖时间轴) - * - * @date 2020.05.08 + * @description 唯一码生成器 (依赖时间轴) + * @date 2020-05-08 */ public class UniqueKeyGenerator { diff --git a/dubhe-server/common/src/main/java/org/dubhe/utils/WrapperHelp.java b/dubhe-server/common/src/main/java/org/dubhe/utils/WrapperHelp.java index a303e19..c8b46e7 100644 --- a/dubhe-server/common/src/main/java/org/dubhe/utils/WrapperHelp.java +++ b/dubhe-server/common/src/main/java/org/dubhe/utils/WrapperHelp.java @@ -20,7 +20,6 @@ package org.dubhe.utils; import cn.hutool.core.util.ObjectUtil; import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; import lombok.extern.slf4j.Slf4j; -import lombok.val; import org.dubhe.annotation.Query; import java.lang.reflect.Field; @@ -31,7 +30,7 @@ import java.util.List; /** * @description 构建Wrapper - * @date 2020-03-15 13:52:30 + * @date 2020-03-15 */ @Slf4j public class WrapperHelp { diff --git a/dubhe-server/common/src/test/java/org/dubhe/HttpUtilsTest.java b/dubhe-server/common/src/test/java/org/dubhe/HttpUtilsTest.java index 5403c02..55823a2 100644 --- a/dubhe-server/common/src/test/java/org/dubhe/HttpUtilsTest.java +++ b/dubhe-server/common/src/test/java/org/dubhe/HttpUtilsTest.java @@ -23,8 +23,8 @@ import org.junit.Test; import static org.dubhe.utils.HttpUtils.isSuccess; /** - * @description: HttpUtil - * @date 2020.04.30 + * @description HttpUtil + * @date 2020-04-30 */ public class HttpUtilsTest { diff --git a/dubhe-server/deploy.sh b/dubhe-server/deploy.sh new file mode 100644 index 0000000..4622654 --- /dev/null +++ b/dubhe-server/deploy.sh @@ -0,0 +1,71 @@ +#!/bin/bash + +PROG_NAME=$0 +ACTION=$1 +ENV=$2 +APP_HOME=$3 + +APP_NAME=dubhe-${ENV} + +APP_HOME=$APP_HOME/${APP_NAME} # 从package.tgz中解压出来的jar包放到这个目录下 +JAR_NAME=${APP_HOME}/dubhe-admin/target/dubhe-admin-1.0-exec.jar # jar包的名字 +JAVA_OUT=/dev/null + +# 创建出相关目录 +mkdir -p ${APP_HOME} +mkdir -p ${APP_HOME}/logs + +usage() { + echo "Usage: $PROG_NAME {start|stop|restart} {dev|test|prod}" + exit 2 +} + +start_application() { + echo "starting java process" + echo "nohup java -jar ${JAR_NAME} > ${JAVA_OUT} --spring.profiles.active=${ENV} 2>&1 &" + nohup java -jar ${JAR_NAME} > ${JAVA_OUT} --spring.profiles.active=${ENV} 2>&1 & + echo "started java process" +} + +stop_application() { + checkjavapid=`ps -ef | grep java | grep ${APP_NAME} | grep -v grep |grep -v 'deploy.sh'| awk '{print$2}'` + + if [ -z $checkjavapid ];then + echo -e "\rno java process "$checkjavapid + return + fi + + echo "stop java process" + times=60 + for e in $(seq 60) + do + sleep 1 + COSTTIME=$(($times - $e )) + checkjavapid=`ps -ef | grep java | grep ${APP_NAME} | grep -v grep |grep -v 'deploy.sh'| awk '{print$2}'` + if [ "$checkjavapid" != "" ];then + echo "kill "$checkjavapid + kill -9 $checkjavapid + echo -e "\r -- stopping java lasts `expr $COSTTIME` seconds." + else + echo -e "\rjava process has exited" + break; + fi + done + echo "" +} + +case "$ACTION" in + start) + start_application + ;; + stop) + stop_application + ;; + restart) + stop_application + start_application + ;; + *) + usage + ;; +esac diff --git a/dubhe-server/dubhe-admin/pom.xml b/dubhe-server/dubhe-admin/pom.xml index 8db6e5f..d563dfe 100644 --- a/dubhe-server/dubhe-admin/pom.xml +++ b/dubhe-server/dubhe-admin/pom.xml @@ -75,6 +75,7 @@ false true + exec diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/task/HarborImagePushAsync.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/async/HarborImagePushAsync.java similarity index 71% rename from dubhe-server/dubhe-admin/src/main/java/org/dubhe/task/HarborImagePushAsync.java rename to dubhe-server/dubhe-admin/src/main/java/org/dubhe/async/HarborImagePushAsync.java index 9d39c6e..0de73ce 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/task/HarborImagePushAsync.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/async/HarborImagePushAsync.java @@ -14,7 +14,7 @@ * limitations under the License. * ============================================================= */ -package org.dubhe.task; +package org.dubhe.async; import cn.hutool.core.util.StrUtil; import org.dubhe.base.ResponseCode; @@ -25,6 +25,7 @@ import org.dubhe.enums.ImageStateEnum; import org.dubhe.enums.LogEnum; import org.dubhe.exception.BusinessException; import org.dubhe.harbor.api.HarborApi; +import org.dubhe.utils.IOUtil; import org.dubhe.utils.LogUtil; import org.dubhe.utils.StringUtils; import org.springframework.beans.factory.annotation.Autowired; @@ -62,46 +63,21 @@ public class HarborImagePushAsync { String imageResource = trainHarborConfig.getAddress() + StrUtil.SLASH + trainHarborConfig.getModelName() + StrUtil.SLASH + imageNameandTag; String cmdStr = "docker login --username=" + trainHarborConfig.getUsername() + " " + trainHarborConfig.getAddress() + " --password=" + trainHarborConfig.getPassword() + " ; docker " + - "load < " + imagePath + " |awk '{print $3}' |xargs -I str docker tag str " + imageResource + " ; docker push " + imageResource; + "load < " + imagePath + " |awk '{print $3}' |xargs -I str docker tag str " + imageResource + " ; docker push " + imageResource + "; docker rmi " + imageResource; String[] cmd = {"/bin/bash", "-c", cmdStr}; LogUtil.info(LogEnum.BIZ_TRAIN, "镜像上传执行脚本参数:{}", cmd); Process process = Runtime.getRuntime().exec(cmd); - //读取标准输出流 - BufferedReader brOut = new BufferedReader(new InputStreamReader(process.getInputStream())); - //读取标准错误流 - BufferedReader brErr = new BufferedReader(new InputStreamReader(process.getErrorStream())); - String line; - String outMessage = ""; - String errMessage = ""; - while ((line = brOut.readLine()) != null) { - outMessage += line; - } - if (StringUtils.isNotEmpty(outMessage)) { - LogUtil.info(LogEnum.BIZ_TRAIN, "shell上传镜像输出信息:" + outMessage); - } - while ((line = brErr.readLine()) != null) { - errMessage += line; - } - if (StringUtils.isNotEmpty(errMessage)) { - LogUtil.error(LogEnum.BIZ_TRAIN, "shell上传镜像异常信息:" + errMessage); - } - Integer status = process.waitFor(); - LogUtil.info(LogEnum.BIZ_TRAIN, "上传镜像状态:{}", status); - if (status == null) { - if (harborApi.isExistImage(ptImage.getImageUrl())) { - updateImageStatus(ptImage, ImageStateEnum.SUCCESS.getCode()); - } else { - updateImageStatus(ptImage, ImageStateEnum.FAIL.getCode()); - } - } else if (status == 0) { + if (checkImagePushIsOk(ptImage, process)) { updateImageStatus(ptImage, ImageStateEnum.SUCCESS.getCode()); } else { updateImageStatus(ptImage, ImageStateEnum.FAIL.getCode()); } } catch (Exception e) { LogUtil.error(LogEnum.BIZ_TRAIN, "上传镜像异常:{}", e); + updateImageStatus(ptImage, ImageStateEnum.FAIL.getCode()); throw new BusinessException("上传镜像异常!"); + } } @@ -117,4 +93,52 @@ public class HarborImagePushAsync { ptImageMapper.updateById(ptImage); return ResponseCode.SUCCESS; } + + + /** + * 校验镜像是否上传成功 + * + * @param ptImage 镜像信息 + * @param process process对象 + * @return 是否上传成功 + */ + public boolean checkImagePushIsOk(PtImage ptImage, Process process) { + //读取标准输出流 + BufferedReader brOut = new BufferedReader(new InputStreamReader(process.getInputStream())); + //读取标准错误流 + BufferedReader brErr = new BufferedReader(new InputStreamReader(process.getErrorStream())); + String line; + StringBuilder outMessage = new StringBuilder(); + StringBuilder errMessage = new StringBuilder(); + boolean isPushOk = true; + try { + while ((line = brOut.readLine()) != null) { + outMessage.append(line); + } + if (StringUtils.isNotEmpty(outMessage)) { + LogUtil.info(LogEnum.BIZ_TRAIN, "shell上传镜像输出信息:{}", outMessage.toString()); + } + while ((line = brErr.readLine()) != null) { + errMessage.append(line); + } + if (StringUtils.isNotEmpty(errMessage)) { + LogUtil.error(LogEnum.BIZ_TRAIN, "shell上传镜像异常信息:{}", errMessage.toString()); + } + Integer status = process.waitFor(); + LogUtil.info(LogEnum.BIZ_TRAIN, "上传镜像状态:{}", status); + if (status == null) { + if (!harborApi.isExistImage(ptImage.getImageUrl())) { + isPushOk = false; + } + } else if (status != 0) { + isPushOk = false; + } + } catch (Exception e) { + LogUtil.error(LogEnum.BIZ_TRAIN, "上传镜像异常:{}", e); + return false; + } finally { + IOUtil.close(brErr, brOut); + } + return isPushOk; + } } diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/async/StopTrainJobAsync.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/async/StopTrainJobAsync.java new file mode 100644 index 0000000..2c908ed --- /dev/null +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/async/StopTrainJobAsync.java @@ -0,0 +1,144 @@ +/** + * Copyright 2020 Zhejiang Lab. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================= + */ +package org.dubhe.async; + +import org.dubhe.config.TrainJobConfig; +import org.dubhe.dao.PtTrainJobMapper; +import org.dubhe.domain.dto.UserDTO; +import org.dubhe.domain.entity.PtTrainJob; +import org.dubhe.enums.LogEnum; +import org.dubhe.enums.TrainJobStatusEnum; +import org.dubhe.enums.TrainTypeEnum; +import org.dubhe.k8s.api.DistributeTrainApi; +import org.dubhe.k8s.api.PodApi; +import org.dubhe.k8s.api.TrainJobApi; +import org.dubhe.k8s.domain.resource.BizPod; +import org.dubhe.utils.*; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.scheduling.annotation.Async; +import org.springframework.stereotype.Component; + +import java.time.LocalDateTime; +import java.time.ZoneOffset; +import java.time.format.DateTimeFormatter; +import java.util.List; +import java.util.function.Consumer; + +/** + * @description 停止训练任务异步处理 + * @date 2020-08-13 + */ +@Component +public class StopTrainJobAsync { + + @Autowired + private K8sNameTool k8sNameTool; + + @Autowired + private PodApi podApi; + + @Autowired + private TrainJobApi trainJobApi; + + @Autowired + private TrainJobConfig trainJobConfig; + + @Autowired + private DistributeTrainApi distributeTrainApi; + + @Autowired + private PtTrainJobMapper ptTrainJobMapper; + + /** + * 停止任务 + * + * @param currentUser 用户 + * @param jobList 任务集合 + */ + @Async("trainExecutor") + public void stopJobs(UserDTO currentUser, List jobList) { + String namespace = k8sNameTool.generateNamespace(currentUser.getId()); + jobList.forEach(job -> { + BizPod bizPod = podApi.getWithResourceName(namespace, job.getJobName()); + if (!bizPod.isSuccess()) { + LogUtil.error(LogEnum.BIZ_TRAIN, "User {} stops training Job return code:{},message:{}", currentUser.getUsername(), Integer.valueOf(bizPod.getCode()), bizPod.getMessage()); + } + boolean bool = TrainTypeEnum.isDistributeTrain(job.getTrainType()) ? + distributeTrainApi.deleteByResourceName(namespace, job.getJobName()).isSuccess() : + trainJobApi.delete(namespace, job.getJobName()); + if (!bool) { + LogUtil.error(LogEnum.BIZ_TRAIN, "User {} stops training Job and K8S fails in the stop process, namespace为{}, resourceName为{}", + currentUser.getUsername(), namespace, job.getJobName()); + } + //更新训练状态 + job.setRuntime(calculateRuntime(bizPod)) + .setTrainStatus(TrainJobStatusEnum.STOP.getStatus()); + ptTrainJobMapper.updateById(job); + + }); + } + + + /** + * 计算job训练时长 + * + * @param bizPod pod信息 + * @return String 训练时长 + */ + private String calculateRuntime(BizPod bizPod) { + return calculateRuntime(bizPod, (x) -> { + }); + } + + + /** + * 计算job训练时长 + * + * @param bizPod + * @param consumer pod已经完成状态的回调函数 + * @return res 返回训练时长 + */ + private String calculateRuntime(BizPod bizPod, Consumer consumer) { + Long completedTime; + if (StringUtils.isBlank(bizPod.getStartTime())) { + return TrainUtil.INIT_RUNTIME; + } + Long startTime = transformTime(bizPod.getStartTime()); + boolean hasCompleted = StringUtils.isNotBlank(bizPod.getCompletedTime()); + completedTime = hasCompleted ? transformTime(bizPod.getCompletedTime()) : LocalDateTime.now().toEpochSecond(ZoneOffset.of(trainJobConfig.getPlusEight())); + Long time = completedTime - startTime; + String res = DubheDateUtil.convert2Str(time); + if (hasCompleted) { + consumer.accept(res); + } + return res; + } + + + /** + * 时间转换 + * + * @param time 时间 + * @return Long 时间戳 + */ + private Long transformTime(String time) { + LocalDateTime localDateTime = LocalDateTime.parse(time, DateTimeFormatter.ISO_OFFSET_DATE_TIME); + //没有根据时区做处理, 默认当前为东八区 + localDateTime = localDateTime.plusHours(Long.valueOf(trainJobConfig.getEight())); + return localDateTime.toEpochSecond(ZoneOffset.of(trainJobConfig.getPlusEight())); + } +} diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/async/TrainAlgorithmUploadAsync.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/async/TrainAlgorithmUploadAsync.java new file mode 100644 index 0000000..e52f688 --- /dev/null +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/async/TrainAlgorithmUploadAsync.java @@ -0,0 +1,115 @@ +/** + * Copyright 2020 Zhejiang Lab. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================= + */ +package org.dubhe.async; + +import org.dubhe.config.NfsConfig; +import org.dubhe.dao.PtTrainAlgorithmMapper; +import org.dubhe.domain.dto.PtTrainAlgorithmCreateDTO; +import org.dubhe.domain.dto.UserDTO; +import org.dubhe.domain.entity.PtTrainAlgorithm; +import org.dubhe.enums.AlgorithmStatusEnum; +import org.dubhe.enums.BizNfsEnum; +import org.dubhe.enums.LogEnum; +import org.dubhe.exception.BusinessException; +import org.dubhe.service.NoteBookService; +import org.dubhe.utils.K8sNameTool; +import org.dubhe.utils.LocalFileUtil; +import org.dubhe.utils.LogUtil; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.scheduling.annotation.Async; +import org.springframework.stereotype.Component; + +/** + * @description 异步上传算法 + * @date 2020-08-10 + */ +@Component +public class TrainAlgorithmUploadAsync { + + @Autowired + private LocalFileUtil localFileUtil; + + @Autowired + private NfsConfig nfsConfig; + + @Autowired + private K8sNameTool k8sNameTool; + + @Autowired + private NoteBookService noteBookService; + + @Autowired + private PtTrainAlgorithmMapper trainAlgorithmMapper; + + /** + * 异步任务创建算法 + * + * @param user 当前登录用户信息 + * @param ptTrainAlgorithm 算法信息 + * @param trainAlgorithmCreateDTO 创建算法条件 + */ + @Async("trainExecutor") + public void createTrainAlgorithm(UserDTO user, PtTrainAlgorithm ptTrainAlgorithm, PtTrainAlgorithmCreateDTO trainAlgorithmCreateDTO) { + String path = nfsConfig.getBucket() + trainAlgorithmCreateDTO.getCodeDir(); + //校验创建算法来源(true:由fork创建算法,false:其它创建算法方式),若为true则拷贝预置算法文件至新路径 + if (trainAlgorithmCreateDTO.getFork()) { + //生成算法相对路径 + String algorithmPath = k8sNameTool.getNfsPath(BizNfsEnum.ALGORITHM, user.getId()); + //拷贝预置算法文件夹 + boolean copyResult = localFileUtil.copyPath(path, nfsConfig.getBucket() + algorithmPath); + if (!copyResult) { + LogUtil.error(LogEnum.BIZ_TRAIN, "The user {} copied the preset algorithm path {} successfully", user.getUsername(), path); + updateTrainAlgorithm(ptTrainAlgorithm, trainAlgorithmCreateDTO, false); + throw new BusinessException("内部错误"); + } + + ptTrainAlgorithm.setCodeDir(algorithmPath); + + //修改算法上传状态 + updateTrainAlgorithm(ptTrainAlgorithm, trainAlgorithmCreateDTO, true); + + } else { + updateTrainAlgorithm(ptTrainAlgorithm, trainAlgorithmCreateDTO, true); + } + } + + + /** + * 更新上传算法状态 + * + * @param ptTrainAlgorithm 算法信息 + * @param trainAlgorithmCreateDTO 创建算法的条件 + * @param flag 创建算法是否成功(true:成功,false:失败) + */ + public void updateTrainAlgorithm(PtTrainAlgorithm ptTrainAlgorithm, PtTrainAlgorithmCreateDTO trainAlgorithmCreateDTO, boolean flag) { + + LogUtil.info(LogEnum.BIZ_TRAIN, "async update algorithmPath by algorithmId:{} and update noteBook by noteBookId:{}", ptTrainAlgorithm.getId(), trainAlgorithmCreateDTO.getNoteBookId()); + if (flag) { + ptTrainAlgorithm.setAlgorithmStatus(AlgorithmStatusEnum.SUCCESS.getCode()); + //更新fork算法新路径 + trainAlgorithmMapper.updateById(ptTrainAlgorithm); + //保存算法根据notbookId更新算法id + if (trainAlgorithmCreateDTO.getNoteBookId() != null) { + LogUtil.info(LogEnum.BIZ_TRAIN, "Save algorithm Update algorithm ID :{} according to notBookId:{}", trainAlgorithmCreateDTO.getNoteBookId(), ptTrainAlgorithm.getId()); + noteBookService.updateTrainIdByNoteBookId(trainAlgorithmCreateDTO.getNoteBookId(), ptTrainAlgorithm.getId()); + } + } else { + ptTrainAlgorithm.setAlgorithmStatus(AlgorithmStatusEnum.FAIL.getCode()); + trainAlgorithmMapper.updateById(ptTrainAlgorithm); + } + } +} diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/async/TrainJobAsync.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/async/TrainJobAsync.java new file mode 100644 index 0000000..efca62e --- /dev/null +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/async/TrainJobAsync.java @@ -0,0 +1,417 @@ +/** + * Copyright 2020 Zhejiang Lab. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================= + */ +package org.dubhe.async; + +import cn.hutool.core.util.StrUtil; +import com.alibaba.fastjson.JSONObject; +import org.dubhe.base.MagicNumConstant; +import org.dubhe.constant.SymbolConstant; +import org.dubhe.config.TrainJobConfig; +import org.dubhe.dao.PtTrainJobMapper; +import org.dubhe.domain.dto.BaseTrainJobDTO; +import org.dubhe.domain.dto.UserDTO; +import org.dubhe.domain.entity.PtTrainJob; +import org.dubhe.domain.vo.PtImageAndAlgorithmVO; +import org.dubhe.enums.BizEnum; +import org.dubhe.enums.LogEnum; +import org.dubhe.enums.ResourcesPoolTypeEnum; +import org.dubhe.enums.TrainJobStatusEnum; +import org.dubhe.exception.BusinessException; +import org.dubhe.k8s.api.DistributeTrainApi; +import org.dubhe.k8s.api.NamespaceApi; +import org.dubhe.k8s.api.TrainJobApi; +import org.dubhe.k8s.domain.bo.DistributeTrainBO; +import org.dubhe.k8s.domain.bo.PtJupyterJobBO; +import org.dubhe.k8s.domain.resource.BizDistributeTrain; +import org.dubhe.k8s.domain.resource.BizNamespace; +import org.dubhe.k8s.domain.vo.PtJupyterJobVO; +import org.dubhe.utils.*; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Component; +import java.util.ArrayList; +import java.util.List; + +/** + * @description 提交训练任务 + * @date 2020-07-17 + */ +@Component +public class TrainJobAsync { + + @Autowired + private K8sNameTool k8sNameTool; + + @Autowired + private NamespaceApi namespaceApi; + + @Autowired + private TrainJobConfig trainJobConfig; + + @Autowired + private NfsUtil nfsUtil; + + @Autowired + private LocalFileUtil localFileUtil; + + @Autowired + private PtTrainJobMapper ptTrainJobMapper; + + @Autowired + private TrainJobApi trainJobApi; + + @Autowired + private DistributeTrainApi distributeTrainApi; + + + /** + * 提交分布式训练 + * + * @param baseTrainJobDTO 训练任务信息 + * @param currentUser 用户 + * @param ptImageAndAlgorithmVO 镜像和算法信息 + * @param ptTrainJob 训练任务实体信息 + */ + public void doDistributedJob(BaseTrainJobDTO baseTrainJobDTO, UserDTO currentUser, PtImageAndAlgorithmVO ptImageAndAlgorithmVO, PtTrainJob ptTrainJob) { + try { + //判断是否存在相应的namespace,如果没有则创建 + String namespace = getNamespace(currentUser); + // 构建DistributeTrainBO + DistributeTrainBO bo = buildDistributeTrainBO(baseTrainJobDTO, currentUser, ptImageAndAlgorithmVO, ptTrainJob, namespace); + if (null == bo) { + LogUtil.error(LogEnum.BIZ_TRAIN, "user{}create TrainJob,Encapsulating ptjupyterjobbo object is empty,the received parameters namespace:{}", currentUser.getId(), namespace); + updateTrainStatus(currentUser, ptTrainJob, baseTrainJobDTO, "", false); + return; + } + // 调度K8s + BizDistributeTrain bizDistributeTrain = distributeTrainApi.create(bo); + if (bizDistributeTrain.isSuccess()) { + // 调度成功 + updateTrainStatus(currentUser, ptTrainJob, baseTrainJobDTO, bizDistributeTrain.getName(), true); + } else { + // 调度失败 + LogUtil.error(LogEnum.BIZ_TRAIN, "distributeTrainApi.create FAILED! {}", bizDistributeTrain); + updateTrainStatus(currentUser, ptTrainJob, baseTrainJobDTO, bizDistributeTrain.getName(), false); + } + } catch (Exception e) { + LogUtil.error(LogEnum.BIZ_TRAIN, "doDistributedJob ERROR!{} ", e); + updateTrainStatus(currentUser, ptTrainJob, baseTrainJobDTO, "", false); + } + } + + /** + * 构造分布式训练DistributeTrainBO + * + * @param baseTrainJobDTO 训练任务信息 + * @param currentUser 用户 + * @param ptImageAndAlgorithmVO 镜像和算法信息 + * @param ptTrainJob 训练任务实体信息 + * @param namespace 命名空间 + * @return DistributeTrainBO + */ + private DistributeTrainBO buildDistributeTrainBO(BaseTrainJobDTO baseTrainJobDTO, UserDTO currentUser, PtImageAndAlgorithmVO ptImageAndAlgorithmVO, PtTrainJob ptTrainJob, String namespace) { + //绝对路径 + String basePath = nfsUtil.getNfsConfig().getBucket() + trainJobConfig.getManage() + StrUtil.SLASH + + currentUser.getId() + StrUtil.SLASH + baseTrainJobDTO.getJobName(); + //相对路径 + String relativePath = StrUtil.SLASH + trainJobConfig.getManage() + StrUtil.SLASH + + currentUser.getId() + StrUtil.SLASH + baseTrainJobDTO.getJobName(); + String[] codeDirArray = ptImageAndAlgorithmVO.getCodeDir().split(StrUtil.SLASH); + String workspaceDir = codeDirArray[codeDirArray.length - 1]; + // 算法路径待拷贝的地址 + String sourcePath = nfsUtil.getNfsConfig().getBucket() + ptImageAndAlgorithmVO.getCodeDir().substring(1); + String trainDir = basePath.substring(1) + StrUtil.SLASH + workspaceDir; + + if (!localFileUtil.copyPath(sourcePath, trainDir)) { + LogUtil.error(LogEnum.BIZ_TRAIN, "buildDistributeTrainBO copyPath failed ! sourcePath:{},basePath:{},trainDir:{}", sourcePath, basePath, trainDir); + return null; + } + // 参数前缀 + String paramPrefix = trainJobConfig.getPythonFormat(); + // 初始化固定头命令,获取分布式节点IP + StringBuilder sb = new StringBuilder("export NODE_IPS=`cat /home/hostfile.json |jq -r \".[]|.ip\"|paste -d \",\" -s` "); + // 切换到算法路径下 + sb.append(" && cd ").append(trainJobConfig.getDockerTrainPath()).append(StrUtil.SLASH).append(workspaceDir).append(" && "); + // 拼接用户自定义python启动命令 + sb.append(ptImageAndAlgorithmVO.getRunCommand()); + // 拼接python固定参数 节点IP + sb.append(paramPrefix).append(trainJobConfig.getNodeIps()).append("=\"$NODE_IPS\" "); + // 拼接python固定参数 节点数量 + sb.append(paramPrefix).append(trainJobConfig.getNodeNum()).append(SymbolConstant.FLAG_EQUAL).append(ptTrainJob.getResourcesPoolNode()).append(StrUtil.SPACE); + if (ptImageAndAlgorithmVO.getIsTrainOut()) { + // 拼接 out + nfsUtil.createDir(basePath + StrUtil.SLASH + trainJobConfig.getOutPath()); + baseTrainJobDTO.setOutPath(relativePath + StrUtil.SLASH + trainJobConfig.getOutPath()); + sb.append(paramPrefix).append(trainJobConfig.getDockerOutPath()); + } + if (ptImageAndAlgorithmVO.getIsTrainLog()) { + // 拼接 输出日志 + nfsUtil.createDir(basePath + StrUtil.SLASH + trainJobConfig.getLogPath()); + baseTrainJobDTO.setLogPath(relativePath + StrUtil.SLASH + trainJobConfig.getLogPath()); + sb.append(paramPrefix).append(trainJobConfig.getDockerLogPath()); + } + if (ptImageAndAlgorithmVO.getIsVisualizedLog()) { + // 拼接 输出可视化日志 + nfsUtil.createDir(basePath + StrUtil.SLASH + trainJobConfig.getVisualizedLogPath()); + baseTrainJobDTO.setVisualizedLogPath(relativePath + StrUtil.SLASH + trainJobConfig.getVisualizedLogPath()); + sb.append(paramPrefix).append(trainJobConfig.getDockerVisualizedLogPath()); + } + // 拼接python固定参数 数据集 + sb.append(paramPrefix).append(trainJobConfig.getDockerDataset()); + JSONObject runParams = baseTrainJobDTO.getRunParams(); + if (null != runParams && !runParams.isEmpty()) { + // 拼接用户自定义参数 + runParams.entrySet().forEach(entry -> + sb.append(paramPrefix).append(entry.getKey()).append(SymbolConstant.FLAG_EQUAL).append(entry.getValue()).append(StrUtil.SPACE) + ); + } + // 在用户自定以参数拼接晚后拼接固定参数,防止被用户自定义参数覆盖 + if (ResourcesPoolTypeEnum.isGpuCode(baseTrainJobDTO.getPtTrainJobSpecs().getResourcesPoolType())) { + // 需要GPU + sb.append(paramPrefix).append(trainJobConfig.getGpuNumPerNode()).append(SymbolConstant.FLAG_EQUAL).append(baseTrainJobDTO.getGpuNumPerNode()).append(StrUtil.SPACE); + } + String mainCommand = sb.toString(); + // 拼接辅助日志打印 + String wholeCommand = " echo 'Distribute training mission begins... " + + mainCommand + + " ' && " + + mainCommand + + " && echo 'Distribute training mission is over' "; + DistributeTrainBO distributeTrainBO = new DistributeTrainBO() + .setNamespace(namespace) + .setName(baseTrainJobDTO.getJobName()) + .setSize(ptTrainJob.getResourcesPoolNode()) + .setImage(ptImageAndAlgorithmVO.getImageName()) + .setMasterCmd(wholeCommand) + .setMemNum(baseTrainJobDTO.getMenNum()) + .setCpuNum(baseTrainJobDTO.getCpuNum()) + .setDatasetStoragePath(k8sNameTool.getAbsoluteNfsPath(baseTrainJobDTO.getDataSourcePath())) + .setWorkspaceStoragePath(localFileUtil.formatPath(nfsUtil.getNfsConfig().getRootDir() + basePath)) + .setModelStoragePath(k8sNameTool.getAbsoluteNfsPath(relativePath + StrUtil.SLASH + trainJobConfig.getOutPath())) + .setBusinessLabel(k8sNameTool.getPodLabel(BizEnum.ALGORITHM)); + //延时启动,单位为分钟 + if (baseTrainJobDTO.getDelayCreateTime() != null && baseTrainJobDTO.getDelayCreateTime() > 0) { + distributeTrainBO.setDelayCreateTime(baseTrainJobDTO.getDelayCreateTime() * MagicNumConstant.SIXTY); + } + //定时停止,单位为分钟 + if (baseTrainJobDTO.getDelayDeleteTime() != null && baseTrainJobDTO.getDelayDeleteTime() > 0) { + distributeTrainBO.setDelayDeleteTime(baseTrainJobDTO.getDelayDeleteTime() * MagicNumConstant.SIXTY); + } + if (ResourcesPoolTypeEnum.isGpuCode(baseTrainJobDTO.getPtTrainJobSpecs().getResourcesPoolType())) { + // 需要GPU + distributeTrainBO.setGpuNum(baseTrainJobDTO.getGpuNumPerNode()); + } + // 主从一致 + distributeTrainBO.setSlaveCmd(distributeTrainBO.getMasterCmd()); + return distributeTrainBO; + } + + + /** + * 提交job + * + * @param baseTrainJobDTO 训练任务信息 + * @param currentUser 用户 + * @param ptImageAndAlgorithmVO 镜像和算法信息 + */ + public void doJob(BaseTrainJobDTO baseTrainJobDTO, UserDTO currentUser, PtImageAndAlgorithmVO ptImageAndAlgorithmVO, PtTrainJob ptTrainJob) { + PtJupyterJobBO jobBo = null; + String k8sJobName = ""; + try { + //判断是否存在相应的namespace,如果没有则创建 + String namespace = getNamespace(currentUser); + + //封装PtJupyterJobBO对象,调用创建训练任务接口 + jobBo = pkgPtJupyterJobBo(baseTrainJobDTO, currentUser, ptImageAndAlgorithmVO, namespace); + if (null == jobBo) { + LogUtil.error(LogEnum.BIZ_TRAIN, "user {} create TrainJob,Encapsulating ptjupyterjobbo object is empty,the received parameters namespace:{}", currentUser.getId(), namespace); + updateTrainStatus(currentUser, ptTrainJob, baseTrainJobDTO, k8sJobName, false); + } + PtJupyterJobVO ptJupyterJobResult = trainJobApi.create(jobBo); + if (!ptJupyterJobResult.isSuccess()) { + String message = null == ptJupyterJobResult.getMessage() ? "未知的错误" : ptJupyterJobResult.getMessage(); + LogUtil.error(LogEnum.BIZ_TRAIN, "user {} create TrainJob, K8s creation failed, the received parameters are {}, the wrong information is{}", currentUser.getUsername(), jobBo, message); + ptTrainJob.setTrainMsg(message); + updateTrainStatus(currentUser, ptTrainJob, baseTrainJobDTO, k8sJobName, false); + } + k8sJobName = ptJupyterJobResult.getName(); + //更新训练任务状态 + updateTrainStatus(currentUser, ptTrainJob, baseTrainJobDTO, k8sJobName, true); + } catch (Exception e) { + LogUtil.error(LogEnum.BIZ_TRAIN, "user {} create TrainJob, K8s creation failed, the received parameters are {}, the wrong information is{}", currentUser.getUsername(), + jobBo, e); + ptTrainJob.setTrainMsg("内部错误"); + updateTrainStatus(currentUser, ptTrainJob, baseTrainJobDTO, k8sJobName, false); + } + } + + + /** + * 获取namespace + * + * @param currentUser 用户 + * @return String 命名空间 + */ + private String getNamespace(UserDTO currentUser) { + String namespaceStr = k8sNameTool.generateNamespace(currentUser.getId()); + BizNamespace bizNamespace = namespaceApi.get(namespaceStr); + if (null == bizNamespace) { + BizNamespace namespace = namespaceApi.create(namespaceStr, null); + if (null == namespace || !namespace.isSuccess()) { + LogUtil.error(LogEnum.BIZ_TRAIN, "User {} failed to create namespace during training job..."); + throw new BusinessException("内部错误"); + } + } + return namespaceStr; + } + + /** + * 封装出创建job所需的BO + * + * @param baseTrainJobDTO 训练任务信息 + * @param ptImageAndAlgorithmVO 镜像和算法信息 + * @param namespace 命名空间 + * @return PtJupyterJobBO jupyter任务BO + */ + private PtJupyterJobBO pkgPtJupyterJobBo(BaseTrainJobDTO baseTrainJobDTO, UserDTO currentUser, + PtImageAndAlgorithmVO ptImageAndAlgorithmVO, String namespace) { + + //绝对路径 + String commonPath = nfsUtil.getNfsConfig().getBucket() + trainJobConfig.getManage() + StrUtil.SLASH + + currentUser.getId() + StrUtil.SLASH + baseTrainJobDTO.getJobName(); + //相对路径 + String relativeCommonPath = StrUtil.SLASH + trainJobConfig.getManage() + StrUtil.SLASH + + currentUser.getId() + StrUtil.SLASH + baseTrainJobDTO.getJobName(); + String[] codeDirArray = ptImageAndAlgorithmVO.getCodeDir().split(StrUtil.SLASH); + String workspaceDir = codeDirArray[codeDirArray.length - 1]; + // 算法路径待拷贝的地址 + String sourcePath = nfsUtil.getNfsConfig().getBucket() + ptImageAndAlgorithmVO.getCodeDir().substring(1); + String trainDir = commonPath.substring(1) + StrUtil.SLASH + workspaceDir; + LogUtil.info(LogEnum.BIZ_TRAIN, "Algorithm path copy::sourcePath:{},commonPath:{},trainDir:{}", sourcePath, commonPath, trainDir); + boolean bool = localFileUtil.copyPath(sourcePath.substring(1), trainDir); + if (!bool) { + LogUtil.error(LogEnum.BIZ_TRAIN, "During the process of user {} creating training Job and encapsulating k8s creating job interface parameters, it failed to copy algorithm directory {} to the specified directory {}", currentUser.getUsername(), sourcePath.substring(1), + trainDir); + return null; + } + + List list = new ArrayList<>(); + JSONObject runParams = baseTrainJobDTO.getRunParams(); + + StringBuilder sb = new StringBuilder(); + sb.append(ptImageAndAlgorithmVO.getRunCommand()); + // 拼接out,log和dataset + String pattern = trainJobConfig.getPythonFormat(); + if (ptImageAndAlgorithmVO.getIsTrainOut()) { + nfsUtil.createDir(commonPath + StrUtil.SLASH + trainJobConfig.getOutPath()); + baseTrainJobDTO.setOutPath(relativeCommonPath + StrUtil.SLASH + trainJobConfig.getOutPath()); + sb.append(pattern).append(trainJobConfig.getDockerOutPath()); + } + if (ptImageAndAlgorithmVO.getIsTrainLog()) { + nfsUtil.createDir(commonPath + StrUtil.SLASH + trainJobConfig.getLogPath()); + baseTrainJobDTO.setLogPath(relativeCommonPath + StrUtil.SLASH + trainJobConfig.getLogPath()); + sb.append(pattern).append(trainJobConfig.getDockerLogPath()); + } + if (ptImageAndAlgorithmVO.getIsVisualizedLog()) { + nfsUtil.createDir(commonPath + StrUtil.SLASH + trainJobConfig.getVisualizedLogPath()); + baseTrainJobDTO.setVisualizedLogPath(relativeCommonPath + StrUtil.SLASH + trainJobConfig.getVisualizedLogPath()); + sb.append(pattern).append(trainJobConfig.getDockerVisualizedLogPath()); + } + sb.append(pattern).append(trainJobConfig.getDockerDataset()); + + String valDataSourcePath = baseTrainJobDTO.getValDataSourcePath(); + if (StringUtils.isNotBlank(valDataSourcePath)) { + sb.append(pattern).append(trainJobConfig.getLoadValDatasetKey()).append(SymbolConstant.FLAG_EQUAL).append(trainJobConfig.getDockerValDatasetPath()); + } + //将模型加载路径拼接到 + String modelLoadPathDir = baseTrainJobDTO.getModelLoadPathDir(); + if (StringUtils.isNotBlank(modelLoadPathDir)) { + //将模型路径model_load_dir路径 + sb.append(pattern).append(trainJobConfig.getLoadKey()).append(SymbolConstant.FLAG_EQUAL).append(trainJobConfig.getDockerModelPath()); + } + + if (null != runParams && !runParams.isEmpty()) { + runParams.forEach((k, v) -> + sb.append(pattern).append(k).append(SymbolConstant.FLAG_EQUAL).append(v).append(StrUtil.SPACE) + ); + } + // 在用户自定以参数拼接晚后拼接固定参数,防止被用户自定义参数覆盖 + if (ResourcesPoolTypeEnum.isGpuCode(baseTrainJobDTO.getPtTrainJobSpecs().getResourcesPoolType())) { + // 需要GPU + sb.append(pattern).append(trainJobConfig.getGpuNumPerNode()).append(SymbolConstant.FLAG_EQUAL).append(baseTrainJobDTO.getGpuNumPerNode()).append(StrUtil.SPACE); + } + String executeCmd = sb.toString(); + list.add("-c"); + + String workPath = trainJobConfig.getDockerTrainPath() + StrUtil.SLASH + workspaceDir; + String command = "echo 'training mission begins... " + executeCmd + "\r\n '" + + " && cd " + workPath + + " && " + executeCmd + + " && echo 'the training mission is over' "; + list.add(command); + + PtJupyterJobBO jobBo = new PtJupyterJobBO(); + jobBo.setNamespace(namespace) + .setName(baseTrainJobDTO.getJobName()) + .setImage(ptImageAndAlgorithmVO.getImageName()) + .putNfsMounts(trainJobConfig.getDockerDatasetPath(), nfsUtil.getNfsConfig().getRootDir() + nfsUtil.getNfsConfig().getBucket().substring(1) + baseTrainJobDTO.getDataSourcePath()) + .setCmdLines(list) + .putNfsMounts(trainJobConfig.getDockerTrainPath(), nfsUtil.getNfsConfig().getRootDir() + commonPath.substring(1)) + .putNfsMounts(trainJobConfig.getDockerModelPath(), nfsUtil.formatPath(nfsUtil.getAbsolutePath(modelLoadPathDir))) + .putNfsMounts(trainJobConfig.getDockerValDatasetPath(), nfsUtil.formatPath(nfsUtil.getAbsolutePath(valDataSourcePath))) + .setBusinessLabel(k8sNameTool.getPodLabel(BizEnum.ALGORITHM)); + //延时启动,单位为分钟 + if (baseTrainJobDTO.getDelayCreateTime() != null && baseTrainJobDTO.getDelayCreateTime() > 0) { + jobBo.setDelayCreateTime(baseTrainJobDTO.getDelayCreateTime() * MagicNumConstant.SIXTY); + } + //自动停止,单位为分钟 + if (baseTrainJobDTO.getDelayDeleteTime() != null && baseTrainJobDTO.getDelayDeleteTime() > 0) { + jobBo.setDelayDeleteTime(baseTrainJobDTO.getDelayDeleteTime() * MagicNumConstant.SIXTY); + } + jobBo.setCpuNum(baseTrainJobDTO.getCpuNum()).setMemNum(baseTrainJobDTO.getMenNum()); + if (ResourcesPoolTypeEnum.isGpuCode(baseTrainJobDTO.getPtTrainJobSpecs().getResourcesPoolType())) { + jobBo.setUseGpu(true).setGpuNum(baseTrainJobDTO.getGpuNumPerNode()); + } else { + jobBo.setUseGpu(false); + } + return jobBo; + } + + /** + * 训练任务异步处理更新训练状态 + * + * @param user 用户 + * @param ptTrainJob 训练任务 + * @param baseTrainJobDTO 训练任务信息 + * @param k8sJobName k8s创建的job名称,或者分布式训练名称 + * @param flag 创建训练任务是否异常(true:正常,false:失败) + **/ + private void updateTrainStatus(UserDTO user, PtTrainJob ptTrainJob, BaseTrainJobDTO baseTrainJobDTO, String k8sJobName, boolean flag) { + + ptTrainJob.setK8sJobName(k8sJobName) + .setOutPath(baseTrainJobDTO.getOutPath()) + .setLogPath(baseTrainJobDTO.getLogPath()) + .setVisualizedLogPath(baseTrainJobDTO.getVisualizedLogPath()); + LogUtil.info(LogEnum.BIZ_TRAIN, "user {} training tasks are processed asynchronously to update training status,receiving parameters:{}", user.getId(), ptTrainJob); + if (flag) { + ptTrainJobMapper.updateById(ptTrainJob); + } else { + ptTrainJob.setTrainStatus(TrainJobStatusEnum.CREATE_FAILED.getStatus()); + //训练任务创建失败 + ptTrainJobMapper.updateById(ptTrainJob); + } + } +} diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/task/TransactionAsyncManager.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/async/TransactionAsyncManager.java similarity index 62% rename from dubhe-server/dubhe-admin/src/main/java/org/dubhe/task/TransactionAsyncManager.java rename to dubhe-server/dubhe-admin/src/main/java/org/dubhe/async/TransactionAsyncManager.java index 2a65e55..053df6b 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/task/TransactionAsyncManager.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/async/TransactionAsyncManager.java @@ -14,13 +14,16 @@ * limitations under the License. * ============================================================= */ -package org.dubhe.task; +package org.dubhe.async; import org.dubhe.aspect.LogAspect; +import org.dubhe.base.DataContext; import org.dubhe.domain.dto.BaseTrainJobDTO; +import org.dubhe.domain.dto.CommonPermissionDataDTO; import org.dubhe.domain.dto.UserDTO; import org.dubhe.domain.entity.PtTrainJob; import org.dubhe.domain.vo.PtImageAndAlgorithmVO; +import org.dubhe.enums.TrainTypeEnum; import org.slf4j.MDC; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Component; @@ -38,26 +41,32 @@ import java.util.concurrent.Executor; public class TransactionAsyncManager { @Autowired - private TrainJobAsyncTask trainJobAsyncTask; - - @Resource(name = "trainJobAsyncExecutor") - private Executor trainJobAsyncExecutor; + private TrainJobAsync trainJobAsync; + @Resource(name = "trainExecutor") + private Executor trainExecutor; public void execute(BaseTrainJobDTO baseTrainJobDTO, UserDTO currentUser, PtImageAndAlgorithmVO ptImageAndAlgorithmVO, PtTrainJob ptTrainJob) { - String traceId = MDC.get(LogAspect.TRACE_ID); + CommonPermissionDataDTO commonPermissionDataDTO = DataContext.get(); TransactionSynchronizationManager.registerSynchronization(new TransactionSynchronizationAdapter() { - @Override public void afterCommit() { - trainJobAsyncExecutor.execute( - () -> { - MDC.put(LogAspect.TRACE_ID, traceId); - trainJobAsyncTask.doJob(baseTrainJobDTO, currentUser, ptImageAndAlgorithmVO, ptTrainJob); - MDC.remove(LogAspect.TRACE_ID); - } - ); + + trainExecutor.execute(() -> { + MDC.put(LogAspect.TRACE_ID, traceId); + DataContext.set(commonPermissionDataDTO); + + if (TrainTypeEnum.isDistributeTrain(ptTrainJob.getTrainType())) { + // 分布式训练 + trainJobAsync.doDistributedJob(baseTrainJobDTO, currentUser, ptImageAndAlgorithmVO, ptTrainJob); + } else { + // 普通训练 + trainJobAsync.doJob(baseTrainJobDTO, currentUser, ptImageAndAlgorithmVO, ptTrainJob); + } + MDC.remove(LogAspect.TRACE_ID); + DataContext.remove(); + }); } }); } diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/config/GlobalFilter.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/config/GlobalFilter.java new file mode 100644 index 0000000..21342f2 --- /dev/null +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/config/GlobalFilter.java @@ -0,0 +1,148 @@ +/** + * Copyright 2020 Zhejiang Lab. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================= + */ + +package org.dubhe.config; + +import com.alibaba.fastjson.JSON; +import org.dubhe.constant.StringConstant; +import org.dubhe.constatnts.UserConstant; +import org.dubhe.dto.GlobalRequestRecordDTO; +import org.dubhe.enums.LogEnum; +import org.dubhe.utils.JwtUtils; +import org.dubhe.utils.LogUtil; +import org.dubhe.utils.StringUtils; +import org.springframework.core.annotation.Order; +import org.springframework.stereotype.Component; + +import javax.servlet.*; +import javax.servlet.annotation.WebFilter; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import java.io.IOException; + +import static org.dubhe.constant.StringConstant.K8S_CALLBACK_URI; + +/** + * @description 全局请求拦截器 用于日志收集 + * @date 2020-08-13 + */ +@Order(1) +@Component +@WebFilter(filterName = "GlobalFilter", urlPatterns = "/**") +public class GlobalFilter implements Filter { + @Override + public void init(FilterConfig filterConfig) throws ServletException { + + } + + @Override + public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain chain) throws IOException, ServletException { + + long start = System.currentTimeMillis(); + HttpServletRequest request = ((HttpServletRequest) servletRequest); + HttpServletResponse response = ((HttpServletResponse) servletResponse); + + GlobalRequestRecordDTO dto = new GlobalRequestRecordDTO(); + + try { + if (StringUtils.isNotBlank(request.getContentType()) && request.getContentType().contains(StringConstant.MULTIPART)) { + chain.doFilter(request, response); + } else { + checkScheduleRequest(request); + checkK8sCallback(request); + + RequestBodyWrapper requestBodyWrapper = new RequestBodyWrapper(request); + ResponseBodyWrapper responseBodyWrapper = new ResponseBodyWrapper(response); + dto.setRequestBody(requestBodyWrapper.getBodyString()); + chain.doFilter(requestBodyWrapper, responseBodyWrapper); + + if (StringConstant.JSON_REQUEST.equals(responseBodyWrapper.getContentType())) { + final String responseBody = responseBodyWrapper.getResponseBody(); + dto.setResponseBody(responseBody); + } else { + responseBodyWrapper.flush(); + } + } + } catch (Exception e) { + LogUtil.error(LogEnum.GLOBAL_REQ, "Global request record error : {}", e); + throw e; + } finally { + buildGlobalRequestDTO(dto, request, response); + dto.setTimeCost(System.currentTimeMillis() - start); + LogUtil.info(LogEnum.GLOBAL_REQ, "Global request record: {}", dto); + LogUtil.cleanTrace(); + } + } + + /** + * 构建全局请求对象 + * + * @param dto + * @param request + * @param response + */ + private void buildGlobalRequestDTO(GlobalRequestRecordDTO dto, HttpServletRequest request, HttpServletResponse response) { + dto.setClientHost(request.getRemoteHost()); + dto.setParams(JSON.toJSONString(request.getParameterMap())); + dto.setMethod(request.getMethod()); + dto.setUri(request.getRequestURI()); + //身份认证信息 + String token = request.getHeader(UserConstant.USER_TOKEN_KEY); + dto.setAuthorization(token); + if (token != null) { + String userName = JwtUtils.getUserName(token); + dto.setUsername(userName); + } + dto.setContentType(response.getContentType()); + dto.setStatus(response.getStatus()); + } + + /** + * 检查是否是前端的定时请求 + * + * @param request 请求信息 + * @return 是否是前端的定时请求 + */ + private boolean checkScheduleRequest(HttpServletRequest request) { + + if (StringConstant.REQUEST_METHOD_GET.equals(request.getMethod()) + && StringUtils.isNotBlank(request.getParameter(LogUtil.SCHEDULE_LEVEL))) { + LogUtil.startScheduleTrace(); + return true; + } + + return false; + } + + /** + * 校验请求是否为k8s回调 + * @param request 请求信息 + * @return 是否为k8s回调 + */ + private boolean checkK8sCallback(HttpServletRequest request) { + if (request.getRequestURI() != null && request.getRequestURI().contains(K8S_CALLBACK_URI)) { + LogUtil.startK8sCallbackTrace(); + return true; + } + return false; + } + + @Override + public void destroy() { + + } +} diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/config/RequestBodyWrapper.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/config/RequestBodyWrapper.java new file mode 100644 index 0000000..f041be7 --- /dev/null +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/config/RequestBodyWrapper.java @@ -0,0 +1,101 @@ +/** + * Copyright 2020 Zhejiang Lab. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================= + */ + +package org.dubhe.config; + +import org.dubhe.enums.LogEnum; +import org.dubhe.utils.LogUtil; + +import javax.servlet.ReadListener; +import javax.servlet.ServletInputStream; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletRequestWrapper; +import java.io.BufferedReader; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStreamReader; + +/** + * @description 用于获取请求body参数的包装类 + * @date 2020-08-20 + */ +public class RequestBodyWrapper extends HttpServletRequestWrapper { + + private ByteArrayInputStream byteArrayInputStream; + + private String bodyString; + + public RequestBodyWrapper(HttpServletRequest request) + throws IOException { + super(request); + try (BufferedReader reader = request.getReader()) { + StringBuilder sb = new StringBuilder(); + String line = null; + while ((line = reader.readLine()) != null) { + sb.append(line); + } + if (sb.length() > 0) { + bodyString = sb.toString(); + } else { + bodyString = ""; + } + byteArrayInputStream = new ByteArrayInputStream(bodyString.getBytes()); + } catch (Exception e) { + LogUtil.error(LogEnum.GLOBAL_REQ, "request get reader error : {}", e); + throw e; + } + + } + + /** + * 获取请求体的json数据 + * + * @return + */ + public String getBodyString() { + return bodyString; + } + + @Override + public BufferedReader getReader() throws IOException { + return new BufferedReader(new InputStreamReader(getInputStream())); + } + + @Override + public ServletInputStream getInputStream() throws IOException { + return new ServletInputStream() { + @Override + public boolean isFinished() { + return false; + } + + @Override + public boolean isReady() { + return false; + } + + @Override + public void setReadListener(ReadListener readListener) { + } + + @Override + public int read() throws IOException { + return byteArrayInputStream.read(); + } + }; + } +} \ No newline at end of file diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/config/ResponseBodyWrapper.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/config/ResponseBodyWrapper.java new file mode 100644 index 0000000..dc0bc3e --- /dev/null +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/config/ResponseBodyWrapper.java @@ -0,0 +1,126 @@ +/** + * Copyright 2020 Zhejiang Lab. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================= + */ + +package org.dubhe.config; + +import org.dubhe.enums.LogEnum; +import org.dubhe.utils.LogUtil; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.OutputStreamWriter; +import java.io.PrintWriter; +import java.io.UnsupportedEncodingException; + +import javax.servlet.ServletOutputStream; +import javax.servlet.WriteListener; +import javax.servlet.http.HttpServletResponse; +import javax.servlet.http.HttpServletResponseWrapper; + +/** + * @description 用于获取response 的 json 返回值 + * @date 2020-08-19 + */ +public class ResponseBodyWrapper extends HttpServletResponseWrapper { + + private ByteArrayOutputStream byteArrayOutputStream = null; + + private ServletOutputStream servletOutputStream = null; + + private PrintWriter printWriter = null; + + private HttpServletResponse response; + + public ResponseBodyWrapper(HttpServletResponse response) throws IOException { + super(response); + this.response = response; + byteArrayOutputStream = new ByteArrayOutputStream(); + printWriter = new PrintWriter(new OutputStreamWriter(byteArrayOutputStream, "UTF-8")); + servletOutputStream = new ServletOutputStream() { + + @Override + public void write(int b) throws IOException { + byteArrayOutputStream.write(b); + } + + @Override + public boolean isReady() { + return false; + } + + @Override + public void setWriteListener(WriteListener writeListener) { + } + }; + } + + @Override + public ServletOutputStream getOutputStream() throws IOException { + return servletOutputStream; + } + + @Override + public PrintWriter getWriter() throws IOException { + return printWriter; + } + + @Override + public void flushBuffer() throws IOException { + if (servletOutputStream != null) { + servletOutputStream.flush(); + } + if (printWriter != null) { + printWriter.flush(); + } + } + + @Override + public void reset() { + byteArrayOutputStream.reset(); + } + + /** + * 获取json返回值 + * @return + * @throws IOException + */ + public String getResponseBody() throws IOException { + //清空response的流,之后再添加进去 + flushBuffer(); + byte[] bytes = byteArrayOutputStream.toByteArray(); + try { + return new String(bytes, "UTF-8"); + } catch (UnsupportedEncodingException e) { + LogUtil.error(LogEnum.GLOBAL_REQ, e); + } finally { + response.getOutputStream().write(bytes); + } + return ""; + } + + /** + * 清掉缓冲 + * @throws IOException + */ + public void flush() throws IOException { + //清空response的流,之后再添加进去 + flushBuffer(); + byte[] bytes = byteArrayOutputStream.toByteArray(); + response.getOutputStream().write(bytes); + } + +} \ No newline at end of file diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/config/TimestampConverter.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/config/TimestampConverter.java index 8d851a2..baaf923 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/config/TimestampConverter.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/config/TimestampConverter.java @@ -24,7 +24,7 @@ import org.springframework.stereotype.Component; import java.sql.Timestamp; /** - * @description: 转换时间戳类型 + * @description 转换时间戳类型 * @date 2020-05-22 */ @Component diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/dao/ModelQueryMapper.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/dao/ModelQueryMapper.java new file mode 100644 index 0000000..0e1003f --- /dev/null +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/dao/ModelQueryMapper.java @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Zhejiang Lab. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================= + */ +package org.dubhe.dao; +import com.baomidou.mybatisplus.core.mapper.BaseMapper; +import org.apache.ibatis.annotations.Param; +import org.apache.ibatis.annotations.Select; +import org.dubhe.domain.dto.ModelQueryDTO; +import org.dubhe.domain.entity.ModelQuery; +import org.dubhe.domain.entity.ModelQueryBrance; + +/** + * @description model mapper + * @date 2020-10-09 + */ +public interface ModelQueryMapper extends BaseMapper { + /** + * 根据modelId查询模型信息 + * + * @param modelId 模型id + * @return modelQuery返回查询的模型对象 + */ + @Select("select name,url from pt_model_info where id=#{modelId}") + ModelQuery findModelNameById(@Param("modelId") Integer modelId); + + /** + * 根据模型路径查询模型版本信息 + * + * @param modelLoadPathDir 模型路径 + * @return ModelQueryBrance 模型版本信息 + */ + @Select("select version from pt_model_branch where url=#{modelLoadPathDir}") + ModelQueryBrance findModelVersionByUrl(@Param("modelLoadPathDir") String modelLoadPathDir); +} diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/dao/NoteBookMapper.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/dao/NoteBookMapper.java index 4f977e7..1a95d68 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/dao/NoteBookMapper.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/dao/NoteBookMapper.java @@ -21,6 +21,7 @@ import com.baomidou.mybatisplus.core.mapper.BaseMapper; import com.baomidou.mybatisplus.extension.plugins.pagination.Page; import org.apache.ibatis.annotations.Param; import org.apache.ibatis.annotations.Select; +import org.dubhe.annotation.DataPermission; import org.dubhe.domain.entity.NoteBook; import java.util.List; @@ -29,28 +30,27 @@ import java.util.List; * @description notebook mapper * @date 2020-04-28 */ +@DataPermission(ignoresMethod = {"insert","findByNamespaceAndResourceName","selectRunNotUrlList"}) public interface NoteBookMapper extends BaseMapper { /** * 根据名称查询 * * @param name - * @param userId * @param status * @return NoteBook */ - @Select("select * from notebook where notebook_name = #{name} and user_id = #{userId} and status != #{status} and deleted = 0 limit 1") - NoteBook findByNameAndUserId(@Param("name") String name, @Param("userId") long userId, @Param("status") Integer status); + @Select("select * from notebook where notebook_name = #{name} and status != #{status} and deleted = 0 limit 1") + NoteBook findByNameAndStatus(@Param("name") String name, @Param("status") Integer status); /** * 查询正在运行的notebook数量 * - * @param userId * @param status * @return int */ - @Select("select count(1) from notebook where user_id = #{userId} and status = #{status} and deleted = 0") - int selectRunNoteBookNum(@Param("userId") long userId, @Param("status") Integer status); + @Select("select count(1) from notebook where status = #{status} and deleted = 0") + int selectRunNoteBookNum( @Param("status") Integer status); /** * 根据namespace + resourceName查询 diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/dao/PtImageMapper.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/dao/PtImageMapper.java index d9255e3..974a03f 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/dao/PtImageMapper.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/dao/PtImageMapper.java @@ -19,12 +19,14 @@ package org.dubhe.dao; import com.baomidou.mybatisplus.core.mapper.BaseMapper; +import org.dubhe.annotation.DataPermission; import org.dubhe.domain.entity.PtImage; /** * @description 镜像 Mapper 接口 * @date 2020-04-27 */ +@DataPermission(ignoresMethod = {"insert"}) public interface PtImageMapper extends BaseMapper { } diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/dao/PtTrainAlgorithmMapper.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/dao/PtTrainAlgorithmMapper.java index 1828c2c..8a4d571 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/dao/PtTrainAlgorithmMapper.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/dao/PtTrainAlgorithmMapper.java @@ -20,6 +20,7 @@ package org.dubhe.dao; import com.baomidou.mybatisplus.core.mapper.BaseMapper; import org.apache.ibatis.annotations.Param; import org.apache.ibatis.annotations.Select; +import org.dubhe.annotation.DataPermission; import org.dubhe.domain.entity.PtTrainAlgorithm; import java.util.List; @@ -28,6 +29,7 @@ import java.util.List; * @description 训练算法Mapper * @date 2020-04-27 */ +@DataPermission(ignoresMethod = {"insert"}) public interface PtTrainAlgorithmMapper extends BaseMapper { /** diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/dao/PtTrainAlgorithmUsageMapper.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/dao/PtTrainAlgorithmUsageMapper.java index 95d68bb..de786d3 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/dao/PtTrainAlgorithmUsageMapper.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/dao/PtTrainAlgorithmUsageMapper.java @@ -17,17 +17,17 @@ package org.dubhe.dao; +import org.dubhe.annotation.DataPermission; import org.dubhe.domain.entity.PtTrainAlgorithmUsage; import com.baomidou.mybatisplus.core.mapper.BaseMapper; /** * - * 用户辅助信息Mapper 接口 - *

- * - * @since 2020-06-23 + * @description 用户辅助信息Mapper 接口 + * @date 2020-06-23 */ +@DataPermission(ignoresMethod = "insert") public interface PtTrainAlgorithmUsageMapper extends BaseMapper { } diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/dao/PtTrainJobMapper.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/dao/PtTrainJobMapper.java index 5d9edd7..08c992c 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/dao/PtTrainJobMapper.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/dao/PtTrainJobMapper.java @@ -21,6 +21,7 @@ import com.baomidou.mybatisplus.core.mapper.BaseMapper; import com.baomidou.mybatisplus.extension.plugins.pagination.Page; import org.apache.ibatis.annotations.Param; import org.apache.ibatis.annotations.Select; +import org.dubhe.annotation.DataPermission; import org.dubhe.domain.entity.PtTrainJob; import org.dubhe.domain.vo.PtTrainVO; @@ -28,6 +29,7 @@ import org.dubhe.domain.vo.PtTrainVO; * @description 训练作业job Mapper 接口 * @date 2020-04-27 */ +@DataPermission(ignoresMethod = {"insert","selectCountByStatus","getPageTrain"}) public interface PtTrainJobMapper extends BaseMapper { /** diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/dao/PtTrainMapper.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/dao/PtTrainMapper.java index c70db46..3e8f025 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/dao/PtTrainMapper.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/dao/PtTrainMapper.java @@ -18,12 +18,14 @@ package org.dubhe.dao; import com.baomidou.mybatisplus.core.mapper.BaseMapper; +import org.dubhe.annotation.DataPermission; import org.dubhe.domain.entity.PtTrain; /** * @description 训练作业主 Mapper 接口 * @date 2020-04-27 */ +@DataPermission(ignoresMethod = {"insert"}) public interface PtTrainMapper extends BaseMapper { } diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/dao/PtTrainParamMapper.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/dao/PtTrainParamMapper.java index 9ceff8e..ff08d99 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/dao/PtTrainParamMapper.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/dao/PtTrainParamMapper.java @@ -17,6 +17,7 @@ package org.dubhe.dao; +import org.dubhe.annotation.DataPermission; import org.dubhe.domain.entity.PtTrainParam; import com.baomidou.mybatisplus.core.mapper.BaseMapper; @@ -24,6 +25,7 @@ import com.baomidou.mybatisplus.core.mapper.BaseMapper; * @description 任务参数 Mapper 接口 * @date 2020-04-27 */ +@DataPermission(ignoresMethod = "insert") public interface PtTrainParamMapper extends BaseMapper { } diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/BaseTrainJobDTO.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/BaseTrainJobDTO.java index fa63396..21ee36b 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/BaseTrainJobDTO.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/BaseTrainJobDTO.java @@ -40,4 +40,40 @@ public class BaseTrainJobDTO implements Serializable { private String outPath; private String logPath; private String visualizedLogPath; + private Integer delayCreateTime; + private Integer delayDeleteTime; + + /** + * @return 每个节点的GPU数量 + */ + public Integer getGpuNumPerNode(){ + return getPtTrainJobSpecs().getSpecsInfo().getInteger("gpuNum"); + } + + /** + * @return cpu数量 + */ + public Integer getCpuNum(){ + return getPtTrainJobSpecs().getSpecsInfo().getInteger("cpuNum"); + } + + /** + * @return memNum + */ + public Integer getMenNum(){ + return getPtTrainJobSpecs().getSpecsInfo().getInteger("memNum"); + } + /** + * "验证数据来源名称" + */ + private String valDataSourceName; + + /** + * 验证数据来源路径 + */ + private String valDataSourcePath; + /** + * 模型路径 + */ + private String modelLoadPathDir; } diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/ModelQueryDTO.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/ModelQueryDTO.java new file mode 100644 index 0000000..a10256c --- /dev/null +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/ModelQueryDTO.java @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Zhejiang Lab. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================= + */ +package org.dubhe.domain.dto; +import io.swagger.annotations.ApiModel; +import io.swagger.annotations.ApiModelProperty; +import lombok.Data; + +/** + * @description model 查询dto + * @date 2020-10-09 + */ +@Data +@ApiModel("模型查询") +public class ModelQueryDTO { + @ApiModelProperty(value = "模型类型") + private Integer modelResource; + @ApiModelProperty(value = "模型名称") + private String name; + @ApiModelProperty(value = "模型版本") + private String version; + @ApiModelProperty(value = "模型id") + private Integer id; + @ApiModelProperty(value = "模型路径") + private String modelPath; +} diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/NoteBookListQueryDTO.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/NoteBookListQueryDTO.java index 485d3d3..f388fd8 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/NoteBookListQueryDTO.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/NoteBookListQueryDTO.java @@ -30,7 +30,6 @@ import java.io.Serializable; @Data public class NoteBookListQueryDTO implements Serializable { - @Query(propName = "status", type = Query.Type.EQ) @ApiModelProperty("0运行中,1停止, 2删除, 3启动中,4停止中,5删除中,6运行异常(暂未启用)") private Integer status; @@ -38,7 +37,7 @@ public class NoteBookListQueryDTO implements Serializable { @ApiModelProperty("notebook名称") private String noteBookName; - @Query(propName = "user_id", type = Query.Type.EQ) + @Query(propName = "origin_user_id", type = Query.Type.EQ) @ApiModelProperty(value = "所属用户ID", hidden = true) private Long userId; diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/NoteBookQueryDTO.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/NoteBookQueryDTO.java index 93712f8..faa1a38 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/NoteBookQueryDTO.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/NoteBookQueryDTO.java @@ -42,7 +42,7 @@ public class NoteBookQueryDTO implements Serializable { @Query(propName = "k8s_pvc_path", type = Query.Type.EQ) private String k8sPvcPath; - @Query(propName = "user_id", type = Query.Type.EQ) + @Query(propName = "origin_user_id", type = Query.Type.EQ) private Long userId; @Query(propName = "last_operation_timeout", type = Query.Type.LT) diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtImageDTO.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtImageDTO.java index 85e0d41..94adff4 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtImageDTO.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtImageDTO.java @@ -66,4 +66,6 @@ public class PtImageDTO implements Serializable { @ApiModelProperty("删除(0正常,1已删除)") private Boolean deleted; + @ApiModelProperty("资源拥有者ID") + private Long originUserId; } diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtImageDeleteDTO.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtImageDeleteDTO.java new file mode 100644 index 0000000..d332df4 --- /dev/null +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtImageDeleteDTO.java @@ -0,0 +1,39 @@ +/** + * Copyright 2020 Zhejiang Lab. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================= + */ +package org.dubhe.domain.dto; + +import io.swagger.annotations.ApiModelProperty; +import lombok.Data; +import lombok.experimental.Accessors; + +import javax.validation.constraints.NotNull; +import java.io.Serializable; +import java.util.List; + +/** + * @description 训练镜像删除DTO + * @date 2020-08-13 + */ +@Data +@Accessors(chain = true) +public class PtImageDeleteDTO implements Serializable { + private static final long serialVersionUID = 1L; + + @ApiModelProperty(value = "id", required = true) + @NotNull(message = "镜像id不能为空") + private List ids; +} diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtImageQueryDTO.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtImageQueryDTO.java index e1bc556..17cc7ce 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtImageQueryDTO.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtImageQueryDTO.java @@ -40,4 +40,7 @@ public class PtImageQueryDTO extends PageQueryBase implements Serializable { @ApiModelProperty(value = "镜像状态,0为制作中,1位制作成功,2位制作失败") private Integer imageStatus; + @ApiModelProperty(value = "镜像名称或id") + private String imageNameOrId; + } diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtImageUpdateDTO.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtImageUpdateDTO.java new file mode 100644 index 0000000..6f08648 --- /dev/null +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtImageUpdateDTO.java @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Zhejiang Lab. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================= + */ +package org.dubhe.domain.dto; + +import io.swagger.annotations.ApiModelProperty; +import lombok.Data; +import lombok.experimental.Accessors; +import org.dubhe.utils.TrainUtil; +import org.hibernate.validator.constraints.Length; + +import javax.validation.constraints.NotNull; +import java.io.Serializable; +import java.util.List; + +/** + * @description 训练镜像信息修改DTO + * @date 2020-08-13 + */ +@Data +@Accessors(chain = true) +public class PtImageUpdateDTO implements Serializable { + + private static final long serialVersionUID = 1L; + + @ApiModelProperty(value = "id", required = true) + @NotNull(message = "镜像id不能为空") + private List ids; + + @ApiModelProperty("镜像描述") + @Length(max = TrainUtil.NUMBER_ONE_THOUSAND_AND_TWENTY_FOUR, message = "镜像描述-输入长度不能超过1024个字符") + private String remark; + +} diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtImageUploadDTO.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtImageUploadDTO.java index a3ae025..970186b 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtImageUploadDTO.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtImageUploadDTO.java @@ -55,7 +55,7 @@ public class PtImageUploadDTO implements Serializable { @Pattern(regexp = TrainUtil.REGEXP_TAG, message = "镜像版本号支持字母、数字、英文横杠、英文.号和下划线") private String imageTag; - @ApiModelProperty("备注") - @Length(max = TrainUtil.NUMBER_ONE_THOUSAND_AND_TWENTY_FOUR, message = "备注-输入长度不能超过1024个字符") + @ApiModelProperty("镜像描述") + @Length(max = TrainUtil.NUMBER_ONE_THOUSAND_AND_TWENTY_FOUR, message = "镜像描述-输入长度不能超过1024个字符") private String remark; } diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtTrainAlgorithmCreateDTO.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtTrainAlgorithmCreateDTO.java index 4109495..2370ef2 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtTrainAlgorithmCreateDTO.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtTrainAlgorithmCreateDTO.java @@ -85,4 +85,7 @@ public class PtTrainAlgorithmCreateDTO implements Serializable { @ApiModelProperty("noteBookId") private Long noteBookId; + @ApiModelProperty("资源拥有者ID") + private Long originUserId; + } diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtTrainAlgorithmUsageCreateDTO.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtTrainAlgorithmUsageCreateDTO.java index 4dba0f8..42e6a84 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtTrainAlgorithmUsageCreateDTO.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtTrainAlgorithmUsageCreateDTO.java @@ -17,7 +17,6 @@ package org.dubhe.domain.dto; -import io.swagger.annotations.ApiModel; import io.swagger.annotations.ApiModelProperty; import lombok.Data; import lombok.experimental.Accessors; diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtTrainJobCreateDTO.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtTrainJobCreateDTO.java index d380647..7ccc0b6 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtTrainJobCreateDTO.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtTrainJobCreateDTO.java @@ -76,10 +76,13 @@ public class PtTrainJobCreateDTO extends BaseImageDTO { @NotNull(message = "类型(0为CPU,1为GPU)不能为空") private Integer resourcesPoolType; - @ApiModelProperty(value = "规格类型Id", required = true) - @NotNull(message = "规格类型id不能为空") - @Min(value = TrainUtil.NUMBER_ONE, message = "规格类型id必须不小于1") - private Integer trainJobSpecsId; + @ApiModelProperty(value = "规格名称", required = true) + @NotNull(message = "规格名称不能为空") + private String trainJobSpecsName; + + @ApiModelProperty(value = "规格信息", required = true) + @NotNull(message = "规格信息不能为空") + private String trainJobSpecsInfo; @ApiModelProperty("true代表保存作业参数") private Boolean saveParams; @@ -93,4 +96,55 @@ public class PtTrainJobCreateDTO extends BaseImageDTO { @Length(max = TrainUtil.NUMBER_TWO_HUNDRED_AND_FIFTY_FIVE, message = "作业参数描述-输入长度不能超过255个字符") private String trainParamDesc; + @ApiModelProperty(value = "训练类型 0:普通训练,1:分布式训练", required = true) + @Min(value = TrainUtil.NUMBER_ZERO, message = "训练类型错误") + @Max(value = TrainUtil.NUMBER_ONE, message = "训练类型错误") + @NotNull(message = "训练类型(0为普通训练,1为分布式训练)不能为空") + private Integer trainType; + + @ApiModelProperty(value = "节点个数", required = true) + @Min(value = TrainUtil.NUMBER_ONE, message = "节点个数在1~8之间") + @Max(value = TrainUtil.NUMBER_EIGHT, message = "节点个数在1~8之间") + @NotNull(message = "节点个数") + private Integer resourcesPoolNode; + + @ApiModelProperty("验证数据来源名称") + private String valDataSourceName; + + @ApiModelProperty("验证数据来源路径") + private String valDataSourcePath; + + @ApiModelProperty("是否验证数据集") + private Integer valType; + + @ApiModelProperty(value = "训练延时启动时长,单位为小时") + @Min(value = TrainUtil.NUMBER_ZERO, message = "训练延时启动时长不能小于0小时") + @Max(value = TrainUtil.NUMBER_ONE_HUNDRED_AND_SIXTY_EIGHT, message = "训练延时启动时长不能大于168小时即时长不能超过一周(7*24小时)") + private Integer delayCreateTime; + + @ApiModelProperty(value = "训练自动停止时长,单位为小时") + @Min(value = TrainUtil.NUMBER_ZERO, message = "训练自动停止时长不能小于0小时") + @Max(value = TrainUtil.NUMBER_ONE_HUNDRED_AND_SIXTY_EIGHT, message = "训练自动停止时长不能大于168小时即时长不能超过一周(7*24小时)") + private Integer delayDeleteTime; + + @ApiModelProperty("资源拥有者ID") + private Long originUserId; + + @ApiModelProperty(value = "训练信息(失败信息)") + private String trainMsg; + + @ApiModelProperty(value = "是否打开模型选择") + private Integer modelType; + + @ApiModelProperty(value = "模型类型(0我的模型1预置模型)") + private Integer modelResource; + + @ApiModelProperty(value = "模型id") + private Integer modelId; + + @ApiModelProperty(value = "模型路径") + private String modelLoadPathDir; + + @ApiModelProperty(value = "模型名称") + private String modelName; } diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtTrainJobDTO.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtTrainJobDTO.java index 306dca3..064ef8a 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtTrainJobDTO.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtTrainJobDTO.java @@ -50,4 +50,5 @@ public class PtTrainJobDTO implements Serializable { private Timestamp createTime; private Timestamp updateTime; private Boolean deleted; + private Long originUserId; } diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtTrainJobResumeDTO.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtTrainJobResumeDTO.java index 2637c7c..ab6e87c 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtTrainJobResumeDTO.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtTrainJobResumeDTO.java @@ -27,8 +27,8 @@ import javax.validation.constraints.NotNull; import java.io.Serializable; /** - * @description: 恢复训练 - * @date: 2020-04-27 + * @description 恢复训练 + * @date 2020-04-27 */ @Data @Accessors(chain = true) @@ -41,4 +41,8 @@ public class PtTrainJobResumeDTO implements Serializable { @Min(value = TrainUtil.NUMBER_ONE, message = "id数值不合法") private Long id; + @ApiModelProperty(value = "path", required = true) + @NotNull(message = "path不能为空") + private String path; + } diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtTrainJobUpdateDTO.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtTrainJobUpdateDTO.java index 401f454..6b5e606 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtTrainJobUpdateDTO.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtTrainJobUpdateDTO.java @@ -78,9 +78,58 @@ public class PtTrainJobUpdateDTO extends BaseImageDTO { @NotNull(message = "类型(0为CPU,1为GPU)不能为空") private Integer resourcesPoolType; - @ApiModelProperty(value = "规格类型Id", required = true) - @NotNull(message = "规格类型id不能为空") - @Min(value = TrainUtil.NUMBER_ONE, message = "规格类型id必须不小于1") - private Integer trainJobSpecsId; + @ApiModelProperty(value = "规格名称", required = true) + @NotNull(message = "规格名称不能为空") + private String trainJobSpecsName; + + @ApiModelProperty(value = "规格信息", required = true) + @NotNull(message = "规格信息不能为空") + private String trainJobSpecsInfo; + + @ApiModelProperty(value = "训练类型 0:普通训练,1:分布式训练", required = true) + @Min(value = TrainUtil.NUMBER_ZERO, message = "训练类型错误") + @Max(value = TrainUtil.NUMBER_ONE, message = "训练类型错误") + @NotNull(message = "训练类型(0为普通训练,1为分布式训练)不能为空") + private Integer trainType; + + @ApiModelProperty(value = "节点个数", required = true) + @Min(value = TrainUtil.NUMBER_ONE, message = "节点个数在1~8之间") + @Max(value = TrainUtil.NUMBER_EIGHT, message = "节点个数在1~8之间") + @NotNull(message = "节点个数") + private Integer resourcesPoolNode; + + @ApiModelProperty("验证数据来源名称") + private String valDataSourceName; + + @ApiModelProperty("验证数据来源路径") + private String valDataSourcePath; + + @ApiModelProperty("是否验证数据集") + private Integer valType; + + @ApiModelProperty(value = "训练延时启动时长,单位为小时") + @Min(value = TrainUtil.NUMBER_ZERO, message = "训练延时启动时长不能小于0小时") + @Max(value = TrainUtil.NUMBER_ONE_HUNDRED_AND_SIXTY_EIGHT, message = "训练延时启动时长不能大于168小时即时长不能超过一周(7*24小时)") + private Integer delayCreateTime; + + @ApiModelProperty(value = "训练自动停止时长,单位为小时") + @Min(value = TrainUtil.NUMBER_ZERO, message = "训练自动停止时长必须不能小于0小时") + @Max(value = TrainUtil.NUMBER_ONE_HUNDRED_AND_SIXTY_EIGHT, message = "训练自动停止时长不能大于168小时即时长不能超过一周(7*24小时)") + private Integer delayDeleteTime; + + @ApiModelProperty(value = "是否打开模型选择") + private Integer modelType; + + @ApiModelProperty(value = "模型类型(0我的模型1预置模型)") + private Integer modelResource; + + @ApiModelProperty(value = "模型名称") + private String modelName; + + @ApiModelProperty(value = "模型加载路径") + private String modelLoadPathDir; + + @ApiModelProperty(value = "模型id") + private Integer modelId; } diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtTrainParamCreateDTO.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtTrainParamCreateDTO.java index 1b9e636..1b399f7 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtTrainParamCreateDTO.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtTrainParamCreateDTO.java @@ -1,12 +1,12 @@ /** * Copyright 2020 Zhejiang Lab. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -76,9 +76,44 @@ public class PtTrainParamCreateDTO extends BaseImageDTO { @NotNull(message = "类型(0为CPU,1为GPU)不能为空") private Integer resourcesPoolType; - @ApiModelProperty(value = "规格类型Id", required = true) - @NotNull(message = "规格类型id不能为空") - @Min(value = TrainUtil.NUMBER_ONE, message = "规格类型id必须不小于1") - private Integer trainJobSpecsId; + @ApiModelProperty(value = "规格名称", required = true) + @NotNull(message = "规格名称不能为空") + private String trainJobSpecsName; + + @ApiModelProperty("验证数据来源名称") + private String valDataSourceName; + + @ApiModelProperty("验证数据来源路径") + private String valDataSourcePath; + + @ApiModelProperty("是否验证数据集") + private Integer valType; + + @ApiModelProperty(value = "是否打开模型选择") + private Integer modelType; + + @ApiModelProperty(value = "模型类型(0我的模型1预置模型)") + private Integer modelResource; + + @ApiModelProperty(value = "模型名称") + private String modelName; + + @ApiModelProperty(value = "模型加载路径") + private String modelLoadPathDir; + + @ApiModelProperty(value = "模型id") + private Integer modelId; + + @ApiModelProperty(value = "训练类型 0:普通训练,1:分布式训练", required = true) + @Min(value = TrainUtil.NUMBER_ZERO, message = "训练类型错误") + @Max(value = TrainUtil.NUMBER_ONE, message = "训练类型错误") + @NotNull(message = "训练类型(0为普通训练,1为分布式训练)不能为空") + private Integer trainType; + + @ApiModelProperty(value = "节点个数", required = true) + @Min(value = TrainUtil.NUMBER_ONE, message = "节点个数在1~8之间") + @Max(value = TrainUtil.NUMBER_EIGHT, message = "节点个数在1~8之间") + @NotNull(message = "节点个数") + private Integer resourcesPoolNode; } diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtTrainParamUpdateDTO.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtTrainParamUpdateDTO.java index 59675e0..96506fe 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtTrainParamUpdateDTO.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtTrainParamUpdateDTO.java @@ -81,9 +81,42 @@ public class PtTrainParamUpdateDTO extends BaseImageDTO { @NotNull(message = "类型(0为CPU,1为GPU)不能为空") private Integer resourcesPoolType; - @ApiModelProperty(value = "规格类型Id", required = true) - @NotNull(message = "规格类型id不能为空") - @Min(value = TrainUtil.NUMBER_ONE, message = "规格类型id必须不小于1") - private Integer trainJobSpecsId; + @ApiModelProperty(value = "规格名称", required = true) + @NotNull(message = "规格名称不能为空") + private String trainJobSpecsName; + + @ApiModelProperty("验证数据来源名称") + private String valDataSourceName; + + @ApiModelProperty("验证数据来源路径") + private String valDataSourcePath; + + @ApiModelProperty("是否验证数据集") + private Integer valType; + + @ApiModelProperty(value = "是否打开模型选择") + private Integer modelType; + + @ApiModelProperty(value = "模型类型(0我的模型1预置模型)") + private Integer modelResource; + + @ApiModelProperty(value = "模型名称") + private String modelName; + + @ApiModelProperty(value = "模型加载路径") + private String modelLoadPathDir; + + @ApiModelProperty(value = "模型id") + private Integer modelId; + + @ApiModelProperty(value = "训练类型 0:普通训练,1:分布式训练") + @Min(value = TrainUtil.NUMBER_ZERO, message = "训练类型错误") + @Max(value = TrainUtil.NUMBER_ONE, message = "训练类型错误") + private Integer trainType; + + @ApiModelProperty(value = "节点个数") + @Min(value = TrainUtil.NUMBER_ONE, message = "节点个数在1~8之间") + @Max(value = TrainUtil.NUMBER_EIGHT, message = "节点个数在1~8之间") + private Integer resourcesPoolNode; } diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/entity/ModelQuery.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/entity/ModelQuery.java new file mode 100644 index 0000000..b6366ba --- /dev/null +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/entity/ModelQuery.java @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Zhejiang Lab. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================= + */ +package org.dubhe.domain.entity; +import com.baomidou.mybatisplus.annotation.TableField; +import com.baomidou.mybatisplus.annotation.TableName; +import io.swagger.annotations.ApiModelProperty; +import lombok.Data; + +/** + * @description model详情 + * @date 2020-10-09 + */ +@Data +@TableName("pt_model_info") +public class ModelQuery { + @TableField(value = "model_resource") + @ApiModelProperty(value = "模型类型") + private Integer modelResource; + @TableField(value = "name") + @ApiModelProperty(value = "模型名称") + private String name; + @TableField(value = "model_version") + @ApiModelProperty(value = "模型版本") + private String modelVersion; + @TableField(value = "id") + @ApiModelProperty(value = "模型id") + private Integer id; + @TableField(value = "url") + @ApiModelProperty(value = "模型路径") + private String url; + +} diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/entity/ModelQueryBrance.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/entity/ModelQueryBrance.java new file mode 100644 index 0000000..c4c62cb --- /dev/null +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/entity/ModelQueryBrance.java @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Zhejiang Lab. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================= + */ +package org.dubhe.domain.entity; +import com.baomidou.mybatisplus.annotation.TableField; +import com.baomidou.mybatisplus.annotation.TableName; +import io.swagger.annotations.ApiModelProperty; +import lombok.Data; +/** + * @description model 模型版本 + * @date 2020-10-09 + */ +@Data +@TableName("pt_model_branch") +public class ModelQueryBrance { + @TableField(value = "model_resource") + @ApiModelProperty(value = "模型类型") + private Integer modelResource; + @TableField(value = "name") + @ApiModelProperty(value = "模型名称") + private String name; + @TableField(value = "version") + @ApiModelProperty(value = "模型版本") + private String version; + @TableField(value = "id") + @ApiModelProperty(value = "模型id") + private Integer id; + @TableField(value = "url") + @ApiModelProperty(value = "模型路径") + private String url; + +} diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/entity/NoteBook.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/entity/NoteBook.java index 2d822a9..21196a9 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/entity/NoteBook.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/entity/NoteBook.java @@ -17,16 +17,14 @@ package org.dubhe.domain.entity; -import com.baomidou.mybatisplus.annotation.IdType; -import com.baomidou.mybatisplus.annotation.TableField; -import com.baomidou.mybatisplus.annotation.TableId; -import com.baomidou.mybatisplus.annotation.TableName; +import com.baomidou.mybatisplus.annotation.*; import com.fasterxml.jackson.annotation.JsonIgnore; import io.swagger.annotations.ApiModelProperty; import lombok.Data; +import org.dubhe.base.BaseEntity; -import javax.validation.constraints.*; -import java.io.Serializable; +import javax.validation.constraints.NotNull; +import javax.validation.constraints.Size; import java.util.Date; /** @@ -35,13 +33,13 @@ import java.util.Date; */ @Data @TableName("notebook") -public class NoteBook implements Serializable { +public class NoteBook extends BaseEntity { @TableId(value = "id", type = IdType.AUTO) @ApiModelProperty(hidden = true) private Long id; - @TableField(value = "user_id") + @TableField(value = "origin_user_id",fill = FieldFill.INSERT) @ApiModelProperty(hidden = true) private Long userId; @@ -55,7 +53,8 @@ public class NoteBook implements Serializable { @TableField(value = "description") private String description; - @TableField(value = "url") + public final static String COLUMN_URL = "url"; + @TableField(value = COLUMN_URL) @ApiModelProperty(hidden = true) private String url; @@ -79,10 +78,12 @@ public class NoteBook implements Serializable { @ApiModelProperty(value = "硬盘内存大小") private Integer diskMemNum; + + public final static String COLUMN_STATUS = "status"; /** * 0运行中,1停止, 2删除, 3启动中,4停止中,5删除中,6运行异常(暂未启用) */ - @TableField(value = "status") + @TableField(value = COLUMN_STATUS) @ApiModelProperty(hidden = true) private Integer status; @@ -130,26 +131,6 @@ public class NoteBook implements Serializable { @ApiModelProperty(hidden = true) private String k8sPvcPath; - @TableField(value = "create_time") - @ApiModelProperty(hidden = true) - private Date createTime; - - @TableField(value = "create_user_id") - @ApiModelProperty(hidden = true) - private Long createUserId; - - @TableField(value = "update_time") - @ApiModelProperty(hidden = true) - private Date updateTime; - - @TableField(value = "update_user_id") - @ApiModelProperty(hidden = true) - private Long updateUserId; - - @TableField(value = "deleted") - @ApiModelProperty(hidden = true) - private Integer deleted; - @TableField(value = "data_source_name") @ApiModelProperty(hidden = true) @Size(max = 255, message = "数据集名称超长") @@ -189,11 +170,6 @@ public class NoteBook implements Serializable { ", k8sImageName='" + k8sImageName + '\'' + ", k8sMountPath='" + k8sMountPath + '\'' + ", k8sPvcPath='" + k8sPvcPath + '\'' + - ", createTime=" + createTime + - ", createUserId=" + createUserId + - ", updateTime=" + updateTime + - ", updateUserId=" + updateUserId + - ", deleted=" + deleted + ", dataSourceName='" + dataSourceName + '\'' + ", dataSourcePath='" + dataSourcePath + '\'' + ", algorithmId=" + algorithmId + diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/entity/PtImage.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/entity/PtImage.java index cc64ce1..c99fde3 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/entity/PtImage.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/entity/PtImage.java @@ -17,10 +17,7 @@ package org.dubhe.domain.entity; -import com.baomidou.mybatisplus.annotation.IdType; -import com.baomidou.mybatisplus.annotation.TableField; -import com.baomidou.mybatisplus.annotation.TableId; -import com.baomidou.mybatisplus.annotation.TableName; +import com.baomidou.mybatisplus.annotation.*; import lombok.Data; import lombok.EqualsAndHashCode; import lombok.experimental.Accessors; @@ -84,4 +81,10 @@ public class PtImage extends BaseEntity { */ @TableField(value = "image_status") private Integer imageStatus; + + /** + * 资源拥有者ID + */ + @TableField(value = "origin_user_id",fill = FieldFill.INSERT) + private Long originUserId; } diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/entity/PtJobParam.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/entity/PtJobParam.java index 4cba152..615c8ad 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/entity/PtJobParam.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/entity/PtJobParam.java @@ -28,6 +28,8 @@ import lombok.EqualsAndHashCode; import lombok.experimental.Accessors; import org.dubhe.base.BaseEntity; +import java.sql.Timestamp; + /** * @description job运行参数及结果 * @date 2020-04-27 @@ -98,4 +100,16 @@ public class PtJobParam extends BaseEntity { @TableField(value = "param_accuracy") private String paramAccuracy; + /** + *训练延时启动时间 + */ + @TableField(value = "delay_create_time") + private Timestamp delayCreateTime; + + /** + *训练自动停止时间 + */ + @TableField(value = "delay_delete_time") + private Timestamp delayDeleteTime; + } diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/entity/PtTrain.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/entity/PtTrain.java index 9f6251c..b227523 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/entity/PtTrain.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/entity/PtTrain.java @@ -17,10 +17,7 @@ package org.dubhe.domain.entity; -import com.baomidou.mybatisplus.annotation.IdType; -import com.baomidou.mybatisplus.annotation.TableField; -import com.baomidou.mybatisplus.annotation.TableId; -import com.baomidou.mybatisplus.annotation.TableName; +import com.baomidou.mybatisplus.annotation.*; import lombok.Data; import lombok.EqualsAndHashCode; import lombok.experimental.Accessors; @@ -65,4 +62,10 @@ public class PtTrain extends BaseEntity { */ @TableField(value = "total_num") private Integer totalNum; + + /** + * 资源拥有者ID + */ + @TableField(value = "origin_user_id",fill = FieldFill.INSERT) + private Long originUserId; } diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/entity/PtTrainAlgorithm.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/entity/PtTrainAlgorithm.java index 8f1f6b4..ecea1dd 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/entity/PtTrainAlgorithm.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/entity/PtTrainAlgorithm.java @@ -18,10 +18,7 @@ package org.dubhe.domain.entity; import com.alibaba.fastjson.JSONObject; -import com.baomidou.mybatisplus.annotation.IdType; -import com.baomidou.mybatisplus.annotation.TableField; -import com.baomidou.mybatisplus.annotation.TableId; -import com.baomidou.mybatisplus.annotation.TableName; +import com.baomidou.mybatisplus.annotation.*; import com.baomidou.mybatisplus.extension.handlers.FastjsonTypeHandler; import lombok.Data; import lombok.EqualsAndHashCode; @@ -126,4 +123,15 @@ public class PtTrainAlgorithm extends BaseEntity { @TableField(value = "is_visualized_log") private Boolean isVisualizedLog; + /** + * 算法状态 + */ + @TableField(value = "algorithm_status") + private Integer algorithmStatus; + + /** + * 资源拥有者ID + */ + @TableField(value = "origin_user_id",fill = FieldFill.INSERT) + private Long originUserId; } diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/entity/PtTrainAlgorithmUsage.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/entity/PtTrainAlgorithmUsage.java index 9739e7a..722098e 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/entity/PtTrainAlgorithmUsage.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/entity/PtTrainAlgorithmUsage.java @@ -17,10 +17,7 @@ package org.dubhe.domain.entity; -import com.baomidou.mybatisplus.annotation.IdType; -import com.baomidou.mybatisplus.annotation.TableField; -import com.baomidou.mybatisplus.annotation.TableId; -import com.baomidou.mybatisplus.annotation.TableName; +import com.baomidou.mybatisplus.annotation.*; import lombok.Data; import lombok.EqualsAndHashCode; import lombok.experimental.Accessors; @@ -49,8 +46,8 @@ public class PtTrainAlgorithmUsage extends BaseEntity { /** * 用户id */ - @TableField(value = "user_id") - private Long userId; + @TableField(value = "origin_user_id",fill = FieldFill.INSERT) + private Long originUserId; /** * 类型 @@ -64,11 +61,4 @@ public class PtTrainAlgorithmUsage extends BaseEntity { @TableField(value = "aux_info") private String auxInfo; - /** - * 是否为默认值(0否,1是默认值) - */ - @TableField(value = "is_default") - private Boolean isDefault; - - } diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/entity/PtTrainJob.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/entity/PtTrainJob.java index 2e7f752..f78c1af 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/entity/PtTrainJob.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/entity/PtTrainJob.java @@ -17,10 +17,7 @@ package org.dubhe.domain.entity; -import com.baomidou.mybatisplus.annotation.IdType; -import com.baomidou.mybatisplus.annotation.TableField; -import com.baomidou.mybatisplus.annotation.TableId; -import com.baomidou.mybatisplus.annotation.TableName; +import com.baomidou.mybatisplus.annotation.*; import lombok.Data; import lombok.EqualsAndHashCode; import lombok.experimental.Accessors; @@ -104,10 +101,10 @@ public class PtTrainJob extends BaseEntity { private String logPath; /** - * 规格Id + * 规格名称 */ - @TableField(value = "train_job_specs_id") - private Integer trainJobSpecsId; + @TableField(value = "train_job_specs_name") + private String trainJobSpecsName; /** * 类型(0为CPU,1为GPU) @@ -145,4 +142,66 @@ public class PtTrainJob extends BaseEntity { @TableField(value = "k8s_job_name") private String k8sJobName; + /** + * 训练类型,0:普通训练,1:分布式训练 + */ + @TableField(value = "train_type") + private Integer trainType; + + /** + * 验证数据集来源名称 + */ + @TableField(value = "val_data_source_name") + private String valDataSourceName; + + /** + * 验证数据集来源路径 + */ + @TableField(value = "val_data_source_path") + private String valDataSourcePath; + + /** + * 是否验证数据集 + */ + @TableField(value = "val_type") + private Integer valType; + + /** + * 资源拥有者ID + */ + @TableField(value = "origin_user_id",fill = FieldFill.INSERT) + private Long originUserId; + + /** + * 是否打开模型原则 + */ + @TableField(value = "model_type") + private Integer modelType; + /** + * 模型来源 + */ + @TableField(value = "model_resource") + private Integer modelResource; + /** + * 模型加载路径 + */ + @TableField(value = "model_load_dir") + private String modelLoadPathDir; + /** + * 模型名称 + */ + @TableField(value = "model_name") + private String modelName; + /** + * 模型id + */ + @TableField(value = "model_id") + private Integer modelId; + + /** + * 训练信息(失败信息) + */ + @TableField(value = "train_msg") + private String trainMsg; + } diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/entity/PtTrainParam.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/entity/PtTrainParam.java index 34543e6..f1ac63e 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/entity/PtTrainParam.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/entity/PtTrainParam.java @@ -18,10 +18,7 @@ package org.dubhe.domain.entity; import com.alibaba.fastjson.JSONObject; -import com.baomidou.mybatisplus.annotation.IdType; -import com.baomidou.mybatisplus.annotation.TableField; -import com.baomidou.mybatisplus.annotation.TableId; -import com.baomidou.mybatisplus.annotation.TableName; +import com.baomidou.mybatisplus.annotation.*; import com.baomidou.mybatisplus.extension.handlers.FastjsonTypeHandler; import lombok.Data; import lombok.EqualsAndHashCode; @@ -111,10 +108,10 @@ public class PtTrainParam extends BaseEntity { private String logPath; /** - * 规格Id + * 规格名称 */ - @TableField(value = "train_job_specs_id") - private Integer trainJobSpecsId; + @TableField(value = "train_job_specs_name") + private String trainJobSpecsName; /** * 类型(0为CPU,1为GPU) @@ -127,4 +124,61 @@ public class PtTrainParam extends BaseEntity { */ @TableField(value = "resources_pool_node") private Integer resourcesPoolNode; + + /** + * 验证数据集来源名称 + */ + @TableField(value = "val_data_source_name") + private String valDataSourceName; + + /** + * 验证数据集来源路径 + */ + @TableField(value = "val_data_source_path") + private String valDataSourcePath; + + /** + * 是否验证数据集 + */ + @TableField(value = "val_type") + private Integer valType; + /** + * 模型名称 + */ + @TableField(value = "model_name") + private String modelName; + /** + * 模型id + */ + @TableField(value = "model_id") + private Integer modelId; + /** + * 模型来源 + */ + @TableField(value = "model_resource") + private Integer modelResource; + /** + * 模型类型 + */ + @TableField(value = "model_type") + private Integer modelType; + /** + * 模型路径 + */ + @TableField(value = "model_load_dir") + private String modelLoadPathDir; + + /** + * 训练类型,0:普通训练,1:分布式训练 + */ + @TableField(value = "train_type") + private Integer trainType; + + /** + * 资源拥有者ID + */ + @TableField(value = "origin_user_id",fill = FieldFill.INSERT) + private Long originUserId; + } + diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/NoteBookVO.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/NoteBookVO.java index 65c1470..d4328b6 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/NoteBookVO.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/NoteBookVO.java @@ -18,14 +18,12 @@ package org.dubhe.domain.vo; import cn.hutool.core.date.DatePattern; -import com.baomidou.mybatisplus.annotation.TableField; import com.fasterxml.jackson.annotation.JsonFormat; import com.fasterxml.jackson.annotation.JsonIgnore; import io.swagger.annotations.ApiModel; import io.swagger.annotations.ApiModelProperty; import lombok.Data; -import javax.validation.constraints.Size; import java.io.Serializable; import java.util.Date; @@ -112,4 +110,6 @@ public class NoteBookVO implements Serializable { @ApiModelProperty("算法ID") private Long algorithmId; + @ApiModelProperty("资源拥有者ID") + private Long originUserId; } diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtDoJobResultVO.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtDoJobResultVO.java index 9f7669f..97e6775 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtDoJobResultVO.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtDoJobResultVO.java @@ -21,8 +21,8 @@ import lombok.Data; import lombok.experimental.Accessors; /** - * @description: doJob返回封装 - * @Date:2020-07-03 + * @description doJob返回封装 + * @date 2020-07-03 */ @Data @Accessors(chain = true) diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtImageAndAlgorithmVO.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtImageAndAlgorithmVO.java index 124efdd..260d761 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtImageAndAlgorithmVO.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtImageAndAlgorithmVO.java @@ -24,7 +24,7 @@ import lombok.experimental.Accessors; import java.io.Serializable; /** - * @description: 镜像 + * @description 镜像 * @date 2020-04-27 */ @Data diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtImageQueryVO.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtImageQueryVO.java index 033a79b..9251bd2 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtImageQueryVO.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtImageQueryVO.java @@ -24,8 +24,8 @@ import java.io.Serializable; import java.sql.Timestamp; /** - * @description: 返回镜像查询结果 - * @date: 2020-04-27 + * @description 返回镜像查询结果 + * @date 2020-04-27 */ @Data public class PtImageQueryVO implements Serializable { @@ -52,4 +52,6 @@ public class PtImageQueryVO implements Serializable { @ApiModelProperty("创建时间") private Timestamp createTime; + @ApiModelProperty("资源拥有者ID") + private Long originUserId; } diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtJobMetricsGrafanaVO.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtJobMetricsGrafanaVO.java index a367563..3532c00 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtJobMetricsGrafanaVO.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtJobMetricsGrafanaVO.java @@ -33,4 +33,7 @@ public class PtJobMetricsGrafanaVO implements Serializable { @ApiModelProperty("job监控地址") private String jobMetricsGrafanaUrl; + + @ApiModelProperty("job对应k8s中的podName") + private String jobPodName; } diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtTrainAlgorithmQueryVO.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtTrainAlgorithmQueryVO.java index cdb2278..28221ea 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtTrainAlgorithmQueryVO.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtTrainAlgorithmQueryVO.java @@ -25,7 +25,7 @@ import java.io.Serializable; import java.sql.Timestamp; /** - * @description: 训练算法返回列表 + * @description 训练算法返回列表 * @date 2020-04-27 */ @Data @@ -48,8 +48,8 @@ public class PtTrainAlgorithmQueryVO implements Serializable { @ApiModelProperty(value = "镜像名称") private String imageName; - @ApiModelProperty(value = "镜像Project") - private String imageNameProject; + @ApiModelProperty(value = "算法文件大小") + private String algorithmFileSize; @ApiModelProperty(value = "镜像版本") private String imageTag; @@ -84,4 +84,6 @@ public class PtTrainAlgorithmQueryVO implements Serializable { @ApiModelProperty(value = "可视化日志(1是,0否)") private Boolean isVisualizedLog; + @ApiModelProperty("资源拥有者ID") + private Long originUserId; } diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtTrainAlgorithmUsageQueryVO.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtTrainAlgorithmUsageQueryVO.java index b5e7db9..10b54a9 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtTrainAlgorithmUsageQueryVO.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtTrainAlgorithmUsageQueryVO.java @@ -24,7 +24,7 @@ import java.io.Serializable; import java.sql.Timestamp; /** - * @description: 算法用途返回列表 + * @description 算法用途返回列表 * @date 2020-06-23 */ @Data @@ -35,18 +35,12 @@ public class PtTrainAlgorithmUsageQueryVO implements Serializable { @ApiModelProperty(value = "ID") private Long id; - @ApiModelProperty(value = "用户ID") - private Long userId; - @ApiModelProperty(value = "类型") private String type; @ApiModelProperty(value = "辅助信息") private String auxInfo; - @ApiModelProperty(value = "辅助信息") - private Boolean deleted; - @ApiModelProperty(value = "创建时间") private Timestamp createTime; diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtTrainDataSourceStatusQueryVO.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtTrainDataSourceStatusQueryVO.java index 712b9eb..3d318d1 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtTrainDataSourceStatusQueryVO.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtTrainDataSourceStatusQueryVO.java @@ -24,8 +24,8 @@ import lombok.Setter; import java.io.Serializable; /** - * @description: 查询数据集状态查询结果 - * @date: 2020-05-21 + * @description 查询数据集状态查询结果 + * @date 2020-05-21 */ @Data @Setter diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtTrainJobDeleteVO.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtTrainJobDeleteVO.java index 1a03ee0..af0d1a8 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtTrainJobDeleteVO.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtTrainJobDeleteVO.java @@ -23,8 +23,8 @@ import lombok.Data; import java.io.Serializable; /** - * @description: 返回删除训练任务结果 - * @date: 2020-04-28 + * @description 返回删除训练任务结果 + * @date 2020-04-28 */ @Data public class PtTrainJobDeleteVO implements Serializable { diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtTrainJobDetailQueryVO.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtTrainJobDetailQueryVO.java index 6228c3e..efd7962 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtTrainJobDetailQueryVO.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtTrainJobDetailQueryVO.java @@ -1,12 +1,12 @@ /** * Copyright 2020 Zhejiang Lab. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -26,8 +26,8 @@ import java.io.Serializable; import java.sql.Timestamp; /** - * @description: 根据jobId查询训练任务详情返回结果 - * @date: 2020-06-12 + * @description 根据jobId查询训练任务详情返回结果 + * @date 2020-06-12 */ @Data @Accessors(chain = true) @@ -35,6 +35,9 @@ public class PtTrainJobDetailQueryVO implements Serializable { private static final long serialVersionUID = 1L; + @ApiModelProperty("训练作业名") + private String trainName; + @ApiModelProperty("jobID") private Long id; @@ -71,8 +74,8 @@ public class PtTrainJobDetailQueryVO implements Serializable { @ApiModelProperty("可视化日志路径") private String visualizedLogPath; - @ApiModelProperty("规格ID") - private Integer trainJobSpecsId; + @ApiModelProperty("规格名称") + private String trainJobSpecsName; @ApiModelProperty("类型(0为CPU,1为GPU)") private Integer resourcesPoolType; @@ -131,5 +134,33 @@ public class PtTrainJobDetailQueryVO implements Serializable { @ApiModelProperty("P4推理速度(ms)") private Integer p4InferenceSpeed; + @ApiModelProperty(value = "算法文件路径") + private String algorithmCodeDir; + + @ApiModelProperty("训练类型 0:普通训练,1:分布式训练") + private Integer trainType; + + @ApiModelProperty("验证数据来源名称") + private String valDataSourceName; + + @ApiModelProperty("验证数据来源路径") + private String valDataSourcePath; + + @ApiModelProperty("是否验证数据集") + private Integer valType; + + @ApiModelProperty("训练延时启动倒计时,单位:分钟") + private Integer delayCreateCountDown; + + @ApiModelProperty("训练自动停止倒计时,单位:分钟") + private Integer delayDeleteCountDown; + + @ApiModelProperty("资源拥有者ID") + private Long originUserId; + + @ApiModelProperty("训练信息(失败信息)") + private String trainMsg; + @ApiModelProperty("模型名称") + private String modelName; } diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtTrainJobDetailVO.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtTrainJobDetailVO.java index 29987c6..a0ca8ee 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtTrainJobDetailVO.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtTrainJobDetailVO.java @@ -26,7 +26,7 @@ import java.io.Serializable; import java.sql.Timestamp; /** - * @description: 返回训练版本查询详情 + * @description 返回训练版本查询详情 * @date 2020-04-27 */ @Data @@ -74,8 +74,8 @@ public class PtTrainJobDetailVO implements Serializable { @ApiModelProperty("可视化日志路径") private String visualizedLogPath; - @ApiModelProperty("规格ID") - private Integer trainJobSpecsId; + @ApiModelProperty("规格名称") + private String trainJobSpecsName; @ApiModelProperty("类型(0为CPU,1为GPU)") private Integer resourcesPoolType; @@ -134,4 +134,42 @@ public class PtTrainJobDetailVO implements Serializable { @ApiModelProperty("P4推理速度(ms)") private Integer p4InferenceSpeed; + @ApiModelProperty(value = "算法文件路径") + private String algorithmCodeDir; + + @ApiModelProperty("训练类型") + private Integer trainType; + + @ApiModelProperty("验证数据来源名称") + private String valDataSourceName; + + @ApiModelProperty("验证数据来源路径") + private String valDataSourcePath; + + @ApiModelProperty("是否验证数据集") + private Integer valType; + + @ApiModelProperty("训练延时启动倒计时,单位:分钟") + private Integer delayCreateCountDown; + + @ApiModelProperty("训练自动停止倒计时,单位:分钟") + private Integer delayDeleteCountDown; + + @ApiModelProperty("模型路径") + private String modelLoadPathDir; + + @ApiModelProperty(value = "是否打开模型选择") + private Integer modelType; + + @ApiModelProperty(value = "模型类型(0我的模型1预置模型)") + private Integer modelResource; + + @ApiModelProperty(value = "模型名称") + private String modelName; + + @ApiModelProperty(value = "模型id") + private Integer modelId; + + @ApiModelProperty(value = "训练信息(失败信息)") + private String trainMsg; } diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtTrainJobSpecsQueryVO.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtTrainJobSpecsQueryVO.java index 23b0bc1..cf4fcd1 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtTrainJobSpecsQueryVO.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtTrainJobSpecsQueryVO.java @@ -25,7 +25,7 @@ import lombok.experimental.Accessors; import java.io.Serializable; /** - * @description: 训练作业规格 + * @description 训练作业规格 * @date 2020-05-06 */ @Data diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtTrainJobStatisticsMineVO.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtTrainJobStatisticsMineVO.java index 686e39e..98f2d1e 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtTrainJobStatisticsMineVO.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtTrainJobStatisticsMineVO.java @@ -24,8 +24,8 @@ import lombok.Data; import java.io.Serializable; /** - * @description: 我的训练任务统计 - * @date: 2020-07-15 + * @description 我的训练任务统计 + * @date 2020-07-15 */ @Data @ApiModel(description = "我的训练任务统计结果") diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtTrainJobStopVO.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtTrainJobStopVO.java index f05900e..b2573b7 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtTrainJobStopVO.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtTrainJobStopVO.java @@ -23,8 +23,8 @@ import lombok.Data; import java.io.Serializable; /** - * @description: 返回停止训练任务结果 - * @date: 2020-04-28 + * @description 返回停止训练任务结果 + * @date 2020-04-28 */ @Data public class PtTrainJobStopVO implements Serializable { diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtTrainLogQueryVO.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtTrainLogQueryVO.java index 2409a48..a3294ee 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtTrainLogQueryVO.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtTrainLogQueryVO.java @@ -24,7 +24,7 @@ import lombok.experimental.Accessors; import java.util.List; /** - * @description: 训练日志 查询VO + * @description 训练日志 查询VO * @date 2020-05-08 */ @Data @@ -44,7 +44,7 @@ public class PtTrainLogQueryVO { @ApiModelProperty(value = "结束行") private Integer endLine; - @ApiModelProperty(value = "lines") + @ApiModelProperty(value = "查询行数") private Integer lines; } diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtTrainParamQueryVO.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtTrainParamQueryVO.java index 0db9ef8..6ee35f3 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtTrainParamQueryVO.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtTrainParamQueryVO.java @@ -26,8 +26,8 @@ import org.dubhe.base.BaseVO; import java.io.Serializable; /** - * @description: 任务参数查询返回查询结果 - * @date: 2020-04-27 + * @description 任务参数查询返回查询结果 + * @date 2020-04-27 */ @EqualsAndHashCode(callSuper = true) @Data @@ -74,13 +74,40 @@ public class PtTrainParamQueryVO extends BaseVO implements Serializable { @ApiModelProperty("运行参数(算法来源为我的算法时为调优参数,算法来源为预置算法时为运行参数)") private JSONObject runParams; - @ApiModelProperty("规格ID") - private Integer trainJobSpecsId; + @ApiModelProperty("规格名称") + private String trainJobSpecsName; @ApiModelProperty("类型(0为CPU,1为GPU)") private Integer resourcesPoolType; + @ApiModelProperty("训练类型") + private Integer trainType; + @ApiModelProperty("节点个数") private Integer resourcesPoolNode; + @ApiModelProperty("验证数据来源名称") + private String valDataSourceName; + + @ApiModelProperty("验证数据来源路径") + private String valDataSourcePath; + + @ApiModelProperty("是否验证数据集") + private Integer valType; + + @ApiModelProperty(value = "是否打开模型选择") + private Integer modelType; + + @ApiModelProperty(value = "模型类型(0我的模型1预置模型)") + private Integer modelResource; + + @ApiModelProperty(value = "模型名称") + private String modelName; + + @ApiModelProperty(value = "模型加载路径") + private String modelLoadPathDir; + + @ApiModelProperty(value = "模型id") + private Integer modelId; + } diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtTrainVO.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtTrainVO.java index 910632e..efda628 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtTrainVO.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/vo/PtTrainVO.java @@ -25,7 +25,7 @@ import java.io.Serializable; import java.sql.Timestamp; /** - * @description: 训练查询结果 + * @description 训练查询结果 * @date 2020-04-27 */ @Data @@ -63,4 +63,7 @@ public class PtTrainVO implements Serializable { @ApiModelProperty("数据来源名称") private String dataSourceName; + + @ApiModelProperty("资源拥有者ID") + private Long originUserId; } diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/enums/AlgorithmStatusEnum.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/enums/AlgorithmStatusEnum.java new file mode 100644 index 0000000..b27c8a2 --- /dev/null +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/enums/AlgorithmStatusEnum.java @@ -0,0 +1,63 @@ +/** + * Copyright 2020 Zhejiang Lab. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================= + */ +package org.dubhe.enums; + +/** + * @description 算法状态枚举 + * @date 2020-08-19 + */ +public enum AlgorithmStatusEnum { + + + /** + * 创建中 + */ + MAKING(0, "创建中"), + /** + * 创建成功 + */ + SUCCESS(1, "创建成功"), + /** + * 创建失败 + */ + FAIL(2, "创建失败"); + + + /** + * 编码 + */ + private Integer code; + + /** + * 描述 + */ + private String description; + + AlgorithmStatusEnum(int code, String description) { + this.code = code; + this.description = description; + } + + public Integer getCode() { + return code; + } + + public String getDescription() { + return description; + } +} + diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/enums/ResourcesPoolTypeEnum.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/enums/ResourcesPoolTypeEnum.java index 32a7640..53d1e2f 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/enums/ResourcesPoolTypeEnum.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/enums/ResourcesPoolTypeEnum.java @@ -17,9 +17,8 @@ package org.dubhe.enums; /** - *@description 规格类型 - *@date: 2020-07-15 - + * @description 规格类型 + * @date 2020-07-15 */ public enum ResourcesPoolTypeEnum { @@ -50,4 +49,13 @@ public enum ResourcesPoolTypeEnum { return description; } + /** + * 是否是GPU编码 + * @param code + * @return true 是 ,false 否 + */ + public static boolean isGpuCode(Integer code){ + return GPU.getCode().equals(code); + } + } diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/enums/TrainTypeEnum.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/enums/TrainTypeEnum.java new file mode 100644 index 0000000..8d50bff --- /dev/null +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/enums/TrainTypeEnum.java @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Zhejiang Lab. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================= + */ +package org.dubhe.enums; + +import lombok.Getter; + +/** + * @description 训练类型 + * @date 2020-08-31 + */ +@Getter +public enum TrainTypeEnum { + + /** + * 普通训练 + */ + JOB(0,"普通训练"), + /** + * 分布式训练 + */ + DISTRIBUTE_TRAIN(1,"分布式训练"), + ; + + + private Integer code; + + private String name; + + TrainTypeEnum(Integer code, String name) { + this.code = code; + this.name = name; + } + + /** + * 判断是否是分布式训练 + * @param trainType 训练类型 + * @return true 分布式训练,false 普通训练 + */ + public static boolean isDistributeTrain(int trainType){ + return DISTRIBUTE_TRAIN.getCode() == trainType; + } +} diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/rest/K8sCallbackPodController.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/rest/K8sCallbackPodController.java index 87f05f4..de537b4 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/rest/K8sCallbackPodController.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/rest/K8sCallbackPodController.java @@ -17,20 +17,21 @@ package org.dubhe.rest; -import javax.annotation.Resource; - +import io.swagger.annotations.Api; +import io.swagger.annotations.ApiOperation; +import io.swagger.annotations.ApiParam; import org.dubhe.base.DataResponseBody; -import org.dubhe.factory.DataResponseFactory; import org.dubhe.dto.callback.AlgorithmK8sPodCallbackCreateDTO; import org.dubhe.dto.callback.NotebookK8sPodCallbackCreateDTO; +import org.dubhe.factory.DataResponseFactory; import org.dubhe.service.PodCallbackAsyncService; import org.dubhe.utils.K8sCallBackTool; import org.springframework.validation.annotation.Validated; import org.springframework.web.bind.annotation.*; -import io.swagger.annotations.Api; -import io.swagger.annotations.ApiOperation; -import io.swagger.annotations.ApiParam; +import javax.annotation.Resource; + +import static org.dubhe.constant.StringConstant.K8S_CALLBACK_URI; /** * @description k8s Pod 异步回调处理类 @@ -39,7 +40,7 @@ import io.swagger.annotations.ApiParam; */ @Api(tags = "k8s回调:Pod") @RestController -@RequestMapping("/api/k8s/callback/pod") +@RequestMapping(K8S_CALLBACK_URI) public class K8sCallbackPodController { @Resource(name = "noteBookAsyncServiceImpl") @@ -72,7 +73,7 @@ public class K8sCallbackPodController { */ @PostMapping(value = "/algorithm") @ApiOperation("算法管理 pod 回调") - public DataResponseBody notebookPodCallBack(@ApiParam(type = "head") @RequestHeader(name= K8sCallBackTool.K8S_CALLBACK_TOKEN) String k8sToken + public DataResponseBody algorithmPodCallBack(@ApiParam(type = "head") @RequestHeader(name= K8sCallBackTool.K8S_CALLBACK_TOKEN) String k8sToken ,@Validated @RequestBody AlgorithmK8sPodCallbackCreateDTO k8sPodCallbackReq) { algorithmAsyncServiceImpl.podCallBack(k8sPodCallbackReq); return DataResponseFactory.success("算法管理异步回调处理方法中"); diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/rest/NoteBookController.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/rest/NoteBookController.java index 1076cc7..5ac67d6 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/rest/NoteBookController.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/rest/NoteBookController.java @@ -68,7 +68,6 @@ public class NoteBookController { @GetMapping(value = "/notebooks") @RequiresPermissions(Permissions.DEVELOPMENT_NOTEBOOK) public DataResponseBody getNoteBookList(Page page, NoteBookListQueryDTO noteBookListQueryDTO) { - noteBookListQueryDTO.setUserId(NotebookUtil.getCurUserId()); return new DataResponseBody(noteBookService.getNoteBookList(page, noteBookListQueryDTO)); } @@ -102,7 +101,7 @@ public class NoteBookController { } - @ApiOperation("开启notebook") + @ApiOperation("打开notebook") @GetMapping(value = "/{id}") @RequiresPermissions(Permissions.DEVELOPMENT_NOTEBOOK) public DataResponseBody openNotebook(@PathVariable(name = "id", required = true) Long noteBookId) { @@ -141,7 +140,7 @@ public class NoteBookController { @GetMapping(value = "/run-number") @RequiresPermissions(Permissions.DEVELOPMENT_NOTEBOOK) public DataResponseBody getNoteBookRunNumber() { - return new DataResponseBody(noteBookService.getNoteBookRunNumber(NotebookUtil.getCurUserId())); + return new DataResponseBody(noteBookService.getNoteBookRunNumber()); } diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/rest/PtImageController.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/rest/PtImageController.java index 792b32d..d4a3460 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/rest/PtImageController.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/rest/PtImageController.java @@ -23,7 +23,9 @@ import org.apache.shiro.authz.annotation.RequiresPermissions; import org.dubhe.annotation.ApiVersion; import org.dubhe.base.DataResponseBody; import org.dubhe.constant.Permissions; +import org.dubhe.domain.dto.PtImageDeleteDTO; import org.dubhe.domain.dto.PtImageQueryDTO; +import org.dubhe.domain.dto.PtImageUpdateDTO; import org.dubhe.domain.dto.PtImageUploadDTO; import org.dubhe.service.PtImageService; import org.springframework.beans.factory.annotation.Autowired; @@ -71,4 +73,28 @@ public class PtImageController { public DataResponseBody getHarborProjectList() { return new DataResponseBody(ptImageService.getHarborProjectList()); } + + + @DeleteMapping + @ApiOperation("删除镜像") + @RequiresPermissions(Permissions.TRAINING_IMAGE) + public DataResponseBody deleteTrainImage(@RequestBody PtImageDeleteDTO ptImageDeleteDTO) { + ptImageService.deleteTrainImage(ptImageDeleteDTO); + return new DataResponseBody(); + } + + @PutMapping + @ApiOperation("修改镜像信息") + @RequiresPermissions(Permissions.TRAINING_IMAGE) + public DataResponseBody updateTrainImage(@Validated @RequestBody PtImageUpdateDTO ptImageUpdateDTO) { + ptImageService.updateTrainImage(ptImageUpdateDTO); + return new DataResponseBody(); + } + + @GetMapping("/imageNameList") + @ApiOperation("获取镜像名称列表") + @RequiresPermissions(Permissions.TRAINING_IMAGE) + public DataResponseBody getImageNameList() { + return new DataResponseBody(ptImageService.getImageNameList()); + } } diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/rest/PtTrainLogController.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/rest/PtTrainLogController.java index add7226..ab82378 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/rest/PtTrainLogController.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/rest/PtTrainLogController.java @@ -26,10 +26,12 @@ import org.dubhe.base.MagicNumConstant; import org.dubhe.constant.Permissions; import org.dubhe.domain.dto.PtTrainLogQueryDTO; import org.dubhe.domain.vo.PtTrainLogQueryVO; +import org.dubhe.factory.DataResponseFactory; import org.dubhe.service.PtTrainLogService; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.validation.annotation.Validated; import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.PathVariable; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; @@ -64,4 +66,14 @@ public class PtTrainLogController { return new DataResponseBody(ptTrainLogService.getTrainLogString(ptTrainLogQueryVO.getContent())); } + @GetMapping("/pod/{id}") + @ApiOperation("获取pod节点") + @RequiresPermissions(Permissions.TRAINING_JOB) + public DataResponseBody getPods(@PathVariable Long id) { + return DataResponseFactory.success(ptTrainLogService.getPods(id)); + } + + + + } diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/NoteBookService.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/NoteBookService.java index abfe52d..92f33cc 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/NoteBookService.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/NoteBookService.java @@ -159,10 +159,9 @@ public interface NoteBookService { /** * 获取正在运行的notebook数量 * - * @param curUserId * @return int */ - int getNoteBookRunNumber(long curUserId); + int getNoteBookRunNumber(); /** * 获取notebook模板 diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/PtImageService.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/PtImageService.java index f549ffd..8fb06ea 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/PtImageService.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/PtImageService.java @@ -17,12 +17,15 @@ package org.dubhe.service; +import org.dubhe.domain.dto.PtImageDeleteDTO; import org.dubhe.domain.dto.PtImageQueryDTO; +import org.dubhe.domain.dto.PtImageUpdateDTO; import org.dubhe.domain.dto.PtImageUploadDTO; import org.dubhe.domain.entity.HarborProject; import java.util.List; import java.util.Map; +import java.util.Set; /** * @description 镜像服务service @@ -47,12 +50,6 @@ public interface PtImageService { void uploadImage(PtImageUploadDTO ptImageUploadDTO); - /** - * 定时到harbor同步imageName - */ - void harborImageNameSync(); - - /** * 通过imageName查询所含镜像版本信息 * @@ -68,4 +65,27 @@ public interface PtImageService { * @return List harbor镜像集合 **/ List getHarborProjectList(); + + + /** + * 删除镜像 + * + * @param imageDeleteDTO 删除镜像条件参数 + */ + void deleteTrainImage(PtImageDeleteDTO imageDeleteDTO); + + /** + * 修改镜像信息 + * + * @param imageUpdateDTO 修改的镜像信息 + */ + void updateTrainImage(PtImageUpdateDTO imageUpdateDTO); + + + /** + * 获取镜像名称列表 + * + * @return Set 镜像列表 + */ + Set getImageNameList(); } diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/PtTrainJobService.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/PtTrainJobService.java index d42cb19..3714da8 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/PtTrainJobService.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/PtTrainJobService.java @@ -119,7 +119,7 @@ public interface PtTrainJobService { * 获取job在grafana监控的地址 * * @param jobId 任务ID - * @return PtJobMetricsGrafanaVO Pod Metrics Grafana url + * @return List Pod Metrics Grafana url */ - PtJobMetricsGrafanaVO getGrafanaUrl(Long jobId); + List getGrafanaUrl(Long jobId); } diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/PtTrainLogService.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/PtTrainLogService.java index b4abf6d..e1d5c02 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/PtTrainLogService.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/PtTrainLogService.java @@ -20,6 +20,7 @@ package org.dubhe.service; import java.util.List; import org.dubhe.domain.dto.PtTrainLogQueryDTO; +import org.dubhe.k8s.domain.vo.PodVO; import org.dubhe.domain.vo.PtTrainLogQueryVO; /** @@ -43,4 +44,12 @@ public interface PtTrainLogService { * @return String 字符串 */ String getTrainLogString(List content); + + /** + * 获取训练任务的Pod + * + * @param id 训练作业job表 id + * @return 训练任务的Pod + */ + List getPods(Long id); } diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/convert/PtJupyterResourceConvert.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/convert/PtJupyterResourceConvert.java index 9c371e8..219f6a8 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/convert/PtJupyterResourceConvert.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/convert/PtJupyterResourceConvert.java @@ -37,9 +37,11 @@ public class PtJupyterResourceConvert { * NoteBook 转换为 PtJupyterResourceBO * * @param noteBook + * @param k8sNameTool + * @param notebookDelayDeleteTime * @return PtJupyterResourceBO */ - public static PtJupyterResourceBO toPtJupyterResourceBo(NoteBook noteBook,K8sNameTool k8sNameTool){ + public static PtJupyterResourceBO toPtJupyterResourceBo(NoteBook noteBook, K8sNameTool k8sNameTool, Integer notebookDelayDeleteTime){ if (noteBook == null){ return null; } @@ -59,6 +61,7 @@ public class PtJupyterResourceConvert { .setDatasetDir(k8sNameTool.getAbsoluteNfsPath(noteBook.getDataSourcePath())) .setDatasetMountPath(k8sNameTool.getDatasetPath()) .setDatasetReadOnly(true) + .setDelayDeleteTime(notebookDelayDeleteTime) ; return bo; } diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/impl/AlgorithmAsyncServiceImpl.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/impl/AlgorithmAsyncServiceImpl.java index 34fb1ae..4dba6cd 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/impl/AlgorithmAsyncServiceImpl.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/impl/AlgorithmAsyncServiceImpl.java @@ -19,21 +19,34 @@ package org.dubhe.service.impl; import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; import org.dubhe.base.MagicNumConstant; -import org.dubhe.enums.TrainJobStatusEnum; +import org.dubhe.dao.PtJobParamMapper; import org.dubhe.dao.PtTrainJobMapper; +import org.dubhe.domain.entity.PtJobParam; import org.dubhe.domain.entity.PtTrainJob; import org.dubhe.dto.callback.AlgorithmK8sPodCallbackCreateDTO; import org.dubhe.dto.callback.BaseK8sPodCallbackCreateDTO; import org.dubhe.enums.LogEnum; +import org.dubhe.enums.TrainJobStatusEnum; +import org.dubhe.enums.TrainTypeEnum; +import org.dubhe.k8s.api.LogMonitoringApi; +import org.dubhe.k8s.api.PodApi; +import org.dubhe.k8s.domain.bo.LogMonitoringBO; +import org.dubhe.k8s.domain.resource.BizPod; +import org.dubhe.k8s.enums.K8sKindEnum; +import org.dubhe.k8s.enums.ContainerStatusesStateEnum; +import org.dubhe.k8s.utils.PodUtil; import org.dubhe.service.PodCallbackAsyncService; import org.dubhe.service.abstracts.AbstractPodCallback; -import org.dubhe.utils.K8sNameTool; -import org.dubhe.utils.LogUtil; -import org.dubhe.utils.TrainUtil; +import org.dubhe.utils.*; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; +import java.util.ArrayList; +import java.util.List; +import java.util.UUID; import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + /** * @description 训练任务回调 @@ -44,9 +57,22 @@ public class AlgorithmAsyncServiceImpl extends AbstractPodCallback implements Po @Autowired private PtTrainJobMapper ptTrainJobMapper; + @Autowired private K8sNameTool k8sNameTool; + @Autowired + private PodApi podApi; + + @Autowired + private PtJobParamMapper ptJobParamMapper; + + @Autowired + private LogMonitoringApi logMonitoringApi; + + @Autowired + private RedisUtils redisUtils; + /** * pod 异步回调具体实现处理类 * @@ -60,48 +86,161 @@ public class AlgorithmAsyncServiceImpl extends AbstractPodCallback implements Po // 强制转型 AlgorithmK8sPodCallbackCreateDTO req = (AlgorithmK8sPodCallbackCreateDTO) k8sPodCallbackCreateDTO; LogUtil.info(LogEnum.BIZ_TRAIN, "Thread {} try {} time.Request: {}", Thread.currentThread(), times, req.toString()); - // 根据namespace和podName找到job - QueryWrapper queryTrainJonWrapper = new QueryWrapper<>(); - Long userId = k8sNameTool.getUserIdFromNameSpace(req.getNamespace()); - String podName = req.getPodName(); - if (null == podName || podName.length() <= MagicNumConstant.SIX) { - LogUtil.error(LogEnum.BIZ_TRAIN, "podName={} too short", podName); - return false; - } - String k8sJobName = podName.substring(MagicNumConstant.ZERO, podName.length() - MagicNumConstant.SIX); - queryTrainJonWrapper.eq("k8s_job_name", k8sJobName). - eq("create_user_id", userId); - PtTrainJob ptTrainJob = ptTrainJobMapper.selectOne(queryTrainJonWrapper); + // 匹配训练任务 + PtTrainJob ptTrainJob = getPtTrainJob(req); if (null == ptTrainJob) { - LogUtil.error(LogEnum.BIZ_TRAIN, "k8s_job_name={} not found", k8sJobName); + LogUtil.error(LogEnum.BIZ_TRAIN, "req={} not found", req); return false; } String phase = req.getPhase(); - // 对于当前状态为结束状态或上报状态为删除的任务不做处理 - if (TrainJobStatusEnum.isEnd(ptTrainJob.getTrainStatus()) || TrainJobStatusEnum.DELETED.getMessage().equalsIgnoreCase(phase)) { + if (TrainJobStatusEnum.isEnd(ptTrainJob.getTrainStatus())) { + // 对于当前状态为结束状态的任务不做处理 + return true; + } + // 处理启动异常日志 + dealFailed(req,times); + if (undoDistributeTrain(ptTrainJob, req)){ + // 不需要做回调处理的分布式训练场景 return true; } - PtTrainJob updatePtTrainJob = new PtTrainJob(); - - // 更新job运行时间和状态 - updatePtTrainJob.setId(ptTrainJob.getId()) - .setTrainStatus(TrainJobStatusEnum.get(phase).getStatus()); // 如果上报状态是结束状态并没指定过运行时间,则更新运行时间 - if (TrainJobStatusEnum.isEnd(phase) && "".equals(ptTrainJob.getRuntime())) { - long timeDelta = System.currentTimeMillis() - ptTrainJob.getCreateTime().getTime(); + if (TrainJobStatusEnum.isEnd(phase) && TrainUtil.INIT_RUNTIME.equals(ptTrainJob.getRuntime())) { + //获取训练运行参数 + QueryWrapper jobParamQueryWrapper = new QueryWrapper<>(); + jobParamQueryWrapper.eq("train_job_id", ptTrainJob.getId()).last(" limit 1 "); + PtJobParam ptJobParam = ptJobParamMapper.selectOne(jobParamQueryWrapper); + if (ptJobParam == null) { + LogUtil.error(LogEnum.BIZ_TRAIN, "the data of table pt_job_param queried by trainJobId does not exist {}", ptTrainJob.getId()); + return false; + } + //判断训练是否延时启动,并更新运行时长 + long timeDelta = (ptJobParam.getDelayCreateTime() != null && ptJobParam.getDelayCreateTime().getTime() > ptTrainJob.getUpdateTime().getTime()) ? System.currentTimeMillis() - ptJobParam.getDelayCreateTime().getTime() : System.currentTimeMillis() - ptTrainJob.getUpdateTime().getTime(); String runTime = String.format(TrainUtil.RUNTIME, TimeUnit.MILLISECONDS.toHours(timeDelta), TimeUnit.MILLISECONDS.toMinutes(timeDelta) % TimeUnit.HOURS.toMinutes(1), TimeUnit.MILLISECONDS.toSeconds(timeDelta) % TimeUnit.MINUTES.toSeconds(1) ); - updatePtTrainJob.setRuntime(runTime); + ptTrainJob.setRuntime(runTime); } - int updateResult = ptTrainJobMapper.updateById(updatePtTrainJob); + // 更新job运行时间和状态 + ptTrainJob.setTrainStatus(TrainJobStatusEnum.transferStatus(phase).getStatus()); + int updateResult = ptTrainJobMapper.updateById(ptTrainJob); if (updateResult < 1) { LogUtil.error(LogEnum.BIZ_TRAIN, "update trainJob_id={} failed, phase={}", ptTrainJob.getId(), req.getPhase()); return false; } + return true; + } + + /** + * 记录异常情况的日志 + * @param req + * @param times 尝试次数 + */ + private void dealFailed(AlgorithmK8sPodCallbackCreateDTO req, int times) { + if (times != 1){ + // 仅第一次执行,避免重复产生日志 + return; + } + TrainJobStatusEnum trainJobStatusEnum = TrainJobStatusEnum.getByMessage(req.getPhase()); + if(TrainJobStatusEnum.FAILED != trainJobStatusEnum || StringUtils.isBlank(req.getMessages())){ + // 必须是回调FAILED且有日志才执行日志记录 + return; + } + // 生成资源唯一标识,避免并发调用重复执行 + String key = req.getNamespace() + "#" + req.getResourceName(); + // 线程唯一身份标识 + String uuid = UUID.randomUUID().toString(); + try { + if (!redisUtils.getDistributedLock(key,uuid,MagicNumConstant.TEN)){ + return; + } + if(logMonitoringApi.searchLogByPodName( + 0, + 1, + new LogMonitoringBO(req.getNamespace(),req.getResourceName()) + ).getTotalLogs() > 0){ + // 已有失败日志,不执行 + return; + } + List logList = new ArrayList<>(2); + logList.add(DateUtil.getCurrentTimeStr() + ": Pod startup failure!"); + logList.add("Reason: "+ ContainerStatusesStateEnum.getStateMessage(req.getMessages())); + logMonitoringApi.addLogsToEs(req.getPodName(), req.getNamespace(),logList); + }finally { + redisUtils.releaseDistributedLock(key,uuid); + } + } + /** + * 验证是否是不需要做回调处理的分布式训练场景 + * 1,RUNNING回调时有Pod还没启动成功 + * 2,非 Master Pod的回调的结束状态状态 + * @param ptTrainJob + * @param req + * @return true 不需要做回调处理,false,需要做回调处理 + */ + private boolean undoDistributeTrain(PtTrainJob ptTrainJob,AlgorithmK8sPodCallbackCreateDTO req){ + String phase = req.getPhase(); + if (TrainTypeEnum.isDistributeTrain(ptTrainJob.getTrainType())) { + // 分布式训练 + if (ptTrainJob.getResourcesPoolNode() > MagicNumConstant.ONE + && TrainJobStatusEnum.RUNNING == TrainJobStatusEnum.getByMessage(phase) + && !validateDistributedRunningPod(req.getNamespace(), ptTrainJob)) { + // 节点数大于1 且 其回调状态为RUNNING时,需要做 多节点是否都已RUNNING的判断,以保证分布式训练任务已经处于运行状态 + // 没有启动完毕,等待下次Pod回调 + return true; + } + if (TrainJobStatusEnum.isEnd(phase) + && !PodUtil.isMaster(req.getPodName())) { + // 仅是主节点结束状态才需要更新分布式训练结束状态信息 + return true; + } + } + return false; + } + + /** + * 匹配训练任务 + * @param req + * @return PtTrainJob + */ + private PtTrainJob getPtTrainJob(AlgorithmK8sPodCallbackCreateDTO req) { + // 根据namespace和podName找到job + Long userId = k8sNameTool.getUserIdFromNamespace(req.getNamespace()); + QueryWrapper queryTrainJonWrapper = new QueryWrapper<>(); + queryTrainJonWrapper.eq("create_user_id", userId); + if (K8sKindEnum.DISTRIBUTETRAIN.getKind().equals(req.getPodParentType()) + || K8sKindEnum.JOB.getKind().equals(req.getPodParentType())) { + queryTrainJonWrapper.eq("k8s_job_name", req.getPodParentName()); + } else { + LogUtil.error(LogEnum.BIZ_TRAIN, "Pod parent type [{}] not support in callback!", req.getPodParentType()); + return null; + } + return ptTrainJobMapper.selectOne(queryTrainJonWrapper); + } + + + /** + * 验证分布式训练节点是否都已启动 + * @param namespace + * @param ptTrainJob + * @return true 完全启动,false 没有启动完毕 + */ + private boolean validateDistributedRunningPod(String namespace, PtTrainJob ptTrainJob) { + List podList = podApi.getListByResourceName(namespace, ptTrainJob.getJobName()); + if (podList.size() != ptTrainJob.getResourcesPoolNode()) { + LogUtil.error(LogEnum.BIZ_TRAIN, "k8s pod num ne resources pod num {}/{} !", podList.size(), ptTrainJob.getResourcesPoolNode()); + return false; + } + int runningPodSize = podList.stream() + .filter(p -> TrainJobStatusEnum.RUNNING == TrainJobStatusEnum.getByMessage(p.getPhase())) + .collect(Collectors.toList()) + .size(); + if (runningPodSize != ptTrainJob.getResourcesPoolNode()) { + LogUtil.warn(LogEnum.BIZ_TRAIN, "k8s running pod num {}/{} ", runningPodSize, ptTrainJob.getResourcesPoolNode()); + return false; + } return true; } diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/impl/NoteBookServiceImpl.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/impl/NoteBookServiceImpl.java index 0537a24..3a7043e 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/impl/NoteBookServiceImpl.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/impl/NoteBookServiceImpl.java @@ -34,10 +34,7 @@ import org.dubhe.domain.dto.SourceNoteBookDTO; import org.dubhe.domain.entity.NoteBook; import org.dubhe.domain.entity.NoteBookModel; import org.dubhe.domain.vo.NoteBookVO; -import org.dubhe.enums.BizEnum; -import org.dubhe.enums.BizNfsEnum; -import org.dubhe.enums.LogEnum; -import org.dubhe.enums.NoteBookStatusEnum; +import org.dubhe.enums.*; import org.dubhe.exception.NotebookBizException; import org.dubhe.harbor.api.HarborApi; import org.dubhe.k8s.api.PodApi; @@ -60,6 +57,7 @@ import org.dubhe.utils.NumberUtil; import org.dubhe.utils.PageUtil; import org.dubhe.utils.WrapperHelp; import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; import org.springframework.scheduling.annotation.Async; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; @@ -106,6 +104,9 @@ public class NoteBookServiceImpl implements NoteBookService { @Autowired private HarborProjectService harborProjectService; + @Value("${delay.notebook.delete}") + private Integer notebookDelayDeleteTime; + private static final String BLANK = SymbolConstant.BLANK; /** @@ -117,10 +118,28 @@ public class NoteBookServiceImpl implements NoteBookService { */ @Override public Map getNoteBookList(Page page, NoteBookListQueryDTO noteBookListQueryDTO) { - IPage noteBookPage = noteBookMapper.selectPage(page, WrapperHelp.getWrapper(noteBookListQueryDTO) - .ne(true, "status", NoteBookStatusEnum.DELETE.getCode()) - .ne(true, "deleted", NoteBookStatusEnum.STOP.getCode()) - .orderBy(true, false, "id")); + QueryWrapper queryWrapper = WrapperHelp.getWrapper(noteBookListQueryDTO); + queryWrapper.ne(true, NoteBook.COLUMN_STATUS, NoteBookStatusEnum.DELETE.getCode()) + .ne(true, "deleted", NoteBookStatusEnum.STOP.getCode()); + if (noteBookListQueryDTO.getStatus() != null){ + if (noteBookListQueryDTO.getStatus().equals(NoteBookStatusEnum.RUN.getCode())){ + //运行中的notebook必须有url + queryWrapper.eq(NoteBook.COLUMN_STATUS, NoteBookStatusEnum.RUN.getCode()) + .ne(NoteBook.COLUMN_URL,SymbolConstant.BLANK); + }else if (noteBookListQueryDTO.getStatus().equals(NoteBookStatusEnum.STARTING.getCode())){ + //启动中的notebook还包括运行中但没有url + queryWrapper.and((qw)-> + qw.eq(NoteBook.COLUMN_STATUS, NoteBookStatusEnum.RUN.getCode()).eq(NoteBook.COLUMN_URL, SymbolConstant.BLANK) + .or() + .eq(NoteBook.COLUMN_STATUS,NoteBookStatusEnum.STARTING.getCode()) + ); + }else { + // 其他状态照常 + queryWrapper.eq(NoteBook.COLUMN_STATUS, NoteBookStatusEnum.RUN.getCode()); + } + } + queryWrapper.orderBy(true, false, "id"); + IPage noteBookPage = noteBookMapper.selectPage(page, queryWrapper); return PageUtil.toPage(noteBookPage, noteBookConvert::toDto); } @@ -166,13 +185,13 @@ public class NoteBookServiceImpl implements NoteBookService { @Override @Transactional(rollbackFor = Exception.class) public NoteBookVO createNoteBook(NoteBook noteBook) { - if (noteBookMapper.findByNameAndUserId(noteBook.getNoteBookName(), noteBook.getCreateUserId(),NoteBookStatusEnum.DELETE.getCode()) != null) { + if (noteBookMapper.findByNameAndStatus(noteBook.getNoteBookName(),NoteBookStatusEnum.DELETE.getCode()) != null) { throw new NotebookBizException("Notebook名称已使用过!请重新提交。"); } if (StringUtils.isEmpty(noteBook.getName())) { noteBook.setName(k8sNameTool.getK8sName()); } - noteBook.setK8sNamespace(k8sNameTool.generateNameSpace(noteBook.getCreateUserId())); + noteBook.setK8sNamespace(k8sNameTool.generateNamespace(noteBook.getCreateUserId())); noteBook.setK8sResourceName(k8sNameTool.generateResourceName(BizEnum.NOTEBOOK, noteBook.getName())); if (StringUtils.isBlank(noteBook.getK8sPvcPath())) { //20200618 修改为 使用训练路劲 @@ -368,7 +387,7 @@ public class NoteBookServiceImpl implements NoteBookService { if (initNameSpace(noteBook, null)) { try { //20200618 修改为 创建时不创建PVC - PtJupyterDeployVO result = jupyterResourceApi.create(PtJupyterResourceConvert.toPtJupyterResourceBo(noteBook, k8sNameTool)); + PtJupyterDeployVO result = jupyterResourceApi.create(PtJupyterResourceConvert.toPtJupyterResourceBo(noteBook, k8sNameTool, notebookDelayDeleteTime)); noteBook.setK8sStatusCode(result.getCode() == null ? BLANK : result.getCode()); noteBook.setK8sStatusInfo(NotebookUtil.getK8sStatusInfo(result)); return HttpUtils.isSuccess(result.getCode()); @@ -556,7 +575,7 @@ public class NoteBookServiceImpl implements NoteBookService { noteBook.setDescription(bizNfsEnum.getBizName()); noteBook.setName(k8sNameTool.getK8sName()); String notebookName = NotebookUtil.generateName(bizNfsEnum, sourceNoteBookDTO.getSourceId()); - if (noteBookMapper.findByNameAndUserId(notebookName, noteBook.getCreateUserId(),NoteBookStatusEnum.DELETE.getCode()) != null) { + if (noteBookMapper.findByNameAndStatus(notebookName,NoteBookStatusEnum.DELETE.getCode()) != null) { // 重名随机符号拼接 notebookName += RandomUtil.randomString(MagicNumConstant.TWO); } @@ -611,7 +630,7 @@ public class NoteBookServiceImpl implements NoteBookService { @Override public String deletePvc(NoteBook noteBook) { noteBook.setStatus(NoteBookStatusEnum.DELETE.getCode()); - noteBook.setDeleted(MagicNumConstant.ONE); + noteBook.setDeleted(true); return NoteBookStatusEnum.DELETE.getDescription(); } @@ -637,12 +656,11 @@ public class NoteBookServiceImpl implements NoteBookService { /** * 获取正在运行的notebook数量 * - * @param curUserId * @return int */ @Override - public int getNoteBookRunNumber(long curUserId) { - return noteBookMapper.selectRunNoteBookNum(curUserId,NoteBookStatusEnum.RUN.getCode()); + public int getNoteBookRunNumber() { + return noteBookMapper.selectRunNoteBookNum(NoteBookStatusEnum.RUN.getCode()); } /** @@ -720,8 +738,7 @@ public class NoteBookServiceImpl implements NoteBookService { public List getNotebookDetail(Set noteBookIds) { QueryWrapper queryWrapper = new QueryWrapper<>(); queryWrapper.in("id",noteBookIds); - queryWrapper.eq("user_id",NotebookUtil.getCurUserId()); - queryWrapper.ne("status",NoteBookStatusEnum.DELETE.getCode()); + queryWrapper.ne(NoteBook.COLUMN_STATUS,NoteBookStatusEnum.DELETE.getCode()); List noteBookList = noteBookMapper.selectList(queryWrapper); return noteBookConvert.toDto(noteBookList); } diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/impl/PtDatasetServiceImpl.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/impl/PtDatasetServiceImpl.java index 28fcced..ac88a20 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/impl/PtDatasetServiceImpl.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/impl/PtDatasetServiceImpl.java @@ -27,7 +27,6 @@ import org.dubhe.service.PtDatasetService; import org.dubhe.service.convert.PtDatasetConvert; import org.dubhe.utils.FileUtil; import org.dubhe.utils.PageUtil; -import org.dubhe.utils.StringUtils; import org.dubhe.utils.WrapperHelp; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.cache.annotation.CacheConfig; diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/impl/PtDevEnvsServiceImpl.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/impl/PtDevEnvsServiceImpl.java index 09d43f9..262a962 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/impl/PtDevEnvsServiceImpl.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/impl/PtDevEnvsServiceImpl.java @@ -27,7 +27,6 @@ import org.dubhe.service.PtDevEnvsService; import org.dubhe.service.convert.PtDevEnvsConvert; import org.dubhe.utils.FileUtil; import org.dubhe.utils.PageUtil; -import org.dubhe.utils.StringUtils; import org.dubhe.utils.WrapperHelp; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.cache.annotation.CacheConfig; diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/impl/PtImageServiceImpl.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/impl/PtImageServiceImpl.java index 05040fc..c23109c 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/impl/PtImageServiceImpl.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/impl/PtImageServiceImpl.java @@ -18,29 +18,26 @@ package org.dubhe.service.impl; import cn.hutool.core.util.StrUtil; +import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; import com.baomidou.mybatisplus.core.metadata.IPage; import com.baomidou.mybatisplus.extension.plugins.pagination.Page; +import org.dubhe.annotation.DataPermissionMethod; +import org.dubhe.async.HarborImagePushAsync; import org.dubhe.base.ResponseCode; -import org.dubhe.config.TrainHarborConfig; import org.dubhe.config.NfsConfig; +import org.dubhe.config.TrainHarborConfig; import org.dubhe.dao.HarborProjectMapper; import org.dubhe.dao.PtImageMapper; import org.dubhe.data.constant.Constant; -import org.dubhe.domain.dto.PtImageQueryDTO; -import org.dubhe.domain.dto.PtImageUploadDTO; -import org.dubhe.domain.dto.UserDTO; +import org.dubhe.domain.dto.*; import org.dubhe.domain.entity.HarborProject; import org.dubhe.domain.entity.PtImage; import org.dubhe.domain.vo.PtImageQueryVO; -import org.dubhe.enums.HarborResourceEnum; -import org.dubhe.enums.ImageSourceEnum; -import org.dubhe.enums.ImageStateEnum; -import org.dubhe.enums.LogEnum; +import org.dubhe.enums.*; import org.dubhe.exception.BusinessException; import org.dubhe.harbor.api.HarborApi; import org.dubhe.service.PtImageService; -import org.dubhe.task.HarborImagePushAsync; import org.dubhe.utils.*; import org.springframework.beans.BeanUtils; import org.springframework.beans.factory.annotation.Autowired; @@ -76,10 +73,10 @@ public class PtImageServiceImpl implements PtImageService { @Autowired private TrainHarborConfig trainHarborConfig; - public final static List filedNames; + public final static List FIELD_NAMES; static { - filedNames = ReflectionUtils.getFieldNames(PtImageQueryVO.class); + FIELD_NAMES = ReflectionUtils.getFieldNames(PtImageQueryVO.class); } /** @@ -89,6 +86,7 @@ public class PtImageServiceImpl implements PtImageService { * @return Map 返回镜像分页数据 **/ @Override + @DataPermissionMethod(dataType = DatasetTypeEnum.PUBLIC) public Map getImage(PtImageQueryDTO ptImageQueryDTO) { //从会话中获取用户信息 @@ -100,17 +98,19 @@ public class PtImageServiceImpl implements PtImageService { ptImageQueryDTO.setImageResource(ImageSourceEnum.MINE.getCode()); } QueryWrapper query = new QueryWrapper<>(); - if (ImageSourceEnum.MINE.getCode().equals(ptImageQueryDTO.getImageResource())) { - query.eq("create_user_id", user.getId()); - } if (ptImageQueryDTO.getImageStatus() != null) { query.eq("image_status", ptImageQueryDTO.getImageStatus()); } query.eq("image_resource", ptImageQueryDTO.getImageResource()); + + if (StringUtils.isNotEmpty(ptImageQueryDTO.getImageNameOrId())) { + query.and(x -> x.eq("id", ptImageQueryDTO.getImageNameOrId()).or().like("image_name", ptImageQueryDTO.getImageNameOrId())); + } + //排序 IPage ptImages; try { - if (ptImageQueryDTO.getSort() != null && filedNames.contains(ptImageQueryDTO.getSort())) { + if (ptImageQueryDTO.getSort() != null && FIELD_NAMES.contains(ptImageQueryDTO.getSort())) { if (Constant.SORT_ASC.equalsIgnoreCase(ptImageQueryDTO.getOrder())) { query.orderByAsc(StringUtils.humpToLine(ptImageQueryDTO.getSort())); } else { @@ -143,14 +143,6 @@ public class PtImageServiceImpl implements PtImageService { public void uploadImage(PtImageUploadDTO ptImageUploadDTO) { LogUtil.info(LogEnum.BIZ_TRAIN, "Upload image to harbor to receive parameters :{}", ptImageUploadDTO); UserDTO currentUser = JwtUtils.getCurrentUserDto(); - QueryWrapper query = new QueryWrapper<>(); - query.eq("image_name", ptImageUploadDTO.getImageName()); - Integer harborProjectCountResult = harborProjectMapper.selectCount(query); - if (harborProjectCountResult < 1) { - LogUtil.info(LogEnum.BIZ_TRAIN, "The imageName for uploading the image is [{}] not configured", ptImageUploadDTO.getImageName()); - - throw new BusinessException(ResponseCode.SUCCESS, "上传镜像的harborProject未配置!"); - } //校验用户自定义镜像不能和预置镜像重名 List resList = checkUploadImage(ptImageUploadDTO, currentUser, ImageSourceEnum.MINE.getCode()); @@ -166,8 +158,8 @@ public class PtImageServiceImpl implements PtImageService { throw new BusinessException(ResponseCode.SUCCESS, "镜像信息已存在,不允许重复上传!"); } - String harborImagePath = trainHarborConfig.getModelName() + StrUtil.SLASH + ptImageUploadDTO.getImageName() + - StrUtil.COLON + ptImageUploadDTO.getImageTag() + StrUtil.DASHED + currentUser.getId(); + String harborImagePath = trainHarborConfig.getModelName() + StrUtil.SLASH + ptImageUploadDTO.getImageName() + StrUtil.DASHED + currentUser.getId() + + StrUtil.COLON + ptImageUploadDTO.getImageTag(); //存储镜像信息 PtImage ptImage = new PtImage(); ptImage.setImageName(ptImageUploadDTO.getImageName()) @@ -178,6 +170,7 @@ public class PtImageServiceImpl implements PtImageService { .setRemark(ptImageUploadDTO.getRemark()) .setImageTag(ptImageUploadDTO.getImageTag()) .setCreateUserId(currentUser.getId()); + ptImage.setOriginUserId(currentUser.getId()); int count = ptImageMapper.insert(ptImage); if (count < 1) { imagePushAsync.updateImageStatus(ptImage, ImageStateEnum.FAIL.getCode()); @@ -187,7 +180,7 @@ public class PtImageServiceImpl implements PtImageService { //shell脚本上传镜像 try { String imagePath = nfsConfig.getRootDir() + nfsConfig.getBucket().substring(1) + ptImageUploadDTO.getImagePath(); - String imageNameAndTag = ptImageUploadDTO.getImageName() + StrUtil.COLON + ptImageUploadDTO.getImageTag() + StrUtil.DASHED + currentUser.getId(); + String imageNameAndTag = ptImageUploadDTO.getImageName() + StrUtil.DASHED + currentUser.getId() + StrUtil.COLON + ptImageUploadDTO.getImageTag(); imagePushAsync.execShell(imagePath, imageNameAndTag, ptImage); } catch (Exception e) { LogUtil.error(LogEnum.BIZ_TRAIN, "Image upload exception :{}", e); @@ -195,33 +188,6 @@ public class PtImageServiceImpl implements PtImageService { } } - /** - *定时到harbor同步projectName - */ - @Override - public void harborImageNameSync() { - //每天晚上11点定时去harbor同步项目名到表harbor_project - QueryWrapper query = new QueryWrapper<>(); - List imageNames = harborApi.searchImageByProjects(Arrays.asList(trainHarborConfig.getModelName())); - Set imageList = new HashSet<>(); - imageNames.forEach(image -> { - imageList.add((String) image.get("imageName")); - }); - query.in("image_name", imageList); - List harborProjects = harborProjectMapper.selectList(query); - harborProjects.forEach(harborProject -> { - imageList.removeIf(image -> image.contains(harborProject.getImageName())); - }); - - HarborProject project = new HarborProject(); - project.setCreateResource(HarborResourceEnum.TRAIN_SYNC.getCode()); - project.setSyncStatus(TrainUtil.NUMBER_ONE); - imageList.forEach(imageName -> { - project.setImageName(imageName); - harborProjectMapper.insert(project); - }); - } - /** * 查询所含镜像版本信息 * @@ -229,6 +195,7 @@ public class PtImageServiceImpl implements PtImageService { * @return List 通过imageName查询所含镜像版本信息 */ @Override + @DataPermissionMethod(dataType = DatasetTypeEnum.PUBLIC) public List searchImages(String imageName) { QueryWrapper queryWrapper = new QueryWrapper<>(); queryWrapper.eq("image_name", imageName); @@ -253,6 +220,7 @@ public class PtImageServiceImpl implements PtImageService { * @return List 获取Harbor **/ @Override + @DataPermissionMethod(dataType = DatasetTypeEnum.PUBLIC) public List getHarborProjectList() { LogUtil.info(LogEnum.BIZ_TRAIN, "Query the mirror project list..."); QueryWrapper queryWrapper = new QueryWrapper<>(); @@ -261,20 +229,99 @@ public class PtImageServiceImpl implements PtImageService { return harborProjectMapper.selectList(queryWrapper); } + /** + * 删除镜像 + * + * @param imageDeleteDTO 删除镜像条件参数 + */ + @Override + @Transactional(rollbackFor = Exception.class) + public void deleteTrainImage(PtImageDeleteDTO imageDeleteDTO) { + + UserDTO user = JwtUtils.getCurrentUserDto(); + List imageList = ptImageMapper.selectList(new LambdaQueryWrapper() + .eq(PtImage::getCreateUserId, user.getId()) + .in(PtImage::getId, imageDeleteDTO.getIds())); + + imageList.forEach(image -> { + //禁止删除预置镜像 + if (ImageSourceEnum.PRE.getCode().equals(image.getImageResource())) { + throw new BusinessException("禁止删除预置镜像"); + } + String imageUrl = trainHarborConfig.getAddress() + StrUtil.SLASH + image.getImageUrl(); + LogUtil.info(LogEnum.BIZ_TRAIN, "delete harbor image url:{}", imageUrl); + //同步删除harbor镜像 + harborApi.deleteImageByTag(imageUrl); + }); + + + //删除本地镜像 + int deleteSum = ptImageMapper.deleteBatchIds(imageDeleteDTO.getIds()); + if (deleteSum < imageDeleteDTO.getIds().size()) { + LogUtil.error(LogEnum.BIZ_TRAIN, "The user {} failed to delete image, and the pt_image table deletion operation failed according to the id array {}", user.getId(), imageDeleteDTO.getIds()); + throw new BusinessException("内部错误"); + } + } + + /** + * 修改镜像信息 + * + * @param imageUpdateDTO + */ + @Override + @Transactional(rollbackFor = Exception.class) + public void updateTrainImage(PtImageUpdateDTO imageUpdateDTO) { + + UserDTO user = JwtUtils.getCurrentUserDto(); + + List imageList = ptImageMapper.selectList(new LambdaQueryWrapper() + .eq(PtImage::getCreateUserId, user.getId()) + .in(PtImage::getId, imageUpdateDTO.getIds())); + + if (CollectionUtils.isEmpty(imageList)) { + LogUtil.error(LogEnum.BIZ_TRAIN, "The user{} update image failed,inquire condition ids{} not result", user.getId(), imageUpdateDTO.getIds()); + throw new BusinessException("内部错误"); + } + for (PtImage image : imageList) { + //禁止修改预置镜像 + if (ImageSourceEnum.PRE.getCode().equals(image.getImageResource())) { + throw new BusinessException("无法修改预置镜像信息"); + } + image.setRemark(imageUpdateDTO.getRemark()); + LogUtil.info(LogEnum.BIZ_TRAIN, "The user{}update image,update image info:{}", user.getId(), image); + ptImageMapper.updateById(image); + } + } + + /** + * 获取镜像名称列表 + * + * @return Set 镜像列表 + */ + @Override + @DataPermissionMethod(dataType = DatasetTypeEnum.PUBLIC) + public Set getImageNameList() { + + UserDTO currentUser = JwtUtils.getCurrentUserDto(); + List imageList = ptImageMapper.selectList(new LambdaQueryWrapper().eq(PtImage::getCreateUserId, currentUser.getId()).or().eq(PtImage::getImageResource, ImageSourceEnum.PRE.getCode())); + Set imageNames = new HashSet<>(); + imageList.forEach(image -> { + imageNames.add(image.getImageName()); + }); + return imageNames; + } + + /** * @param ptImageUploadDTO 镜像上传逻辑校验 * @param user 用户 * @param source 来源 * @return List 镜像列表 **/ + @DataPermissionMethod(dataType = DatasetTypeEnum.PUBLIC) private List checkUploadImage(PtImageUploadDTO ptImageUploadDTO, UserDTO user, int source) { QueryWrapper queryWrapper = new QueryWrapper<>(); - if (ImageSourceEnum.PRE.getCode().equals(source)) { - queryWrapper.eq("create_user_id", user.getId()); - } else { - queryWrapper.eq("image_resource", ImageSourceEnum.PRE.getCode()); - } queryWrapper.eq("image_name", ptImageUploadDTO.getImageName()); queryWrapper.eq("image_tag", ptImageUploadDTO.getImageTag()); List imageList = ptImageMapper.selectList(queryWrapper); diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/impl/PtStorageServiceImpl.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/impl/PtStorageServiceImpl.java index 8426079..ae491a0 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/impl/PtStorageServiceImpl.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/impl/PtStorageServiceImpl.java @@ -28,7 +28,6 @@ import org.dubhe.service.PtStorageService; import org.dubhe.service.convert.PtStorageConvert; import org.dubhe.utils.FileUtil; import org.dubhe.utils.PageUtil; -import org.dubhe.utils.StringUtils; import org.dubhe.utils.WrapperHelp; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.cache.annotation.CacheConfig; diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/impl/PtTrainAlgorithmServiceImpl.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/impl/PtTrainAlgorithmServiceImpl.java index 999dfbe..6dd0216 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/impl/PtTrainAlgorithmServiceImpl.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/impl/PtTrainAlgorithmServiceImpl.java @@ -24,10 +24,13 @@ import com.baomidou.mybatisplus.core.metadata.IPage; import com.baomidou.mybatisplus.extension.plugins.pagination.Page; import org.apache.commons.collections4.CollectionUtils; import org.apache.commons.collections4.map.HashedMap; +import org.dubhe.annotation.DataPermissionMethod; +import org.dubhe.async.TrainAlgorithmUploadAsync; import org.dubhe.base.MagicNumConstant; import org.dubhe.base.ResponseCode; import org.dubhe.config.NfsConfig; import org.dubhe.constant.AlgorithmSourceEnum; +import org.dubhe.config.RecycleConfig; import org.dubhe.constant.TrainAlgorithmConstant; import org.dubhe.dao.NoteBookMapper; import org.dubhe.dao.PtImageMapper; @@ -38,13 +41,11 @@ import org.dubhe.domain.entity.NoteBook; import org.dubhe.domain.entity.PtImage; import org.dubhe.domain.entity.PtTrainAlgorithm; import org.dubhe.domain.vo.PtTrainAlgorithmQueryVO; -import org.dubhe.enums.BizNfsEnum; -import org.dubhe.enums.ImageSourceEnum; -import org.dubhe.enums.ImageStateEnum; -import org.dubhe.enums.LogEnum; +import org.dubhe.enums.*; import org.dubhe.exception.BusinessException; import org.dubhe.service.NoteBookService; import org.dubhe.service.PtTrainAlgorithmService; +import org.dubhe.service.RecycleTaskService; import org.dubhe.utils.*; import org.springframework.beans.BeanUtils; import org.springframework.beans.factory.annotation.Autowired; @@ -74,6 +75,9 @@ public class PtTrainAlgorithmServiceImpl implements PtTrainAlgorithmService { @Autowired private NfsUtil nfsUtil; + @Autowired + private LocalFileUtil localFileUtil; + @Autowired private K8sNameTool k8sNameTool; @@ -89,10 +93,19 @@ public class PtTrainAlgorithmServiceImpl implements PtTrainAlgorithmService { @Autowired private NoteBookMapper noteBookMapper; - public final static List filedNames; + @Autowired + private TrainAlgorithmUploadAsync algorithmUpdateAsync; + + @Autowired + private RecycleTaskService recycleTaskService; + + @Autowired + private RecycleConfig recycleConfig; + + public final static List FIELD_NAMES; static { - filedNames = ReflectionUtils.getFieldNames(PtTrainAlgorithmQueryVO.class); + FIELD_NAMES = ReflectionUtils.getFieldNames(PtTrainAlgorithmQueryVO.class); } /** @@ -102,6 +115,7 @@ public class PtTrainAlgorithmServiceImpl implements PtTrainAlgorithmService { * @return Map 返回查询数据 */ @Override + @DataPermissionMethod(dataType = DatasetTypeEnum.PUBLIC) public Map queryAll(PtTrainAlgorithmQueryDTO ptTrainAlgorithmQueryDTO) { //从会话中获取用户信息 UserDTO user = JwtUtils.getCurrentUserDto(); @@ -128,7 +142,7 @@ public class PtTrainAlgorithmServiceImpl implements PtTrainAlgorithmService { Page page = ptTrainAlgorithmQueryDTO.toPage(); IPage ptTrainAlgorithms; try { - if (ptTrainAlgorithmQueryDTO.getSort() != null && filedNames.contains(ptTrainAlgorithmQueryDTO.getSort())) { + if (ptTrainAlgorithmQueryDTO.getSort() != null && FIELD_NAMES.contains(ptTrainAlgorithmQueryDTO.getSort())) { if (Constant.SORT_ASC.equalsIgnoreCase(ptTrainAlgorithmQueryDTO.getOrder())) { wrapper.orderByAsc(StringUtils.humpToLine(ptTrainAlgorithmQueryDTO.getSort())); } else { @@ -166,14 +180,9 @@ public class PtTrainAlgorithmServiceImpl implements PtTrainAlgorithmService { //从会话中获取用户信息 UserDTO user = JwtUtils.getCurrentUserDto(); LogUtil.info(LogEnum.BIZ_TRAIN, "Save the new algorithm and receive the parameter {}", ptTrainAlgorithmCreateDTO); - // 校验path - if (!(k8sNameTool.validateBizNfsPath(ptTrainAlgorithmCreateDTO.getCodeDir(), BizNfsEnum.ALGORITHM))) { - LogUtil.error(LogEnum.BIZ_TRAIN, "The user {} passed in the path {} is not valid", user.getUsername(), ptTrainAlgorithmCreateDTO.getCodeDir()); - throw new BusinessException("路径名称不合法"); - } //获取镜像url if (StringUtils.isNotBlank(ptTrainAlgorithmCreateDTO.getImageName()) && StringUtils.isNotBlank(ptTrainAlgorithmCreateDTO.getImageTag())) { - ptTrainAlgorithmCreateDTO.setImageName(getImages(ptTrainAlgorithmCreateDTO, user)); + ptTrainAlgorithmCreateDTO.setImageName(getImageUrl(ptTrainAlgorithmCreateDTO, user)); } //创建算法校验DTO并设置默认值 setAlgorithmDtoDefault(ptTrainAlgorithmCreateDTO); @@ -192,7 +201,6 @@ public class PtTrainAlgorithmServiceImpl implements PtTrainAlgorithmService { //算法名称校验 QueryWrapper queryWrapper = new QueryWrapper<>(); queryWrapper.eq("algorithm_name", ptTrainAlgorithmCreateDTO.getAlgorithmName()); - queryWrapper.eq("create_user_id", user.getId()); Integer countResult = ptTrainAlgorithmMapper.selectCount(queryWrapper); //如果是通过【保存至算法】接口创建算法,名称重复可用随机数生成新算法名,待后续客户自主修改 if (countResult > 0) { @@ -208,28 +216,12 @@ public class PtTrainAlgorithmServiceImpl implements PtTrainAlgorithmService { if (path.toLowerCase().endsWith(Constant.COMPRESS_ZIP)) { unZip(user, path, ptTrainAlgorithm); } - //校验创建算法来源(true:由fork创建算法,false:其它创建算法方式),若为true则拷贝预置算法文件至新路径 - if (ptTrainAlgorithmCreateDTO.getFork()) { - //生成算法相对路径 - String algorithmPath = k8sNameTool.getNfsPath(BizNfsEnum.ALGORITHM, user.getId()); - //拷贝预置算法文件夹 - boolean copyResult = nfsUtil.copyPath(path, nfsConfig.getBucket() + algorithmPath); - if (!copyResult) { - LogUtil.error(LogEnum.BIZ_TRAIN, "The user {} copied the preset algorithm path {} successfully", user.getUsername(), path); - throw new BusinessException("内部错误"); - } - ptTrainAlgorithm.setCodeDir(algorithmPath); - } - try { //算法未保存成功,抛出异常,并返回失败信息 ptTrainAlgorithmMapper.insert(ptTrainAlgorithm); - //保存算法根据notbookId更新算法id - if (ptTrainAlgorithmCreateDTO.getNoteBookId() != null) { - LogUtil.info(LogEnum.BIZ_TRAIN, "Save algorithm Update algorithm ID :{} according to notBookId:{}", ptTrainAlgorithmCreateDTO.getNoteBookId(), ptTrainAlgorithm.getId()); - noteBookService.updateTrainIdByNoteBookId(ptTrainAlgorithmCreateDTO.getNoteBookId(), ptTrainAlgorithm.getId()); - } + //上传算法异步处理 + algorithmUpdateAsync.createTrainAlgorithm(user, ptTrainAlgorithm, ptTrainAlgorithmCreateDTO); } catch (Exception e) { LogUtil.error(LogEnum.BIZ_TRAIN, "The user {} saving algorithm was not successful. Failure reason :{}", user.getUsername(), e.getMessage()); throw new BusinessException("算法未保存成功"); @@ -263,7 +255,6 @@ public class PtTrainAlgorithmServiceImpl implements PtTrainAlgorithmService { //算法名称校验 QueryWrapper queryWrapper = new QueryWrapper<>(); queryWrapper.eq("algorithm_name", ptTrainAlgorithmUpdateDTO.getAlgorithmName()) - .eq("create_user_id", currentUser.getId()) .ne("id", ptTrainAlgorithmUpdateDTO.getId()); Integer countResult = ptTrainAlgorithmMapper.selectCount(queryWrapper); if (countResult > 0) { @@ -300,28 +291,22 @@ public class PtTrainAlgorithmServiceImpl implements PtTrainAlgorithmService { } /** - * 解压缩zip压缩包 + * 解压缩zip压缩包 * - * @param user 用户 - * @param path 文件路径 - * @param ptTrainAlgorithm 算法参数 + * @param user 用户 + * @param path 文件路径 + * @param ptTrainAlgorithm 算法参数 */ private void unZip(UserDTO user, String path, PtTrainAlgorithm ptTrainAlgorithm) { - String[] pathArray = path.split(StrUtil.SLASH); - String pathSuffix = pathArray[pathArray.length - 1]; - String targetPath = path.replace(pathSuffix, ""); - //上传路径垃圾文件清理 - Boolean aBoolean = nfsUtil.cleanPath(path, targetPath); - if (!aBoolean) { - LogUtil.error(LogEnum.BIZ_TRAIN, "User {} failed to clean up {} garbage", user.getUsername(), targetPath); - } - Boolean unzip = nfsUtil.unzip(path, targetPath); + //目标路径 + String targetPath = k8sNameTool.getNfsPath(BizNfsEnum.ALGORITHM, user.getId()); + boolean unzip = localFileUtil.unzipLocalPath(path, nfsConfig.getBucket() + targetPath); if (!unzip) { LogUtil.error(LogEnum.BIZ_TRAIN, "User {} failed to unzip", user.getUsername()); throw new BusinessException("内部错误"); } //算法路径 - ptTrainAlgorithm.setCodeDir(StrUtil.SLASH + path.replace(nfsConfig.getBucket(), "").replace(pathSuffix, "")); + ptTrainAlgorithm.setCodeDir(targetPath); } /** @@ -331,6 +316,7 @@ public class PtTrainAlgorithmServiceImpl implements PtTrainAlgorithmService { */ @Override @Transactional(rollbackFor = Exception.class) + @DataPermissionMethod(dataType = DatasetTypeEnum.PUBLIC) public void deleteAll(PtTrainAlgorithmDeleteDTO ptTrainAlgorithmDeleteDTO) { //从会话中获取用户信息 UserDTO user = JwtUtils.getCurrentUserDto(); @@ -338,10 +324,9 @@ public class PtTrainAlgorithmServiceImpl implements PtTrainAlgorithmService { Set idList = ptTrainAlgorithmDeleteDTO.getIds(); //权限校验 QueryWrapper query = new QueryWrapper<>(); - query.eq("create_user_id", user.getId()); query.in("id", idList); - Integer queryCountResult = ptTrainAlgorithmMapper.selectCount(query); - if (queryCountResult < idList.size()) { + List algorithmList = ptTrainAlgorithmMapper.selectList(query); + if (algorithmList.size() < idList.size()) { LogUtil.error(LogEnum.BIZ_TRAIN, "User {} delete algorithm failed, no permission to delete the corresponding data in the algorithm table", user.getUsername()); throw new BusinessException(ResponseCode.SUCCESS, "您删除的ID不存在或已被删除"); } @@ -352,7 +337,6 @@ public class PtTrainAlgorithmServiceImpl implements PtTrainAlgorithmService { } //同步更新noteBook表中algorithmId=0 QueryWrapper queryWrapper = new QueryWrapper<>(); - queryWrapper.eq("user_id", user.getId()); queryWrapper.in("algorithm_id", idList); List noteBookList = noteBookMapper.selectList(queryWrapper); if (!CollectionUtils.isEmpty(noteBookList)) { @@ -360,6 +344,17 @@ public class PtTrainAlgorithmServiceImpl implements PtTrainAlgorithmService { noteBookService.updateTrainIdByNoteBookId(noteBook.getId(), null); }); } + //定时任务删除相应的算法文件 + for (PtTrainAlgorithm algorithm : algorithmList) { + RecycleTaskCreateDTO recycleTask = new RecycleTaskCreateDTO(); + recycleTask.setRecycleModule(RecycleModuleEnum.BIZ_ALGORITHM.getValue()) + .setRecycleType(RecycleTypeEnum.FILE.getCode()) + .setRecycleDelayDate(recycleConfig.getAlgorithmValid()) + .setRecycleCondition(nfsUtil.formatPath(nfsConfig.getRootDir() + nfsConfig.getBucket() + algorithm.getCodeDir())) + .setRecycleNote("删除算法文件"); + recycleTaskService.createRecycleTask(recycleTask); + } + LogUtil.info(LogEnum.BIZ_TRAIN, "User {} delete algorithm end, delete algorithm ID array IDS ={}", user.getUsername(), idList); } @@ -374,7 +369,6 @@ public class PtTrainAlgorithmServiceImpl implements PtTrainAlgorithmService { UserDTO user = JwtUtils.getCurrentUserDto(); LogUtil.info(LogEnum.BIZ_TRAIN, "The user {} queries his algorithm number", user.getUsername()); QueryWrapper wrapper = new QueryWrapper(); - wrapper.eq("create_user_id", user.getId()); wrapper.eq("algorithm_source", AlgorithmSourceEnum.MINE.getStatus()); Integer countResult = ptTrainAlgorithmMapper.selectCount(wrapper); return new HashedMap() {{ @@ -385,8 +379,8 @@ public class PtTrainAlgorithmServiceImpl implements PtTrainAlgorithmService { /** * 获取镜像名称与版本 * - * @param trainAlgorithm 镜像URL - * @param ptTrainAlgorithmQueryVO 镜像名称与版本 + * @param trainAlgorithm 镜像URL + * @param ptTrainAlgorithmQueryVO 镜像名称与版本 */ private void getImageNameAndImageTag(PtTrainAlgorithm trainAlgorithm, PtTrainAlgorithmQueryVO ptTrainAlgorithmQueryVO) { if (StringUtils.isNotBlank(trainAlgorithm.getImageName())) { @@ -426,37 +420,22 @@ public class PtTrainAlgorithmServiceImpl implements PtTrainAlgorithmService { /** * 获取镜像url * - * @param ptTrainAlgorithmCreateDTO 获取镜像 - * @param user 用户 + * @param ptTrainAlgorithmCreateDTO 获取镜像 + * @param user 用户 * @return String 返回镜像路径 **/ - private String getImages(PtTrainAlgorithmCreateDTO ptTrainAlgorithmCreateDTO, UserDTO user) { + private String getImageUrl(PtTrainAlgorithmCreateDTO ptTrainAlgorithmCreateDTO, UserDTO user) { //获取镜像url QueryWrapper queryWrapper = new QueryWrapper<>(); - queryWrapper.eq("image_name", ptTrainAlgorithmCreateDTO.getImageName()); - queryWrapper.eq("image_tag", ptTrainAlgorithmCreateDTO.getImageTag()); - queryWrapper.eq("image_status", ImageStateEnum.SUCCESS.getCode()); - List ptImages = ptImageMapper.selectList(queryWrapper); - if (CollectionUtils.isEmpty(ptImages)) { + queryWrapper.eq("image_name", ptTrainAlgorithmCreateDTO.getImageName()) + .eq("image_tag", ptTrainAlgorithmCreateDTO.getImageTag()) + .eq("image_status", ImageStateEnum.SUCCESS.getCode()).last(" limit 1 "); + ; + PtImage ptImage = ptImageMapper.selectOne(queryWrapper); + if (ptImage == null || StringUtils.isBlank(ptImage.getImageUrl())) { LogUtil.error(LogEnum.BIZ_TRAIN, "User {} gets image ,the imageName is {}, the imageTag is {}, and the result of query image table (PT_image) is empty", user.getUsername(), ptTrainAlgorithmCreateDTO.getImageName(), ptTrainAlgorithmCreateDTO.getImageTag()); throw new BusinessException("镜像不存在"); } - //获取镜像为用户自定义镜像或预置镜像,且两者自身不能重复 - if (ptImages.size() > MagicNumConstant.TWO) { - LogUtil.error(LogEnum.BIZ_TRAIN, "User {} got more images than scheduled, the imageName provided is {} and the imageTag is {}. The parameters are illegal", user.getUsername(), ptTrainAlgorithmCreateDTO.getImageName(), ptTrainAlgorithmCreateDTO.getImageTag()); - throw new BusinessException("镜像不合法"); - } - for (PtImage ptImage : ptImages) { - if (ImageSourceEnum.PRE.getCode().equals(ptImage.getImageResource())) { - ptTrainAlgorithmCreateDTO.setImageName(ptImage.getImageUrl()); - } else if (user.getId().equals(ptImage.getCreateUserId())) { - ptTrainAlgorithmCreateDTO.setImageName(ptImage.getImageUrl()); - } - } - if (StringUtils.isBlank(ptTrainAlgorithmCreateDTO.getImageName())) { - LogUtil.error(LogEnum.BIZ_TRAIN, "User {} gets image, the imageName provided is {} and the imageTag is {}. The parameters are illegal", user.getUsername(), ptTrainAlgorithmCreateDTO.getImageName(), ptTrainAlgorithmCreateDTO.getImageTag()); - throw new BusinessException("镜像不合法"); - } - return ptTrainAlgorithmCreateDTO.getImageName(); + return ptImage.getImageUrl(); } } diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/impl/PtTrainAlgorithmUsageServiceImpl.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/impl/PtTrainAlgorithmUsageServiceImpl.java index 0a209cd..8ff64df 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/impl/PtTrainAlgorithmUsageServiceImpl.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/impl/PtTrainAlgorithmUsageServiceImpl.java @@ -20,6 +20,8 @@ package org.dubhe.service.impl; import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; import com.baomidou.mybatisplus.core.metadata.IPage; import com.baomidou.mybatisplus.extension.plugins.pagination.Page; +import org.dubhe.aspect.PermissionAspect; +import org.dubhe.base.DataContext; import org.dubhe.base.MagicNumConstant; import org.dubhe.base.ResponseCode; import org.dubhe.dao.PtTrainAlgorithmUsageMapper; @@ -38,10 +40,7 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.Set; +import java.util.*; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -72,23 +71,24 @@ public class PtTrainAlgorithmUsageServiceImpl implements PtTrainAlgorithmUsageSe IPage ptTrainAlgorithms = null; if (ptTrainAlgorithmUsageQueryDTO.getIsContainDefault()) { - wrapper.and(qw -> qw.eq("user_id", user.getId()).or().eq("is_default", - ptTrainAlgorithmUsageQueryDTO.getIsContainDefault())); + wrapper.in("origin_user_id", user.getId(), PermissionAspect.PUBLIC_DATA_USER_ID); } else { - wrapper.eq("user_id", user.getId()); + wrapper.eq("origin_user_id", user.getId()); } wrapper.eq("type", ptTrainAlgorithmUsageQueryDTO.getType()); + DataContext.set(CommonPermissionDataDTO.builder().type(true).build()); ptTrainAlgorithms = ptTrainAlgorithUsagemMapper.selectPage(page, wrapper); + DataContext.remove(); List ptTrainAlgorithmUsageQueryResult = ptTrainAlgorithms.getRecords().stream() .map(x -> { PtTrainAlgorithmUsageQueryVO ptTrainAlgorithmUsageQueryVO = new PtTrainAlgorithmUsageQueryVO(); BeanUtils.copyProperties(x, ptTrainAlgorithmUsageQueryVO); + ptTrainAlgorithmUsageQueryVO.setIsDefault(Objects.equals(x.getOriginUserId(),PermissionAspect.PUBLIC_DATA_USER_ID)); return ptTrainAlgorithmUsageQueryVO; }).collect(Collectors.toList()); - return PageUtil.toPage(page, ptTrainAlgorithmUsageQueryResult); } @@ -104,7 +104,7 @@ public class PtTrainAlgorithmUsageServiceImpl implements PtTrainAlgorithmUsageSe UserDTO user = JwtUtils.getCurrentUserDto(); PtTrainAlgorithmUsage ptTrainAlgorithmUsage = new PtTrainAlgorithmUsage(); ptTrainAlgorithmUsage.setAuxInfo(ptTrainAlgorithmUsageCreateDTO.getAuxInfo()) - .setType(ptTrainAlgorithmUsageCreateDTO.getType()).setUserId(user.getId()); + .setType(ptTrainAlgorithmUsageCreateDTO.getType()); int insertResult = ptTrainAlgorithUsagemMapper.insert(ptTrainAlgorithmUsage); @@ -126,7 +126,6 @@ public class PtTrainAlgorithmUsageServiceImpl implements PtTrainAlgorithmUsageSe UserDTO user = JwtUtils.getCurrentUserDto(); Set idList = Stream.of(ptTrainAlgorithmUsageDeleteDTO.getIds()).collect(Collectors.toSet()); QueryWrapper query = new QueryWrapper<>(); - query.eq("user_id", user.getId()); query.in("id", idList); Integer queryCountResult = ptTrainAlgorithUsagemMapper.selectCount(query); @@ -140,7 +139,6 @@ public class PtTrainAlgorithmUsageServiceImpl implements PtTrainAlgorithmUsageSe LogUtil.error(LogEnum.BIZ_TRAIN, "User {} failed to delete user assistance information. User service deletion based on id array {} failed", user.getUsername(), ptTrainAlgorithmUsageDeleteDTO.getIds()); throw new BusinessException(ResponseCode.SUCCESS, "内部错误"); } - } /** @@ -154,7 +152,6 @@ public class PtTrainAlgorithmUsageServiceImpl implements PtTrainAlgorithmUsageSe UserDTO user = JwtUtils.getCurrentUserDto(); QueryWrapper query = new QueryWrapper<>(); - query.eq("user_id", user.getId()); query.in("id", ptTrainAlgorithmUsageUpdateDTO.getId()); Integer queryIntResult = ptTrainAlgorithUsagemMapper.selectCount(query); diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/impl/PtTrainJobServiceImpl.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/impl/PtTrainJobServiceImpl.java index 8393072..1b1a466 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/impl/PtTrainJobServiceImpl.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/impl/PtTrainJobServiceImpl.java @@ -18,45 +18,109 @@ package org.dubhe.service.impl; import cn.hutool.core.bean.BeanUtil; +import cn.hutool.core.io.FileUtil; import cn.hutool.core.util.StrUtil; import com.alibaba.fastjson.JSONObject; +import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper; import com.baomidou.mybatisplus.extension.plugins.pagination.Page; import org.apache.commons.collections4.CollectionUtils; import org.apache.commons.collections4.map.HashedMap; +import org.dubhe.annotation.DataPermissionMethod; +import org.dubhe.async.StopTrainJobAsync; +import org.dubhe.async.TransactionAsyncManager; import org.dubhe.base.MagicNumConstant; import org.dubhe.base.ResponseCode; +import org.dubhe.config.NfsConfig; import org.dubhe.config.TrainHarborConfig; import org.dubhe.constant.AlgorithmSourceEnum; -import org.dubhe.constant.TrainJobConstant; -import org.dubhe.dao.*; +import org.dubhe.config.RecycleConfig; +import org.dubhe.constant.SymbolConstant; +import org.dubhe.config.TrainJobConfig; +import org.dubhe.dao.DictDetailMapper; +import org.dubhe.dao.ModelQueryMapper; +import org.dubhe.dao.PtJobParamMapper; +import org.dubhe.dao.PtTrainAlgorithmMapper; +import org.dubhe.dao.PtTrainJobMapper; +import org.dubhe.dao.PtTrainJobSpecsMapper; +import org.dubhe.dao.PtTrainMapper; +import org.dubhe.dao.PtTrainParamMapper; import org.dubhe.data.constant.Constant; -import org.dubhe.domain.dto.*; -import org.dubhe.domain.entity.*; -import org.dubhe.domain.vo.*; +import org.dubhe.domain.dto.BaseTrainJobDTO; +import org.dubhe.domain.dto.PtTrainDataSourceStatusQueryDTO; +import org.dubhe.domain.dto.PtTrainJobCreateDTO; +import org.dubhe.domain.dto.PtTrainJobDeleteDTO; +import org.dubhe.domain.dto.PtTrainJobDetailQueryDTO; +import org.dubhe.domain.dto.PtTrainJobResumeDTO; +import org.dubhe.domain.dto.PtTrainJobStopDTO; +import org.dubhe.domain.dto.PtTrainJobUpdateDTO; +import org.dubhe.domain.dto.PtTrainJobVersionQueryDTO; +import org.dubhe.domain.dto.PtTrainQueryDTO; +import org.dubhe.domain.dto.RecycleTaskCreateDTO; +import org.dubhe.domain.dto.UserDTO; +import org.dubhe.domain.entity.DictDetail; +import org.dubhe.domain.entity.ModelQuery; +import org.dubhe.domain.entity.ModelQueryBrance; +import org.dubhe.domain.entity.PtJobParam; +import org.dubhe.domain.entity.PtTrain; +import org.dubhe.domain.entity.PtTrainAlgorithm; +import org.dubhe.domain.entity.PtTrainJob; +import org.dubhe.domain.entity.PtTrainJobSpecs; +import org.dubhe.domain.entity.PtTrainParam; +import org.dubhe.domain.vo.PtImageAndAlgorithmVO; +import org.dubhe.domain.vo.PtJobMetricsGrafanaVO; +import org.dubhe.domain.vo.PtTrainDataSourceStatusQueryVO; +import org.dubhe.domain.vo.PtTrainJobDeleteVO; +import org.dubhe.domain.vo.PtTrainJobDetailQueryVO; +import org.dubhe.domain.vo.PtTrainJobDetailVO; +import org.dubhe.domain.vo.PtTrainJobStatisticsMineVO; +import org.dubhe.domain.vo.PtTrainJobStopVO; +import org.dubhe.domain.vo.PtTrainVO; +import org.dubhe.enums.AlgorithmStatusEnum; +import org.dubhe.enums.DatasetTypeEnum; import org.dubhe.enums.LogEnum; +import org.dubhe.enums.RecycleModuleEnum; +import org.dubhe.enums.RecycleTypeEnum; import org.dubhe.enums.TrainJobStatusEnum; +import org.dubhe.enums.TrainTypeEnum; import org.dubhe.exception.BusinessException; +import org.dubhe.k8s.api.DistributeTrainApi; import org.dubhe.k8s.api.PersistentVolumeClaimApi; import org.dubhe.k8s.api.PodApi; import org.dubhe.k8s.api.TrainJobApi; import org.dubhe.k8s.domain.PtBaseResult; import org.dubhe.k8s.domain.resource.BizPod; +import org.dubhe.k8s.utils.PodUtil; import org.dubhe.service.PtTrainJobService; -import org.dubhe.task.TransactionAsyncManager; -import org.dubhe.utils.*; +import org.dubhe.service.RecycleTaskService; +import org.dubhe.utils.ImageUtil; +import org.dubhe.utils.JwtUtils; +import org.dubhe.utils.K8sNameTool; +import org.dubhe.utils.KeyUtil; +import org.dubhe.utils.LogUtil; +import org.dubhe.utils.NfsUtil; +import org.dubhe.utils.PageUtil; +import org.dubhe.utils.ReflectionUtils; +import org.dubhe.utils.SqlUtil; +import org.dubhe.utils.StringUtils; +import org.dubhe.utils.TrainUtil; import org.springframework.beans.BeanUtils; import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; import org.springframework.lang.NonNull; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; -import java.time.LocalDateTime; -import java.time.ZoneOffset; -import java.time.format.DateTimeFormatter; -import java.util.*; -import java.util.function.Consumer; +import java.io.File; +import java.sql.Timestamp; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; import java.util.stream.Collectors; /** @@ -94,10 +158,7 @@ public class PtTrainJobServiceImpl implements PtTrainJobService { private PodApi podApi; @Autowired - private TrainJobConstant trainJobConstant; - - @Autowired - private NfsUtil nfsUtil; + private TrainJobConfig trainJobConfig; @Autowired private TrainHarborConfig trainHarborConfig; @@ -111,10 +172,38 @@ public class PtTrainJobServiceImpl implements PtTrainJobService { @Autowired private TransactionAsyncManager asyncManager; - public final static List filedNames; + @Autowired + private StopTrainJobAsync stopTrainJobAsync; + + @Autowired + private DistributeTrainApi distributeTrainApi; + + @Autowired + private DictDetailMapper dictDetailMapper; + + + @Autowired + private ModelQueryMapper modelQueryMapper; + + @Autowired + private NfsConfig nfsConfig; + + @Autowired + private NfsUtil nfsUtil; + + @Autowired + private RecycleConfig recycleConfig; + + @Autowired + private RecycleTaskService recycleTaskService; + + @Value("${k8s.pod.metrics.grafanaUrl}") + private String k8sPodMetricsGrafanaUrl; + + public final static List FIELD_NAMES; static { - filedNames = ReflectionUtils.getFieldNames(PtTrainVO.class); + FIELD_NAMES = ReflectionUtils.getFieldNames(PtTrainVO.class); } /** @@ -124,6 +213,7 @@ public class PtTrainJobServiceImpl implements PtTrainJobService { * @return Map 作业列表分页信息 **/ @Override + @DataPermissionMethod(dataType = DatasetTypeEnum.PUBLIC) public Map getTrainJob(@NonNull PtTrainQueryDTO ptTrainQueryDTO) { Page pageTrainResult; UserDTO currentUser = JwtUtils.getCurrentUserDto(); @@ -136,7 +226,7 @@ public class PtTrainJobServiceImpl implements PtTrainJobService { //排序方式 order = Constant.SORT_ASC.equalsIgnoreCase(ptTrainQueryDTO.getOrder()) ? Constant.SORT_ASC : Constant.SORT_DESC; //排序字段 - String sortField = filedNames.contains(ptTrainQueryDTO.getSort()) ? ptTrainQueryDTO.getSort() : Constant.ID; + String sortField = FIELD_NAMES.contains(ptTrainQueryDTO.getSort()) ? ptTrainQueryDTO.getSort() : Constant.ID; sort = StringUtils.humpToLine(sortField); pageTrainResult = ptTrainJobMapper.getPageTrain(page, currentUser.getId(), ptTrainQueryDTO.getTrainStatus(), ptTrainQueryDTO.getTrainName(), sort, order); } catch (Exception e) { @@ -151,53 +241,6 @@ public class PtTrainJobServiceImpl implements PtTrainJobService { } - /** - * 计算job训练时长 - * - * @param bizPod pod信息 - * @return String 训练时长 - */ - private String calculateRuntime(BizPod bizPod) { - return calculateRuntime(bizPod, (x) -> { - }); - } - - /** - * 计算job训练时长 - * - * @param bizPod - * @param consumer pod已经完成状态的回调函数 - * @return res 返回训练时长 - */ - private String calculateRuntime(BizPod bizPod, Consumer consumer) { - Long completedTime; - if (StringUtils.isBlank(bizPod.getStartTime())) { - return ""; - } - Long startTime = transformTime(bizPod.getStartTime()); - boolean hasCompleted = StringUtils.isNotBlank(bizPod.getCompletedTime()); - completedTime = hasCompleted ? transformTime(bizPod.getCompletedTime()) : LocalDateTime.now().toEpochSecond(ZoneOffset.of(trainJobConstant.getPlusEight())); - Long time = completedTime - startTime; - String res = DubheDateUtil.convert2Str(time); - if (hasCompleted) { - consumer.accept(res); - } - return res; - } - - /** - * 时间转换 - * - * @param time 时间 - * @return Long 时间戳 - */ - private Long transformTime(String time) { - LocalDateTime localDateTime = LocalDateTime.parse(time, DateTimeFormatter.ISO_OFFSET_DATE_TIME); - //没有根据时区做处理, 默认当前为东八区 - localDateTime = localDateTime.plusHours(Long.valueOf(trainJobConstant.getEight())); - return localDateTime.toEpochSecond(ZoneOffset.of(trainJobConstant.getPlusEight())); - } - /** * 作业不同版本job列表展示 * @@ -205,6 +248,7 @@ public class PtTrainJobServiceImpl implements PtTrainJobService { * @return List 训练详情集合 **/ @Override + @DataPermissionMethod(dataType = DatasetTypeEnum.PUBLIC) public List getTrainJobVersion(PtTrainJobVersionQueryDTO ptTrainJobVersionQueryDTO) { UserDTO currentUser = JwtUtils.getCurrentUserDto(); LogUtil.info(LogEnum.BIZ_TRAIN, "User {} queries different versions of job list display, received parameter trainId is {}", currentUser.getUsername(), ptTrainJobVersionQueryDTO.getTrainId()); @@ -273,6 +317,15 @@ public class PtTrainJobServiceImpl implements PtTrainJobService { .setRunParams(x.getRunParams()) .setParamF1(x.getParamF1()).setParamCallback(x.getParamCallback()) .setParamPrecise(x.getParamPrecise()).setParamAccuracy(x.getParamAccuracy()); + long nowTime = System.currentTimeMillis(); + //获取训练延时启动倒计时(分钟) + if (x.getDelayCreateTime() != null && nowTime < x.getDelayCreateTime().getTime() && TrainJobStatusEnum.checkRunStatus(ptTrainJobDetailVO.getTrainStatus())) { + ptTrainJobDetailVO.setDelayCreateCountDown(TrainUtil.getCountDown(x.getDelayCreateTime().getTime())); + } + //获取训练自动停止倒计时(分钟) + if (x.getDelayDeleteTime() != null && nowTime < x.getDelayDeleteTime().getTime() && TrainJobStatusEnum.checkRunStatus(ptTrainJobDetailVO.getTrainStatus())) { + ptTrainJobDetailVO.setDelayDeleteCountDown(TrainUtil.getCountDown(x.getDelayDeleteTime().getTime())); + } //image信息拼装 if (StringUtils.isNotBlank(x.getImageName())) { String imageNameSuffix = x.getImageName().substring(x.getImageName().lastIndexOf(StrUtil.SLASH) + MagicNumConstant.ONE); @@ -294,6 +347,9 @@ public class PtTrainJobServiceImpl implements PtTrainJobService { .setAlgorithmUsage(ptTrainAlgorithm.getAlgorithmUsage()) .setAccuracy(ptTrainAlgorithm.getAccuracy()) .setP4InferenceSpeed(ptTrainAlgorithm.getP4InferenceSpeed()); + if (ptTrainAlgorithm.getAlgorithmSource() == MagicNumConstant.ONE) { + ptTrainJobDetailVO.setAlgorithmCodeDir(ptTrainAlgorithm.getCodeDir()); + } } } @@ -327,10 +383,12 @@ public class PtTrainJobServiceImpl implements PtTrainJobService { */ @Override @Transactional(rollbackFor = Exception.class) + @DataPermissionMethod(dataType = DatasetTypeEnum.PUBLIC) public List createTrainJobVersion(PtTrainJobCreateDTO ptTrainJobCreateDTO) { UserDTO currentUser = JwtUtils.getCurrentUserDto(); LogUtil.info(LogEnum.BIZ_TRAIN, "User {} creates a training job and receives {} as an argument", currentUser.getUsername(), ptTrainJobCreateDTO); + // 判断当前trainName是否已经存在 checkTrainName(ptTrainJobCreateDTO.getTrainName(), currentUser.getId()); // 校验trainParamName是否存在 @@ -343,21 +401,28 @@ public class PtTrainJobServiceImpl implements PtTrainJobService { PtImageAndAlgorithmVO ptImageAndAlgorithmVO = getPtImageByAlgorithmId(ptTrainJobCreateDTO.getAlgorithmId(), currentUser.getId()); //使用用户创建训练时提供的镜像与运行命令 - String images = imageUtil.getImages(ptTrainJobCreateDTO, currentUser); + String images = imageUtil.getImageUrl(ptTrainJobCreateDTO, currentUser); ptImageAndAlgorithmVO.setImageName(trainHarborConfig.getAddress() + StrUtil.SLASH + images).setRunCommand(ptTrainJobCreateDTO.getRunCommand()); - // 获取规格 - PtTrainJobSpecs ptTrainJobSpecs = getSpecs(ptTrainJobCreateDTO.getTrainJobSpecsId(), currentUser); //jobKey String trainKey = KeyUtil.generateTrainKey(currentUser.getId()); + + //获取规格 + PtTrainJobSpecs ptTrainJobSpecs = new PtTrainJobSpecs(); + + ptTrainJobSpecs.setResourcesPoolType(ptTrainJobCreateDTO.getResourcesPoolType()); + ptTrainJobSpecs.setSpecsName(ptTrainJobCreateDTO.getTrainJobSpecsName()); + ptTrainJobSpecs.setSpecsInfo(JSONObject.parseObject(ptTrainJobCreateDTO.getTrainJobSpecsInfo())); + //版本 - String version = trainJobConstant.getVersionLabel() + String.format(TrainUtil.FOUR_DECIMAL, 1); + String version = trainJobConfig.getVersionLabel() + String.format(TrainUtil.FOUR_DECIMAL, 1); //生成k8s 的job名称 - String jobName = trainKey + trainJobConstant.getSeparator() + version; + String jobName = trainKey + trainJobConfig.getSeparator() + version; BaseTrainJobDTO baseTrainJobDTO = new BaseTrainJobDTO(); BeanUtil.copyProperties(ptTrainJobCreateDTO, baseTrainJobDTO); baseTrainJobDTO.setJobName(jobName); + baseTrainJobDTO.setPtTrainJobSpecs(ptTrainJobSpecs); //结果集处理 @@ -390,12 +455,24 @@ public class PtTrainJobServiceImpl implements PtTrainJobService { LogUtil.error(LogEnum.BIZ_TRAIN, "User {} creates training job, pt Train table insert data failed", currentUser.getUsername()); throw new BusinessException("内部错误"); } - // 添加train_job表 PtTrainJob ptTrainJob = new PtTrainJob(); + //查询modelName值 + ModelQuery modelNameById = modelQueryMapper.findModelNameById(ptTrainJobCreateDTO.getModelId()); + ModelQueryBrance modelVersionByUrl = modelQueryMapper.findModelVersionByUrl(ptTrainJobCreateDTO.getModelLoadPathDir()); + if (modelNameById != null) { + String name = modelNameById.getName(); + if (modelVersionByUrl != null) { + ptTrainJobCreateDTO.setModelName(name + SymbolConstant.COLON + modelVersionByUrl.getVersion()); + } else { + //设置预置模型的url路径 + ptTrainJobCreateDTO.setModelLoadPathDir(modelNameById.getUrl()); + ptTrainJobCreateDTO.setModelName(name); + } + } BeanUtil.copyProperties(ptTrainJobCreateDTO, ptTrainJob); ptTrainJob.setTrainId(ptTrain.getId()) - .setTrainVersion(trainJobConstant.getVersionLabel().toUpperCase() + String.format(TrainUtil.FOUR_DECIMAL, 1)) + .setTrainVersion(trainJobConfig.getVersionLabel().toUpperCase() + String.format(TrainUtil.FOUR_DECIMAL, 1)) .setJobName(jobName) .setCreateUserId(currentUser.getId()); int jobResult = ptTrainJobMapper.insert(ptTrainJob); @@ -412,6 +489,18 @@ public class PtTrainJobServiceImpl implements PtTrainJobService { .setImageName(imageName) .setRunParams(ptTrainJobCreateDTO.getRunParams()) .setCreateUserId(currentUser.getId()); + //保存训练延时启动时间 + if (ptTrainJobCreateDTO.getDelayCreateTime() != null && ptTrainJobCreateDTO.getDelayCreateTime() > 0) { + ptJobParam.setDelayCreateTime(TrainUtil.getDelayTime(ptTrainJobCreateDTO.getDelayCreateTime())); + } + //保存训练自动停止时间 + if (ptTrainJobCreateDTO.getDelayDeleteTime() != null && ptTrainJobCreateDTO.getDelayDeleteTime() > 0) { + if (ptTrainJobCreateDTO.getDelayCreateTime() != null && ptTrainJobCreateDTO.getDelayCreateTime() > 0) { + ptJobParam.setDelayDeleteTime(TrainUtil.getDelayTime(ptTrainJobCreateDTO.getDelayCreateTime() + ptTrainJobCreateDTO.getDelayDeleteTime())); + } else { + ptJobParam.setDelayDeleteTime(TrainUtil.getDelayTime(ptTrainJobCreateDTO.getDelayDeleteTime())); + } + } int jobParamResult = ptJobParamMapper.insert(ptJobParam); if (jobParamResult < 1) { LogUtil.error(LogEnum.BIZ_TRAIN, "User {} created training job, pT_job_parAM table insert data failed", currentUser.getUsername()); @@ -430,7 +519,7 @@ public class PtTrainJobServiceImpl implements PtTrainJobService { PtTrainParam ptTrainParam = new PtTrainParam(); BeanUtil.copyProperties(ptTrainJobCreateDTO, ptTrainParam); //获取镜像url - String images = imageUtil.getImages(ptTrainJobCreateDTO, currentUser); + String images = imageUtil.getImageUrl(ptTrainJobCreateDTO, currentUser); ptTrainParam.setImageName(images); ptTrainParam.setParamName(ptTrainJobCreateDTO.getTrainParamName()) .setDescription(ptTrainJobCreateDTO.getTrainParamDesc()) @@ -472,11 +561,17 @@ public class PtTrainJobServiceImpl implements PtTrainJobService { PtTrainAlgorithm ptTrainAlgorithm = ptTrainAlgorithmMapper.selectById(algorithmId); if (null == ptTrainAlgorithm || StringUtils.isBlank(ptTrainAlgorithm.getCodeDir())) { LogUtil.error(LogEnum.BIZ_TRAIN, "The record with algorithm training ID {} has no corresponding image or algorithm directory configuration", algorithmId); - throw new BusinessException(ResponseCode.SUCCESS, "该id的记录没有相应的镜像或者算法目录配置"); + throw new BusinessException(ResponseCode.ERROR, "该id的记录没有相应的镜像或者算法目录配置"); + } + + if (!AlgorithmStatusEnum.SUCCESS.getCode().equals(ptTrainAlgorithm.getAlgorithmStatus())) { + LogUtil.error(LogEnum.BIZ_TRAIN, "The algorithm ID {} algorithmStatus is{} unusual", algorithmId, ptTrainAlgorithm.getAlgorithmStatus()); + throw new BusinessException(ResponseCode.ERROR, "该算法状态异常!"); } + if (!(userId.equals(ptTrainAlgorithm.getCreateUserId()) || AlgorithmSourceEnum.PRE.getStatus().equals(ptTrainAlgorithm.getAlgorithmSource()))) { LogUtil.error(LogEnum.BIZ_TRAIN, "The data {} does not belong to the user {}!", ptTrainAlgorithm, userId); - throw new BusinessException(ResponseCode.SUCCESS, "该数据不属于该用户!"); + throw new BusinessException(ResponseCode.ERROR, "该数据不属于该用户!"); } PtImageAndAlgorithmVO ptImageAndAlgorithmVO = new PtImageAndAlgorithmVO(); @@ -527,7 +622,7 @@ public class PtTrainJobServiceImpl implements PtTrainJobService { * @return String 版本 */ private String buildVersion(PtTrain ptTrain) { - return ptTrain.getTrainKey() + trainJobConstant.getSeparator() + trainJobConstant.getVersionLabel() + String.format(TrainUtil.FOUR_DECIMAL, ptTrain.getTotalNum() + 1); + return ptTrain.getTrainKey() + trainJobConfig.getSeparator() + trainJobConfig.getVersionLabel() + String.format(TrainUtil.FOUR_DECIMAL, ptTrain.getTotalNum() + 1); } /** @@ -538,6 +633,7 @@ public class PtTrainJobServiceImpl implements PtTrainJobService { **/ @Override @Transactional(rollbackFor = Exception.class) + @DataPermissionMethod(dataType = DatasetTypeEnum.PUBLIC) public List updateTrainJob(PtTrainJobUpdateDTO ptTrainJobUpdateDTO) { UserDTO currentUser = JwtUtils.getCurrentUserDto(); @@ -548,16 +644,18 @@ public class PtTrainJobServiceImpl implements PtTrainJobService { LogUtil.error(LogEnum.BIZ_TRAIN, "It is illegal for a user {} to modify a training job, jobId, to {}", currentUser.getUsername(), ptTrainJobUpdateDTO.getId()); throw new BusinessException(ResponseCode.SUCCESS, "您输入的id不存在或已被删除"); } - //获取算法 PtImageAndAlgorithmVO ptImageAndAlgorithmVO = getPtImageByAlgorithmId(ptTrainJobUpdateDTO.getAlgorithmId(), currentUser.getId()); //使用用户修改训练时提供的镜像与运行命令 //获取镜像url - String images = imageUtil.getImages(ptTrainJobUpdateDTO, currentUser); + String images = imageUtil.getImageUrl(ptTrainJobUpdateDTO, currentUser); ptImageAndAlgorithmVO.setImageName(trainHarborConfig.getAddress() + StrUtil.SLASH + images).setRunCommand(ptTrainJobUpdateDTO.getRunCommand()); //获取规格 - PtTrainJobSpecs ptTrainJobSpecs = getSpecs(ptTrainJobUpdateDTO.getTrainJobSpecsId(), currentUser); + PtTrainJobSpecs ptTrainJobSpecs = new PtTrainJobSpecs(); + ptTrainJobSpecs.setResourcesPoolType(ptTrainJobUpdateDTO.getResourcesPoolType()); + ptTrainJobSpecs.setSpecsName(ptTrainJobUpdateDTO.getTrainJobSpecsName()); + ptTrainJobSpecs.setSpecsInfo(JSONObject.parseObject(ptTrainJobUpdateDTO.getTrainJobSpecsInfo())); PtTrain ptTrain = ptTrainMapper.selectById(existPtTrainJob.getTrainId()); @@ -590,8 +688,22 @@ public class PtTrainJobServiceImpl implements PtTrainJobService { currentUser, PtTrainJob existPtTrainJob, String imageName, PtTrain ptTrain, String jobName) { //添加train_job表 PtTrainJob ptTrainJob = new PtTrainJob(); + //根据id查询model的name值 + ModelQuery modelName = modelQueryMapper.findModelNameById(ptTrainJobUpdateDTO.getModelId()); + //根据Url查询版本的路径值 + ModelQueryBrance modelVersion = modelQueryMapper.findModelVersionByUrl(ptTrainJobUpdateDTO.getModelLoadPathDir()); + if (modelName != null) { + String name = modelName.getName(); + if (modelVersion != null) { + ptTrainJobUpdateDTO.setModelName(name + SymbolConstant.COLON + modelVersion.getVersion()); + } else { + //设置预置模型的url + ptTrainJobUpdateDTO.setModelLoadPathDir(modelName.getUrl()); + ptTrainJobUpdateDTO.setModelName(name); + } + } BeanUtil.copyProperties(ptTrainJobUpdateDTO, ptTrainJob); - ptTrainJob.setTrainId(ptTrain.getId()).setTrainVersion(trainJobConstant.getVersionLabel().toUpperCase() + String.format(TrainUtil.FOUR_DECIMAL, ptTrain.getTotalNum() + 1)) + ptTrainJob.setTrainId(ptTrain.getId()).setTrainVersion(trainJobConfig.getVersionLabel().toUpperCase() + String.format(TrainUtil.FOUR_DECIMAL, ptTrain.getTotalNum() + 1)) .setJobName(jobName).setParentTrainVersion(existPtTrainJob.getTrainVersion()) .setCreateUserId(currentUser.getId()); int jobResult = ptTrainJobMapper.insert(ptTrainJob); @@ -608,6 +720,18 @@ public class PtTrainJobServiceImpl implements PtTrainJobService { .setImageName(imageName) .setRunParams(ptTrainJobUpdateDTO.getRunParams()) .setCreateUserId(currentUser.getId()); + //保存训练延时启动时间 + if (ptTrainJobUpdateDTO.getDelayCreateTime() != null && ptTrainJobUpdateDTO.getDelayCreateTime() > 0) { + ptJobParam.setDelayCreateTime(TrainUtil.getDelayTime(ptTrainJobUpdateDTO.getDelayCreateTime())); + } + //保存训练自动停止时间 + if (ptTrainJobUpdateDTO.getDelayDeleteTime() != null && ptTrainJobUpdateDTO.getDelayDeleteTime() > 0) { + if (ptTrainJobUpdateDTO.getDelayCreateTime() != null && ptTrainJobUpdateDTO.getDelayCreateTime() > 0) { + ptJobParam.setDelayDeleteTime(TrainUtil.getDelayTime(ptTrainJobUpdateDTO.getDelayCreateTime() + ptTrainJobUpdateDTO.getDelayDeleteTime())); + } else { + ptJobParam.setDelayDeleteTime(TrainUtil.getDelayTime(ptTrainJobUpdateDTO.getDelayDeleteTime())); + } + } int jobParamResult = ptJobParamMapper.insert(ptJobParam); if (jobParamResult < 1) { LogUtil.error(LogEnum.BIZ_TRAIN, "User {} created training job, pT_job_parAM table insert data failed", currentUser.getUsername()); @@ -634,6 +758,7 @@ public class PtTrainJobServiceImpl implements PtTrainJobService { **/ @Override @Transactional(rollbackFor = Exception.class) + @DataPermissionMethod(dataType = DatasetTypeEnum.PUBLIC) public PtTrainJobDeleteVO deleteTrainJob(PtTrainJobDeleteDTO ptTrainJobDeleteDTO) { UserDTO currentUser = JwtUtils.getCurrentUserDto(); LogUtil.info(LogEnum.BIZ_TRAIN, "User {} deletes the training job and receives the parameter {}", currentUser.getUsername(), ptTrainJobDeleteDTO); @@ -644,14 +769,15 @@ public class PtTrainJobServiceImpl implements PtTrainJobService { Collection jobIdList = new ArrayList<>(); if (null != ptTrainJobDeleteDTO.getId()) { - //删除job - deleteJobs(currentUser, jobList); - - int jobResult = ptTrainJobMapper.deleteById(ptTrainJobDeleteDTO.getId()); - if (jobResult < 1) { + //要删除的训练任务 + PtTrainJob ptTrainJob = ptTrainJobMapper.selectById(ptTrainJobDeleteDTO.getId()); + if (ptTrainJob == null) { LogUtil.error(LogEnum.BIZ_TRAIN, "User {} deleted training job, pT_train_job table failed to delete data", currentUser.getUsername()); throw new BusinessException(ResponseCode.SUCCESS, "训练任务已删除或参数不合法"); } + //删除job + deleteJobs(currentUser, jobList); + ptTrainJobMapper.deleteById(ptTrainJobDeleteDTO.getId()); PtTrain updatePtTrain = new PtTrain(); updatePtTrain.setVersionNum(ptTrain.getVersionNum() - 1); @@ -671,8 +797,15 @@ public class PtTrainJobServiceImpl implements PtTrainJobService { } } jobIdList.add(ptTrainJobDeleteDTO.getId()); + + //回收已删除训练任务无效文件 + String recyclePath = nfsUtil.formatPath(nfsConfig.getRootDir() + nfsConfig.getBucket() + trainJobConfig.getManage() + + File.separator + ptTrainJob.getCreateUserId() + File.separator + ptTrainJob.getJobName()); + recycleTaskWithTrain(recyclePath); + } else { deleteTrainAndJob(ptTrainJobDeleteDTO, currentUser, jobList, ptTrain, jobIdList); + } //删除pt_job_param表中相关数据 @@ -738,6 +871,13 @@ public class PtTrainJobServiceImpl implements PtTrainJobService { LogUtil.error(LogEnum.BIZ_TRAIN, "User {} deleted training job, pT_train_job table failed to delete data", currentUser.getUsername()); throw new BusinessException("内部错误"); } + + //回收已删除训练任务无效文件 + for (PtTrainJob trainJob : ptTrainJobs) { + String recyclePath = nfsUtil.formatPath(nfsConfig.getRootDir() + nfsConfig.getBucket() + trainJobConfig.getManage() + + StrUtil.SLASH + trainJob.getCreateUserId() + StrUtil.SLASH + trainJob.getJobName()); + recycleTaskWithTrain(recyclePath); + } } /** @@ -767,7 +907,7 @@ public class PtTrainJobServiceImpl implements PtTrainJobService { } /** - * 检测停止训练任务DTO + * 检测停止训练任务 * * @param ptTrainJobStopDTO 停止训练DTO * @param currentUser 用户 @@ -798,11 +938,13 @@ public class PtTrainJobServiceImpl implements PtTrainJobService { * @param jobList 任务集合 */ private void deleteJobs(UserDTO currentUser, List jobList) { - String namespace = k8sNameTool.generateNameSpace(currentUser.getId()); + String namespace = k8sNameTool.generateNamespace(currentUser.getId()); try { for (PtTrainJob job : jobList) { if (TrainJobStatusEnum.STOP.getStatus().equals(job.getTrainStatus())) { - boolean bool = trainJobApi.delete(namespace, job.getJobName()); + boolean bool = TrainTypeEnum.isDistributeTrain(job.getTrainType()) ? + distributeTrainApi.deleteByResourceName(namespace, job.getJobName()).isSuccess() : + trainJobApi.delete(namespace, job.getJobName()); if (!bool) { LogUtil.error(LogEnum.BIZ_TRAIN, "User {} deletes the training job and K8s fails to execute the delete() method, namespace为{}, resourceName为{}", currentUser.getUsername(), namespace, job.getJobName()); @@ -828,6 +970,7 @@ public class PtTrainJobServiceImpl implements PtTrainJobService { **/ @Override @Transactional(rollbackFor = Exception.class) + @DataPermissionMethod(dataType = DatasetTypeEnum.PUBLIC) public PtTrainJobStopVO stopTrainJob(PtTrainJobStopDTO ptTrainJobStopDTO) { UserDTO currentUser = JwtUtils.getCurrentUserDto(); LogUtil.info(LogEnum.BIZ_TRAIN, "User {} stops training Job and receives the parameter {}", currentUser.getUsername(), ptTrainJobStopDTO); @@ -837,7 +980,7 @@ public class PtTrainJobServiceImpl implements PtTrainJobService { if (null != ptTrainJobStopDTO.getId()) { //停止job - stopJobs(currentUser, jobList); + stopTrainJobAsync.stopJobs(currentUser, jobList); } else if (null != ptTrainJobStopDTO.getTrainId()) { QueryWrapper queryTrainJonWrapper = new QueryWrapper<>(); @@ -855,42 +998,22 @@ public class PtTrainJobServiceImpl implements PtTrainJobService { } //停止job - stopJobs(currentUser, jobList); + stopTrainJobAsync.stopJobs(currentUser, jobList); } - //更新job状态 - updateJobStatus(currentUser, jobList); - PtTrainJobStopVO ptTrainJobStopVO = new PtTrainJobStopVO(); ptTrainJobStopVO.setTrainId(ptTrainJobStopDTO.getTrainId()); ptTrainJobStopVO.setId(ptTrainJobStopDTO.getId()); return ptTrainJobStopVO; } - /** - * 更新训练状态 - * - * @param currentUser 用户 - * @param jobList 任务集合 - */ - private void updateJobStatus(UserDTO currentUser, List jobList) { - for (PtTrainJob ptTrainJob : jobList) { - PtTrainJob updateTrainJob = new PtTrainJob(); - updateTrainJob.setId(ptTrainJob.getId()).setRuntime(ptTrainJob.getRuntime()).setTrainStatus(TrainJobStatusEnum.STOP.getStatus()); - int updateResult = ptTrainJobMapper.updateById(updateTrainJob); - if (updateResult < 1) { - LogUtil.error(LogEnum.BIZ_TRAIN, "User {} stops training job, pT_train_job table fails to update status, the information is as follows {}", currentUser.getUsername(), updateTrainJob); - throw new BusinessException(ResponseCode.SUCCESS, "没有待停止的job"); - } - } - } - /** * 任务统计 * * @return PtTrainJobStatisticsMineVO 我的训练任务统计结果 **/ @Override + @DataPermissionMethod(dataType = DatasetTypeEnum.PUBLIC) public PtTrainJobStatisticsMineVO statisticsMine() { UserDTO userDTO = JwtUtils.getCurrentUserDto(); // 获取运行中的任务 @@ -909,29 +1032,6 @@ public class PtTrainJobServiceImpl implements PtTrainJobService { return vo; } - - /** - * 停止任务 - * - * @param currentUser 用户 - * @param jobList 任务集合 - */ - void stopJobs(UserDTO currentUser, List jobList) { - String namespace = k8sNameTool.generateNameSpace(currentUser.getId()); - jobList.forEach(job -> { - BizPod bizPod = podApi.getWithResourceName(namespace, job.getJobName()); - if (!bizPod.isSuccess()) { - LogUtil.error(LogEnum.BIZ_TRAIN, "User {} stops training Job return code:{},message:{}", currentUser.getUsername(), Integer.valueOf(bizPod.getCode()), bizPod.getMessage()); - } - boolean bool = trainJobApi.delete(namespace, job.getJobName()); - if (!bool) { - LogUtil.error(LogEnum.BIZ_TRAIN, "User {} stops training Job and K8S fails in the stop process, namespace为{}, resourceName为{}", - currentUser.getUsername(), namespace, job.getJobName()); - } - job.setRuntime(calculateRuntime(bizPod)); - }); - } - /** * 查询训练作业job状态 * @@ -939,6 +1039,7 @@ public class PtTrainJobServiceImpl implements PtTrainJobService { * @return HashedMap 数据集路径-是否可以删除 的map集合 **/ @Override + @DataPermissionMethod(dataType = DatasetTypeEnum.PUBLIC) public Map getTrainDataSourceStatus(PtTrainDataSourceStatusQueryDTO ptTrainDataSourceStatusQueryDTO) { UserDTO currentUser = JwtUtils.getCurrentUserDto(); LogUtil.info(LogEnum.BIZ_TRAIN, "The user {} queries the state of the dataset starting with the received parameter {}", currentUser.getUsername(), ptTrainDataSourceStatusQueryDTO); @@ -980,34 +1081,47 @@ public class PtTrainJobServiceImpl implements PtTrainJobService { * @return PtTrainQueryJobDetailVO 根据jobId查询训练任务详情返回结果 */ @Override + @DataPermissionMethod(dataType = DatasetTypeEnum.PUBLIC) public PtTrainJobDetailQueryVO getTrainJobDetail(PtTrainJobDetailQueryDTO ptTrainJobDetailQueryDTO) { //从会话中获取用户信息 UserDTO user = JwtUtils.getCurrentUserDto(); LogUtil.info(LogEnum.BIZ_TRAIN, "The user {} starts by querying jobId={} for training task details", user.getUsername(), ptTrainJobDetailQueryDTO.getId()); + //获取训练job参数 QueryWrapper trainJobQuery = new QueryWrapper<>(); trainJobQuery.eq("create_user_id", user.getId()); trainJobQuery.eq("id", ptTrainJobDetailQueryDTO.getId()); - //获取训练 PtTrainJob ptTrainJob = ptTrainJobMapper.selectOne(trainJobQuery); if (ptTrainJob == null) { LogUtil.error(LogEnum.BIZ_TRAIN, "The jobId for the user {} query does not exist", user.getUsername()); throw new BusinessException(ResponseCode.SUCCESS, "您查询的id不存在或已被删除"); } + //获取训练参数 + PtTrain ptTrain = ptTrainMapper.selectById(ptTrainJob.getTrainId()); + //获取训练任务参数 QueryWrapper jobParamQuery = new QueryWrapper<>(); jobParamQuery.eq("train_job_id", ptTrainJob.getId()); - //获取训练任务参数 PtJobParam ptJobParam = ptJobParamMapper.selectOne(jobParamQuery); if (ptJobParam == null || ptJobParam.getAlgorithmId() < MagicNumConstant.ONE) { LogUtil.error(LogEnum.BIZ_TRAIN, "The algorithm ID corresponding to the jobId={} query by the user {} does not exist", user.getUsername(), ptTrainJobDetailQueryDTO.getId()); throw new BusinessException(ResponseCode.SUCCESS, "您查询的jobId对应的算法id不存在或已被删除"); } - //获取算法 + //获取算法参数 PtTrainAlgorithm ptTrainAlgorithm = ptTrainAlgorithmMapper.selectAllById(ptJobParam.getAlgorithmId()); - //拼装job Detail信息 + //结果集处理 PtTrainJobDetailQueryVO ptTrainJobDetailQueryVO = new PtTrainJobDetailQueryVO(); BeanUtils.copyProperties(ptTrainJob, ptTrainJobDetailQueryVO); - ptTrainJobDetailQueryVO.setAlgorithmId(ptJobParam.getAlgorithmId()).setRunCommand(ptJobParam.getRunCommand()).setRunParams(ptJobParam.getRunParams()) - .setParamF1(ptJobParam.getParamF1()).setParamCallback(ptJobParam.getParamCallback()).setParamPrecise(ptJobParam.getParamPrecise()).setParamAccuracy(ptJobParam.getParamAccuracy()); + ptTrainJobDetailQueryVO.setTrainName(ptTrain.getTrainName()).setAlgorithmId(ptJobParam.getAlgorithmId()).setRunCommand(ptJobParam.getRunCommand()) + .setRunParams(ptJobParam.getRunParams()).setParamF1(ptJobParam.getParamF1()).setParamCallback(ptJobParam.getParamCallback()) + .setParamPrecise(ptJobParam.getParamPrecise()).setParamAccuracy(ptJobParam.getParamAccuracy()); + long nowTime = System.currentTimeMillis(); + //获取训练延时启动倒计时(分钟) + if (ptJobParam.getDelayCreateTime() != null && nowTime < ptJobParam.getDelayCreateTime().getTime() && TrainJobStatusEnum.checkRunStatus(ptTrainJob.getTrainStatus())) { + ptTrainJobDetailQueryVO.setDelayCreateCountDown(TrainUtil.getCountDown(ptJobParam.getDelayCreateTime().getTime())); + } + //获取训练自动停止倒计时(分钟) + if (ptJobParam.getDelayDeleteTime() != null && nowTime < ptJobParam.getDelayDeleteTime().getTime() && TrainJobStatusEnum.checkRunStatus(ptTrainJob.getTrainStatus())) { + ptTrainJobDetailQueryVO.setDelayDeleteCountDown(TrainUtil.getCountDown(ptJobParam.getDelayDeleteTime().getTime())); + } //拼装镜像信息 if (StringUtils.isNotBlank(ptJobParam.getImageName())) { String imageNameSuffix = ptJobParam.getImageName().substring(ptJobParam.getImageName().lastIndexOf(StrUtil.SLASH) + MagicNumConstant.ONE); @@ -1022,6 +1136,9 @@ public class PtTrainJobServiceImpl implements PtTrainJobService { .setAlgorithmUsage(ptTrainAlgorithm.getAlgorithmUsage()) .setAccuracy(ptTrainAlgorithm.getAccuracy()) .setP4InferenceSpeed(ptTrainAlgorithm.getP4InferenceSpeed()); + if (ptTrainAlgorithm.getAlgorithmSource() == MagicNumConstant.ONE) { + ptTrainJobDetailQueryVO.setAlgorithmCodeDir(ptTrainAlgorithm.getCodeDir()); + } } return ptTrainJobDetailQueryVO; @@ -1033,6 +1150,8 @@ public class PtTrainJobServiceImpl implements PtTrainJobService { * @param ptTrainJobResumeDTO 恢复训练请求参数 */ @Override + @Transactional(rollbackFor = Exception.class) + @DataPermissionMethod(dataType = DatasetTypeEnum.PUBLIC) public void resumeTrainJob(PtTrainJobResumeDTO ptTrainJobResumeDTO) { //从会话中获取用户信息 UserDTO currentUser = JwtUtils.getCurrentUserDto(); @@ -1054,27 +1173,39 @@ public class PtTrainJobServiceImpl implements PtTrainJobService { //获取镜像 PtImageAndAlgorithmVO ptImageAndAlgorithmVO = getPtImageByAlgorithmId(ptJobParam.getAlgorithmId(), currentUser.getId()); //使用用户训练时提供的镜像与运行命令 - ptImageAndAlgorithmVO.setImageName(ptJobParam.getImageName()).setRunCommand(ptJobParam.getRunCommand()); - //获取规格 - PtTrainJobSpecs ptTrainJobSpecs = getSpecs(ptTrainJob.getTrainJobSpecsId(), currentUser); + ptImageAndAlgorithmVO.setImageName(trainHarborConfig.getAddress() + StrUtil.SLASH + ptJobParam.getImageName()).setRunCommand(ptJobParam.getRunCommand()); + String[] codeDirResult = ptImageAndAlgorithmVO.getCodeDir().split(StrUtil.SLASH); + String codeDirName = codeDirResult[codeDirResult.length - 1]; //处理目录问题 - String commonPath = nfsUtil.getNfsConfig().getBucket() + trainJobConstant.getManage() + StrUtil.SLASH - + currentUser.getId() + StrUtil.SLASH + ptTrainJob.getJobName(); - String outPath = commonPath + StrUtil.SLASH + trainJobConstant.getOutPath(); - String loadPath = commonPath + StrUtil.SLASH + trainJobConstant.getLoadPath(); - String modelLoadDirName = nfsUtil.find2ndNewDir(outPath); - if ("".equals(modelLoadDirName)) { - LogUtil.error(LogEnum.BIZ_TRAIN, "outPath: {}", outPath); - throw new BusinessException(ResponseCode.ERROR, "该任务没有前序结果可以继续训练"); + String noEnvPath = StrUtil.SLASH + trainJobConfig.getManage() + StrUtil.SLASH + currentUser.getId() + StrUtil.SLASH + + ptTrainJob.getJobName(); + String commonPath = nfsConfig.getBucket() + noEnvPath.substring(1); + String outPath = commonPath + StrUtil.SLASH + trainJobConfig.getOutPath(); + String loadPath = commonPath + StrUtil.SLASH + trainJobConfig.getLoadPath(); + String codePath = commonPath + StrUtil.SLASH + codeDirName; + String noEnvOut = noEnvPath + StrUtil.SLASH + trainJobConfig.getOutPath(); + String path = ptTrainJobResumeDTO.getPath(); + if (!path.startsWith(noEnvOut)) { + LogUtil.error(LogEnum.BIZ_TRAIN, "path: {}", path); + throw new BusinessException("内部错误"); } - nfsUtil.deleteDirOrFile(loadPath); - nfsUtil.renameDir(outPath, loadPath); + String modelLoadDir = path.substring(noEnvOut.length()); + FileUtil.del(nfsConfig.getRootDir() + loadPath); + FileUtil.del(nfsConfig.getRootDir() + codePath); + FileUtil.rename(new File(nfsConfig.getRootDir() + outPath), nfsConfig.getRootDir() + loadPath, false, true); + + //获取训练规格信息 + PtTrainJobSpecs ptTrainJobSpecs = new PtTrainJobSpecs(); + List dictDetails = dictDetailMapper.selectList(new LambdaQueryWrapper().eq(DictDetail::getLabel, ptTrainJob.getTrainJobSpecsName())); + ptTrainJobSpecs.setResourcesPoolType(ptTrainJob.getResourcesPoolType()); + ptTrainJobSpecs.setSpecsName(ptTrainJob.getTrainJobSpecsName()); + ptTrainJobSpecs.setSpecsInfo(JSONObject.parseObject(dictDetails.get(0).getValue())); // 拼load路径 JSONObject runParams = ptJobParam.getRunParams(); - runParams.put(trainJobConstant.getLoadKey(), trainJobConstant.getDockerTrainPath() + StrUtil.SLASH + - trainJobConstant.getLoadPath() + StrUtil.SLASH + modelLoadDirName); + runParams.put(trainJobConfig.getLoadKey(), trainJobConfig.getDockerTrainPath() + StrUtil.SLASH + + trainJobConfig.getLoadPath() + modelLoadDir); BaseTrainJobDTO baseTrainJobDTO = new BaseTrainJobDTO(); BeanUtil.copyProperties(ptTrainJob, baseTrainJobDTO); baseTrainJobDTO.setPtTrainJobSpecs(ptTrainJobSpecs); @@ -1082,12 +1213,16 @@ public class PtTrainJobServiceImpl implements PtTrainJobService { // 初始化训练时间和状态 PtTrainJob updatePtTrainJob = new PtTrainJob(); - updatePtTrainJob.setId(ptTrainJob.getId()).setRuntime("").setTrainStatus(TrainJobStatusEnum.PENDING.getStatus()); + updatePtTrainJob.setId(ptTrainJob.getId()).setRuntime(TrainUtil.INIT_RUNTIME) + .setTrainStatus(TrainJobStatusEnum.PENDING.getStatus()) + .setUpdateTime(new Timestamp(System.currentTimeMillis())); int updateResult = ptTrainJobMapper.updateById(updatePtTrainJob); if (updateResult < 1) { LogUtil.error(LogEnum.BIZ_TRAIN, "User {} resumed training job, pt train Job table update failed", currentUser.getUsername()); throw new BusinessException("内部错误"); } + // 此处将ptTrainJob的trainStatus和runTime设为null以避免doJob中再次调用updateById错误更新状态和时间 + ptTrainJob.setTrainStatus(null).setRuntime(null).setCreateTime(null); // 提交job asyncManager.execute(baseTrainJobDTO, currentUser, ptImageAndAlgorithmVO, ptTrainJob); } @@ -1095,11 +1230,11 @@ public class PtTrainJobServiceImpl implements PtTrainJobService { /** * 获取job在grafana监控的地址 * - * @param jobId 任务ID - * @return PtJobMetricsGrafanaVO grafana监控的地址信息 + * @param jobId 任务ID + * @return List grafana监控的地址信息 */ @Override - public PtJobMetricsGrafanaVO getGrafanaUrl(Long jobId) { + public List getGrafanaUrl(Long jobId) { UserDTO currentUser = JwtUtils.getCurrentUserDto(); LogUtil.info(LogEnum.BIZ_TRAIN, "User {} gets grafanaUrl of Job and receives parameter [jobId] {}", currentUser.getUsername(), jobId); @@ -1111,13 +1246,41 @@ public class PtTrainJobServiceImpl implements PtTrainJobService { throw new BusinessException(ResponseCode.SUCCESS, "您输入的id不存在或已被删除,请重新输入"); } - String podMetricsGrafanaUrl = podApi - .getPodMetricsGrafanaUrl(k8sNameTool.generateNameSpace(currentUser.getId()), ptTrainJob.getJobName()); + List list = new ArrayList<>(); + try { + List bizPodList = podApi.getListByResourceName(k8sNameTool.generateNamespace(currentUser.getId()), ptTrainJob.getJobName()); + bizPodList.stream() + .filter(bizPod -> bizPod.getPhase().equalsIgnoreCase(TrainJobStatusEnum.RUNNING.getMessage())) + .forEach(bizPod -> { + String podName = bizPod.getName(); + PtJobMetricsGrafanaVO ptJobMetricsGrafanaVO = new PtJobMetricsGrafanaVO(); + ptJobMetricsGrafanaVO.setJobMetricsGrafanaUrl(k8sPodMetricsGrafanaUrl.concat(podName)); + ptJobMetricsGrafanaVO.setJobPodName(podName); + if (ptTrainJob.getTrainType() == 1 && PodUtil.isMaster(podName)) { + list.add(0, ptJobMetricsGrafanaVO); + } else { + list.add(ptJobMetricsGrafanaVO); + } + }); + } catch (Exception e) { + LogUtil.info(LogEnum.BIZ_K8S, "Failed to obtain grafanaUrl of Pod, params:[namespace]={}, [resourceName]={}, error:{}", + k8sNameTool.generateNamespace(currentUser.getId()), ptTrainJob.getJobName(), e); + } - PtJobMetricsGrafanaVO ptJobMetricsGrafanaVO = new PtJobMetricsGrafanaVO(); - ptJobMetricsGrafanaVO.setJobMetricsGrafanaUrl(podMetricsGrafanaUrl); LogUtil.info(LogEnum.BIZ_TRAIN, "User {} completes getting grafanaUrl on job, receives {} parameter [jobId], returns {} result", - currentUser.getUsername(), jobId, ptJobMetricsGrafanaVO); - return ptJobMetricsGrafanaVO; + currentUser.getUsername(), jobId, JSONObject.toJSONString(list)); + return list; + } + + public void recycleTaskWithTrain(String recyclePath) { + //创建已删除训练任务的无效文件回收任务 + RecycleTaskCreateDTO recycleTask = new RecycleTaskCreateDTO(); + recycleTask.setRecycleModule(RecycleModuleEnum.BIZ_TRAIN.getValue()) + .setRecycleType(RecycleTypeEnum.FILE.getCode()) + .setRecycleDelayDate(recycleConfig.getTrainValid()) + .setRecycleCondition(recyclePath) + .setRecycleNote("回收已删除训练任务文件"); + LogUtil.info(LogEnum.BIZ_TRAIN, "delete train job add recycle task:{}", recycleTask); + recycleTaskService.createRecycleTask(recycleTask); } } diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/impl/PtTrainLogServiceImpl.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/impl/PtTrainLogServiceImpl.java index 7a59754..4ccea5e 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/impl/PtTrainLogServiceImpl.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/impl/PtTrainLogServiceImpl.java @@ -17,7 +17,9 @@ package org.dubhe.service.impl; +import cn.hutool.core.util.StrUtil; import org.dubhe.constant.NumberConstant; +import org.dubhe.constant.SymbolConstant; import org.dubhe.dao.PtTrainJobMapper; import org.dubhe.domain.dto.PtTrainLogQueryDTO; import org.dubhe.domain.dto.UserDTO; @@ -27,20 +29,25 @@ import org.dubhe.enums.LogEnum; import org.dubhe.exception.BusinessException; import org.dubhe.k8s.api.LogMonitoringApi; import org.dubhe.k8s.domain.bo.LogMonitoringBO; +import org.dubhe.k8s.domain.dto.PodQueryDTO; import org.dubhe.k8s.domain.vo.LogMonitoringVO; +import org.dubhe.k8s.domain.vo.PodVO; +import org.dubhe.k8s.service.PodService; import org.dubhe.service.PtTrainLogService; import org.dubhe.utils.JwtUtils; import org.dubhe.utils.K8sNameTool; import org.dubhe.utils.LogUtil; +import org.apache.commons.lang3.StringUtils; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; import org.springframework.util.CollectionUtils; +import java.util.Collections; import java.util.List; /** - * @description: 训练日志 服务实现类 + * @description 训练日志 服务实现类 * @date 2020-05-08 */ @@ -56,6 +63,9 @@ public class PtTrainLogServiceImpl implements PtTrainLogService { @Autowired private K8sNameTool k8sNameTool; + @Autowired + private PodService podService; + /** * 查询训练任务运行日志 * @@ -68,7 +78,7 @@ public class PtTrainLogServiceImpl implements PtTrainLogService { //从会话中获取用户信息 UserDTO user = JwtUtils.getCurrentUserDto(); - String nameSpace = k8sNameTool.generateNameSpace(user.getId()); + String namespace = k8sNameTool.generateNamespace(user.getId()); PtTrainJob ptTrainJob = ptTrainJobMapper.selectById(ptTrainLogQueryDTO.getJobId()); if (null == ptTrainJob || !user.getId().equals(ptTrainJob.getCreateUserId())) { @@ -82,15 +92,15 @@ public class PtTrainLogServiceImpl implements PtTrainLogService { Integer lines = null == ptTrainLogQueryDTO.getLines() ? NumberConstant.NUMBER_50 : ptTrainLogQueryDTO.getLines(); /** 拼接请求es的参数 **/ LogMonitoringBO logMonitoringBo = new LogMonitoringBO(); - logMonitoringBo.setNamespace(nameSpace) + logMonitoringBo.setNamespace(namespace) .setResourceName(ptTrainJob.getJobName()); PtTrainLogQueryVO ptTrainLogQueryVO = new PtTrainLogQueryVO(); - LogMonitoringVO result = logMonitoringApi.searchLog(startLine, lines, logMonitoringBo); + LogMonitoringVO result = logMonitoringApi.searchLogByResName(startLine, lines, logMonitoringBo); List list = result.getLogs(); - if (result == null || CollectionUtils.isEmpty(list)) { + if (CollectionUtils.isEmpty(list)) { ptTrainLogQueryVO.setContent(list); ptTrainLogQueryVO.setStartLine(startLine); ptTrainLogQueryVO.setEndLine(startLine - 1); @@ -113,13 +123,27 @@ public class PtTrainLogServiceImpl implements PtTrainLogService { */ @Override public String getTrainLogString(List content) { - String strContent = ""; - if (content != null) { - for (String str : content) { - strContent = strContent.concat(str).concat("\r\n"); - } + if (content == null) { + return SymbolConstant.BLANK; } + return StringUtils.join(content,StrUtil.CRLF); + } - return strContent; + /** + * 获取训练任务的Pod + * + * @param id 训练作业job表 id + * @return 训练节点信息 + */ + @Override + public List getPods(Long id) { + PtTrainJob ptTrainJob = ptTrainJobMapper.selectById(id); + if (ptTrainJob == null){ + return Collections.emptyList(); + } + //从会话中获取用户信息 + UserDTO user = JwtUtils.getCurrentUserDto(); + String nameSpace = k8sNameTool.generateNamespace(user.getId()); + return podService.getPods(new PodQueryDTO(nameSpace,ptTrainJob.getJobName())); } } diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/impl/PtTrainParamServiceImpl.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/impl/PtTrainParamServiceImpl.java index 19d4d88..d74f118 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/impl/PtTrainParamServiceImpl.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/service/impl/PtTrainParamServiceImpl.java @@ -21,16 +21,22 @@ import cn.hutool.core.util.StrUtil; import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; import com.baomidou.mybatisplus.core.metadata.IPage; import com.baomidou.mybatisplus.extension.plugins.pagination.Page; +import org.dubhe.annotation.DataPermissionMethod; import org.dubhe.base.MagicNumConstant; import org.dubhe.base.ResponseCode; -import org.dubhe.constant.TrainJobConstant; +import org.dubhe.constant.SymbolConstant; +import org.dubhe.config.TrainJobConfig; +import org.dubhe.dao.ModelQueryMapper; import org.dubhe.dao.PtTrainAlgorithmMapper; import org.dubhe.dao.PtTrainParamMapper; import org.dubhe.data.constant.Constant; import org.dubhe.domain.dto.*; +import org.dubhe.domain.entity.ModelQuery; +import org.dubhe.domain.entity.ModelQueryBrance; import org.dubhe.domain.entity.PtTrainAlgorithm; import org.dubhe.domain.entity.PtTrainParam; import org.dubhe.domain.vo.PtTrainParamQueryVO; +import org.dubhe.enums.DatasetTypeEnum; import org.dubhe.enums.LogEnum; import org.dubhe.exception.BusinessException; import org.dubhe.service.PtTrainParamService; @@ -59,6 +65,8 @@ public class PtTrainParamServiceImpl implements PtTrainParamService { @Autowired private ImageUtil imageUtil; + @Autowired + private ModelQueryMapper modelQueryMapper; /** * 参数列表展示 * @@ -66,6 +74,7 @@ public class PtTrainParamServiceImpl implements PtTrainParamService { * @return Map 任务参数列表分页数据 **/ @Override + @DataPermissionMethod(dataType = DatasetTypeEnum.PUBLIC) public Map getTrainParam(PtTrainParamQueryDTO ptTrainParamQueryDTO) { //从会话中获取用户信息 UserDTO user = JwtUtils.getCurrentUserDto(); @@ -73,7 +82,6 @@ public class PtTrainParamServiceImpl implements PtTrainParamService { Page page = ptTrainParamQueryDTO.toPage(); //查询任务参数列表 QueryWrapper query = new QueryWrapper<>(); - query.eq("create_user_id", user.getId()); //根据任务参数名称模糊搜索 if (ptTrainParamQueryDTO.getParamName() != null) { query.like("param_name", ptTrainParamQueryDTO.getParamName()); @@ -84,7 +92,7 @@ public class PtTrainParamServiceImpl implements PtTrainParamService { } IPage ptTrainParams; try { - if (ptTrainParamQueryDTO.getSort() == null || ptTrainParamQueryDTO.getSort().equalsIgnoreCase(TrainJobConstant.ALGORITHM_NAME)) { + if (ptTrainParamQueryDTO.getSort() == null || ptTrainParamQueryDTO.getSort().equalsIgnoreCase(TrainJobConfig.ALGORITHM_NAME)) { query.orderByDesc(Constant.ID); } else { if (Constant.SORT_ASC.equalsIgnoreCase(ptTrainParamQueryDTO.getOrder())) { @@ -129,6 +137,7 @@ public class PtTrainParamServiceImpl implements PtTrainParamService { **/ @Override @Transactional(rollbackFor = Exception.class) + @DataPermissionMethod(dataType = DatasetTypeEnum.PUBLIC) public List createTrainParam(PtTrainParamCreateDTO ptTrainParamCreateDTO) { //从会话中获取用户信息 UserDTO user = JwtUtils.getCurrentUserDto(); @@ -139,9 +148,23 @@ public class PtTrainParamServiceImpl implements PtTrainParamService { Integer algorithmSource = ptTrainAlgorithm.getAlgorithmSource(); //保存任务参数 PtTrainParam ptTrainParam = new PtTrainParam(); + //模型名称 + ModelQuery modelName = modelQueryMapper.findModelNameById(ptTrainParamCreateDTO.getModelId()); + //模型版本 + ModelQueryBrance modelVersion = modelQueryMapper.findModelVersionByUrl(ptTrainParamCreateDTO.getModelLoadPathDir()); + if(modelName!=null){ + String name = modelName.getName(); + if(modelVersion!=null){ + ptTrainParamCreateDTO.setModelName(name+ SymbolConstant.COLON +modelVersion.getVersion()); + }else { + //设置预置模型的url路径 + ptTrainParamCreateDTO.setModelLoadPathDir(modelName.getUrl()); + ptTrainParamCreateDTO.setModelName(name); + } + } BeanUtils.copyProperties(ptTrainParamCreateDTO, ptTrainParam); //获取镜像 - String images = imageUtil.getImages(ptTrainParamCreateDTO, user); + String images = imageUtil.getImageUrl(ptTrainParamCreateDTO, user); ptTrainParam.setImageName(images).setAlgorithmSource(algorithmSource).setCreateUserId(user.getId()); int insertResult = ptTrainParamMapper.insert(ptTrainParam); //任务参数未保存成功,抛出异常,并返回失败信息 @@ -162,6 +185,7 @@ public class PtTrainParamServiceImpl implements PtTrainParamService { **/ @Override @Transactional(rollbackFor = Exception.class) + @DataPermissionMethod(dataType = DatasetTypeEnum.PUBLIC) public List updateTrainParam(PtTrainParamUpdateDTO ptTrainParamUpdateDTO) { //从会话中获取用户信息 UserDTO user = JwtUtils.getCurrentUserDto(); @@ -170,10 +194,26 @@ public class PtTrainParamServiceImpl implements PtTrainParamService { checkUpdateTrainParam(ptTrainParamUpdateDTO, user); //修改任务参数 PtTrainParam ptTrainParam = new PtTrainParam(); + //模型名称 + ModelQuery modelName = modelQueryMapper.findModelNameById(ptTrainParamUpdateDTO.getModelId()); + //模型版本 + ModelQueryBrance modelVersion = modelQueryMapper.findModelVersionByUrl(ptTrainParamUpdateDTO.getModelLoadPathDir()); + //设置版本 + if(modelName!=null){ + String name=modelName.getName(); + if(modelVersion!=null){ + ptTrainParamUpdateDTO.setModelName(name+SymbolConstant.COLON+modelVersion.getVersion()); + }else { + //设置预置模型的url值 + ptTrainParamUpdateDTO.setModelLoadPathDir(modelName.getUrl()); + ptTrainParamUpdateDTO.setModelName(name); + } + } + BeanUtils.copyProperties(ptTrainParamUpdateDTO, ptTrainParam); ptTrainParam.setUpdateUserId(user.getId()); //获取镜像url - String images = imageUtil.getImages(ptTrainParamUpdateDTO, user); + String images = imageUtil.getImageUrl(ptTrainParamUpdateDTO, user); //添加镜像url ptTrainParam.setImageName(images); try { @@ -196,6 +236,7 @@ public class PtTrainParamServiceImpl implements PtTrainParamService { **/ @Override @Transactional(rollbackFor = Exception.class) + @DataPermissionMethod(dataType = DatasetTypeEnum.PUBLIC) public void deleteTrainParam(PtTrainParamDeleteDTO ptTrainParamDeleteDTO) { //从会话中获取用户信息 UserDTO user = JwtUtils.getCurrentUserDto(); @@ -246,7 +287,6 @@ public class PtTrainParamServiceImpl implements PtTrainParamService { //任务参数名称校验 QueryWrapper query = new QueryWrapper<>(); query.eq("param_name", ptTrainParamCreateDTO.getParamName()); - query.eq("create_user_id", user.getId()); Integer trainParamCountResult = ptTrainParamMapper.selectCount(query); if (trainParamCountResult > 0) { LogUtil.error(LogEnum.BIZ_TRAIN, "The task parameter name ({}) already exists", ptTrainParamCreateDTO.getParamName()); @@ -275,7 +315,7 @@ public class PtTrainParamServiceImpl implements PtTrainParamService { } //权限校验 QueryWrapper queryWrapper = new QueryWrapper<>(); - queryWrapper.eq("id", ptTrainParamUpdateDTO.getId()).eq("create_user_id", user.getId()); + queryWrapper.eq("id", ptTrainParamUpdateDTO.getId()); Integer countResult = ptTrainParamMapper.selectCount(queryWrapper); if (countResult < 1) { LogUtil.error(LogEnum.BIZ_TRAIN, "The user {} failed to modify the task parameters and has no permission to modify the corresponding data in the pt_train_param table", user.getUsername()); @@ -283,7 +323,7 @@ public class PtTrainParamServiceImpl implements PtTrainParamService { } //任务参数名称校验 QueryWrapper query = new QueryWrapper<>(); - query.eq("param_name", ptTrainParamUpdateDTO.getParamName()).eq("create_user_id", user.getId()); + query.eq("param_name", ptTrainParamUpdateDTO.getParamName()); PtTrainParam trainParam = ptTrainParamMapper.selectOne(query); if (trainParam != null && !ptTrainParamUpdateDTO.getId().equals(trainParam.getId())) { LogUtil.error(LogEnum.BIZ_TRAIN, "The task parameter name ({}) already exists", ptTrainParamUpdateDTO.getParamName()); @@ -307,7 +347,6 @@ public class PtTrainParamServiceImpl implements PtTrainParamService { } //权限校验 QueryWrapper query = new QueryWrapper<>(); - query.eq("create_user_id", user.getId()); query.in("id", idList); Integer queryCountResult = ptTrainParamMapper.selectCount(query); if (queryCountResult < idList.size()) { diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/task/TrainJobAsyncTask.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/task/TrainJobAsyncTask.java deleted file mode 100644 index 0ccae1e..0000000 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/task/TrainJobAsyncTask.java +++ /dev/null @@ -1,245 +0,0 @@ -/** - * Copyright 2020 Zhejiang Lab. All Rights Reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================= - */ -package org.dubhe.task; - -import cn.hutool.core.util.StrUtil; -import com.alibaba.fastjson.JSONObject; -import org.dubhe.constant.TrainJobConstant; -import org.dubhe.dao.PtTrainJobMapper; -import org.dubhe.domain.dto.BaseTrainJobDTO; -import org.dubhe.domain.dto.UserDTO; -import org.dubhe.domain.entity.PtTrainJob; -import org.dubhe.domain.vo.PtImageAndAlgorithmVO; -import org.dubhe.enums.BizEnum; -import org.dubhe.enums.LogEnum; -import org.dubhe.enums.ResourcesPoolTypeEnum; -import org.dubhe.enums.TrainJobStatusEnum; -import org.dubhe.exception.BusinessException; -import org.dubhe.k8s.api.NamespaceApi; -import org.dubhe.k8s.api.TrainJobApi; -import org.dubhe.k8s.domain.bo.PtJupyterJobBO; -import org.dubhe.k8s.domain.resource.BizNamespace; -import org.dubhe.k8s.domain.vo.PtJupyterJobVO; -import org.dubhe.k8s.enums.K8sResponseEnum; -import org.dubhe.utils.K8sNameTool; -import org.dubhe.utils.LogUtil; -import org.dubhe.utils.NfsUtil; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.stereotype.Component; - -import java.util.ArrayList; -import java.util.List; - -/** - * @description 提交训练任务 - * @date 2020-07-17 - */ - -@Component -public class TrainJobAsyncTask { - - @Autowired - private K8sNameTool k8sNameTool; - - @Autowired - private NamespaceApi namespaceApi; - - @Autowired - private TrainJobConstant trainJobConstant; - - @Autowired - private NfsUtil nfsUtil; - - @Autowired - private PtTrainJobMapper ptTrainJobMapper; - - @Autowired - private TrainJobApi trainJobApi; - - /** - * 提交job - * - * @param baseTrainJobDTO 训练任务信息 - * @param currentUser 用户 - * @param ptImageAndAlgorithmVO 镜像和算法信息 - */ - public void doJob(BaseTrainJobDTO baseTrainJobDTO, UserDTO currentUser, PtImageAndAlgorithmVO ptImageAndAlgorithmVO, PtTrainJob ptTrainJob) { - PtJupyterJobBO jobBo = null; - String k8sJobName = ""; - boolean flag = false; - try { - //判断是否存在相应的namespace,如果没有则创建 - String namespace = getNamespace(currentUser); - - //封装PtJupyterJobBO对象,调用创建训练任务接口 - jobBo = pkgPtJupyterJobBo(baseTrainJobDTO, currentUser, ptImageAndAlgorithmVO, namespace); - if (null == jobBo) { - LogUtil.error(LogEnum.BIZ_TRAIN, "user{}create TrainJob,Encapsulating ptjupyterjobbo object is empty,the received parameters namespace:{}", currentUser.getId(), namespace); - updateTrainStatus(currentUser, ptTrainJob, baseTrainJobDTO, k8sJobName, flag); - } - PtJupyterJobVO ptJupyterJobResult = trainJobApi.create(jobBo); - k8sJobName = ptJupyterJobResult.getName(); - if (null == ptJupyterJobResult || !ptJupyterJobResult.isSuccess()) { - if (null != ptJupyterJobResult && ("" + K8sResponseEnum.LACK_OF_RESOURCES).equals(ptJupyterJobResult.getCode())) { - updateTrainStatus(currentUser, ptTrainJob, baseTrainJobDTO, k8sJobName, flag); - LogUtil.error(LogEnum.BIZ_TRAIN, "user{}create TrainJob, K8s creation failed, the received parameters are{}, the wrong information is{}", currentUser.getUsername(), - jobBo, ptJupyterJobResult.getMessage()); - } - String message = null == ptJupyterJobResult ? "未知的错误" : ptJupyterJobResult.getMessage(); - LogUtil.error(LogEnum.BIZ_TRAIN, "user{}create TrainJob, K8s creation failed, the received parameters are {}, the wrong information is{}", currentUser.getUsername(), jobBo, message); - updateTrainStatus(currentUser, ptTrainJob, baseTrainJobDTO, k8sJobName, flag); - } - flag = true; - //更新训练任务状态 - updateTrainStatus(currentUser, ptTrainJob, baseTrainJobDTO, k8sJobName, flag); - } catch (Exception e) { - LogUtil.error(LogEnum.BIZ_TRAIN, "user{}create TrainJob, K8s creation failed, the received parameters are {}, the wrong information is{}", currentUser.getUsername(), - jobBo, e); - updateTrainStatus(currentUser, ptTrainJob, baseTrainJobDTO, k8sJobName, flag); - } - } - - - /** - * 获取namespace - * - * @param currentUser 用户 - * @return String 命名空间 - */ - private String getNamespace(UserDTO currentUser) { - String namespaceStr = k8sNameTool.generateNameSpace(currentUser.getId()); - BizNamespace bizNamespace = namespaceApi.get(namespaceStr); - if (null == bizNamespace) { - BizNamespace namespace = namespaceApi.create(namespaceStr, null); - if (null == namespace || !namespace.isSuccess()) { - LogUtil.error(LogEnum.BIZ_TRAIN, "User {} failed to create namespace during training job..."); - throw new BusinessException("内部错误"); - } - } - return namespaceStr; - } - - - /** - * 封装出创建job所需的BO - * - * @param baseTrainJobDTO 训练任务信息 - * @param ptImageAndAlgorithmVO 镜像和算法信息 - * @param namespace 命名空间 - * @return PtJupyterJobBO jupyter任务BO - */ - private PtJupyterJobBO pkgPtJupyterJobBo(BaseTrainJobDTO baseTrainJobDTO, UserDTO currentUser, - PtImageAndAlgorithmVO ptImageAndAlgorithmVO, String namespace) { - //绝对路径 - String commonPath = nfsUtil.getNfsConfig().getBucket() + trainJobConstant.getManage() + StrUtil.SLASH - + currentUser.getId() + StrUtil.SLASH + baseTrainJobDTO.getJobName(); - //相对路径 - String relativeCommonPath = StrUtil.SLASH + trainJobConstant.getManage() + StrUtil.SLASH - + currentUser.getId() + StrUtil.SLASH + baseTrainJobDTO.getJobName(); - String[] codeDirArray = ptImageAndAlgorithmVO.getCodeDir().split(StrUtil.SLASH); - String workspaceDir = codeDirArray[codeDirArray.length - 1]; - // 算法路径待拷贝的地址 - String sourcePath = nfsUtil.getNfsConfig().getBucket() + ptImageAndAlgorithmVO.getCodeDir().substring(1); - String trainDir = commonPath.substring(1) + StrUtil.SLASH + workspaceDir; - LogUtil.info(LogEnum.BIZ_TRAIN, "Algorithm path copy::sourcePath:{},commonPath:{},trainDir:{}", sourcePath, commonPath, trainDir); - boolean bool = nfsUtil.copyPath(sourcePath.substring(1), trainDir); - if (!bool) { - LogUtil.error(LogEnum.BIZ_TRAIN, "During the process of user {} creating training Job and encapsulating k8s creating job interface parameters, it failed to copy algorithm directory {} to the specified directory {}", currentUser.getUsername(), sourcePath, - commonPath); - return null; - } - - List list = new ArrayList<>(); - JSONObject runParams = baseTrainJobDTO.getRunParams(); - - StringBuilder sb = new StringBuilder(); - sb.append(ptImageAndAlgorithmVO.getRunCommand()); - // 拼接out,log和dataset - String pattern = trainJobConstant.getPythonFormat(); - if (ptImageAndAlgorithmVO.getIsTrainOut()) { - nfsUtil.createDir(commonPath + StrUtil.SLASH + trainJobConstant.getOutPath()); - baseTrainJobDTO.setOutPath(relativeCommonPath + StrUtil.SLASH + trainJobConstant.getOutPath()); - sb.append(pattern).append(trainJobConstant.getDockerOutPath()); - } - if (ptImageAndAlgorithmVO.getIsTrainLog()) { - nfsUtil.createDir(commonPath + StrUtil.SLASH + trainJobConstant.getLogPath()); - baseTrainJobDTO.setLogPath(relativeCommonPath + StrUtil.SLASH + trainJobConstant.getLogPath()); - sb.append(pattern).append(trainJobConstant.getDockerLogPath()); - } - if (ptImageAndAlgorithmVO.getIsVisualizedLog()) { - nfsUtil.createDir(commonPath + StrUtil.SLASH + trainJobConstant.getVisualizedLogPath()); - baseTrainJobDTO.setVisualizedLogPath(relativeCommonPath + StrUtil.SLASH + trainJobConstant.getVisualizedLogPath()); - sb.append(pattern).append(trainJobConstant.getDockerVisualizedLogPath()); - } - - sb.append(pattern).append(trainJobConstant.getDockerDataset()); - - String command = sb.toString(); - if (null != runParams && !runParams.isEmpty()) { - sb.append(pattern); - runParams.entrySet() - .forEach(entry -> sb.append(entry.getKey()).append("=").append(entry.getValue()).append(pattern)); - command = sb.toString().substring(0, sb.toString().length() - pattern.length()); - } - list.add("-c"); - command = "echo 'training mission begins... " + command + "\r\n" + "'&& cd " + trainJobConstant.getDockerTrainPath() + StrUtil.SLASH - + workspaceDir + " && " + command + " && echo 'the training mission is over' "; - list.add(command); - PtJupyterJobBO jobBo = new PtJupyterJobBO(); - jobBo.setNamespace(namespace) - .setName(baseTrainJobDTO.getJobName()) - .setImage(ptImageAndAlgorithmVO.getImageName()) - .putNfsMounts(trainJobConstant.getDockerDatasetPath(), nfsUtil.getNfsConfig().getRootDir() + nfsUtil.getNfsConfig().getBucket().substring(1) + baseTrainJobDTO.getDataSourcePath()) - .setCmdLines(list) - .putNfsMounts(trainJobConstant.getDockerTrainPath(), nfsUtil.getNfsConfig().getRootDir() + commonPath.substring(1)) - .setBusinessLabel(k8sNameTool.getPodLabel(BizEnum.ALGORITHM)); - - jobBo.setCpuNum(baseTrainJobDTO.getPtTrainJobSpecs().getSpecsInfo().getInteger("cpuNum")).setMemNum(baseTrainJobDTO.getPtTrainJobSpecs().getSpecsInfo().getInteger("memNum")); - if (ResourcesPoolTypeEnum.GPU.getCode().equals(baseTrainJobDTO.getPtTrainJobSpecs().getResourcesPoolType())) { - jobBo.setUseGpu(true).setGpuNum(baseTrainJobDTO.getPtTrainJobSpecs().getSpecsInfo().getInteger("gpuNum")); - } else { - jobBo.setUseGpu(false); - } - return jobBo; - } - - /** - * 训练任务异步处理更新训练状态 - * - * @param user 用户 - * @param ptTrainJob 训练任务 - * @param baseTrainJobDTO 训练任务信息 - * @param k8sJobName k8s创建的job名称 - * @param flag 创建训练任务是否异常(true:正常,false:失败) - **/ - private void updateTrainStatus(UserDTO user, PtTrainJob ptTrainJob, BaseTrainJobDTO baseTrainJobDTO, String k8sJobName, boolean flag) { - - ptTrainJob.setK8sJobName(k8sJobName) - .setOutPath(baseTrainJobDTO.getOutPath()) - .setLogPath(baseTrainJobDTO.getLogPath()) - .setVisualizedLogPath(baseTrainJobDTO.getVisualizedLogPath()); - LogUtil.info(LogEnum.BIZ_TRAIN, "user {} training tasks are processed asynchronously to update training status,receiving parameters:{}", user.getId(), ptTrainJob); - if (flag) { - ptTrainJobMapper.updateById(ptTrainJob); - } else { - ptTrainJob.setTrainStatus(TrainJobStatusEnum.CREATE_FAILED.getStatus()); - //训练任务创建失败 - ptTrainJobMapper.updateById(ptTrainJob); - throw new BusinessException("内部错误"); - } - } -} diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/utils/ImageUtil.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/utils/ImageUtil.java index 00bb2a4..8f1c483 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/utils/ImageUtil.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/utils/ImageUtil.java @@ -18,24 +18,19 @@ package org.dubhe.utils; import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; -import org.apache.commons.collections4.CollectionUtils; import org.dubhe.base.BaseImageDTO; -import org.dubhe.base.MagicNumConstant; import org.dubhe.dao.PtImageMapper; import org.dubhe.domain.dto.UserDTO; import org.dubhe.domain.entity.PtImage; -import org.dubhe.enums.ImageSourceEnum; import org.dubhe.enums.ImageStateEnum; import org.dubhe.enums.LogEnum; import org.dubhe.exception.BusinessException; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Component; -import java.util.List; - /** - * @description: 镜像 - * @date: 2020-06-22 + * @description 镜像 + * @date 2020-06-22 */ @Component public class ImageUtil { @@ -49,32 +44,16 @@ public class ImageUtil { * @param baseImageDTO 镜像参数 * @return BaseImageDTO 镜像url **/ - public String getImages(BaseImageDTO baseImageDTO, UserDTO user) { + public String getImageUrl(BaseImageDTO baseImageDTO, UserDTO user) { QueryWrapper queryWrapper = new QueryWrapper<>(); queryWrapper.eq("image_name", baseImageDTO.getImageName()).eq("image_tag", baseImageDTO.getImageTag()) - .eq("image_status", ImageStateEnum.SUCCESS.getCode()); - List ptImages = ptImageMapper.selectList(queryWrapper); - if (CollectionUtils.isEmpty(ptImages)) { + .eq("image_status", ImageStateEnum.SUCCESS.getCode()).last(" limit 1 "); + PtImage ptImage = ptImageMapper.selectOne(queryWrapper); + if (ptImage == null || StringUtils.isBlank(ptImage.getImageUrl())) { LogUtil.error(LogEnum.BIZ_TRAIN, " User {} gets image ,the imageName is {}, the imageTag is {}, and the result of query image table (PT_image) is empty", user.getUsername(), baseImageDTO.getImageName(), baseImageDTO.getImageTag()); throw new BusinessException("镜像不存在"); } - //获取镜像为用户自定义镜像或预置镜像,且两者自身不能重复 - if (ptImages.size() > MagicNumConstant.TWO) { - LogUtil.error(LogEnum.BIZ_TRAIN, "User {} got more images than scheduled, the imageName provided is {} and the imageTag is {}. The parameters are illegal", user.getUsername(), baseImageDTO.getImageName(), baseImageDTO.getImageTag()); - throw new BusinessException("镜像不合法"); - } - for (PtImage ptImage : ptImages) { - if (ImageSourceEnum.PRE.getCode().equals(ptImage.getImageResource())) { - baseImageDTO.setImageName(ptImage.getImageUrl()); - } else if (user.getId().equals(ptImage.getCreateUserId())) { - baseImageDTO.setImageName(ptImage.getImageUrl()); - } - } - if (StringUtils.isBlank(baseImageDTO.getImageName())) { - LogUtil.error(LogEnum.BIZ_TRAIN, "User {} gets image, the imageName provided is {} and the imageTag is {}. The parameters are illegal", user.getUsername(), baseImageDTO.getImageName(), baseImageDTO.getImageTag()); - throw new BusinessException("镜像不合法"); - } - return baseImageDTO.getImageName(); + return ptImage.getImageUrl(); } } diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/utils/NotebookUtil.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/utils/NotebookUtil.java index 998a597..38913ba 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/utils/NotebookUtil.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/utils/NotebookUtil.java @@ -24,7 +24,6 @@ import cn.hutool.core.util.RandomUtil; import org.dubhe.constant.SymbolConstant; import org.dubhe.enums.BizNfsEnum; import org.dubhe.enums.LogEnum; -import org.dubhe.exception.NotebookBizException; import org.dubhe.k8s.domain.PtBaseResult; import org.apache.shiro.UnavailableSecurityManagerException; import org.dubhe.domain.dto.UserDTO; @@ -32,7 +31,7 @@ import org.dubhe.domain.dto.UserDTO; import java.util.Date; /** - * @description: Notebook 工具类 + * @description Notebook 工具类 * @date 2020-04-27 */ public class NotebookUtil { diff --git a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/utils/TrainUtil.java b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/utils/TrainUtil.java index 5a6e206..8182789 100644 --- a/dubhe-server/dubhe-admin/src/main/java/org/dubhe/utils/TrainUtil.java +++ b/dubhe-server/dubhe-admin/src/main/java/org/dubhe/utils/TrainUtil.java @@ -17,29 +17,64 @@ package org.dubhe.utils; +import org.dubhe.base.MagicNumConstant; +import org.dubhe.constant.SymbolConstant; + +import java.sql.Timestamp; + /** - * @description: 训练任务工具类 - * @date: 2020-07-14 + * @description 训练任务工具类 + * @date 2020-07-14 */ public class TrainUtil { + + + private TrainUtil() { + + } + public static final String REGEXP = "^[a-zA-Z0-9\\-\\_\\u4e00-\\u9fa5]+$"; public static final String REGEXP_NAME = "^[a-zA-Z0-9\\-\\_]+$"; public static final String REGEXP_TAG = "^[a-zA-Z0-9\\-\\_\\.]+$"; public static final String RUNTIME = "%02d:%02d:%02d"; public static final String FOUR_DECIMAL = "%04d"; + public static final String FOUR_TWO = "%.2f"; public static final int NUMBER_ZERO = 0; public static final int NUMBER_ONE = 1; public static final int NUMBER_TWO = 2; public static final int NUMBER_SEVEN = 7; + public static final int NUMBER_EIGHT = 8; public static final int NUMBER_TWENTY = 20; public static final int NUMBER_THIRTY_TWO = 32; public static final int NUMBER_SIXTY_FOUR = 64; public static final int NUMBER_ONE_HUNDRED_AND_TWENTY_SEVEN = 127; public static final int NUMBER_ONE_HUNDRED_AND_TWENTY_EIGHT = 128; + public static final int NUMBER_ONE_HUNDRED_AND_SIXTY_EIGHT = 168; public static final int NUMBER_TWO_HUNDRED_AND_FIFTY_FIVE = 255; public static final int NUMBER_ONE_THOUSAND = 1000; public static final int NUMBER_ONE_THOUSAND_AND_TWENTY_FOUR = 1024; + // 初始化训练时间 + public static final String INIT_RUNTIME = SymbolConstant.BLANK; + + /** + * 获取延时时间 + * @param delayTime 延时时间(单位为小时) + * @return 延时时间 + */ + public static Timestamp getDelayTime(Integer delayTime) { + return new Timestamp(System.currentTimeMillis() + delayTime * MagicNumConstant.SIXTY * MagicNumConstant.SIXTY * MagicNumConstant.ONE_THOUSAND); + } + + /** + * 获取倒计时 + * @param delayTime 延时时间(单位为毫秒) + * @return 倒计时(单位为分钟) + */ + public static Integer getCountDown(Long delayTime) { + return (int) ((delayTime - System.currentTimeMillis()) / (MagicNumConstant.SIXTY * MagicNumConstant.ONE_THOUSAND)); + } + } diff --git a/dubhe-server/dubhe-admin/src/main/resources/kubeconfig-prod b/dubhe-server/dubhe-admin/src/main/resources/kubeconfig-prod new file mode 100644 index 0000000..76cf823 --- /dev/null +++ b/dubhe-server/dubhe-admin/src/main/resources/kubeconfig-prod @@ -0,0 +1,19 @@ +apiVersion: v1 +clusters: +- cluster: + certificate-authority-data: + server: + name: kubernetes +contexts: +- context: + cluster: kubernetes + user: kubernetes-admin + name: kubernetes-admin@kubernetes +current-context: kubernetes-admin@kubernetes +kind: Config +preferences: {} +users: +- name: kubernetes-admin + user: + client-certificate-data: + client-key-data: \ No newline at end of file diff --git a/dubhe-server/dubhe-admin/src/main/resources/logback-spring-dev.xml b/dubhe-server/dubhe-admin/src/main/resources/logback-spring-dev.xml new file mode 100644 index 0000000..d873018 --- /dev/null +++ b/dubhe-server/dubhe-admin/src/main/resources/logback-spring-dev.xml @@ -0,0 +1,263 @@ + + + dubhe + + + + + + + + + ${log.pattern} + ${log.charset} + + + INFO + INFO + ACCEPT + DENY + + + + + + logs/${log.path}/info/dubhe-info.log + + logs/${log.path}/info/dubhe-${app.active}-info-%d{yyyy-MM-dd}.%i.log + + + 50MB + 7 + 250MB + + + %m%n + ${log.charset} + + + true + + INFO + INFO,K8S_CALLBACK + ACCEPT + DENY + + + + + + logs/${log.path}/debug/dubhe-debug.log + + logs/${log.path}/debug/dubhe-${app.active}-debug-%d{yyyy-MM-dd}.%i.log + + + 50MB + 7 + 250MB + + + %m%n + ${log.charset} + + + true + + DEBUG + DEBUG + ACCEPT + DENY + + + + + + logs/${log.path}/error/dubhe-error.log + + logs/${log.path}/error/dubhe-${app.active}-error-%d{yyyy-MM-dd}.%i.log + + + 50MB + 7 + 250MB + + + %m%n + ${log.charset} + + + true + + ERROR + ERROR + ACCEPT + DENY + + + + + + logs/${log.path}/warn/dubhe-warn.log + + logs/${log.path}/warn/dubhe-${app.active}-warn-%d{yyyy-MM-dd}.%i.log + + + 50MB + 7 + 250MB + + + %m%n + ${log.charset} + + + true + + + WARN + WARN + ACCEPT + DENY + + + + + + logs/${log.path}/trace/dubhe-trace.log + + logs/${log.path}/trace/dubhe-${app.active}-trace-%d{yyyy-MM-dd}.%i.log + + + 50MB + 7 + 250MB + + + %m%n + ${log.charset} + + + true + + TRACE + TRACE + ACCEPT + DENY + + + + + + + logs/${log.path}/info/dubhe-schedule.log + + logs/${log.path}/info/dubhe-${app.active}-schedule-%d{yyyy-MM-dd}.%i.log + + + 50MB + 7 + 250MB + + + %m%n + ${log.charset} + + + true + + INFO + SCHEDULE + ACCEPT + DENY + + + + + + logs/${log.path}/info/dubhe-request.log + + logs/${log.path}/info/dubhe-${app.active}-request-%d{yyyy-MM-dd}.%i.log + + + 50MB + 7 + 250MB + + + %m%n + ${log.charset} + + + true + + INFO + + GLOBAL_REQUEST + ACCEPT + DENY + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/dubhe-server/dubhe-admin/src/test/java/org/dubhe/BaseTest.java b/dubhe-server/dubhe-admin/src/test/java/org/dubhe/BaseTest.java old mode 100644 new mode 100755 diff --git a/dubhe-server/dubhe-admin/src/test/java/org/dubhe/K8sNameToolTest.java b/dubhe-server/dubhe-admin/src/test/java/org/dubhe/K8sNameToolTest.java index 33c9db1..f38d6a1 100644 --- a/dubhe-server/dubhe-admin/src/test/java/org/dubhe/K8sNameToolTest.java +++ b/dubhe-server/dubhe-admin/src/test/java/org/dubhe/K8sNameToolTest.java @@ -49,12 +49,12 @@ public class K8sNameToolTest { @Test public void generateNameSpace(){ - Assert.assertEquals("namespace-0", k8sNameTool.generateNameSpace(0L)); + Assert.assertEquals("namespace-0", k8sNameTool.generateNamespace(0L)); } @Test public void getUserIdFromNameSpace(){ - Assert.assertSame(0L, k8sNameTool.getUserIdFromNameSpace("namespace-0")); + Assert.assertSame(0L, k8sNameTool.getUserIdFromNamespace("namespace-0")); } diff --git a/dubhe-server/dubhe-admin/src/test/java/org/dubhe/PtImageTest.java b/dubhe-server/dubhe-admin/src/test/java/org/dubhe/PtImageTest.java index 1bb22b5..09773a7 100644 --- a/dubhe-server/dubhe-admin/src/test/java/org/dubhe/PtImageTest.java +++ b/dubhe-server/dubhe-admin/src/test/java/org/dubhe/PtImageTest.java @@ -18,6 +18,8 @@ package org.dubhe; import com.alibaba.fastjson.JSON; +import org.dubhe.domain.dto.PtImageDeleteDTO; +import org.dubhe.domain.dto.PtImageUpdateDTO; import org.dubhe.domain.dto.PtImageUploadDTO; import org.junit.Test; import org.springframework.test.web.servlet.request.MockMvcRequestBuilders; @@ -25,6 +27,8 @@ import org.springframework.test.web.servlet.result.MockMvcResultMatchers; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; +import java.util.Arrays; + /** * @description 镜像接口单元测试 * @date 2020-06-28 @@ -41,8 +45,9 @@ public class PtImageTest extends BaseTest { params.add("size", "10"); params.add("sort", "id"); params.add("order", "desc"); - params.add("imageResource", "1"); - mockMvcWithNoRequestBody(mockMvc.perform(MockMvcRequestBuilders.get("/api/v1/ptImage/info").param("imageResource", "1")) + params.add("imageResource", "0"); + params.add("imageNameOrId", "oneflow"); + mockMvcWithNoRequestBody(mockMvc.perform(MockMvcRequestBuilders.get("/api/v1/ptImage/info").params(params)) .andExpect(MockMvcResultMatchers.status().isOk()).andReturn().getResponse(), 200); } @@ -74,5 +79,40 @@ public class PtImageTest extends BaseTest { MockMvcResultMatchers.status().is2xxSuccessful(), 200); } + /** + * 修改镜像信息 + */ + @Test + public void updateImageTest() throws Exception { + PtImageUpdateDTO imageUpdateDTO = new PtImageUpdateDTO(); + imageUpdateDTO.setIds(Arrays.asList()); + imageUpdateDTO.setRemark(""); + + mockMvcTest(MockMvcRequestBuilders.put("/api/v1/ptImage"), JSON.toJSONString(imageUpdateDTO), + MockMvcResultMatchers.status().is2xxSuccessful(), 200); + } + + /** + * 删除镜像 + */ + @Test + public void deleteImageTest() throws Exception { + PtImageDeleteDTO imageDeleteDTO = new PtImageDeleteDTO(); + imageDeleteDTO.setIds(Arrays.asList()); + + mockMvcTest(MockMvcRequestBuilders.delete("/api/v1/ptImage"), JSON.toJSONString(imageDeleteDTO), + MockMvcResultMatchers.status().is2xxSuccessful(), 200); + } + + + /** + * + * 获取镜像名称列表 + */ + @Test + public void getImageNameListTest() throws Exception { + mockMvcTest(MockMvcRequestBuilders.get("/api/v1/ptImage/imageNameList"), "", + MockMvcResultMatchers.status().is2xxSuccessful(), 200); + } } diff --git a/dubhe-server/dubhe-admin/src/test/java/org/dubhe/PtTrainModelOptJobApiTest.java b/dubhe-server/dubhe-admin/src/test/java/org/dubhe/PtTrainModelOptJobApiTest.java index fad45f0..c55b0a2 100644 --- a/dubhe-server/dubhe-admin/src/test/java/org/dubhe/PtTrainModelOptJobApiTest.java +++ b/dubhe-server/dubhe-admin/src/test/java/org/dubhe/PtTrainModelOptJobApiTest.java @@ -53,7 +53,7 @@ public class PtTrainModelOptJobApiTest extends BaseTest { ptTrainJobCreateDTO.setDataSourceName("dataset/68"); ptTrainJobCreateDTO.setDataSourcePath("dataset/68"); ptTrainJobCreateDTO.setDescription("job描述"); - ptTrainJobCreateDTO.setTrainJobSpecsId(1).setRunCommand("python p.py").setImageName("tensorflow").setImageTag("latest"); + ptTrainJobCreateDTO.setTrainJobSpecsName("11111111111111").setRunCommand("python p.py").setImageName("tensorflow").setImageTag("latest"); JSONObject runParams = new JSONObject(); runParams.put("key1", 33); runParams.put("key2", 33); @@ -147,7 +147,7 @@ public class PtTrainModelOptJobApiTest extends BaseTest { ptTrainJobUpdateDTO.setAlgorithmId(91L); ptTrainJobUpdateDTO.setDataSourceName("dataset/68"); ptTrainJobUpdateDTO.setDataSourcePath("dataset/68"); - ptTrainJobUpdateDTO.setTrainJobSpecsId(1).setRunCommand("python p.py").setImageName("tensorflow").setImageTag("latest"); + ptTrainJobUpdateDTO.setTrainJobSpecsName("").setRunCommand("python p.py").setImageName("tensorflow").setImageTag("latest"); mockMvcTest(MockMvcRequestBuilders.put("/api/v1/trainJob"), JSON.toJSONString(ptTrainJobUpdateDTO), MockMvcResultMatchers.status().is2xxSuccessful(), 200); diff --git a/dubhe-server/dubhe-admin/src/test/java/org/dubhe/TrainAlgorithmTest.java b/dubhe-server/dubhe-admin/src/test/java/org/dubhe/TrainAlgorithmTest.java index 91178aa..353bae5 100644 --- a/dubhe-server/dubhe-admin/src/test/java/org/dubhe/TrainAlgorithmTest.java +++ b/dubhe-server/dubhe-admin/src/test/java/org/dubhe/TrainAlgorithmTest.java @@ -28,8 +28,8 @@ import org.springframework.test.web.servlet.request.MockMvcRequestBuilders; import org.springframework.test.web.servlet.result.MockMvcResultMatchers; /** - * @description: 算法管理模块算法管理单元测试 - * @date: 2020-06-18 + * @description 算法管理模块算法管理单元测试 + * @date 2020-06-18 */ public class TrainAlgorithmTest extends BaseTest { diff --git a/dubhe-server/dubhe-admin/src/test/java/org/dubhe/TrainModelOptJobApiTest.java b/dubhe-server/dubhe-admin/src/test/java/org/dubhe/TrainModelOptJobApiTest.java index 492f673..0cd2c20 100644 --- a/dubhe-server/dubhe-admin/src/test/java/org/dubhe/TrainModelOptJobApiTest.java +++ b/dubhe-server/dubhe-admin/src/test/java/org/dubhe/TrainModelOptJobApiTest.java @@ -30,8 +30,8 @@ import org.springframework.test.web.servlet.request.MockMvcRequestBuilders; import org.springframework.test.web.servlet.result.MockMvcResultMatchers; /** - * @description: - * @date: 2020-04-17 + * @description + * @date 2020-04-17 */ public class TrainModelOptJobApiTest extends BaseTest { @@ -106,7 +106,7 @@ public class TrainModelOptJobApiTest extends BaseTest { dto.setDataSourcePath("dataset/68"); dto.setTrainName("test-train"+System.currentTimeMillis()); dto.setTrainParamDesc("test-train"); - dto.setTrainJobSpecsId(1).setRunCommand("python p.py").setImageName("tensorflow").setImageTag("latest"); + dto.setTrainJobSpecsName("11111").setRunCommand("python p.py").setImageName("tensorflow").setImageTag("latest"); MockHttpServletResponse response = this.mockMvc.perform( MockMvcRequestBuilders. diff --git a/dubhe-server/dubhe-admin/src/test/java/org/dubhe/TrainParamApiTest.java b/dubhe-server/dubhe-admin/src/test/java/org/dubhe/TrainParamApiTest.java index c2f68d2..9f69df5 100644 --- a/dubhe-server/dubhe-admin/src/test/java/org/dubhe/TrainParamApiTest.java +++ b/dubhe-server/dubhe-admin/src/test/java/org/dubhe/TrainParamApiTest.java @@ -31,8 +31,8 @@ import org.springframework.test.web.servlet.result.MockMvcResultMatchers; import org.springframework.transaction.annotation.Transactional; /** - * @description: 训练任务管理模块任务参数管理单元测试 - * @date: 2020-5-11 + * @description 训练任务管理模块任务参数管理单元测试 + * @date 2020-5-11 */ @RunWith(SpringRunner.class) @SpringBootTest diff --git a/dubhe-server/dubhe-admin/src/test/java/org/dubhe/data/service/FileServiceImplTest.java b/dubhe-server/dubhe-admin/src/test/java/org/dubhe/data/service/FileServiceImplTest.java index e10dedb..00a7cb8 100644 --- a/dubhe-server/dubhe-admin/src/test/java/org/dubhe/data/service/FileServiceImplTest.java +++ b/dubhe-server/dubhe-admin/src/test/java/org/dubhe/data/service/FileServiceImplTest.java @@ -1,12 +1,12 @@ /** * Copyright 2020 Zhejiang Lab. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -17,10 +17,8 @@ package org.dubhe.data.service; -import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; import org.dubhe.BaseTest; import org.dubhe.data.service.impl.FileServiceImpl; -import org.junit.Test; import org.springframework.beans.factory.annotation.Autowired; /** @@ -32,9 +30,5 @@ public class FileServiceImplTest extends BaseTest { @Autowired private FileServiceImpl fileService; - @Test - public void list() { - System.out.println(fileService.listByLimit(3L, 10, new QueryWrapper<>())); - } } diff --git a/dubhe-server/dubhe-admin/src/test/java/org/dubhe/data/service/TaskServiceServiceImplTest.java b/dubhe-server/dubhe-admin/src/test/java/org/dubhe/data/service/TaskServiceServiceImplTest.java index 8f85ae9..ee74465 100644 --- a/dubhe-server/dubhe-admin/src/test/java/org/dubhe/data/service/TaskServiceServiceImplTest.java +++ b/dubhe-server/dubhe-admin/src/test/java/org/dubhe/data/service/TaskServiceServiceImplTest.java @@ -95,7 +95,7 @@ public class TaskServiceServiceImplTest extends BaseTest { .datasetIds(ids) .build()); - Map progress = fileService.listStatistics(idList); + Map progress = fileService.listStatistics(null); System.out.println(progress); try { Thread.sleep(3000L); diff --git a/dubhe-server/dubhe-admin/src/test/java/org/dubhe/data/service/dataset/DataFactory.java b/dubhe-server/dubhe-admin/src/test/java/org/dubhe/data/service/dataset/DataFactory.java index 66b55ca..94bd3b1 100644 --- a/dubhe-server/dubhe-admin/src/test/java/org/dubhe/data/service/dataset/DataFactory.java +++ b/dubhe-server/dubhe-admin/src/test/java/org/dubhe/data/service/dataset/DataFactory.java @@ -17,14 +17,14 @@ package org.dubhe.data.service.dataset; -import org.dubhe.data.constant.DatasetTypeEnum; +import org.dubhe.enums.DatasetTypeEnum; import org.dubhe.data.domain.dto.DatasetCreateDTO; import org.dubhe.data.domain.dto.DatasetVersionCreateDTO; import java.util.UUID; /** - * 使用简单工厂模式 生成测试数据 - * @create: 2020-05-15 09:42 + * @description 使用简单工厂模式 生成测试数据 + * @date 2020-05-15 09:42 */ public class DataFactory { @@ -47,9 +47,9 @@ public class DataFactory { /** * 数据集版本发布用 - * @param id - * @param versionNum - * @param versionNote + * @param id 数据集ID + * @param versionNum 版本名称 + * @param versionNote 版本说明 * @return */ public static DatasetVersionCreateDTO datasetVersionPublish(Long id, String versionNum, String versionNote) { diff --git a/dubhe-server/dubhe-admin/src/test/java/org/dubhe/data/service/dataset/DatasetApiTest.java b/dubhe-server/dubhe-admin/src/test/java/org/dubhe/data/service/dataset/DatasetApiTest.java index d8adb09..9c6d38b 100644 --- a/dubhe-server/dubhe-admin/src/test/java/org/dubhe/data/service/dataset/DatasetApiTest.java +++ b/dubhe-server/dubhe-admin/src/test/java/org/dubhe/data/service/dataset/DatasetApiTest.java @@ -34,7 +34,7 @@ import org.springframework.util.MultiValueMap; import org.springframework.web.context.WebApplicationContext; /** - * @date: 2020-05-14 + * @date 2020-05-14 */ @SpringBootTest @RunWith(SpringRunner.class) diff --git a/dubhe-server/dubhe-admin/src/test/java/org/dubhe/data/service/dataset/MockUtil.java b/dubhe-server/dubhe-admin/src/test/java/org/dubhe/data/service/dataset/MockUtil.java index 1698c95..72a9473 100644 --- a/dubhe-server/dubhe-admin/src/test/java/org/dubhe/data/service/dataset/MockUtil.java +++ b/dubhe-server/dubhe-admin/src/test/java/org/dubhe/data/service/dataset/MockUtil.java @@ -26,7 +26,7 @@ import org.springframework.test.web.servlet.result.MockMvcResultHandlers; import org.springframework.util.MultiValueMap; /** - * @date: 2020-05-15 + * @date 2020-05-15 */ @Slf4j public class MockUtil { diff --git a/dubhe-server/dubhe-admin/src/test/java/org/dubhe/k8s/ResourceCacheTest.java b/dubhe-server/dubhe-admin/src/test/java/org/dubhe/k8s/ResourceCacheTest.java index c2cf2fa..4fc7030 100644 --- a/dubhe-server/dubhe-admin/src/test/java/org/dubhe/k8s/ResourceCacheTest.java +++ b/dubhe-server/dubhe-admin/src/test/java/org/dubhe/k8s/ResourceCacheTest.java @@ -75,4 +75,24 @@ public class ResourceCacheTest { String namespace = "namespace-41",resourceName = "notebook-rn-algorithm-153",podName = "notebook-rn-algorithm-153-7djrk-0"; resourceCache.cachePods(namespace,resourceName); } + + @Test + public void getDistributedLock(){ + System.out.println(redisUtils.getDistributedLock("87jkssshjk","fhfgsssygfjfgh",10)); + System.out.println(redisUtils.getDistributedLock("87jkssshjk","fhfgsssygfjfgh",10)); + System.out.println(redisUtils.releaseDistributedLock("87jkssshjk","fhfgsssygfjfgh")); + System.out.println(redisUtils.getDistributedLock("87jkssshjk","fhfgsssygfjfgh",10)); + try { + Thread.sleep(1000); + } catch (InterruptedException e) { + e.printStackTrace(); + } + System.out.println(redisUtils.getDistributedLock("87jkssshjk","fhfgsssygfjfgh",10)); + try { + Thread.sleep(10000); + } catch (InterruptedException e) { + e.printStackTrace(); + } + System.out.println(redisUtils.getDistributedLock("87jkssshjk","fhfgsssygfjfgh",10)); + } } diff --git a/dubhe-server/dubhe-admin/src/test/java/org/dubhe/k8s/api/DistributeTrainApiTest.java b/dubhe-server/dubhe-admin/src/test/java/org/dubhe/k8s/api/DistributeTrainApiTest.java new file mode 100644 index 0000000..597b76f --- /dev/null +++ b/dubhe-server/dubhe-admin/src/test/java/org/dubhe/k8s/api/DistributeTrainApiTest.java @@ -0,0 +1,84 @@ +/** + * Copyright 2020 Zhejiang Lab. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================= + */ + +package org.dubhe.k8s.api; + +import cn.hutool.core.io.FileUtil; +import com.alibaba.fastjson.JSON; +import org.dubhe.k8s.domain.PtBaseResult; +import org.dubhe.k8s.domain.bo.DistributeTrainBO; +import org.dubhe.k8s.domain.resource.BizDistributeTrain; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.test.context.junit4.SpringRunner; + +import javax.annotation.Resource; + +/** + * @description NodeApiTest测试类 + * @date 2020-04-22 + */ +@SpringBootTest +@RunWith(SpringRunner.class) +public class DistributeTrainApiTest { + @Resource + private DistributeTrainApi distributeTrainApi; + + @Test + public void testCreate(){ + DistributeTrainBO bo = new DistributeTrainBO(); + bo.setName("yepeng11"); + bo.setNamespace("yep"); + bo.setSize(3); + bo.setImage("harbor.dubhe.ai/oneflow/oneflow-cuda:py36-v3"); + bo.setMasterCmd("export NODE_IPS=`cat /home/hostfile.json |jq -r '.[]|.ip'|paste -d \",\" -s` && cd /workspace/Classification/cnns && rm -rf core.* && rm -rf ./output/snapshots/* && python3 of_cnn_train_val.py --train_data_dir=$DATA_ROOT/train --train_data_part_num=$TRAIN_DATA_PART_NUM --val_data_dir=$DATA_ROOT/validation --val_data_part_num=$VAL_DATA_PART_NUM --num_nodes=$NODE_NUM --node_ips=\"$NODE_IPS\" --gpu_num_per_node=$GPU_NUM_PER_NODE --model_update=\"momentum\" --learning_rate=0.256 --loss_print_every_n_iter=1 --batch_size_per_device=1 --val_batch_size_per_device=1 --num_epoch=1 --model=\"resnet50\" --model_save_dir=/model"); + bo.setMemNum(8192); + bo.setCpuNum(4000); + bo.setGpuNum(2); + bo.setSlaveCmd("export NODE_IPS=`cat /home/hostfile.json |jq -r '.[]|.ip'|paste -d \",\" -s` && cd /workspace/Classification/cnns && rm -rf core.* && rm -rf ./output/snapshots/* && python3 of_cnn_train_val.py --train_data_dir=$DATA_ROOT/train --train_data_part_num=$TRAIN_DATA_PART_NUM --val_data_dir=$DATA_ROOT/validation --val_data_part_num=$VAL_DATA_PART_NUM --num_nodes=$NODE_NUM --node_ips=\"$NODE_IPS\" --gpu_num_per_node=$GPU_NUM_PER_NODE --model_update=\"momentum\" --learning_rate=0.256 --loss_print_every_n_iter=1 --batch_size_per_device=1 --val_batch_size_per_device=1 --num_epoch=1 --model=\"resnet50\" --model_save_dir=/model"); + bo.setDatasetStoragePath("/nfs/sunjd/dataset/of_dataset"); + bo.setWorkspaceStoragePath("/nfs/sunjd/workspace"); + bo.setModelStoragePath("/nfs/sunjd/model"); + bo.setBusinessLabel("train"); + bo.setDelayCreateTime(10); + bo.setDelayDeleteTime(10); + + distributeTrainApi.create(bo); + } + @Test + public void deleteByResourceName() { + PtBaseResult result = distributeTrainApi.deleteByResourceName("tianlong", "tianlong-dt"); + System.out.println(JSON.toJSONString(result)); + } + + @Test + public void create() { + String filePath = "D:\\Devial\\之江实验室\\分布式训练\\image and demo\\resnet50\\demo-env.yaml"; + String ymlStr = FileUtil.readString(filePath,"utf-8"); + BizDistributeTrain result = distributeTrainApi.create(ymlStr); + System.out.println(JSON.toJSONString(result)); + } + + @Test + public void delete() { + String filePath = "D:\\Devial\\之江实验室\\分布式训练\\image and demo\\resnet50\\demo-env.yaml"; + String ymlStr = FileUtil.readString(filePath,"utf-8"); + Boolean result = distributeTrainApi.delete(ymlStr); + System.out.println(JSON.toJSONString(result)); + } +} diff --git a/dubhe-server/dubhe-admin/src/test/java/org/dubhe/k8s/api/JupyterResourceApiTest.java b/dubhe-server/dubhe-admin/src/test/java/org/dubhe/k8s/api/JupyterResourceApiTest.java index 725177b..039eefa 100644 --- a/dubhe-server/dubhe-admin/src/test/java/org/dubhe/k8s/api/JupyterResourceApiTest.java +++ b/dubhe-server/dubhe-admin/src/test/java/org/dubhe/k8s/api/JupyterResourceApiTest.java @@ -33,8 +33,8 @@ import javax.annotation.Resource; import java.util.List; /** - * @description: JupyterResourceApiTest测试类 - * @date: 2020-04-14 + * @description JupyterResourceApiTest测试类 + * @date 2020-04-14 */ @SpringBootTest @RunWith(SpringRunner.class) @@ -78,7 +78,8 @@ public class JupyterResourceApiTest { .setWorkspaceDir("/nfs/namespace/workspace1") .setWorkspaceMountPath("/workspace") .setWorkspaceRequest("100Mi") - .setWorkspaceLimit("200Mi"); + .setWorkspaceLimit("200Mi") + .setDelayDeleteTime(20); PtJupyterDeployVO result = jupyterResourceApi.create(bo); System.out.println(JSON.toJSONString(result)); int i = 0; @@ -132,7 +133,8 @@ public class JupyterResourceApiTest { .setWorkspaceMountPath("/workspace") .setWorkspaceRequest("100Mi") .setWorkspaceLimit("200Mi") - .setBusinessLabel("notebook"); + .setBusinessLabel("notebook") + .setDelayDeleteTime(10); PtJupyterDeployVO result = jupyterResourceApi.createWithPvc(bo); System.out.println(JSON.toJSONString(result)); System.out.println(podApi.getUrlByResourceName("namespace","myhfb")); diff --git a/dubhe-server/dubhe-admin/src/test/java/org/dubhe/k8s/api/LogMonitoringApiTest.java b/dubhe-server/dubhe-admin/src/test/java/org/dubhe/k8s/api/LogMonitoringApiTest.java index dd97449..9100ecf 100644 --- a/dubhe-server/dubhe-admin/src/test/java/org/dubhe/k8s/api/LogMonitoringApiTest.java +++ b/dubhe-server/dubhe-admin/src/test/java/org/dubhe/k8s/api/LogMonitoringApiTest.java @@ -26,6 +26,7 @@ import org.springframework.boot.test.context.SpringBootTest; import org.springframework.test.context.junit4.SpringRunner; import javax.annotation.Resource; +import java.util.ArrayList; /** * @description LogMonitoringApiTest测试类 @@ -40,24 +41,56 @@ public class LogMonitoringApiTest { /**通过资源名称查找日志信息**/ @Test - public void searchLog() { + public void searchLogByResName() { int from = 1; /**size为0表示无限制**/ int size = 200; + LogMonitoringBO logMonitoringBo = new LogMonitoringBO(); - logMonitoringBo.setIndexName("logstash-*"); - logMonitoringBo.setResourceName("train-1-20200713103822-v0013"); + logMonitoringBo.setResourceName("train-1-20200803170114-v0033"); logMonitoringBo.setNamespace("namespace-1"); - LogMonitoringVO logMonitoringVO = logMonitoringApi.searchLog(from, size, logMonitoringBo); + LogMonitoringVO logMonitoringVO = logMonitoringApi.searchLogByResName(from, size, logMonitoringBo); + + } + + /**通过Pod名称查找日志信息**/ + @Test + public void searchLogByPodName() { + int from = 1; + /**size为0表示无限制**/ + int size = 200; + + LogMonitoringBO logMonitoringBo = new LogMonitoringBO(); + logMonitoringBo.setPodName("train-1-20200828135251-v0013-5zouu-master-8npcd-95d96"); + logMonitoringBo.setNamespace("namespace-1"); + logMonitoringBo.setBeginTimeMillis(0L); + logMonitoringBo.setLogKeyword("training mission begins"); + logMonitoringBo.setEndTimeMillis(1601347634000L); + + LogMonitoringVO logMonitoringVO = logMonitoringApi.searchLogByPodName(from, size, logMonitoringBo); } @Test public void addlog(){ - logMonitoringApi.addLogsToEs("podName", "namespace"); + //logMonitoringApi.addLogsToEs("podName", "namespace"); + logMonitoringApi.addLogsToEs("train-1-20200915103934-v0055-3ppzh-master-untg7-ndnq8", "namespace-1",new ArrayList(){{ + add("Container is being created"); + add("Container is being created"); + } + }); + + } + @Test + public void searchLogCountByPodName(){ + LogMonitoringBO logMonitoringBo = new LogMonitoringBO(); + logMonitoringBo.setPodName("train-1-20200828135251-v0013-5zouu-master-8npcd-95d96"); + logMonitoringBo.setNamespace("namespace-1"); + Long count = logMonitoringApi.searchLogCountByPodName(logMonitoringBo); + System.out.println(count); } } diff --git a/dubhe-server/dubhe-admin/src/test/java/org/dubhe/k8s/api/NativeResourceApiTest.java b/dubhe-server/dubhe-admin/src/test/java/org/dubhe/k8s/api/NativeResourceApiTest.java new file mode 100644 index 0000000..694b599 --- /dev/null +++ b/dubhe-server/dubhe-admin/src/test/java/org/dubhe/k8s/api/NativeResourceApiTest.java @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Zhejiang Lab. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================= + */ +package org.dubhe.k8s.api; + +import cn.hutool.core.io.FileUtil; +import com.alibaba.fastjson.JSON; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.test.context.junit4.SpringRunner; + +import javax.annotation.Resource; + +/** + * @description + * @date 2020-08-28 + */ +@SpringBootTest +@RunWith(SpringRunner.class) +public class NativeResourceApiTest { + @Resource + private NativeResourceApi nativeResourceApi; + + String crYaml = FileUtil.readString("G:\\Kubernetes\\ingress-controller\\demo\\ingress-demo.yml","utf-8"); + + @Test + public void create(){ + System.out.println(JSON.toJSONString(nativeResourceApi.create(crYaml))); + } + + @Test + public void delete(){ + System.out.println(JSON.toJSONString(nativeResourceApi.delete(crYaml))); + } +} diff --git a/dubhe-server/dubhe-admin/src/test/java/org/dubhe/k8s/api/NodeApiTest.java b/dubhe-server/dubhe-admin/src/test/java/org/dubhe/k8s/api/NodeApiTest.java index 3f67ab8..3a7d50a 100644 --- a/dubhe-server/dubhe-admin/src/test/java/org/dubhe/k8s/api/NodeApiTest.java +++ b/dubhe-server/dubhe-admin/src/test/java/org/dubhe/k8s/api/NodeApiTest.java @@ -94,7 +94,7 @@ public class NodeApiTest { @Test public void isAllocatable(){ LackOfResourcesEnum flag; - flag = nodeApi.isAllocatable(500000,300000 ,30 ); + flag = nodeApi.isAllocatable(10000,300000 ,30 ); System.out.println(flag.getMessage()); } diff --git a/dubhe-server/dubhe-admin/src/test/java/org/dubhe/k8s/api/PodApiTest.java b/dubhe-server/dubhe-admin/src/test/java/org/dubhe/k8s/api/PodApiTest.java index 96be4bc..a268825 100644 --- a/dubhe-server/dubhe-admin/src/test/java/org/dubhe/k8s/api/PodApiTest.java +++ b/dubhe-server/dubhe-admin/src/test/java/org/dubhe/k8s/api/PodApiTest.java @@ -82,12 +82,6 @@ public class PodApiTest { System.out.println(JSON.toJSONString(pod)); } - @Test - public void getPodMetricsGrafanaUrl() { - String podMetricsGrafanaUrl = podApi.getPodMetricsGrafanaUrl("namespace-1", "pod1"); - System.out.println(JSON.toJSONString(podMetricsGrafanaUrl)); - } - @Test public void getWithNamespace() { List podList = podApi.getWithNamespace("namespace-1"); @@ -122,4 +116,9 @@ public class PodApiTest { public void listAllRuningPodGroupByNodeName(){ System.out.println(JSON.toJSONString(podApi.listAllRuningPodGroupByNodeName())); } + @Test + public void findByDtName(){ + List bizPodList = podApi.findByDtName("sun"); + System.out.println(bizPodList); + } } diff --git a/dubhe-server/dubhe-admin/src/test/java/org/dubhe/k8s/api/TrainJobApiTest.java b/dubhe-server/dubhe-admin/src/test/java/org/dubhe/k8s/api/TrainJobApiTest.java index 1be0200..ef48b64 100644 --- a/dubhe-server/dubhe-admin/src/test/java/org/dubhe/k8s/api/TrainJobApiTest.java +++ b/dubhe-server/dubhe-admin/src/test/java/org/dubhe/k8s/api/TrainJobApiTest.java @@ -60,7 +60,7 @@ public class TrainJobApiTest { @Test public void create() { PtJupyterJobBO bo = new PtJupyterJobBO(); - bo.setNamespace("xxx") + bo.setNamespace("namespace-1") .setName("train5") .setCpuNum(500) .setGpuNum(1) @@ -70,9 +70,12 @@ public class TrainJobApiTest { .setNfsMounts(new HashMap(){{ put("/dataset",new PtMountDirBO("/nfs/xxx/dataset")); put("/workspace",new PtMountDirBO("/nfs/xxx/dataset")); + put("/valdataset",new PtMountDirBO("/nfs/xxx/dataset")); }}) .setImage("tensorflow/tensorflow:latest") - .setBusinessLabel("train"); + .setBusinessLabel("train") + .setDelayDeleteTime(10) + .setDelayCreateTime(10); System.out.println("before create"); PtJupyterJobVO result = trainJobApi.create(bo); System.out.println("after create"); diff --git a/dubhe-server/dubhe-admin/src/test/java/org/dubhe/utils/LogUtilTest.java b/dubhe-server/dubhe-admin/src/test/java/org/dubhe/utils/LogUtilTest.java index 859e5a5..395c703 100644 --- a/dubhe-server/dubhe-admin/src/test/java/org/dubhe/utils/LogUtilTest.java +++ b/dubhe-server/dubhe-admin/src/test/java/org/dubhe/utils/LogUtilTest.java @@ -24,7 +24,7 @@ import org.springframework.boot.test.context.SpringBootTest; import org.springframework.test.context.junit4.SpringRunner; /** - * @description: LogUtil 工具测试类 + * @description LogUtil 工具测试类 * @date 2020-6-19 */ @RunWith(SpringRunner.class) diff --git a/dubhe-server/dubhe-admin/src/test/java/org/dubhe/utils/NotebookUtilTest.java b/dubhe-server/dubhe-admin/src/test/java/org/dubhe/utils/NotebookUtilTest.java index 63810b2..46b9214 100644 --- a/dubhe-server/dubhe-admin/src/test/java/org/dubhe/utils/NotebookUtilTest.java +++ b/dubhe-server/dubhe-admin/src/test/java/org/dubhe/utils/NotebookUtilTest.java @@ -38,7 +38,7 @@ import java.text.SimpleDateFormat; import java.util.Date; /** - * @description: Notebook 工具测试类 + * @description Notebook 工具测试类 * * @date 2020.04.27 */ diff --git a/dubhe-server/dubhe-data/pom.xml b/dubhe-server/dubhe-data/pom.xml index 000fa74..220519a 100644 --- a/dubhe-server/dubhe-data/pom.xml +++ b/dubhe-server/dubhe-data/pom.xml @@ -33,114 +33,13 @@ org.bytedeco javacv - 1.4.4 - - - org.bytedeco - javacpp - - - org.bytedeco.javacpp-presets - flycapture - - - org.bytedeco.javacpp-presets - libdc1394 - - - org.bytedeco.javacpp-presets - libfreenect - - - org.bytedeco.javacpp-presets - libfreenect2 - - - org.bytedeco.javacpp-presets - librealsense - - - org.bytedeco.javacpp-presets - videoinput - - - org.bytedeco.javacpp-presets - opencv - - - org.bytedeco.javacpp-presets - tesseract - - - org.bytedeco.javacpp-presets - leptonica - - - org.bytedeco.javacpp-presets - flandmark - - - org.bytedeco.javacpp-presets - artoolkitplus - - + 1.4.3 - org.bytedeco - javacv-platform - 1.4.4 - - - org.bytedeco - javacv - - - org.bytedeco.javacpp-presets - flycapture-platform - - - org.bytedeco.javacpp-presets - libdc1394-platform - - - org.bytedeco.javacpp-presets - libfreenect-platform - - - org.bytedeco.javacpp-presets - libfreenect2-platform - - - org.bytedeco.javacpp-presets - librealsense-platform - - - org.bytedeco.javacpp-presets - videoinput-platform - - - org.bytedeco.javacpp-presets - opencv-platform - - - org.bytedeco.javacpp-presets - tesseract-platform - - - org.bytedeco.javacpp-presets - leptonica-platform - - - org.bytedeco.javacpp-presets - flandmark-platform - - - org.bytedeco.javacpp-presets - artoolkitplus-platform - - + org.bytedeco.javacpp-presets + ffmpeg-platform + 4.0.2-1.4.3 - diff --git a/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/AnnotateTypeEnum.java b/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/AnnotateTypeEnum.java index 9cc3aeb..4d035e0 100644 --- a/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/AnnotateTypeEnum.java +++ b/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/AnnotateTypeEnum.java @@ -17,13 +17,10 @@ package org.dubhe.data.constant; - import lombok.Getter; /** - * 标注类型枚举类 - * - * @description 标注类型 + * @description 标注类型枚举类 * @date 2020-05-21 */ @Getter @@ -54,7 +51,7 @@ public enum AnnotateTypeEnum { * 标注类型校验 用户web端接口调用时参数校验 * * @param value 标注类型Integer值 - * @return + * @return 参数校验结果 */ public static boolean isValid(Integer value) { for (AnnotateTypeEnum annotateTypeEnum : AnnotateTypeEnum.values()) { @@ -68,8 +65,8 @@ public enum AnnotateTypeEnum { /** * 根据标注类型获取类型code值 * - * @param annotate - * @return 类型code值 + * @param annotate 标注类型 + * @return 类型code值 */ public static Integer getConvertAnnotateType(String annotate) { for (AnnotateTypeEnum annotateTypeEnum : AnnotateTypeEnum.values()) { diff --git a/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/Constant.java b/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/Constant.java index e94d7ba..f8fb37c 100644 --- a/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/Constant.java +++ b/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/Constant.java @@ -18,6 +18,8 @@ package org.dubhe.data.constant; import org.dubhe.base.MagicNumConstant; +import org.dubhe.data.machine.constant.DataStateCodeConstant; +import org.dubhe.data.machine.constant.FileStateCodeConstant; import java.io.File; import java.util.HashSet; @@ -53,17 +55,19 @@ public class Constant { */ public static final long RESERVED_LABEL_ID = MagicNumConstant.TEN_THOUSAND_LONG; + /** + * 自动标注需要符合的状态 + */ public static final Set AUTO_ANNOTATION_NEED_STATUS = new HashSet() {{ - add(FileStatusEnum.INIT.getValue()); + add(FileStateCodeConstant.NOT_ANNOTATION_FILE_STATE); }}; - /** * 自动跟踪需要符合的状态 */ public static final Set AUTO_TRACK_NEED_STATUS = new HashSet() {{ - add(DatasetStatusEnum.FINISHED.getValue()); - add(DatasetStatusEnum.AUTO_FINISHED.getValue()); + add(DataStateCodeConstant.ANNOTATION_COMPLETE_STATE); + add(DataStateCodeConstant.AUTO_TAG_COMPLETE_STATE); }}; /** @@ -71,6 +75,9 @@ public class Constant { */ public static final String DATASET_VERSION_NAME_REGEXP = "^V[0-9]{4}$"; + /** + * 数据集版本格式说明 + */ public static final String DATASET_VERSION_NAME_REGEXP_NOTE = "版本规则: 1.满足V0001结构(V0001-V9999) " + "2.只能是字母、数字、下划线或者中划线组成的合法字符串长度限制8个字符"; @@ -195,4 +202,45 @@ public class Constant { */ public static final String DATASET_DIRECTORY = "dataset"; + /** + * 临时文件 + */ + public static final String UPLOAD_TEMP = File.separator + "upload-temp"; + + + /** + * 分表业务编码 - 文件表 + */ + public static final String DATA_FILE = "DATA_FILE"; + + /** + * 分表业务编码 - 文件版本关系表 + */ + public static final String DATA_VERSION_FILE = "DATA_VERSION_FILE"; + + + /** + * 数据集预置标签组默认ID COCO + */ + public static final Long COCO_ID = 1L; + + /** + * 数据集预置标签组默认ID ImageNet + */ + public static final Long IMAGENET_ID = 2L; + + + /** + * 大数据默认删除数量 + */ + public static final int LIMIT_NUMBER = 10000; + + + + /** + * redis 预置标签key + */ + public final static String DATASET_LABEL_PUB_KEY = "dateset:label:pub"; + + } diff --git a/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/ConversionStatusEnum.java b/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/ConversionStatusEnum.java index 967311c..87a46a8 100644 --- a/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/ConversionStatusEnum.java +++ b/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/ConversionStatusEnum.java @@ -56,7 +56,7 @@ public enum ConversionStatusEnum { * 数据转换类型校验 用户web端接口调用时参数校验 * * @param value 数据转换类型 - * @return + * @return 参数校验结果 */ public static boolean isValid(Integer value) { for (ConversionStatusEnum conversionStatusEnum : ConversionStatusEnum.values()) { diff --git a/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/DataStatusEnum.java b/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/DataStatusEnum.java index 7e514f5..5e30978 100644 --- a/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/DataStatusEnum.java +++ b/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/DataStatusEnum.java @@ -48,4 +48,5 @@ public enum DataStatusEnum { private int value; private String msg; + } diff --git a/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/FileStatusEnum.java b/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/DataTaskTypeEnum.java similarity index 53% rename from dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/FileStatusEnum.java rename to dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/DataTaskTypeEnum.java index 56e53f2..ed13875 100644 --- a/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/FileStatusEnum.java +++ b/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/DataTaskTypeEnum.java @@ -19,56 +19,45 @@ package org.dubhe.data.constant; import lombok.Getter; -import java.util.HashSet; -import java.util.Set; - /** - * @description 文件状态 - * @date 2020-04-10 + * @description 数据集任务类型 + * @date 2020-08-27 */ @Getter -public enum FileStatusEnum { +public enum DataTaskTypeEnum { + + /** + * 自动标注 + */ + ANNOTATION(0, "自动标注"), /** - * 未标注 + * ofrecord格式转换 */ - INIT(0, "未标注"), + OFRECORD(1, "ofrecord格式转换"), /** - * 标注中 + * imageNet */ - ANNOTATING(1, "标注中"), + IMAGE_NET(2, "imageNet"), /** - * 自动标注完成 + * 数据增强 */ - AUTO_ANNOTATION(2, "自动标注完成"), + ENHANCE(3, "数据增强"), /** - * 已标注完成 + * 目标跟踪 */ - FINISHED(3, "标注完成"), + TARGET_TRACK(4, "目标跟踪"), /** - * 目标追踪完成 + * 视频采样 */ - FINISH_AUTO_TRACK(4, "目标追踪完成"), + VIDEO_SAMPLE(5, "视频采样") ; - FileStatusEnum(int value, String msg) { + DataTaskTypeEnum(Integer value, String msg) { this.value = value; this.msg = msg; } - private int value; + private Integer value; private String msg; - /** - * 获取所有文件状态值 - * - * @return - */ - public static Set getAllValue() { - Set allValues = new HashSet<>(); - for (FileStatusEnum fileStatusEnum : FileStatusEnum.values()) { - allValues.add(fileStatusEnum.value); - } - return allValues; - } - } diff --git a/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/DatasetStatusEnum.java b/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/DatasetStatusEnum.java old mode 100644 new mode 100755 index afd0548..1697854 --- a/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/DatasetStatusEnum.java +++ b/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/DatasetStatusEnum.java @@ -1,12 +1,12 @@ /** * Copyright 2020 Zhejiang Lab. All Rights Reserved. - * + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -25,6 +25,7 @@ import lombok.Getter; */ @Getter public enum DatasetStatusEnum { + /** * 文件全部未标注 */ @@ -61,7 +62,11 @@ public enum DatasetStatusEnum { /** * 数据增强中 */ - ENHANCING(8, "增强中"); + ENHANCING(8, "增强中"), + /** + * 采样失败 + */ + SAMPLE_FAILED(9, "采样失败"); DatasetStatusEnum(int value, String msg) { this.value = value; @@ -70,4 +75,5 @@ public enum DatasetStatusEnum { private int value; private String msg; + } diff --git a/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/DatatypeEnum.java b/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/DatatypeEnum.java index 24f25dd..c87667e 100644 --- a/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/DatatypeEnum.java +++ b/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/DatatypeEnum.java @@ -45,8 +45,9 @@ public enum DatatypeEnum { /** * 数据类型校验 用户web端接口调用时参数校验 + * * @param value 数据类型 - * @return + * @return 参数校验结果 */ public static boolean isValid(Integer value) { for (DatatypeEnum datatypeEnum : DatatypeEnum.values()) { diff --git a/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/EnhanceTypeEnum.java b/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/EnhanceTypeEnum.java index 6755d79..2d830a9 100644 --- a/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/EnhanceTypeEnum.java +++ b/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/EnhanceTypeEnum.java @@ -18,8 +18,8 @@ package org.dubhe.data.constant; /** - * @description: 增强算法枚举 - * @date: 2020-06-30 + * @description 增强算法枚举 + * @date 2020-06-30 */ public enum EnhanceTypeEnum { diff --git a/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/ErrorEnum.java b/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/ErrorEnum.java old mode 100644 new mode 100755 index 444129a..9b633d4 --- a/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/ErrorEnum.java +++ b/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/ErrorEnum.java @@ -57,6 +57,11 @@ public enum ErrorEnum implements ErrorCode { FILE_EXIST(1304, "文件已存在"), VIDEO_EXIST(1305, "数据集存在视频"), FILE_DELETE_ERROR(1306, "文件删除失败"), + DATASET_TYPE_MODIFY_ERROR(1307,"数据集存在文件不可更改数据类型"), + DATASET_ANNOTATION_MODIFY_ERROR(1308,"非未标注状态不可更改标注类型"), + DATASET_PUBLIC_LIMIT_ERROR(1309,"公共数据集不可操作"), + TASK_ABSENT(1302, "任务不存在"), + /** * 数据集标注操作错误 @@ -75,6 +80,10 @@ public enum ErrorEnum implements ErrorCode { */ LABEL_ERROR(1600, "标签名不能为空或非系统自动标注支持的标签"), LABEL_NAME_EXIST(1601, "本数据集已有同名标签"), + LABEL_NAME_DUPLICATION(1602,"标签名重复,请检查"), + LABEL_NOT_EXISTS(1603,"标签不存在"), + LABEL_NAME_COLOR_NOT_NULL(1604,"JSON文件中标签名称和颜色不能为空"), + LABEL_PUBLIC_EORROR(1605,"不允许操作公共标签"), /** * 数据集操作错误 @@ -91,6 +100,8 @@ public enum ErrorEnum implements ErrorCode { DATASET_VIDEO_HAS_NOT_BEEN_AUTOMATICALLY_TRACKED(1709, "该数据集视频未自动跟踪完成,请稍等"), DATASET_LABEL_EMPTY(1710, "增强类型不能为空!"), DATASET_ENHANCEMENT(1711, "该数据集正在增强中,请稍等"), + DATASET_TRACK_TYPE_ERROR(1712, "数据集类型只能是目标跟踪才能进行跟踪!"), + DATASET_DELETE_ERROR(1713, "数据集数据大数据删除异常!"), /** * 数据集版本校验 @@ -99,8 +110,25 @@ public enum ErrorEnum implements ErrorCode { DATASET_VERSION_PTJOB_STATUS(1802, "当前数据集正在训练不可删除"), DATASET_NOT_ENHANCE(1803, "数据集状态只能是自动标注完成、标注完成、目标跟踪完成才能进行数据增强!"), DATASET_PUBLIC_ERROR(1900, "不允许操作公共数据集"), - ; + /** + * 标签组错误 + */ + LABELGROUP_NAME_DUPLICATED_ERROR(1901,"标签组名已存在"), + LABELGROUP_PUBLIC_ERROR(1902,"不允许操作公共标签组"), + LABELGROUP_IN_USE_STATUS(1903,"当前标签组内标签正在使用,无法操作"), + LABELGROUP_JSON_FILE_ERROR(1904,"请上传json格式文件"), + LABELGROUP_JSON_FILE_SIZE_ERROR(1905,"文件大小不能超过5M"), + LABELGROUP_JSON_FILE_FORMAT_ERROR(1906,"请输入正确的JSON内容"), + LABELGROUP_DOES_NOT_EXIST(1907,"标签组不存在"), + LABELGROUP_FILE_NAME_NOT_EXIST(1908,"请输入文件名称"), + LABELGROUP_LABELG_ID_ERROR(1909,"标签ID异常"), + LABELGROUP_OPERATE_LABEL_ID_ERROR(1910,"不允许操作公共标签组中的标签"), + LABELGROUP_LABEL_NAME_ERROR(1911,"请输入正确预置标签组标签"), + LABELGROUP_LABEL_GROUP_EDIT_ERROR(1912,"标签组下标签不许修改"), + LABELGROUP_LABEL_GROUP_QUOTE_DEL_ERROR(1913,"标签组已被数据集引用,无法删除!"), + ; + ; ErrorEnum(int code, String msg) { this.code = code; diff --git a/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/FileTypeEnum.java b/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/FileTypeEnum.java old mode 100644 new mode 100755 index 638511f..25fb80d --- a/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/FileTypeEnum.java +++ b/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/FileTypeEnum.java @@ -18,6 +18,8 @@ package org.dubhe.data.constant; import lombok.Getter; +import org.dubhe.data.machine.constant.FileStateCodeConstant; +import org.dubhe.data.machine.enums.FileStateEnum; import java.util.HashMap; import java.util.HashSet; @@ -37,47 +39,92 @@ public enum FileTypeEnum { /** * 未标注 */ - UNFINISHED(1, "未标注"), + UNFINISHED(101, "未标注"), + /** + * 手动标注中 + */ + MANUAL_ANNOTATION(102, "手动标注中"), /** * 自动标注完成 */ - AUTO_FINISHED(2, "自动标注完成"), + AUTO_FINISHED(103, "自动标注完成"), + /** + * 已标注完成 + */ + FINISHED(104, "手动标注完成"), /** * 已标注完成 */ - FINISHED(3, "手动标注完成"), + ANNOTATION_NOT_DISTINGUISH_FILE(105, "标注完成未识别"), /** * 自动目标跟踪完成 */ - AUTO_TRACK_FINISHED(4, "自动目标跟踪完成"); + AUTO_TRACK_FINISHED(201, "自动目标跟踪完成"), + + /** + * 未完成 + */ + UNFINISHED_FILE(301,"未完成"), + + /** + * 已完成 + */ + FINISHED_FILE(302,"已完成"); static Set ALL_STATUS = new HashSet() {{ - addAll(FileStatusEnum.getAllValue()); + addAll(FileStateEnum.getAllValue()); }}; static Set UNFINISHED_STATUS = new HashSet() {{ - add(FileStatusEnum.INIT.getValue()); - add(FileStatusEnum.ANNOTATING.getValue()); + add(FileStateCodeConstant.NOT_ANNOTATION_FILE_STATE); + }}; + + static Set MANUAL_ANNOTATION_STATUS = new HashSet() {{ + add(FileStateCodeConstant.MANUAL_ANNOTATION_FILE_STATE); }}; static Set AUTO_FINISHED_STATUS = new HashSet() {{ - add(FileStatusEnum.AUTO_ANNOTATION.getValue()); + add(FileStateCodeConstant.AUTO_TAG_COMPLETE_FILE_STATE); }}; static Set FINISHED_STATUS = new HashSet() {{ - add(FileStatusEnum.FINISHED.getValue()); + add(FileStateCodeConstant.ANNOTATION_COMPLETE_FILE_STATE); + }}; + + static Set ANNOTATION_NOT_DISTINGUISH_FILE_STATUS = new HashSet() {{ + add(FileStateCodeConstant.ANNOTATION_NOT_DISTINGUISH_FILE_STATE); }}; static Set AUTO_TRACK_FINISHED_STATUS = new HashSet() {{ - add(FileStatusEnum.FINISH_AUTO_TRACK.getValue()); + add(FileStateCodeConstant.TARGET_COMPLETE_FILE_STATE); + }}; + + /** + * 未完成 + */ + static Set UNFINISHED_FILE_STATUS = new HashSet() {{ + add(FileStateCodeConstant.NOT_ANNOTATION_FILE_STATE); + add(FileStateCodeConstant.ANNOTATION_NOT_DISTINGUISH_FILE_STATE); + }}; + + /** + * 已完成 + */ + static Set FINISHED_FILE_STATUS = new HashSet() {{ + add(FileStateCodeConstant.AUTO_TAG_COMPLETE_FILE_STATE); + add(FileStateCodeConstant.ANNOTATION_COMPLETE_FILE_STATE); }}; private static final Map> TYPE_STATUS_MAP = new HashMap>() {{ put(All.value, ALL_STATUS); put(UNFINISHED.value, UNFINISHED_STATUS); + put(MANUAL_ANNOTATION.value, MANUAL_ANNOTATION_STATUS); put(AUTO_FINISHED.value, AUTO_FINISHED_STATUS); put(FINISHED.value, FINISHED_STATUS); + put(ANNOTATION_NOT_DISTINGUISH_FILE.value, ANNOTATION_NOT_DISTINGUISH_FILE_STATUS); put(AUTO_TRACK_FINISHED.value, AUTO_TRACK_FINISHED_STATUS); + put(UNFINISHED_FILE.value, UNFINISHED_FILE_STATUS); + put(FINISHED_FILE.value, FINISHED_FILE_STATUS); }}; FileTypeEnum(int value, String msg) { @@ -91,8 +138,8 @@ public enum FileTypeEnum { /** * 获取指定数据集状态下的文件状态列表 * - * @param type - * @return + * @param type 文件类型 + * @return Set 符合条件的文件类型集合 */ public static Set getStatus(Integer type) { return TYPE_STATUS_MAP.get(type); diff --git a/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/TaskSplitStatusEnum.java b/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/TaskSplitStatusEnum.java old mode 100644 new mode 100755 index 9b25d5a..eca7b03 --- a/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/TaskSplitStatusEnum.java +++ b/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/TaskSplitStatusEnum.java @@ -25,10 +25,15 @@ import lombok.Getter; */ @Getter public enum TaskSplitStatusEnum { + /** * 进行中 */ ING(1, "进行中"), + + /** + * 已完成 + */ FINISHED(2, "已完成"), ; @@ -39,4 +44,5 @@ public enum TaskSplitStatusEnum { private int value; private String msg; + } diff --git a/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/TaskStatusEnum.java b/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/TaskStatusEnum.java old mode 100644 new mode 100755 index e8ad0bc..e897c94 --- a/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/TaskStatusEnum.java +++ b/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/TaskStatusEnum.java @@ -25,11 +25,22 @@ import lombok.Getter; */ @Getter public enum TaskStatusEnum { + + /** + * 未处理 + */ + INIT(0, "未处理"), /** * 进行中 */ ING(1, "进行中"), + /** + * 已完成 + */ FINISHED(2, "已完成"), + /** + * 失败 + */ FAIL(3, "失败"), ; @@ -40,4 +51,5 @@ public enum TaskStatusEnum { private int value; private String msg; + } diff --git a/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/dao/DatasetGroupLabelMapper.java b/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/dao/DatasetGroupLabelMapper.java new file mode 100644 index 0000000..a5fee80 --- /dev/null +++ b/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/dao/DatasetGroupLabelMapper.java @@ -0,0 +1,29 @@ +/** + * Copyright 2020 Zhejiang Lab. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================= + */ + +package org.dubhe.data.dao; + +import com.baomidou.mybatisplus.core.mapper.BaseMapper; +import org.dubhe.annotation.DataPermission; +import org.dubhe.data.domain.entity.DatasetGroupLabel; + +/** + * @description 标签组标签中间表 Mapper 接口 + * @date 2020-09-22 + */ +public interface DatasetGroupLabelMapper extends BaseMapper { +} diff --git a/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/dao/DatasetMapper.java b/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/dao/DatasetMapper.java index d7d4159..99c8eeb 100644 --- a/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/dao/DatasetMapper.java +++ b/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/dao/DatasetMapper.java @@ -30,30 +30,17 @@ import org.dubhe.data.domain.entity.Dataset; * @description 数据集管理 Mapper 接口 * @date 2020-04-10 */ -@DataPermission(ignores = {"insert"}) +@DataPermission(ignoresMethod = {"insert","selectById","selectCountByPublic"}) public interface DatasetMapper extends BaseMapper { /** * 分页获取数据集 * - * @param page 分页插件 - * @param queryWrapper 查询条件 + * @param page 分页插件 + * @param queryWrapper 查询条件 * @return Page数据集列表 */ - @DataPermission(permission = PermissionConstant.SELECT) @Select("SELECT * FROM data_dataset ${ew.customSqlSegment}") - @Results(id = "datasetMapperResults", - value = { - @Result(column = "team_id", property = "team", - one = @One(select = "org.dubhe.dao.TeamMapper.selectById", - fetchType = FetchType.LAZY)), - @Result(column = "create_user_id", property = "createUser", - one = @One(select = "org.dubhe.dao.UserMapper.selectById", - fetchType = FetchType.LAZY)), - @Result(column = "update_user_id", property = "updateUser", - one = @One(select = "org.dubhe.dao.UserMapper.selectById", - fetchType = FetchType.LAZY)) - }) Page listPage(Page page, @Param("ew") Wrapper queryWrapper); /** @@ -62,7 +49,6 @@ public interface DatasetMapper extends BaseMapper { * @param id 数据集ID * @param versionName 数据集版本名称 */ - @DataPermission(permission = PermissionConstant.UPDATE) @Update("update data_dataset set current_version_name = #{versionName} where id = #{id}") void updateVersionName(@Param("id") Long id, @Param("versionName") String versionName); @@ -80,9 +66,27 @@ public interface DatasetMapper extends BaseMapper { * @param datasetId 数据集ID * @param sourceState 压缩开始状态 * @param targetState 压缩结束状态 - * @return + * @return int 被修改行数 */ @Update("update data_dataset set decompress_state = #{targetState} where id = #{datasetId} and decompress_state = #{sourceState}") int updateDecompressState(@Param("datasetId") Long datasetId, @Param("sourceState") Integer sourceState, @Param("targetState") Integer targetState); + /** + * 获取指定类型数据集的数量 + * + * @param type 数据集类型 + * @return 公共数据集的数量 + */ + @Select("SELECT count(1) FROM data_dataset where type = #{type}") + int selectCountByPublic(@Param("type") Integer type); + + + /** + * 根据标签组ID查询关联的数据集数量 + * + * @param labelGroupId 标签组ID + * @return 数量 + */ + @Select("SELECT count(1) FROM data_dataset where label_group_id = #{labelGroupId}") + int getCountByLabelGroupId(@Param("labelGroupId")Long labelGroupId); } diff --git a/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/dao/DatasetVersionFileMapper.java b/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/dao/DatasetVersionFileMapper.java index af5c9da..68e7ba8 100644 --- a/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/dao/DatasetVersionFileMapper.java +++ b/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/dao/DatasetVersionFileMapper.java @@ -24,18 +24,20 @@ import org.dubhe.data.domain.entity.DatasetVersionFile; import org.dubhe.data.domain.entity.File; import java.util.List; +import java.util.Map; +import java.util.Set; /** * @description 数据集版本文件 Mapper 接口 - * @date: 2020-05-14 + * @date 2020-05-14 */ public interface DatasetVersionFileMapper extends BaseMapper { /** * 根据数据集ID和版本名称获取正常状态下对应文件列表 * - * @param datasetId 数据集ID - * @param versionName 数据集版本名称 + * @param datasetId 数据集ID + * @param versionName 数据集版本名称 * @return List 数据集ID和版本名称获取正常状态下对应文件列表 */ @Select("select * from data_dataset_version_file where dataset_id = #{datasetId} " + @@ -49,7 +51,6 @@ public interface DatasetVersionFileMapper extends BaseMapper * @param versionSource 原版本号 * @param versionTarget 目的版本号 */ - @UpdateProvider(type = org.dubhe.data.dao.provider.DatasetVersionFileProvider.class, method = "newShipVersionNameChange") void newShipVersionNameChange(@Param("datasetId") Long datasetId, @Param("versionSource") String versionSource, @Param("versionTarget") String versionTarget); @@ -60,7 +61,6 @@ public interface DatasetVersionFileMapper extends BaseMapper * @param versionName 版本名称 * @param fileIds 文件ID */ - @UpdateProvider(type = org.dubhe.data.dao.provider.DatasetVersionFileProvider.class, method = "deleteShip") void deleteShip(@Param("datasetId") Long datasetId, @Param("versionName") String versionName, @Param("fileIds") List fileIds); @@ -75,11 +75,10 @@ public interface DatasetVersionFileMapper extends BaseMapper /** * 按数据集和版本查找文件状态列表 * - * @param datasetId 数据集id - * @param versionName 数据集版本名称 - * @return List 数据集和版本列表 + * @param datasetId 数据集id + * @param versionName 数据集版本名称 + * @return List 数据集和版本列表 */ - @SelectProvider(type = org.dubhe.data.dao.provider.DatasetVersionFileProvider.class, method = "findFileStatusListByDatasetAndVersion") List findFileStatusListByDatasetAndVersion(@Param("datasetId") Long datasetId, @Param("versionName") String versionName); /** @@ -89,7 +88,6 @@ public interface DatasetVersionFileMapper extends BaseMapper * @param versionName 数据集名称 * @param changed 是否改变 */ - @UpdateProvider(type = org.dubhe.data.dao.provider.DatasetVersionFileProvider.class, method = "rollbackFileAndAnnotationStatus") void rollbackFileAndAnnotationStatus(@Param("datasetId") Long datasetId, @Param("versionName") String versionName, @Param("changed") int changed); /** @@ -102,8 +100,8 @@ public interface DatasetVersionFileMapper extends BaseMapper /** * 获取数据集增强文件 * - * @param datasetId 数据集ID - * @param versionName 数据集版本名称 + * @param datasetId 数据集ID + * @param versionName 数据集版本名称 * @return List 数据集增强文件 */ List getNeedEnhanceFilesByDatasetIdAndVersionName(@Param("datasetId") Long datasetId, @Param("versionName") String versionName); @@ -118,11 +116,79 @@ public interface DatasetVersionFileMapper extends BaseMapper */ List getEnhanceFileList(@Param("datasetId") Long datasetId, @Param("versionName") String versionName, @Param("fileId") Long fileId); + /** + * 获取当前版本对应增强文件数量 + * + * @param datasetId 数据集ID + * @param versionName 数据集版本名称 + * @return Integer 当前版本对应增强文件数量 + */ + Integer getEnhanceFileCount(@Param("datasetId")Long datasetId,@Param("versionName") String versionName); + /** * 查询当前数据集版本的原始文件数量 * * @param dataset 当前数据集 - * @return: Integer 原始文件数量 + * @return Integer 原始文件数量 */ Integer getSourceFileCount(@Param("dataset") Dataset dataset); + + /** + * 更新数据集版本文件状态 + * + * @param datasetVersionFile datasetVersionFile + * @param status 数据集文件状态 + */ + @Update("update data_dataset_version_file set annotation_status = #{status} where dataset_id = #{datasetVersionFile.datasetId} " + + " and (version_name = #{datasetVersionFile.versionName} or version_name is NULL) and file_id = #{datasetVersionFile.fileId} and annotation_status = #{datasetVersionFile.status}") + void updateStatus(@Param("datasetVersionFile") DatasetVersionFile datasetVersionFile, @Param("status") Integer status); + + /** + * 获取数据集的版本文件数据 + * + * @param datasetIds 数据集IDS + * @return 数据集的版本文件数据 + */ + List listDatasetVersionFileByDatasetIds(List datasetIds); + + /** + * 获取数据集文件状态统计数据 + * + * @param datasetId 数据集ID + * @param versionName 数据集版本名称 + * @return 数据集文件状态统计数据 + */ + @MapKey("status") + Map getDatasetVersionFileCount(@Param("datasetId") Long datasetId, @Param("versionName") String versionName); + + /** + * 分页查询数据集文件中间表 + * + * @param datasetId 数据集ID + * @param versionName 版本名称 + * @param status 文件状态 + * @param offset 偏移量 + * @param limit 页容量 + * @return 数据集版本文件列表 + */ + List getListByDatasetIdAndAnnotationStatus(@Param("datasetId") Long datasetId, @Param("versionName") String versionName, @Param("status") Set status, @Param("offset") Long offset, @Param("limit") Integer limit); + + /** + * 根据数据集id,版本查询状态为删除的数据版本文件中间表 + * + * @param id 数据集Id + * @param currentVersionName 数据集版本 + * @return DatasetVersionFile Dataset版本文件关系表 + */ + List findStatusByDatasetIdAndVersionName(@Param("datasetId") Long id, @Param("versionName") String currentVersionName); + + /** + * 根据数据集ID删除数据版本文件数据 + * + * @param datasetId 数据集ID + * @param limitNumber 删除数量 + * @return 成功删除条数 + */ + @Delete("delete from data_dataset_version_file where dataset_id = #{datasetId} limit #{limitNumber} ") + int deleteBydatasetId(@Param("datasetId") Long datasetId, @Param("limitNumber") int limitNumber); } diff --git a/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/dao/DatasetVersionMapper.java b/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/dao/DatasetVersionMapper.java index d2ff608..9da18c3 100644 --- a/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/dao/DatasetVersionMapper.java +++ b/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/dao/DatasetVersionMapper.java @@ -18,9 +18,7 @@ package org.dubhe.data.dao; import com.baomidou.mybatisplus.core.mapper.BaseMapper; -import org.apache.ibatis.annotations.Delete; -import org.apache.ibatis.annotations.Param; -import org.apache.ibatis.annotations.Select; +import org.apache.ibatis.annotations.*; import org.dubhe.annotation.DataPermission; import org.dubhe.constant.PermissionConstant; import org.dubhe.data.domain.entity.DatasetVersion; @@ -31,17 +29,16 @@ import java.util.List; * @description 数据集 * @date 2020-05-14 */ -@DataPermission(ignores = {"insert", "getMaxVersionName", "selectPage", "update"}) +@DataPermission(ignoresMethod = {"insert"}) public interface DatasetVersionMapper extends BaseMapper { /** * 查询某个数据集的某个版本是否存在 * - * @param datasetId 数据集ID - * @param versionName 数据集版本 + * @param datasetId 数据集ID + * @param versionName 数据集版本 * @return List 数据集的版本信息 */ - @DataPermission(permission = PermissionConstant.SELECT) @Select("select * from data_dataset_version where dataset_id = #{datasetId} and version_name = #{versionName}") List findDatasetVersion(@Param("datasetId") Long datasetId, @Param("versionName") String versionName); @@ -49,10 +46,9 @@ public interface DatasetVersionMapper extends BaseMapper { /** * 获取指定数据集当前使用最大版本号 * - * @param datasetId 数据集ID - * @return String 指定数据集当前使用最大版本号 + * @param datasetId 数据集ID + * @return String 指定数据集当前使用最大版本号 */ - @DataPermission(permission = PermissionConstant.SELECT) @Select("select max(version_name) from data_dataset_version where dataset_id = #{datasetId} and version_name like 'V%'") String getMaxVersionName(@Param("datasetId") Long datasetId); @@ -67,11 +63,10 @@ public interface DatasetVersionMapper extends BaseMapper { /** * 获取当前数据集版本的url * - * @param datasetId 数据集ID - * @param versionName 数据集版本 - * @return: List 数据集版本的url + * @param datasetId 数据集ID + * @param versionName 数据集版本 + * @return List 数据集版本的url */ - @DataPermission(permission = PermissionConstant.SELECT) @Select("SELECT version_url FROM data_dataset_version WHERE dataset_id = #{datasetId} and version_name = #{versionName}") List selectVersionUrl(@Param("datasetId") Long datasetId, @Param("versionName") String versionName); diff --git a/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/dao/FileMapper.java b/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/dao/FileMapper.java index 5170d4d..8c077f5 100644 --- a/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/dao/FileMapper.java +++ b/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/dao/FileMapper.java @@ -19,6 +19,7 @@ package org.dubhe.data.dao; import com.baomidou.mybatisplus.core.conditions.Wrapper; import com.baomidou.mybatisplus.core.mapper.BaseMapper; +import org.apache.ibatis.annotations.Delete; import org.apache.ibatis.annotations.Param; import org.apache.ibatis.annotations.Select; import org.apache.ibatis.annotations.Update; @@ -33,9 +34,7 @@ import java.util.List; * @description 文件信息 Mapper 接口 * @date 2020-04-10 */ -@DataPermission(ignores = {"insert", "selectListByLimit", "updateSampleStatus", "getOneById", "selectPage", "selectListByLimit", "selectOne", "selectList", "selectBatchIds", "selectById", "selectCount", "update" - -}) +@DataPermission(ignoresMethod = {"insert","getOneById","selectFile"}) public interface FileMapper extends BaseMapper { /** @@ -44,17 +43,16 @@ public interface FileMapper extends BaseMapper { * @param offset 偏移量 * @param limit 页容量 * @param queryWrapper 查询条件 - * @return List File列表 + * @return List File列表 */ - @DataPermission(permission = PermissionConstant.SELECT) @Select("select * from data_file ${ew.customSqlSegment} limit #{offset}, #{limit}") List selectListByLimit(@Param("offset") long offset, @Param("limit") int limit, @Param("ew") Wrapper queryWrapper); /** * 将文件状态改为采样中 * - * @param id 文件ID - * @param status 文件状态 + * @param id 文件ID + * @param status 文件状态 * @return updateSampleStatus 执行次数 */ @Update("update data_file set status=1 where id = #{id} and status = #{status}") @@ -63,25 +61,78 @@ public interface FileMapper extends BaseMapper { /** * 根据文件ID获取文件 * - * @param fileId 文件ID - * @return File 文件对象 + * @param fileId 文件ID + * @return File 文件对象 */ - @Select("select * from data_file where id = #{fileId}") - File getOneById(@Param("fileId") Long fileId); + @Select("select * from data_file where id = #{fileId} and dataset_id = #{datasetId}") + File getOneById(@Param("fileId") Long fileId,@Param("datasetId") long datasetId); /** * 批量保存 * - * @param files 上传文件列表 - * @param userId 用户Id + * @param files 上传文件列表 + * @param userId 用户Id * @param datasetUserId 数据集用户id */ void saveList(@Param("files") List files, @Param("userId") Long userId, @Param("datasetUserId") Long datasetUserId); /** * 查询图片宽高 + * + * @param name 数据集的版本文件名称 + * @param datasetId 数据集ID + * @return FileCreateDTO 文件详情 */ @Select("select width,height from data_file where name = #{name} and dataset_id = #{datasetId}") FileCreateDTO selectWidthAndHeight(@Param("name") String name, @Param("datasetId") Long datasetId); + /** + * 获取文件详情 + * + * @param fileId 文件ID + * @param datasetId 数据集ID + * @return File 文件详情 + */ + @Select("select * from data_file where id = #{fileId} and dataset_id=#{datasetId} and deleted=0") + File selectFile(@Param("fileId") Long fileId,@Param("datasetId") Long datasetId); + + /** + * 分页获取数据集文件 + * + * @param datasetId 数据集ID + * @param currentVersionName 数据集版本名称 + * @param offset 偏移量 + * @param batchSize 批长度 + * @return List 文件列表 + */ + @Select("") + List selectListOne(@Param("datasetId") Long datasetId,@Param("currentVersionName") String currentVersionName,@Param("offset") int offset,@Param("batchSize") int batchSize); + + /** + * 更新文件状态 + * + * @param datasetId 数据集ID + * @param id 文件ID + * @param status 文件状态 + */ + @Update("update data_file set status = #{status} where dataset_id = #{datasetId} and id = #{id}") + void updateFileStatus(@Param("datasetId") Long datasetId, @Param("id") Long id, @Param("status") Integer status); + + /** + * 根据数据集ID删除文件数据 + * + * @param datasetId 数据集ID + * @param limitNumber 删除数量 + * @return 成功删除条数 + */ + @Delete("delete from data_file where dataset_id = #{datasetId} limit #{limitNumber} ") + int deleteBydatasetId(@Param("datasetId") Long datasetId, @Param("limitNumber") int limitNumber); } diff --git a/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/dao/LabelGroupMapper.java b/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/dao/LabelGroupMapper.java new file mode 100644 index 0000000..e96fe86 --- /dev/null +++ b/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/dao/LabelGroupMapper.java @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Zhejiang Lab. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================= + */ + +package org.dubhe.data.dao; + +import com.baomidou.mybatisplus.core.mapper.BaseMapper; +import org.dubhe.annotation.DataPermission; +import org.dubhe.data.domain.entity.LabelGroup; + +/** + * @description 标签组管理 Mapper 接口 + * @date 2020-09-22 + */ +@DataPermission(ignoresMethod = {"insert"}) +public interface LabelGroupMapper extends BaseMapper { +} diff --git a/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/dao/LabelMapper.java b/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/dao/LabelMapper.java index e666f34..9dc490a 100644 --- a/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/dao/LabelMapper.java +++ b/dubhe-server/dubhe-data/src/main/java/org/dubhe/data/dao/LabelMapper.java @@ -17,20 +17,19 @@ package org.dubhe.data.dao; +import com.baomidou.mybatisplus.core.mapper.BaseMapper; import org.apache.ibatis.annotations.Param; import org.apache.ibatis.annotations.Select; +import org.dubhe.data.domain.dto.LabelDTO; import org.dubhe.data.domain.entity.Label; -import com.baomidou.mybatisplus.core.mapper.BaseMapper; -import org.dubhe.annotation.DataPermission; +import org.springframework.security.core.parameters.P; import java.util.List; -import java.util.Set; /** * @description 数据集标签管理 Mapper 接口 * @date 2020-04-10 */ -@DataPermission(ignores = {"insert", "listLabelByDatasetId", "getDatasetLabelTypes", "selectListByType", "batchListByIds"}) public interface LabelMapper extends BaseMapper