Browse Source

Fix the implementation of get_flattened_lineage

tags/v1.0.0
Li Hongzhang 5 years ago
parent
commit
726c0b9fda
1 changed file with 37 additions and 73 deletions
  1. +37
    -73
      mindinsight/lineagemgr/model.py

+ 37
- 73
mindinsight/lineagemgr/model.py View File

@@ -206,80 +206,44 @@ def get_flattened_lineage(data_manager, search_condition=None):
Dict[str, list]: A dict contains keys and values from lineages.

"""
summary_base_dir = data_manager.summary_base_dir
lineages = filter_summary_lineage(data_manager=data_manager, search_condition=search_condition)
lineage_objects = lineages.get("object", [])

# Step 1, get column names
column_names = _get_columns_name(lineage_objects)

# Step 2, collect data
column_data = _organize_data_to_matrix(lineage_objects, column_names, summary_base_dir)

return column_data


def _get_columns_name(lineage_objects):
"""Get columns name."""
column_names = set()
user_defined_num = 0
for lineage in lineage_objects:
model_lineage = lineage.get("model_lineage", {})
metric = model_lineage.get("metric", {})
metric_names = tuple('{}{}'.format(_METRIC_PREFIX, key) for key in metric.keys())
user_defined = model_lineage.get("user_defined", {})
user_defined_names = tuple('{}{}'.format(_USER_DEFINED_PREFIX, key) for key in user_defined.keys())
model_lineage_temp = list(model_lineage.keys())
for key in model_lineage_temp:
if key in ["metric", "user_defined"]:
model_lineage_temp.remove(key)
column_names.update(model_lineage_temp)
column_names.update(metric_names)
if user_defined_num + len(user_defined_names) <= USER_DEFINED_INFO_LIMIT:
column_names.update(user_defined_names)
user_defined_num += len(user_defined_names)
elif user_defined_num < USER_DEFINED_INFO_LIMIT <= user_defined_num + len(user_defined_names):
names = []
for i in range(USER_DEFINED_INFO_LIMIT - user_defined_num):
names.append(user_defined_names[i])
column_names.update(names)
user_defined_num += len(names)
log.info("Partial user_defined_info is deleted. Currently saved length is: %s.", user_defined_num)
summary_base_dir, flatten_dict, user_count = data_manager.summary_base_dir, {'train_id': []}, 0
lineages = filter_summary_lineage(data_manager=data_manager, search_condition=search_condition).get("object", [])
for index, lineage in enumerate(lineages):
flatten_dict['train_id'].append(get_relative_path(lineage.get("summary_dir"), summary_base_dir))
for key, val in _flatten_lineage(lineage.get('model_lineage', {})):
if key.startswith(_USER_DEFINED_PREFIX) and key not in flatten_dict:
if user_count > USER_DEFINED_INFO_LIMIT:
log.warning("The user_defined_info has reached the limit %s. %r is ignored",
USER_DEFINED_INFO_LIMIT, key)
continue
user_count += 1
if key not in flatten_dict:
flatten_dict[key] = [None] * index
flatten_dict[key].append(_parse_value(val))
for vals in flatten_dict.values():
if len(vals) == index:
vals.append(None)
return flatten_dict


def _flatten_lineage(lineage):
"""Flatten the lineage."""
for key, val in lineage.items():
if key == 'metric':
for k, v in val.items():
yield f'{_METRIC_PREFIX}{k}', v
elif key == 'user_defined':
for k, v in val.items():
yield f'{_USER_DEFINED_PREFIX}{k}', v
else:
log.warning("The quantity of user_defined_info has reached the limit %s.", USER_DEFINED_INFO_LIMIT)
column_names.update(["train_id"])

return column_names


def _organize_data_to_matrix(lineage_objects, column_names, summary_base_dir):
    """
    Collect lineage data and arrange it as one value list per column.

    Args:
        lineage_objects (list[dict]): Lineage records. Each record is read for
            its "summary_dir" entry and its optional "model_lineage" dict.
        column_names (Iterable[str]): Names of the columns to collect. Values
            whose key is not among these names are dropped.
        summary_base_dir (str): Passed with each record's "summary_dir" to
            `get_relative_path`; the result is stored in the "train_id" column.

    Returns:
        Dict[str, list]: Maps every column name to a list holding one entry per
        lineage record, with None where the record provides no value.
    """
    nested_keys = ("metric", "user_defined")
    cnt_lineages = len(lineage_objects)
    column_data = {key: [None] * cnt_lineages for key in column_names}

    for ind, lineage in enumerate(lineage_objects):
        train_id = get_relative_path(lineage.get("summary_dir"), summary_base_dir)

        model_lineage = lineage.get("model_lineage", {})
        # Read the nested sections with get() instead of pop() so that the
        # caller's lineage records are not modified by flattening them.
        metric = model_lineage.get("metric", {})
        user_defined = model_lineage.get("user_defined", {})

        # Later updates win on key clashes: plain fields, then metrics,
        # then user-defined items, then train_id.
        final_content = {key: val for key, val in model_lineage.items() if key not in nested_keys}
        final_content.update((f'{_METRIC_PREFIX}{key}', val) for key, val in metric.items())
        final_content.update((f'{_USER_DEFINED_PREFIX}{key}', val) for key, val in user_defined.items())
        final_content["train_id"] = train_id

        for key, val in final_content.items():
            if key not in column_data:
                continue
            # The strings 'nan' and 'inf' (any letter case) are stored as np.nan.
            if isinstance(val, str) and val.lower() in ('nan', 'inf'):
                val = np.nan
            column_data[key][ind] = val
    return column_data
yield key, val


def _parse_value(val):
"""Parse value."""
if isinstance(val, str) and val.lower() in ['nan', 'inf']:
return np.nan
return val


class LineageTable:


Loading…
Cancel
Save