@@ -83,17 +83,24 @@ class TrainTaskManager(BaseProcessor):
plugins=plugins
)
def query_train_jobs(self, offset=0, limit=10):
def query_train_jobs(self, offset=0, limit=10, request_train_id=None ):
"""
Query train jobs.
Args:
offset (int): Specify page number. Default is 0.
limit (int): Specify page size. Default is 10.
request_train_id (str): Specify train id. Default is None.
Returns:
tuple, return quantity of total train jobs and list of train jobs specified by offset and limit.
"""
if request_train_id is not None:
train_job_item = self._get_train_job_item(request_train_id)
if train_job_item is None:
return 0, []
return 1, [train_job_item]
brief_cache = self._data_manager.get_brief_cache()
brief_train_jobs = list(brief_cache.get_train_jobs().values())
brief_train_jobs.sort(key=lambda x: x.basic_info.update_time, reverse=True)
@@ -106,37 +113,52 @@ class TrainTaskManager(BaseProcessor):
train_ids = [train_job.basic_info.train_id for train_job in brief_train_jobs[start:end]]
for train_id in train_ids:
try:
train_job = self._data_manager.get_train_job(train_id)
except exceptions.TrainJobNotExistError:
logger.warning('Train job %s not existed', train_id)
train_job_item = self._get_train_job_item(train_id)
if train_job_item is None:
continue
basic_info = train_job.get_basic_info()
train_job_item = dict(
train_id=basic_info.train_id,
relative_path=basic_info.train_id,
create_time=basic_info.create_time.strftime('%Y-%m-%d %H:%M:%S'),
update_time=basic_info.update_time.strftime('%Y-%m-%d %H:%M:%S'),
profiler_dir=basic_info.profiler_dir,
cache_status=train_job.cache_status.value,
)
if train_job.cache_status == CacheStatus.CACHED:
plugins = self.get_plugins(train_id)
else:
plugins = dict(plugins={
'graph': [],
'scalar': [],
'image': [],
'histogram': [],
})
train_job_item.update(plugins)
train_jobs.append(train_job_item)
return total, train_jobs
def _get_train_job_item(self, train_id):
"""
Get train job item.
Args:
train_id (str): Specify train id.
Returns:
dict, a dict of train job item.
"""
try:
train_job = self._data_manager.get_train_job(train_id)
except exceptions.TrainJobNotExistError:
logger.warning('Train job %s not existed', train_id)
return None
basic_info = train_job.get_basic_info()
train_job_item = dict(
train_id=basic_info.train_id,
relative_path=basic_info.train_id,
create_time=basic_info.create_time.strftime('%Y-%m-%d %H:%M:%S'),
update_time=basic_info.update_time.strftime('%Y-%m-%d %H:%M:%S'),
profiler_dir=basic_info.profiler_dir,
cache_status=train_job.cache_status.value,
)
if train_job.cache_status == CacheStatus.CACHED:
plugins = self.get_plugins(train_id)
else:
plugins = dict(plugins={
'graph': [],
'scalar': [],
'image': [],
'histogram': [],
})
train_job_item.update(plugins)
return train_job_item
def cache_train_jobs(self, train_ids):
"""
Cache train jobs.