Browse Source

fix: 修复 TADL 模块存在的循环依赖问题(在 TadlRedisServiceImpl 与 TrialJobAsyncTask 相互注入的 @Resource 字段上添加 @Lazy,延迟初始化以打破循环依赖)

tags/V3.0
jiang 2 years ago
parent
commit
b39dfb2b40
2 changed files with 51 additions and 47 deletions
  1. +48 -46  dubhe-server/dubhe-tadl/src/main/java/org/dubhe/tadl/service/impl/TadlRedisServiceImpl.java
  2. +3 -1  dubhe-server/dubhe-tadl/src/main/java/org/dubhe/tadl/task/TrialJobAsyncTask.java

+ 48
- 46
dubhe-server/dubhe-tadl/src/main/java/org/dubhe/tadl/service/impl/TadlRedisServiceImpl.java View File

@@ -61,6 +61,7 @@ import org.redisson.api.RLock;
import org.redisson.api.RedissonClient; import org.redisson.api.RedissonClient;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value; import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Lazy;
import org.springframework.data.domain.Range; import org.springframework.data.domain.Range;
import org.springframework.data.redis.connection.stream.MapRecord; import org.springframework.data.redis.connection.stream.MapRecord;
import org.springframework.data.redis.connection.stream.ObjectRecord; import org.springframework.data.redis.connection.stream.ObjectRecord;
@@ -90,7 +91,7 @@ import java.util.stream.Collectors;
@Service @Service
public class TadlRedisServiceImpl implements TadlRedisService { public class TadlRedisServiceImpl implements TadlRedisService {


@Resource
@Resource
private StringRedisTemplate stringRedisTemplate; private StringRedisTemplate stringRedisTemplate;


@Resource @Resource
@@ -109,6 +110,7 @@ public class TadlRedisServiceImpl implements TadlRedisService {
private K8sNameTool k8sNameTool; private K8sNameTool k8sNameTool;


@Resource @Resource
@Lazy
private TrialJobAsyncTask trialJobAsyncTask; private TrialJobAsyncTask trialJobAsyncTask;


@Resource @Resource
@@ -253,7 +255,7 @@ public class TadlRedisServiceImpl implements TadlRedisService {
//查询同个实验阶段下运行中,等待中的实验trial ,变更状态为待运行 //查询同个实验阶段下运行中,等待中的实验trial ,变更状态为待运行
List<Trial> trialList = trialService.getTrialList(new LambdaQueryWrapper<Trial>() List<Trial> trialList = trialService.getTrialList(new LambdaQueryWrapper<Trial>()
.eq(Trial::getStageId, stageId) .eq(Trial::getStageId, stageId)
.in(Trial::getStatus, TrialStatusEnum.RUNNING.getVal(), TrialStatusEnum.WAITING.getVal())
.in(Trial::getStatus, TrialStatusEnum.RUNNING.getVal(), TrialStatusEnum.WAITING.getVal())
); );
if (!CollectionUtils.isEmpty(trialList)){ if (!CollectionUtils.isEmpty(trialList)){
Map<Integer, List<Long>> statusTrialIdListMap = trialList.stream().collect(Collectors.groupingBy(Trial::getStatus, Map<Integer, List<Long>> statusTrialIdListMap = trialList.stream().collect(Collectors.groupingBy(Trial::getStatus,
@@ -475,8 +477,8 @@ public class TadlRedisServiceImpl implements TadlRedisService {
//查询运行中,待运行,等待中,运行失败的trial实验(因有些异常状态的处理是先变更状态为运行失败,再进行删除pod操作) //查询运行中,待运行,等待中,运行失败的trial实验(因有些异常状态的处理是先变更状态为运行失败,再进行删除pod操作)
List<Trial> trialList = trialService.getTrialList(new LambdaQueryWrapper<Trial>() List<Trial> trialList = trialService.getTrialList(new LambdaQueryWrapper<Trial>()
.eq(Trial::getStageId, stageId) .eq(Trial::getStageId, stageId)
.in(Trial::getStatus, TrialStatusEnum.RUNNING.getVal(), TrialStatusEnum.TO_RUN.getVal(),TrialStatusEnum.WAITING.getVal(),TrialStatusEnum.FAILED.getVal())
.isNotNull(Trial::getResourceName)
.in(Trial::getStatus, TrialStatusEnum.RUNNING.getVal(), TrialStatusEnum.TO_RUN.getVal(),TrialStatusEnum.WAITING.getVal(),TrialStatusEnum.FAILED.getVal())
.isNotNull(Trial::getResourceName)
); );
LogUtil.info(LogEnum.TADL, TadlConstant.PROCESS_STAGE_KEYWORD_LOG+"Delete running trial.The trial size:{}", experiment.getId(), stageId, trialList.size()); LogUtil.info(LogEnum.TADL, TadlConstant.PROCESS_STAGE_KEYWORD_LOG+"Delete running trial.The trial size:{}", experiment.getId(), stageId, trialList.size());
List<TrialDeleteDTO> trialDeleteDTOList = trialList.stream().map(trial -> { List<TrialDeleteDTO> trialDeleteDTOList = trialList.stream().map(trial -> {
@@ -493,7 +495,7 @@ public class TadlRedisServiceImpl implements TadlRedisService {
return trialDeleteDTO; return trialDeleteDTO;
}).collect(Collectors.toList()); }).collect(Collectors.toList());
//调用删除trial任务方法 //调用删除trial任务方法
trialJobAsyncTask.deleteTrialList(trialDeleteDTOList);
trialJobAsyncTask.deleteTrialList(trialDeleteDTOList);
String taskIdentify = (String) redisUtils.get(experimentIdPrefix + experimentStage.getExperimentId()); String taskIdentify = (String) redisUtils.get(experimentIdPrefix + experimentStage.getExperimentId());
if (StringUtils.isNotEmpty(taskIdentify)) { if (StringUtils.isNotEmpty(taskIdentify)) {
redisUtils.del(taskIdentify, experimentIdPrefix + experimentStage.getExperimentId()); redisUtils.del(taskIdentify, experimentIdPrefix + experimentStage.getExperimentId());
@@ -554,9 +556,9 @@ public class TadlRedisServiceImpl implements TadlRedisService {
} }
List<Long> trialIdList = trialService.getTrialList(new LambdaQueryWrapper<Trial>() List<Long> trialIdList = trialService.getTrialList(new LambdaQueryWrapper<Trial>()
.eq(Trial::getExperimentId, experimentAndTrailDTO.getExperimentId()) .eq(Trial::getExperimentId, experimentAndTrailDTO.getExperimentId())
.eq(Trial::getStageId, experimentAndTrailDTO.getStageId())
.in(Trial::getStatus, TrialStatusEnum.RUNNING.getVal(),TrialStatusEnum.WAITING.getVal())
).stream().map(Trial::getId).collect(Collectors.toList());
.eq(Trial::getStageId, experimentAndTrailDTO.getStageId())
.in(Trial::getStatus, TrialStatusEnum.RUNNING.getVal(),TrialStatusEnum.WAITING.getVal())
).stream().map(Trial::getId).collect(Collectors.toList());


//从推送的消息队列中获取recordId //从推送的消息队列中获取recordId
StreamOperations<String, String, TrialRunParamDTO> streamOperations = stringRedisTemplate.opsForStream(); StreamOperations<String, String, TrialRunParamDTO> streamOperations = stringRedisTemplate.opsForStream();
@@ -601,46 +603,46 @@ public class TadlRedisServiceImpl implements TadlRedisService {
private boolean checkAndPushMessages(ExperimentAndTrailDTO experimentAndTrailDTO) { private boolean checkAndPushMessages(ExperimentAndTrailDTO experimentAndTrailDTO) {
RLock lock = redissonClient.getLock(TadlConstant.LOCK + experimentAndTrailDTO.getStageId()); RLock lock = redissonClient.getLock(TadlConstant.LOCK + experimentAndTrailDTO.getStageId());
try { try {
lock.lock(30, TimeUnit.SECONDS);
Experiment experiment = experimentService.selectById(experimentAndTrailDTO.getExperimentId());
if (ExperimentStatusEnum.FAILED_EXPERIMENT_STATE.getValue().equals(experiment.getStatus())){
LogUtil.info(LogEnum.TADL,TadlConstant.PROCESS_STAGE_KEYWORD_LOG+" The experiment status is :{}. ", experimentAndTrailDTO.getExperimentId(), experimentAndTrailDTO.getStageId(),ExperimentStatusEnum.FAILED_EXPERIMENT_STATE.getMsg());
return true ;
}
LogUtil.info(LogEnum.TADL,TadlConstant.PROCESS_STAGE_KEYWORD_LOG +"Get stream operations. ", experimentAndTrailDTO.getExperimentId(), experimentAndTrailDTO.getStageId());
StreamOperations<String, String, TrialRunParamDTO> streamOperations = stringRedisTemplate.opsForStream();
List<MapRecord<String, String, TrialRunParamDTO>> redisDataList = streamOperations.range(RedisKeyConstant.buildStreamStageKey(experimentAndTrailDTO.getExperimentId(), experimentAndTrailDTO.getStageId()), Range.closed("-", "+"));
//判断 若消息队列中消息数量 >= 并发数量 则不进行消息推送
if (redisDataList.size() >= experimentAndTrailDTO.getTrialConcurrentNum()){
LogUtil.info(LogEnum.TADL,TadlConstant.PROCESS_STAGE_KEYWORD_LOG +"Steam size are grater than concurrent number.");
return true;
}
//获取stage key消息队列中的trial id集合
List<Long> trialIdList = redisDataList.stream().map(mapRecord -> {
TrialRunParamDTO trialRunParamDTO = new TrialRunParamDTO();
try {
BeanUtils.populate(trialRunParamDTO, mapRecord.getValue());
} catch (Exception e) {
LogUtil.error(LogEnum.TADL, TadlConstant.PROCESS_STAGE_KEYWORD_LOG+"Redis Stream 消息转化实体异常!异常信息:{}.",experimentAndTrailDTO.getExperimentId(),experimentAndTrailDTO.getStageId(), e.getMessage());
throw new BusinessException("Redis Stream 消息转化实体异常!");
lock.lock(30, TimeUnit.SECONDS);
Experiment experiment = experimentService.selectById(experimentAndTrailDTO.getExperimentId());
if (ExperimentStatusEnum.FAILED_EXPERIMENT_STATE.getValue().equals(experiment.getStatus())){
LogUtil.info(LogEnum.TADL,TadlConstant.PROCESS_STAGE_KEYWORD_LOG+" The experiment status is :{}. ", experimentAndTrailDTO.getExperimentId(), experimentAndTrailDTO.getStageId(),ExperimentStatusEnum.FAILED_EXPERIMENT_STATE.getMsg());
return true ;
}
LogUtil.info(LogEnum.TADL,TadlConstant.PROCESS_STAGE_KEYWORD_LOG +"Get stream operations. ", experimentAndTrailDTO.getExperimentId(), experimentAndTrailDTO.getStageId());
StreamOperations<String, String, TrialRunParamDTO> streamOperations = stringRedisTemplate.opsForStream();
List<MapRecord<String, String, TrialRunParamDTO>> redisDataList = streamOperations.range(RedisKeyConstant.buildStreamStageKey(experimentAndTrailDTO.getExperimentId(), experimentAndTrailDTO.getStageId()), Range.closed("-", "+"));
//判断 若消息队列中消息数量 >= 并发数量 则不进行消息推送
if (redisDataList.size() >= experimentAndTrailDTO.getTrialConcurrentNum()){
LogUtil.info(LogEnum.TADL,TadlConstant.PROCESS_STAGE_KEYWORD_LOG +"Steam size are grater than concurrent number.");
return true;
}
//获取stage key消息队列中的trial id集合
List<Long> trialIdList = redisDataList.stream().map(mapRecord -> {
TrialRunParamDTO trialRunParamDTO = new TrialRunParamDTO();
try {
BeanUtils.populate(trialRunParamDTO, mapRecord.getValue());
} catch (Exception e) {
LogUtil.error(LogEnum.TADL, TadlConstant.PROCESS_STAGE_KEYWORD_LOG+"Redis Stream 消息转化实体异常!异常信息:{}.",experimentAndTrailDTO.getExperimentId(),experimentAndTrailDTO.getStageId(), e.getMessage());
throw new BusinessException("Redis Stream 消息转化实体异常!");
}
return trialRunParamDTO.getTrialId();
}).collect(Collectors.toList());
//获取不存在于 消息队列中的 trial数据组装实体
List<TrialRunParamDTO> trialRunParamDTOList = experimentAndTrailDTO.getTrialRunParamDTOList().stream().filter(e -> !trialIdList.contains(e.getTrialId())).collect(Collectors.toList());
LogUtil.info(LogEnum.TADL,TadlConstant.PROCESS_STAGE_KEYWORD_LOG+" trialRunParamDTOList size:{}. ", experimentAndTrailDTO.getExperimentId(), experimentAndTrailDTO.getStageId(),trialRunParamDTOList.size());

if (CollectionUtils.isEmpty(trialRunParamDTOList)){
LogUtil.info(LogEnum.TADL,TadlConstant.PROCESS_STAGE_KEYWORD_LOG+" trialRunParamDTOList size:{}.The trial run param size is zero. ", experimentAndTrailDTO.getExperimentId(), experimentAndTrailDTO.getStageId(),trialRunParamDTOList.size());
return true;
} }
return trialRunParamDTO.getTrialId();
}).collect(Collectors.toList());
//获取不存在于 消息队列中的 trial数据组装实体
List<TrialRunParamDTO> trialRunParamDTOList = experimentAndTrailDTO.getTrialRunParamDTOList().stream().filter(e -> !trialIdList.contains(e.getTrialId())).collect(Collectors.toList());
LogUtil.info(LogEnum.TADL,TadlConstant.PROCESS_STAGE_KEYWORD_LOG+" trialRunParamDTOList size:{}. ", experimentAndTrailDTO.getExperimentId(), experimentAndTrailDTO.getStageId(),trialRunParamDTOList.size());

if (CollectionUtils.isEmpty(trialRunParamDTOList)){
LogUtil.info(LogEnum.TADL,TadlConstant.PROCESS_STAGE_KEYWORD_LOG+" trialRunParamDTOList size:{}.The trial run param size is zero. ", experimentAndTrailDTO.getExperimentId(), experimentAndTrailDTO.getStageId(),trialRunParamDTOList.size());
return true;
}


//对消息实体类进行推送
for (TrialRunParamDTO trialRunParamDTO : trialRunParamDTOList) {
ObjectRecord<String, TrialRunParamDTO> mapRecord = ObjectRecord.create(RedisKeyConstant.buildStreamStageKey(experimentAndTrailDTO.getExperimentId(), experimentAndTrailDTO.getStageId()), trialRunParamDTO);
//添加mapRecord 生成recordId
stringRedisTemplate.opsForStream().add(mapRecord);
}
//对消息实体类进行推送
for (TrialRunParamDTO trialRunParamDTO : trialRunParamDTOList) {
ObjectRecord<String, TrialRunParamDTO> mapRecord = ObjectRecord.create(RedisKeyConstant.buildStreamStageKey(experimentAndTrailDTO.getExperimentId(), experimentAndTrailDTO.getStageId()), trialRunParamDTO);
//添加mapRecord 生成recordId
stringRedisTemplate.opsForStream().add(mapRecord);
}


}catch (Exception e){ }catch (Exception e){
LogUtil.error(LogEnum.TADL, TadlConstant.PROCESS_STAGE_KEYWORD_LOG+"获取分布式锁失败,失败信息:{}", experimentAndTrailDTO.getExperimentId(),experimentAndTrailDTO.getStageId(),e.getMessage()); LogUtil.error(LogEnum.TADL, TadlConstant.PROCESS_STAGE_KEYWORD_LOG+"获取分布式锁失败,失败信息:{}", experimentAndTrailDTO.getExperimentId(),experimentAndTrailDTO.getStageId(),e.getMessage());


+ 3
- 1
dubhe-server/dubhe-tadl/src/main/java/org/dubhe/tadl/task/TrialJobAsyncTask.java View File

@@ -50,6 +50,7 @@ import org.dubhe.tadl.service.ExperimentService;
import org.dubhe.tadl.service.TadlRedisService; import org.dubhe.tadl.service.TadlRedisService;
import org.dubhe.tadl.service.TadlTrialService; import org.dubhe.tadl.service.TadlTrialService;
import org.springframework.beans.factory.annotation.Value; import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Lazy;
import org.springframework.scheduling.annotation.Async; import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import org.springframework.transaction.annotation.Transactional; import org.springframework.transaction.annotation.Transactional;
@@ -80,6 +81,7 @@ public class TrialJobAsyncTask {
private ResourceCache resourceCache; private ResourceCache resourceCache;


@Resource @Resource
@Lazy
private TadlRedisService tadlRedisService; private TadlRedisService tadlRedisService;


@Resource @Resource
@@ -148,7 +150,7 @@ public class TrialJobAsyncTask {
*/ */
@Transactional(rollbackFor = Exception.class) @Transactional(rollbackFor = Exception.class)
public void deleteTrialList(List<TrialDeleteDTO> trialDeleteDTOList) { public void deleteTrialList(List<TrialDeleteDTO> trialDeleteDTOList) {
//三次重试均反馈失败则给予删除失败结果
//三次重试均反馈失败则给予删除失败结果
int tryTime = 1; int tryTime = 1;
while (!trialDeleteDTOList.isEmpty()){ while (!trialDeleteDTOList.isEmpty()){
//重试三次 //重试三次


Loading…
Cancel
Save