Browse Source

!800 Fix bug in ExplainJob and EventParser, change the proto for Explain.

Merge pull request !800 from YuhanShi/ExplainJob_v1.0
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
9b5166cd0d
5 changed files with 252 additions and 363 deletions
  1. +6
    -14
      mindinsight/datavisual/proto_files/mindinsight_summary.proto
  2. +135
    -224
      mindinsight/datavisual/proto_files/mindinsight_summary_pb2.py
  3. +57
    -79
      mindinsight/explainer/manager/event_parse.py
  4. +41
    -39
      mindinsight/explainer/manager/explain_job.py
  5. +13
    -7
      mindinsight/explainer/manager/explain_manager.py

+ 6
- 14
mindinsight/datavisual/proto_files/mindinsight_summary.proto View File

@@ -116,22 +116,14 @@ message Explain {
message Explanation{
optional string explain_method = 1;
optional int32 label = 2;
optional bytes hitmap = 3;
optional bytes heatmap = 3;
}

message Benchmark{
message TotalScore{
optional string benchmark_method = 1;
optional float score = 2;
}
message LabelScore{
repeated float score = 1;
optional string benchmark_method = 2;
}

optional string explain_method = 1;
repeated TotalScore total_score = 2;
repeated LabelScore label_score = 3;
optional string benchmark_method = 1;
optional string explain_method = 2;
optional float total_score = 3;
repeated float label_score = 4;
}

message Metadata{
@@ -151,4 +143,4 @@ message Explain {

optional Metadata metadata = 7;
optional string status = 8; // enum value: run, end
}
}

+ 135
- 224
mindinsight/datavisual/proto_files/mindinsight_summary_pb2.py View File

@@ -1,26 +1,25 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: mindinsight_summary.proto

import sys
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
from google.protobuf import descriptor_pb2
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()


from mindinsight.datavisual.proto_files import mindinsight_anf_ir_pb2 as mindinsight__anf__ir__pb2
from . import mindinsight_anf_ir_pb2 as mindinsight__anf__ir__pb2


DESCRIPTOR = _descriptor.FileDescriptor(
name='mindinsight_summary.proto',
package='mindinsight',
syntax='proto2',
serialized_pb=_b('\n\x19mindinsight_summary.proto\x12\x0bmindinsight\x1a\x18mindinsight_anf_ir.proto\"\xc3\x01\n\x05\x45vent\x12\x11\n\twall_time\x18\x01 \x02(\x01\x12\x0c\n\x04step\x18\x02 \x01(\x03\x12\x11\n\x07version\x18\x03 \x01(\tH\x00\x12,\n\tgraph_def\x18\x04 \x01(\x0b\x32\x17.mindinsight.GraphProtoH\x00\x12\'\n\x07summary\x18\x05 \x01(\x0b\x32\x14.mindinsight.SummaryH\x00\x12\'\n\x07\x65xplain\x18\x06 \x01(\x0b\x32\x14.mindinsight.ExplainH\x00\x42\x06\n\x04what\"\xc0\x04\n\x07Summary\x12)\n\x05value\x18\x01 \x03(\x0b\x32\x1a.mindinsight.Summary.Value\x1aQ\n\x05Image\x12\x0e\n\x06height\x18\x01 \x02(\x05\x12\r\n\x05width\x18\x02 \x02(\x05\x12\x12\n\ncolorspace\x18\x03 \x02(\x05\x12\x15\n\rencoded_image\x18\x04 \x02(\x0c\x1a\xf0\x01\n\tHistogram\x12\x36\n\x07\x62uckets\x18\x01 \x03(\x0b\x32%.mindinsight.Summary.Histogram.bucket\x12\x11\n\tnan_count\x18\x02 \x01(\x03\x12\x15\n\rpos_inf_count\x18\x03 \x01(\x03\x12\x15\n\rneg_inf_count\x18\x04 \x01(\x03\x12\x0b\n\x03max\x18\x05 \x01(\x01\x12\x0b\n\x03min\x18\x06 \x01(\x01\x12\x0b\n\x03sum\x18\x07 \x01(\x01\x12\r\n\x05\x63ount\x18\x08 \x01(\x03\x1a\x34\n\x06\x62ucket\x12\x0c\n\x04left\x18\x01 \x02(\x01\x12\r\n\x05width\x18\x02 \x02(\x01\x12\r\n\x05\x63ount\x18\x03 \x02(\x03\x1a\xc3\x01\n\x05Value\x12\x0b\n\x03tag\x18\x01 \x02(\t\x12\x16\n\x0cscalar_value\x18\x03 \x01(\x02H\x00\x12+\n\x05image\x18\x04 \x01(\x0b\x32\x1a.mindinsight.Summary.ImageH\x00\x12*\n\x06tensor\x18\x08 \x01(\x0b\x32\x18.mindinsight.TensorProtoH\x00\x12\x33\n\thistogram\x18\t \x01(\x0b\x32\x1e.mindinsight.Summary.HistogramH\x00\x42\x07\n\x05value\"\xa9\x06\n\x07\x45xplain\x12\x10\n\x08image_id\x18\x01 \x01(\t\x12\x12\n\nimage_data\x18\x02 \x01(\x0c\x12\x1a\n\x12ground_truth_label\x18\x03 \x03(\x05\x12\x31\n\tinference\x18\x04 \x01(\x0b\x32\x1e.mindinsight.Explain.Inference\x12\x35\n\x0b\x65xplanation\x18\x05 \x03(\x0b\x32 .mindinsight.Explain.Explanation\x12\x31\n\tbenchmark\x18\x06 \x03(\x0b\x32\x1e.mindinsight.Explain.Benchmark\x12/\n\x08metadata\x18\x07 \x01(\x0b\x32\x1d.mindinsight.Explain.Metadata\x12\x0e\n\x06status\x18\x08 \x01(\t\x1aW\n\tInference\x12\x19\n\x11ground_truth_prob\x18\x01 \x03(\x02\x12\x17\n\x0fpredicted_label\x18\x02 \x03(\x05\x12\x16\n\x0epredicted_prob\x18\x03 \x03(\x02\x1a\x44\n\x0b\x45xplanation\x12\x16\n\x0e\x65xplain_method\x18\x01 \x01(\t\x12\r\n\x05label\x18\x02 \x01(\x05\x12\x0e\n\x06hitmap\x18\x03 \x01(\x0c\x1a\x91\x02\n\tBenchmark\x12\x16\n\x0e\x65xplain_method\x18\x01 \x01(\t\x12>\n\x0btotal_score\x18\x02 \x03(\x0b\x32).mindinsight.Explain.Benchmark.TotalScore\x12>\n\x0blabel_score\x18\x03 \x03(\x0b\x32).mindinsight.Explain.Benchmark.LabelScore\x1a\x35\n\nTotalScore\x12\x18\n\x10\x62\x65nchmark_method\x18\x01 \x01(\t\x12\r\n\x05score\x18\x02 \x01(\x02\x1a\x35\n\nLabelScore\x12\r\n\x05score\x18\x01 \x03(\x02\x12\x18\n\x10\x62\x65nchmark_method\x18\x02 \x01(\t\x1aK\n\x08Metadata\x12\r\n\x05label\x18\x01 \x03(\t\x12\x16\n\x0e\x65xplain_method\x18\x02 \x03(\t\x12\x18\n\x10\x62\x65nchmark_method\x18\x03 \x03(\tB\x03\xf8\x01\x01')
serialized_options=b'\370\001\001',
serialized_pb=b'\n\x19mindinsight_summary.proto\x12\x0bmindinsight\x1a\x18mindinsight_anf_ir.proto\"\xc3\x01\n\x05\x45vent\x12\x11\n\twall_time\x18\x01 \x02(\x01\x12\x0c\n\x04step\x18\x02 \x01(\x03\x12\x11\n\x07version\x18\x03 \x01(\tH\x00\x12,\n\tgraph_def\x18\x04 \x01(\x0b\x32\x17.mindinsight.GraphProtoH\x00\x12\'\n\x07summary\x18\x05 \x01(\x0b\x32\x14.mindinsight.SummaryH\x00\x12\'\n\x07\x65xplain\x18\x06 \x01(\x0b\x32\x14.mindinsight.ExplainH\x00\x42\x06\n\x04what\"\xc0\x04\n\x07Summary\x12)\n\x05value\x18\x01 \x03(\x0b\x32\x1a.mindinsight.Summary.Value\x1aQ\n\x05Image\x12\x0e\n\x06height\x18\x01 \x02(\x05\x12\r\n\x05width\x18\x02 \x02(\x05\x12\x12\n\ncolorspace\x18\x03 \x02(\x05\x12\x15\n\rencoded_image\x18\x04 \x02(\x0c\x1a\xf0\x01\n\tHistogram\x12\x36\n\x07\x62uckets\x18\x01 \x03(\x0b\x32%.mindinsight.Summary.Histogram.bucket\x12\x11\n\tnan_count\x18\x02 \x01(\x03\x12\x15\n\rpos_inf_count\x18\x03 \x01(\x03\x12\x15\n\rneg_inf_count\x18\x04 \x01(\x03\x12\x0b\n\x03max\x18\x05 \x01(\x01\x12\x0b\n\x03min\x18\x06 \x01(\x01\x12\x0b\n\x03sum\x18\x07 \x01(\x01\x12\r\n\x05\x63ount\x18\x08 \x01(\x03\x1a\x34\n\x06\x62ucket\x12\x0c\n\x04left\x18\x01 \x02(\x01\x12\r\n\x05width\x18\x02 \x02(\x01\x12\r\n\x05\x63ount\x18\x03 \x02(\x03\x1a\xc3\x01\n\x05Value\x12\x0b\n\x03tag\x18\x01 \x02(\t\x12\x16\n\x0cscalar_value\x18\x03 \x01(\x02H\x00\x12+\n\x05image\x18\x04 \x01(\x0b\x32\x1a.mindinsight.Summary.ImageH\x00\x12*\n\x06tensor\x18\x08 \x01(\x0b\x32\x18.mindinsight.TensorProtoH\x00\x12\x33\n\thistogram\x18\t \x01(\x0b\x32\x1e.mindinsight.Summary.HistogramH\x00\x42\x07\n\x05value\"\xff\x04\n\x07\x45xplain\x12\x10\n\x08image_id\x18\x01 \x01(\t\x12\x12\n\nimage_data\x18\x02 \x01(\x0c\x12\x1a\n\x12ground_truth_label\x18\x03 \x03(\x05\x12\x31\n\tinference\x18\x04 \x01(\x0b\x32\x1e.mindinsight.Explain.Inference\x12\x35\n\x0b\x65xplanation\x18\x05 \x03(\x0b\x32 .mindinsight.Explain.Explanation\x12\x31\n\tbenchmark\x18\x06 \x03(\x0b\x32\x1e.mindinsight.Explain.Benchmark\x12/\n\x08metadata\x18\x07 \x01(\x0b\x32\x1d.mindinsight.Explain.Metadata\x12\x0e\n\x06status\x18\x08 \x01(\t\x1aW\n\tInference\x12\x19\n\x11ground_truth_prob\x18\x01 \x03(\x02\x12\x17\n\x0fpredicted_label\x18\x02 \x03(\x05\x12\x16\n\x0epredicted_prob\x18\x03 \x03(\x02\x1a\x45\n\x0b\x45xplanation\x12\x16\n\x0e\x65xplain_method\x18\x01 \x01(\t\x12\r\n\x05label\x18\x02 \x01(\x05\x12\x0f\n\x07heatmap\x18\x03 \x01(\x0c\x1ag\n\tBenchmark\x12\x18\n\x10\x62\x65nchmark_method\x18\x01 \x01(\t\x12\x16\n\x0e\x65xplain_method\x18\x02 \x01(\t\x12\x13\n\x0btotal_score\x18\x03 \x01(\x02\x12\x13\n\x0blabel_score\x18\x04 \x03(\x02\x1aK\n\x08Metadata\x12\r\n\x05label\x18\x01 \x03(\t\x12\x16\n\x0e\x65xplain_method\x18\x02 \x03(\t\x12\x18\n\x10\x62\x65nchmark_method\x18\x03 \x03(\tB\x03\xf8\x01\x01'
,
dependencies=[mindinsight__anf__ir__pb2.DESCRIPTOR,])

@@ -40,49 +39,49 @@ _EVENT = _descriptor.Descriptor(
has_default_value=False, default_value=float(0),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='step', full_name='mindinsight.Event.step', index=1,
number=2, type=3, cpp_type=2, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='version', full_name='mindinsight.Event.version', index=2,
number=3, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='graph_def', full_name='mindinsight.Event.graph_def', index=3,
number=4, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='summary', full_name='mindinsight.Event.summary', index=4,
number=5, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='explain', full_name='mindinsight.Event.explain', index=5,
number=6, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
serialized_options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
@@ -109,35 +108,35 @@ _SUMMARY_IMAGE = _descriptor.Descriptor(
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='width', full_name='mindinsight.Summary.Image.width', index=1,
number=2, type=5, cpp_type=1, label=2,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='colorspace', full_name='mindinsight.Summary.Image.colorspace', index=2,
number=3, type=5, cpp_type=1, label=2,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='encoded_image', full_name='mindinsight.Summary.Image.encoded_image', index=3,
number=4, type=12, cpp_type=9, label=2,
has_default_value=False, default_value=_b(""),
has_default_value=False, default_value=b"",
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
serialized_options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
@@ -160,28 +159,28 @@ _SUMMARY_HISTOGRAM_BUCKET = _descriptor.Descriptor(
has_default_value=False, default_value=float(0),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='width', full_name='mindinsight.Summary.Histogram.bucket.width', index=1,
number=2, type=1, cpp_type=5, label=2,
has_default_value=False, default_value=float(0),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='count', full_name='mindinsight.Summary.Histogram.bucket.count', index=2,
number=3, type=3, cpp_type=2, label=2,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
serialized_options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
@@ -204,63 +203,63 @@ _SUMMARY_HISTOGRAM = _descriptor.Descriptor(
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='nan_count', full_name='mindinsight.Summary.Histogram.nan_count', index=1,
number=2, type=3, cpp_type=2, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='pos_inf_count', full_name='mindinsight.Summary.Histogram.pos_inf_count', index=2,
number=3, type=3, cpp_type=2, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='neg_inf_count', full_name='mindinsight.Summary.Histogram.neg_inf_count', index=3,
number=4, type=3, cpp_type=2, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='max', full_name='mindinsight.Summary.Histogram.max', index=4,
number=5, type=1, cpp_type=5, label=1,
has_default_value=False, default_value=float(0),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='min', full_name='mindinsight.Summary.Histogram.min', index=5,
number=6, type=1, cpp_type=5, label=1,
has_default_value=False, default_value=float(0),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='sum', full_name='mindinsight.Summary.Histogram.sum', index=6,
number=7, type=1, cpp_type=5, label=1,
has_default_value=False, default_value=float(0),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='count', full_name='mindinsight.Summary.Histogram.count', index=7,
number=8, type=3, cpp_type=2, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR),
],
extensions=[
],
nested_types=[_SUMMARY_HISTOGRAM_BUCKET, ],
enum_types=[
],
options=None,
serialized_options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
@@ -280,45 +279,45 @@ _SUMMARY_VALUE = _descriptor.Descriptor(
_descriptor.FieldDescriptor(
name='tag', full_name='mindinsight.Summary.Value.tag', index=0,
number=1, type=9, cpp_type=9, label=2,
has_default_value=False, default_value=_b("").decode('utf-8'),
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='scalar_value', full_name='mindinsight.Summary.Value.scalar_value', index=1,
number=3, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=float(0),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='image', full_name='mindinsight.Summary.Value.image', index=2,
number=4, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='tensor', full_name='mindinsight.Summary.Value.tensor', index=3,
number=8, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='histogram', full_name='mindinsight.Summary.Value.histogram', index=4,
number=9, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
serialized_options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
@@ -344,14 +343,14 @@ _SUMMARY = _descriptor.Descriptor(
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR),
],
extensions=[
],
nested_types=[_SUMMARY_IMAGE, _SUMMARY_HISTOGRAM, _SUMMARY_VALUE, ],
enum_types=[
],
options=None,
serialized_options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
@@ -375,28 +374,28 @@ _EXPLAIN_INFERENCE = _descriptor.Descriptor(
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='predicted_label', full_name='mindinsight.Explain.Inference.predicted_label', index=1,
number=2, type=5, cpp_type=1, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='predicted_prob', full_name='mindinsight.Explain.Inference.predicted_prob', index=2,
number=3, type=2, cpp_type=6, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
serialized_options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
@@ -416,156 +415,89 @@ _EXPLAIN_EXPLANATION = _descriptor.Descriptor(
_descriptor.FieldDescriptor(
name='explain_method', full_name='mindinsight.Explain.Explanation.explain_method', index=0,
number=1, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='label', full_name='mindinsight.Explain.Explanation.label', index=1,
number=2, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='hitmap', full_name='mindinsight.Explain.Explanation.hitmap', index=2,
name='heatmap', full_name='mindinsight.Explain.Explanation.heatmap', index=2,
number=3, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=_b(""),
has_default_value=False, default_value=b"",
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
serialized_options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=1234,
serialized_end=1302,
serialized_end=1303,
)

_EXPLAIN_BENCHMARK_TOTALSCORE = _descriptor.Descriptor(
name='TotalScore',
full_name='mindinsight.Explain.Benchmark.TotalScore',
_EXPLAIN_BENCHMARK = _descriptor.Descriptor(
name='Benchmark',
full_name='mindinsight.Explain.Benchmark',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='benchmark_method', full_name='mindinsight.Explain.Benchmark.TotalScore.benchmark_method', index=0,
name='benchmark_method', full_name='mindinsight.Explain.Benchmark.benchmark_method', index=0,
number=1, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='score', full_name='mindinsight.Explain.Benchmark.TotalScore.score', index=1,
number=2, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=float(0),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=1470,
serialized_end=1523,
)

_EXPLAIN_BENCHMARK_LABELSCORE = _descriptor.Descriptor(
name='LabelScore',
full_name='mindinsight.Explain.Benchmark.LabelScore',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='score', full_name='mindinsight.Explain.Benchmark.LabelScore.score', index=0,
number=1, type=2, cpp_type=6, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='benchmark_method', full_name='mindinsight.Explain.Benchmark.LabelScore.benchmark_method', index=1,
name='explain_method', full_name='mindinsight.Explain.Benchmark.explain_method', index=1,
number=2, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=1525,
serialized_end=1578,
)

_EXPLAIN_BENCHMARK = _descriptor.Descriptor(
name='Benchmark',
full_name='mindinsight.Explain.Benchmark',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='explain_method', full_name='mindinsight.Explain.Benchmark.explain_method', index=0,
number=1, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='total_score', full_name='mindinsight.Explain.Benchmark.total_score', index=1,
number=2, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
name='total_score', full_name='mindinsight.Explain.Benchmark.total_score', index=2,
number=3, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=float(0),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='label_score', full_name='mindinsight.Explain.Benchmark.label_score', index=2,
number=3, type=11, cpp_type=10, label=3,
name='label_score', full_name='mindinsight.Explain.Benchmark.label_score', index=3,
number=4, type=2, cpp_type=6, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR),
],
extensions=[
],
nested_types=[_EXPLAIN_BENCHMARK_TOTALSCORE, _EXPLAIN_BENCHMARK_LABELSCORE, ],
nested_types=[],
enum_types=[
],
options=None,
serialized_options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=1305,
serialized_end=1578,
serialized_end=1408,
)

_EXPLAIN_METADATA = _descriptor.Descriptor(
@@ -581,35 +513,35 @@ _EXPLAIN_METADATA = _descriptor.Descriptor(
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='explain_method', full_name='mindinsight.Explain.Metadata.explain_method', index=1,
number=2, type=9, cpp_type=9, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='benchmark_method', full_name='mindinsight.Explain.Metadata.benchmark_method', index=2,
number=3, type=9, cpp_type=9, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
serialized_options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=1580,
serialized_end=1655,
serialized_start=1410,
serialized_end=1485,
)

_EXPLAIN = _descriptor.Descriptor(
@@ -622,73 +554,73 @@ _EXPLAIN = _descriptor.Descriptor(
_descriptor.FieldDescriptor(
name='image_id', full_name='mindinsight.Explain.image_id', index=0,
number=1, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='image_data', full_name='mindinsight.Explain.image_data', index=1,
number=2, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=_b(""),
has_default_value=False, default_value=b"",
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='ground_truth_label', full_name='mindinsight.Explain.ground_truth_label', index=2,
number=3, type=5, cpp_type=1, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='inference', full_name='mindinsight.Explain.inference', index=3,
number=4, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='explanation', full_name='mindinsight.Explain.explanation', index=4,
number=5, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='benchmark', full_name='mindinsight.Explain.benchmark', index=5,
number=6, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='metadata', full_name='mindinsight.Explain.metadata', index=6,
number=7, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='status', full_name='mindinsight.Explain.status', index=7,
number=8, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR),
],
extensions=[
],
nested_types=[_EXPLAIN_INFERENCE, _EXPLAIN_EXPLANATION, _EXPLAIN_BENCHMARK, _EXPLAIN_METADATA, ],
enum_types=[
],
options=None,
serialized_options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=846,
serialized_end=1655,
serialized_end=1485,
)

_EVENT.fields_by_name['graph_def'].message_type = mindinsight__anf__ir__pb2._GRAPHPROTO
@@ -729,10 +661,6 @@ _SUMMARY_VALUE.fields_by_name['histogram'].containing_oneof = _SUMMARY_VALUE.one
_SUMMARY.fields_by_name['value'].message_type = _SUMMARY_VALUE
_EXPLAIN_INFERENCE.containing_type = _EXPLAIN
_EXPLAIN_EXPLANATION.containing_type = _EXPLAIN
_EXPLAIN_BENCHMARK_TOTALSCORE.containing_type = _EXPLAIN_BENCHMARK
_EXPLAIN_BENCHMARK_LABELSCORE.containing_type = _EXPLAIN_BENCHMARK
_EXPLAIN_BENCHMARK.fields_by_name['total_score'].message_type = _EXPLAIN_BENCHMARK_TOTALSCORE
_EXPLAIN_BENCHMARK.fields_by_name['label_score'].message_type = _EXPLAIN_BENCHMARK_LABELSCORE
_EXPLAIN_BENCHMARK.containing_type = _EXPLAIN
_EXPLAIN_METADATA.containing_type = _EXPLAIN
_EXPLAIN.fields_by_name['inference'].message_type = _EXPLAIN_INFERENCE
@@ -744,108 +672,91 @@ DESCRIPTOR.message_types_by_name['Summary'] = _SUMMARY
DESCRIPTOR.message_types_by_name['Explain'] = _EXPLAIN
_sym_db.RegisterFileDescriptor(DESCRIPTOR)

Event = _reflection.GeneratedProtocolMessageType('Event', (_message.Message,), dict(
DESCRIPTOR = _EVENT,
__module__ = 'mindinsight_summary_pb2'
Event = _reflection.GeneratedProtocolMessageType('Event', (_message.Message,), {
'DESCRIPTOR' : _EVENT,
'__module__' : 'mindinsight_summary_pb2'
# @@protoc_insertion_point(class_scope:mindinsight.Event)
))
})
_sym_db.RegisterMessage(Event)

Summary = _reflection.GeneratedProtocolMessageType('Summary', (_message.Message,), dict(
Summary = _reflection.GeneratedProtocolMessageType('Summary', (_message.Message,), {

Image = _reflection.GeneratedProtocolMessageType('Image', (_message.Message,), dict(
DESCRIPTOR = _SUMMARY_IMAGE,
__module__ = 'mindinsight_summary_pb2'
'Image' : _reflection.GeneratedProtocolMessageType('Image', (_message.Message,), {
'DESCRIPTOR' : _SUMMARY_IMAGE,
'__module__' : 'mindinsight_summary_pb2'
# @@protoc_insertion_point(class_scope:mindinsight.Summary.Image)
))
})
,

Histogram = _reflection.GeneratedProtocolMessageType('Histogram', (_message.Message,), dict(
'Histogram' : _reflection.GeneratedProtocolMessageType('Histogram', (_message.Message,), {

bucket = _reflection.GeneratedProtocolMessageType('bucket', (_message.Message,), dict(
DESCRIPTOR = _SUMMARY_HISTOGRAM_BUCKET,
__module__ = 'mindinsight_summary_pb2'
'bucket' : _reflection.GeneratedProtocolMessageType('bucket', (_message.Message,), {
'DESCRIPTOR' : _SUMMARY_HISTOGRAM_BUCKET,
'__module__' : 'mindinsight_summary_pb2'
# @@protoc_insertion_point(class_scope:mindinsight.Summary.Histogram.bucket)
))
})
,
DESCRIPTOR = _SUMMARY_HISTOGRAM,
__module__ = 'mindinsight_summary_pb2'
'DESCRIPTOR' : _SUMMARY_HISTOGRAM,
'__module__' : 'mindinsight_summary_pb2'
# @@protoc_insertion_point(class_scope:mindinsight.Summary.Histogram)
))
})
,

Value = _reflection.GeneratedProtocolMessageType('Value', (_message.Message,), dict(
DESCRIPTOR = _SUMMARY_VALUE,
__module__ = 'mindinsight_summary_pb2'
'Value' : _reflection.GeneratedProtocolMessageType('Value', (_message.Message,), {
'DESCRIPTOR' : _SUMMARY_VALUE,
'__module__' : 'mindinsight_summary_pb2'
# @@protoc_insertion_point(class_scope:mindinsight.Summary.Value)
))
})
,
DESCRIPTOR = _SUMMARY,
__module__ = 'mindinsight_summary_pb2'
'DESCRIPTOR' : _SUMMARY,
'__module__' : 'mindinsight_summary_pb2'
# @@protoc_insertion_point(class_scope:mindinsight.Summary)
))
})
_sym_db.RegisterMessage(Summary)
_sym_db.RegisterMessage(Summary.Image)
_sym_db.RegisterMessage(Summary.Histogram)
_sym_db.RegisterMessage(Summary.Histogram.bucket)
_sym_db.RegisterMessage(Summary.Value)

Explain = _reflection.GeneratedProtocolMessageType('Explain', (_message.Message,), dict(
Explain = _reflection.GeneratedProtocolMessageType('Explain', (_message.Message,), {

Inference = _reflection.GeneratedProtocolMessageType('Inference', (_message.Message,), dict(
DESCRIPTOR = _EXPLAIN_INFERENCE,
__module__ = 'mindinsight_summary_pb2'
'Inference' : _reflection.GeneratedProtocolMessageType('Inference', (_message.Message,), {
'DESCRIPTOR' : _EXPLAIN_INFERENCE,
'__module__' : 'mindinsight_summary_pb2'
# @@protoc_insertion_point(class_scope:mindinsight.Explain.Inference)
))
})
,

Explanation = _reflection.GeneratedProtocolMessageType('Explanation', (_message.Message,), dict(
DESCRIPTOR = _EXPLAIN_EXPLANATION,
__module__ = 'mindinsight_summary_pb2'
'Explanation' : _reflection.GeneratedProtocolMessageType('Explanation', (_message.Message,), {
'DESCRIPTOR' : _EXPLAIN_EXPLANATION,
'__module__' : 'mindinsight_summary_pb2'
# @@protoc_insertion_point(class_scope:mindinsight.Explain.Explanation)
))
})
,

Benchmark = _reflection.GeneratedProtocolMessageType('Benchmark', (_message.Message,), dict(

TotalScore = _reflection.GeneratedProtocolMessageType('TotalScore', (_message.Message,), dict(
DESCRIPTOR = _EXPLAIN_BENCHMARK_TOTALSCORE,
__module__ = 'mindinsight_summary_pb2'
# @@protoc_insertion_point(class_scope:mindinsight.Explain.Benchmark.TotalScore)
))
,

LabelScore = _reflection.GeneratedProtocolMessageType('LabelScore', (_message.Message,), dict(
DESCRIPTOR = _EXPLAIN_BENCHMARK_LABELSCORE,
__module__ = 'mindinsight_summary_pb2'
# @@protoc_insertion_point(class_scope:mindinsight.Explain.Benchmark.LabelScore)
))
,
DESCRIPTOR = _EXPLAIN_BENCHMARK,
__module__ = 'mindinsight_summary_pb2'
'Benchmark' : _reflection.GeneratedProtocolMessageType('Benchmark', (_message.Message,), {
'DESCRIPTOR' : _EXPLAIN_BENCHMARK,
'__module__' : 'mindinsight_summary_pb2'
# @@protoc_insertion_point(class_scope:mindinsight.Explain.Benchmark)
))
})
,

Metadata = _reflection.GeneratedProtocolMessageType('Metadata', (_message.Message,), dict(
DESCRIPTOR = _EXPLAIN_METADATA,
__module__ = 'mindinsight_summary_pb2'
'Metadata' : _reflection.GeneratedProtocolMessageType('Metadata', (_message.Message,), {
'DESCRIPTOR' : _EXPLAIN_METADATA,
'__module__' : 'mindinsight_summary_pb2'
# @@protoc_insertion_point(class_scope:mindinsight.Explain.Metadata)
))
})
,
DESCRIPTOR = _EXPLAIN,
__module__ = 'mindinsight_summary_pb2'
'DESCRIPTOR' : _EXPLAIN,
'__module__' : 'mindinsight_summary_pb2'
# @@protoc_insertion_point(class_scope:mindinsight.Explain)
))
})
_sym_db.RegisterMessage(Explain)
_sym_db.RegisterMessage(Explain.Inference)
_sym_db.RegisterMessage(Explain.Explanation)
_sym_db.RegisterMessage(Explain.Benchmark)
_sym_db.RegisterMessage(Explain.Benchmark.TotalScore)
_sym_db.RegisterMessage(Explain.Benchmark.LabelScore)
_sym_db.RegisterMessage(Explain.Metadata)


DESCRIPTOR.has_options = True
DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('\370\001\001'))
DESCRIPTOR._options = None
# @@protoc_insertion_point(module_scope)

+ 57
- 79
mindinsight/explainer/manager/event_parse.py View File

@@ -13,8 +13,8 @@
# limitations under the License.
# ============================================================================
"""EventParser for summary event."""
from collections import namedtuple
from typing import Dict, Iterable, List, Optional, Tuple
from collections import namedtuple, defaultdict
from typing import Dict, List, Optional, Tuple

from mindinsight.explainer.common.enums import PluginNameEnum
from mindinsight.explainer.common.log import logger
@@ -35,32 +35,34 @@ class EventParser:
self._job = job
self._sample_pool = {}

def clear(self):
"""Clear the loaded data."""
self._sample_pool.clear()

def parse_metadata(self, metadata) -> Tuple[List, List, List]:
@staticmethod
def parse_metadata(metadata) -> Tuple[List, List, List]:
"""Parse the metadata event."""
explainers = list(metadata.explain_method)
metrics = list(metadata.benchmark_method)
labels = list(metadata.label)
return explainers, metrics, labels

def parse_benchmark(self, benchmark) -> Dict:
@staticmethod
def parse_benchmark(benchmarks) -> Tuple[Dict, Dict]:
"""Parse the benchmark event."""
imported_benchmark = {}
for explainer_result in benchmark:
explainer = explainer_result.explain_method
total_score = explainer_result.total_score
label_score = explainer_result.label_score

explainer_benchmark = {
'explainer': explainer,
'evaluations': EventParser._total_score_to_dict(total_score),
'class_scores': EventParser._label_score_to_dict(label_score, self._job.labels)
}
imported_benchmark[explainer] = explainer_benchmark
return imported_benchmark
explainer_score_dict = defaultdict(list)
label_score_dict = defaultdict(dict)

for benchmark in benchmarks:
explainer = benchmark.explain_method
metric = benchmark.benchmark_method
metric_score = benchmark.total_score
label_score_event = benchmark.label_score

explainer_score_dict[explainer].append({
'metric': metric,
'score': metric_score})
new_label_score_dict = EventParser._score_event_to_dict(label_score_event, metric)
for label, label_scores in new_label_score_dict.items():
label_score_dict[explainer][label] = label_score_dict[explainer].get(label, []) + label_scores

return explainer_score_dict, label_score_dict

def parse_sample(self, sample: namedtuple) -> Optional[namedtuple]:
"""Parse the sample event."""
@@ -83,68 +85,13 @@ class EventParser:
" detail: %r", tag, str(ex))
continue

if EventParser._is_sample_data_complete(self._sample_pool[sample_id]):
return self._sample_pool.pop(sample_id)
if EventParser._is_ready_for_display(self._sample_pool[sample_id]):
return self._sample_pool[sample_id]
return None

def _parse_inference(self, event, sample_id):
"""Parse the inference event."""
self._sample_pool[sample_id].inference.ground_truth_prob.extend(
event.inference.ground_truth_prob)
self._sample_pool[sample_id].inference.predicted_label.extend(
event.inference.predicted_label)
self._sample_pool[sample_id].inference.predicted_prob.extend(
event.inference.predicted_prob)

def _parse_explanation(self, event, sample_id):
"""Parse the explanation event."""
if event.explanation:
for explanation_item in event.explanation:
new_explanation = self._sample_pool[sample_id].explanation.add()
new_explanation.explain_method = explanation_item.explain_method
new_explanation.label = explanation_item.label
new_explanation.heatmap = explanation_item.heatmap

def _parse_sample_info(self, event, sample_id, tag):
"""Parse the event containing image info."""
if not getattr(self._sample_pool[sample_id], tag):
setattr(self._sample_pool[sample_id], tag, getattr(event, tag))

@staticmethod
def _total_score_to_dict(total_scores: Iterable):
"""Transfer a list of benchmark score to a list of dict."""
evaluation_info = []
for total_score in total_scores:
metric_result = {
'metric': total_score.benchmark_method,
'score': total_score.score}
evaluation_info.append(metric_result)
return evaluation_info

@staticmethod
def _label_score_to_dict(label_scores: Iterable, labels: List[str]):
"""Transfer a list of benchmark score."""
evaluation_info = [{'label': label, 'evaluations': []}
for label in labels]
for label_score in label_scores:
metric = label_score.benchmark_method
for i, score in enumerate(label_score.score):
label_metric_score = {
'metric': metric,
'score': score}
evaluation_info[i]['evaluations'].append(label_metric_score)
return evaluation_info

@staticmethod
def _is_sample_data_complete(image_container: namedtuple) -> bool:
"""Check whether sample data completely loaded."""
required_attrs = ['image_id', 'image_data', 'ground_truth_label', 'inference', 'explanation']
for attr in required_attrs:
if not EventParser.is_attr_ready(image_container, attr):
return False
return True
def clear(self):
"""Clear the loaded data."""
self._sample_pool.clear()

@staticmethod
def _is_ready_for_display(image_container: namedtuple) -> bool:
@@ -178,3 +125,34 @@ class EventParser:
if getattr(image_container, attr, False):
return True
return False

@staticmethod
def _score_event_to_dict(label_score_event, metric):
"""Transfer metric scores per label to pre-defined structure."""
new_label_score_dict = defaultdict(list)
for label_id, label_score in enumerate(label_score_event):
new_label_score_dict[label_id].append({
'metric': metric,
'score': label_score,
})
return new_label_score_dict

def _parse_inference(self, event, sample_id):
"""Parse the inference event."""
self._sample_pool[sample_id].inference.ground_truth_prob.extend(event.inference.ground_truth_prob)
self._sample_pool[sample_id].inference.predicted_label.extend(event.inference.predicted_label)
self._sample_pool[sample_id].inference.predicted_prob.extend(event.inference.predicted_prob)

def _parse_explanation(self, event, sample_id):
"""Parse the explanation event."""
if event.explanation:
for explanation_item in event.explanation:
new_explanation = self._sample_pool[sample_id].explanation.add()
new_explanation.explain_method = explanation_item.explain_method
new_explanation.label = explanation_item.label
new_explanation.heatmap = explanation_item.heatmap

def _parse_sample_info(self, event, sample_id, tag):
"""Parse the event containing image info."""
if not getattr(self._sample_pool[sample_id], tag):
setattr(self._sample_pool[sample_id], tag, getattr(event, tag))

+ 41
- 39
mindinsight/explainer/manager/explain_job.py View File

@@ -15,8 +15,9 @@
"""ExplainJob."""

import os
from collections import defaultdict
from datetime import datetime
from typing import List, Iterable, Union
from typing import Union

from mindinsight.explainer.common.enums import PluginNameEnum
from mindinsight.explainer.common.log import logger
@@ -47,7 +48,8 @@ class ExplainJob:
self._explainers = []
self._samples_info = {}
self._labels_info = {}
self._benchmark = {}
self._explainer_score_dict = defaultdict(list)
self._label_score_dict = defaultdict(dict)
self._overlay_dict = {}
self._image_dict = {}

@@ -66,9 +68,10 @@ class ExplainJob:
"""
all_classes_return = []
for label_id, label_info in self._labels_info.items():
single_info = {'id': label_id,
'label': label_info['label'],
'sample_count': len(label_info['sample_ids'])}
single_info = {
'id': label_id,
'label': label_info['label'],
'sample_count': len(label_info['sample_ids'])}
all_classes_return.append(single_info)
return all_classes_return

@@ -85,7 +88,21 @@ class ExplainJob:
@property
def explainer_scores(self):
"""Return evaluation results for every explainer."""
return [score for score in self._benchmark.values()]
merged_scores = []
for explainer, explainer_score_on_metric in self._explainer_score_dict.items():
label_scores = []
for label, label_score_on_metric in self._label_score_dict[explainer].items():
score_single_label = {
'label': self._labels[label],
'evaluations': label_score_on_metric,
}
label_scores.append(score_single_label)
merged_scores.append({
'explainer': explainer,
'evaluations': explainer_score_on_metric,
'class_scores': label_scores,
})
return merged_scores

@property
def sample_count(self):
@@ -164,10 +181,10 @@ class ExplainJob:
"""
if isinstance(new_time, datetime):
self._latest_update_time = new_time.timestamp()
elif isinstance(new_time, str):
elif isinstance(new_time, float):
self._latest_update_time = new_time
else:
raise TypeError('new_time should have type of str or datetime')
raise TypeError('new_time should have type of float or datetime')

@property
def loader_id(self):
@@ -191,30 +208,6 @@ class ExplainJob:
update_time = os.stat(file_path).st_mtime
return update_time

@staticmethod
def _total_score_to_dict(total_scores: Iterable):
"""Transfer a list of benchmark score to a list of dict."""
evaluation_info = []
for total_score in total_scores:
metric_result = {'metric': total_score.benchmark_method,
'score': total_score.score}
evaluation_info.append(metric_result)
return evaluation_info

@staticmethod
def _label_score_to_dict(label_scores: Iterable, labels: List[str]):
"""Transfer a list of benchmark score."""
evaluation_info = [{'label': label, 'evaluations': []}
for label in labels]
for label_score in label_scores:
metric = label_score.benchmark_method
for i, score in enumerate(label_score.score):
label_metric_score = dict()
label_metric_score['metric'] = metric
label_metric_score['score'] = score
evaluation_info[i]['evaluations'].append(label_metric_score)
return evaluation_info

def _initialize_labels_info(self):
"""Initialize a dict for labels in the job."""
if self._labels is None:
@@ -288,7 +281,6 @@ class ExplainJob:

Return:
string, image data in base64 byte

"""
return self._image_dict.get(image_id, None)

@@ -332,8 +324,7 @@ class ExplainJob:
}

if 'metadata' not in event and self._is_metadata_empty():
raise ValueError('metadata is empty, should write metadata first'
'in the summary.')
raise ValueError('metadata is empty, should write metadata first in the summary.')
for tag in tags:
if tag not in event:
continue
@@ -347,12 +338,12 @@ class ExplainJob:

if tag == PluginNameEnum.BENCHMARK.value:
benchmark_event = event[tag].benchmark
benchmark = self._event_parser.parse_benchmark(benchmark_event)
self._benchmark = benchmark
explain_score_dict, label_score_dict = EventParser.parse_benchmark(benchmark_event)
self._update_benchmark(explain_score_dict, label_score_dict)

elif tag == PluginNameEnum.METADATA.value:
metadata_event = event[tag].metadata
metadata = self._event_parser.parse_metadata(metadata_event)
metadata = EventParser.parse_metadata(metadata_event)
self._explainers, self._metrics, self._labels = metadata
self._initialize_labels_info()

@@ -389,7 +380,18 @@ class ExplainJob:
self._explainers.clear()
self._samples_info.clear()
self._labels_info.clear()
self._benchmark.clear()
self._explainer_score_dict.clear()
self._label_score_dict.clear()
self._overlay_dict.clear()
self._image_dict.clear()
self._event_parser.clear()

def _update_benchmark(self, explainer_score_dict, labels_score_dict):
"""Update the benchmark info."""
for explainer, score in explainer_score_dict.items():
self._explainer_score_dict[explainer].extend(score)

for explainer, score in labels_score_dict.items():
for label, score_of_label in score.items():
self._label_score_dict[explainer][label] = (self._label_score_dict[explainer].get(label, [])
+ score_of_label)

+ 13
- 7
mindinsight/explainer/manager/explain_manager.py View File

@@ -24,7 +24,7 @@ from mindinsight.explainer.common.log import logger
from mindinsight.explainer.manager.explain_job import ExplainJob
from mindinsight.datavisual.data_access.file_handler import FileHandler
from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher
from mindinsight.utils.exceptions import MindInsightException, ParamValueError
from mindinsight.utils.exceptions import MindInsightException, ParamValueError, UnknownError

_MAX_LOADER_NUM = 3
_MAX_INTERVAL = 3
@@ -54,11 +54,14 @@ class ExplainManager:
def _reload_data(self):
"""periodically load summary from file."""
while True:
self._load_data()
try:
self._load_data()

if not self._reload_interval:
break
time.sleep(self._reload_interval)
if not self._reload_interval:
break
time.sleep(self._reload_interval)
except UnknownError:
self._status = _ExplainManagerStatus.INVALID.value

def _load_data(self):
"""Loading the summary in the given base directory."""
@@ -73,8 +76,11 @@ class ExplainManager:

self._status = _ExplainManagerStatus.LOADING.value

self._generate_loaders()
self._execute_load_data()
try:
self._generate_loaders()
self._execute_load_data()
except Exception as ex:
raise UnknownError(ex)

if not self._loader_pool:
self._status = _ExplainManagerStatus.INVALID.value


Loading…
Cancel
Save