Browse Source

!816 Fixed Job List issue and added UT for explainer JSON API

Merge pull request !816 from TonyNG/master
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
d7e5a7103c
13 changed files with 641 additions and 74 deletions
  1. +1
    -3
      mindinsight/backend/explainer/__init__.py
  2. +26
    -15
      mindinsight/backend/explainer/explainer_api.py
  3. +26
    -37
      mindinsight/explainer/encapsulator/explain_job_encap.py
  4. +19
    -19
      mindinsight/explainer/encapsulator/saliency_encap.py
  5. +22
    -0
      mindinsight/explainer/manager/explain_manager.py
  6. +15
    -0
      tests/ut/backend/explainer/__init__.py
  7. +41
    -0
      tests/ut/backend/explainer/conftest.py
  8. +194
    -0
      tests/ut/backend/explainer/test_explainer_api.py
  9. +15
    -0
      tests/ut/explainer/encapsulator/__init__.py
  10. +115
    -0
      tests/ut/explainer/encapsulator/mock_explain_manager.py
  11. +52
    -0
      tests/ut/explainer/encapsulator/test_evaluation_encap.py
  12. +53
    -0
      tests/ut/explainer/encapsulator/test_explain_job_encap.py
  13. +62
    -0
      tests/ut/explainer/encapsulator/test_saliency_encap.py

+ 1
- 3
mindinsight/backend/explainer/__init__.py View File

@@ -12,9 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
module init file.
"""
"""Module init file."""
from mindinsight.backend.explainer.explainer_api import init_module as init_query_module




+ 26
- 15
mindinsight/backend/explainer/explainer_api.py View File

@@ -15,6 +15,7 @@
"""Explainer restful api."""

import os
import json
import urllib.parse

from flask import Blueprint
@@ -36,8 +37,6 @@ from mindinsight.explainer.encapsulator.evaluation_encap import EvaluationEncap
URL_PREFIX = settings.URL_PATH_PREFIX+settings.API_PREFIX
BLUEPRINT = Blueprint("explainer", __name__, url_prefix=URL_PREFIX)

STATIC_EXPLAIN_MGR = True


class ExplainManagerHolder:
"""ExplainManger instance holder."""
@@ -46,38 +45,50 @@ class ExplainManagerHolder:

@classmethod
def get_instance(cls):
if cls.static_instance:
return cls.static_instance
instance = ExplainManager(settings.SUMMARY_BASE_DIR)
instance.start_load_data()
return instance
return cls.static_instance

@classmethod
def initialize(cls):
if STATIC_EXPLAIN_MGR:
cls.static_instance = ExplainManager(settings.SUMMARY_BASE_DIR)
cls.static_instance.start_load_data()
cls.static_instance = ExplainManager(settings.SUMMARY_BASE_DIR)
cls.static_instance.start_load_data()


def _image_url_formatter(train_id, image_id, image_type):
"""returns image url."""
"""Returns image url."""
train_id = urllib.parse.quote(str(train_id))
image_id = urllib.parse.quote(str(image_id))
image_type = urllib.parse.quote(str(image_type))
return f"{URL_PREFIX}/explainer/image?train_id={train_id}&image_id={image_id}&type={image_type}"


def _read_post_request(post_request):
"""
Extract the body of post request.

Args:
post_request (object): The post request.

Returns:
dict, the deserialized body of request.
"""
body = post_request.stream.read()
try:
body = json.loads(body if body else "{}")
except json.decoder.JSONDecodeError:
raise ParamValueError("Json data parse failed.")
return body


@BLUEPRINT.route("/explainer/explain-jobs", methods=["GET"])
def query_explain_jobs():
"""Query explain jobs."""
offset = request.args.get("offset", default=0)
limit = request.args.get("limit", default=10)
train_id = get_train_id(request)
offset = Validation.check_offset(offset=offset)
limit = Validation.check_limit(limit, min_value=1, max_value=SummaryWatcher.MAX_SUMMARY_DIR_COUNT)

encapsulator = ExplainJobEncap(ExplainManagerHolder.get_instance())
total, jobs = encapsulator.query_explain_jobs(offset, limit, train_id)
total, jobs = encapsulator.query_explain_jobs(offset, limit)

return jsonify({
'name': os.path.basename(os.path.realpath(settings.SUMMARY_BASE_DIR)),
@@ -102,7 +113,7 @@ def query_explain_job():
def query_saliency():
"""Query saliency map related results."""

data = request.get_json(silent=True)
data = _read_post_request(request)

train_id = data.get("train_id")
if train_id is None:
@@ -149,7 +160,7 @@ def query_evaluation():

@BLUEPRINT.route("/explainer/image", methods=["GET"])
def query_image():
"""Query image"""
"""Query image."""
train_id = get_train_id(request)
if train_id is None:
raise ParamMissError("train_id")


+ 26
- 37
mindinsight/explainer/encapsulator/explain_job_encap.py View File

@@ -18,7 +18,6 @@ import copy
from datetime import datetime

from mindinsight.utils.exceptions import ParamValueError
from mindinsight.datavisual.data_transform.summary_watcher import SummaryWatcher
from mindinsight.explainer.encapsulator.explain_data_encap import ExplainDataEncap


@@ -26,48 +25,29 @@ class ExplainJobEncap(ExplainDataEncap):
"""Explain job list encapsulator."""

DATETIME_FORMAT = "%Y-%m-%d %H:%M:%S"
DEFAULT_MIN_CONFIDENCE = 0.5

def query_explain_jobs(self, offset, limit, train_id):
def query_explain_jobs(self, offset, limit):
"""
Query explain job list.
Args:
offset (int): offset
limit (int): max. no. of items to be returned
train_id (str): job id
offset (int): Page offset.
limit (int): Max. no. of items to be returned.
Returns:
Tuple[int, List[Dict]], total no. of jobs and job list
Tuple[int, List[Dict]], total no. of jobs and job list.
"""
watcher = SummaryWatcher()
total, dir_infos = \
watcher.list_explain_directories(self.job_manager.summary_base_dir,
offset=offset, limit=limit)
obj_offset = offset * limit
job_infos = []

if train_id is None:
end = total
if obj_offset + limit < end:
end = obj_offset + limit

for i in range(obj_offset, end):
job_id = dir_infos[i]["relative_path"]
job = self.job_manager.get_job(job_id)
if job is not None:
job_infos.append(self._job_2_info(job))
else:
job = self.job_manager.get_job(train_id)
if job is not None:
job_infos.append(self._job_2_info(job))
total, dir_infos = self.job_manager.get_job_list(offset=offset, limit=limit)
job_infos = [self._dir_2_info(dir_info) for dir_info in dir_infos]

return total, job_infos

def query_meta(self, train_id):
"""
Query explain job meta-data
Query explain job meta-data.
Args:
train_id (str): job id
train_id (str): Job ID.
Returns:
Dict, the metadata
Dict, the metadata.
"""
job = self.job_manager.get_job(train_id)
if job is None:
@@ -78,11 +58,11 @@ class ExplainJobEncap(ExplainDataEncap):
"""
Query image binary content.
Args:
train_id (str): job id
image_id (str): image id
image_type (str) 'original' or 'overlay'
train_id (str): Job ID.
image_id (str): Image ID.
image_type (str): Image type, 'original' or 'overlay'.
Returns:
bytes, image binary
bytes, image binary.
"""
job = self.job_manager.get_job(train_id)

@@ -97,9 +77,18 @@ class ExplainJobEncap(ExplainDataEncap):

return binary

@classmethod
def _dir_2_info(cls, dir_info):
"""Convert ExplainJob object to jsonable info object."""
info = dict()
info["train_id"] = dir_info["relative_path"]
info["create_time"] = dir_info["create_time"].strftime(cls.DATETIME_FORMAT)
info["update_time"] = dir_info["update_time"].strftime(cls.DATETIME_FORMAT)
return info

@classmethod
def _job_2_info(cls, job):
"""Convert ExplainJob object to jsonable info object"""
"""Convert ExplainJob object to jsonable info object."""
info = dict()
info["train_id"] = job.train_id
info["create_time"] = datetime.fromtimestamp(job.create_time)\
@@ -110,13 +99,13 @@ class ExplainJobEncap(ExplainDataEncap):

@classmethod
def _job_2_meta(cls, job):
"""Convert ExplainJob's meta-data to jsonable info object"""
"""Convert ExplainJob's meta-data to jsonable info object."""
info = cls._job_2_info(job)
info["sample_count"] = job.sample_count
info["classes"] = copy.deepcopy(job.all_classes)
saliency_info = dict()
if job.min_confidence is None:
saliency_info["min_confidence"] = 0.5
saliency_info["min_confidence"] = cls.DEFAULT_MIN_CONFIDENCE
else:
saliency_info["min_confidence"] = job.min_confidence
saliency_info["explainers"] = list(job.explainers)


+ 19
- 19
mindinsight/explainer/encapsulator/saliency_encap.py View File

@@ -19,8 +19,8 @@ import copy
from mindinsight.explainer.encapsulator.explain_data_encap import ExplainDataEncap


def _sort_key_confid(sample):
"""Samples sort key by the max. confidence"""
def _sort_key_confidence(sample):
"""Samples sort key by the max. confidence."""
max_confid = None
for inference in sample["inferences"]:
if max_confid is None or inference["confidence"] > max_confid:
@@ -44,23 +44,23 @@ class SaliencyEncap(ExplainDataEncap):
sorted_name,
sorted_type):
"""
Query saliency maps
Query saliency maps.
Args:
train_id (str): job id
labels (List[str]): labels filter
explainers (List[str]): explainers of saliency maps to be shown
limit (int): max. no. of items to be returned
offset (int): item offset
sorted_name (str): field to be sorted
sorted_type (str): 'ascending' or 'descending' order
train_id (str): Job ID.
labels (List[str]): Label filter.
explainers (List[str]): Explainers of saliency maps to be shown.
limit (int): Max. no. of items to be returned.
offset (int): Page offset.
sorted_name (str): Field to be sorted.
sorted_type (str): Sorting order, 'ascending' or 'descending'.

Returns:
Tuple[int, List[dict]], total no. of samples after filtering and
list of sample result
list of sample result.
"""
job = self.job_manager.get_job(train_id)
if job is None:
return None
return 0, None

samples = copy.deepcopy(job.get_all_samples())
if labels:
@@ -77,7 +77,7 @@ class SaliencyEncap(ExplainDataEncap):

reverse = sorted_type == "descending"
if sorted_name == "confidence":
samples.sort(key=_sort_key_confid, reverse=reverse)
samples.sort(key=_sort_key_confidence, reverse=reverse)

sample_infos = []
obj_offset = offset*limit
@@ -93,13 +93,13 @@ class SaliencyEncap(ExplainDataEncap):

def _touch_sample(self, sample, job, explainers):
"""
Final editing the sample info
Final editing the sample info.
Args:
sample (dict): sample info
job (ExplainJob): job
explainers (List[str]): explainer names
sample (dict): Sample info.
job (ExplainJob): Explain job.
explainers (List[str]): Explainer names.
Returns:
Dict, edited sample info
Dict, the edited sample info.
"""
sample["image"] = self._get_image_url(job.train_id, sample["id"], "original")
for inference in sample["inferences"]:
@@ -116,7 +116,7 @@ class SaliencyEncap(ExplainDataEncap):
return sample

def _get_image_url(self, train_id, image_id, image_type):
"""Returns image's url"""
"""Returns image's url."""
if self._image_url_formatter is None:
return image_id
return self._image_url_formatter(train_id, image_id, image_type)

+ 22
- 0
mindinsight/explainer/manager/explain_manager.py View File

@@ -282,6 +282,28 @@ class ExplainManager:
"""Return the base directory for summary records."""
return self._summary_base_dir

def get_job_list(self, offset=0, limit=None):
"""
Return List of explain jobs. includes job ID, create and update time.

Args:
offset (int): An offset for page. Ex, offset is 0, mean current page is 1. Default value is 0.
limit (int): The max data items for per page. Default value is 10.

Returns:
tuple[total, directories], total indicates the overall number of explain directories and directories
indicate list of summary directory info including the following attributes.
- relative_path (str): Relative path of summary directory, referring to settings.SUMMARY_BASE_DIR,
starting with "./".
- create_time (datetime): Creation time of summary file.
- update_time (datetime): Modification time of summary file.
"""
watcher = SummaryWatcher()
total, dir_infos = \
watcher.list_explain_directories(self._summary_base_dir,
offset=offset, limit=limit)
return total, dir_infos

def get_job(self, train_id):
"""
Return ExplainJob given train_id.


+ 15
- 0
tests/ut/backend/explainer/__init__.py View File

@@ -0,0 +1,15 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""UT for backend.explainer."""

+ 41
- 0
tests/ut/backend/explainer/conftest.py View File

@@ -0,0 +1,41 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Description: This file is used for constants and fixtures.
"""
import pytest
from flask import Response

from mindinsight.backend.application import APP


@pytest.fixture
def client():
"""This fixture is flask client."""
APP.response_class = Response
app_client = APP.test_client()

yield app_client


EXPLAINER_URL_BASE = '/v1/mindinsight/explainer'

EXPLAINER_ROUTES = dict(
explain_jobs=f'{EXPLAINER_URL_BASE}/explain-jobs',
job_metadata=f'{EXPLAINER_URL_BASE}/explain-job',
saliency=f'{EXPLAINER_URL_BASE}/saliency',
evaluation=f'{EXPLAINER_URL_BASE}/evaluation',
image=f'{EXPLAINER_URL_BASE}/image'
)

+ 194
- 0
tests/ut/backend/explainer/test_explainer_api.py View File

@@ -0,0 +1,194 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test the module of backend/explainer/explainer_api."""
import json

from unittest.mock import patch

from mindinsight.explainer.encapsulator.explain_job_encap import ExplainJobEncap
from mindinsight.explainer.encapsulator.saliency_encap import SaliencyEncap
from mindinsight.explainer.encapsulator.evaluation_encap import EvaluationEncap
from .conftest import EXPLAINER_ROUTES


class TestExplainerApi:
"""Test the restful api of search_model."""

@patch("mindinsight.backend.explainer.explainer_api.settings")
@patch.object(ExplainJobEncap, "query_explain_jobs")
def test_query_explain_jobs(self, mock_query_explain_jobs, mock_settings, client):
"""Test query all explain jobs information in the SUMMARY_BASE_DIR."""
mock_settings.SUMMARY_BASE_DIR = "mock_base_dir"

job_list = [
{
"train_id": "./mock_job_1",
"create_time": "2020-10-01 20:21:23",
"update_time": "2020-10-01 20:21:23",
},
{
"train_id": "./mock_job_2",
"create_time": "2020-10-02 20:21:23",
"update_time": "2020-10-02 20:21:23",
}
]

mock_query_explain_jobs.return_value = (2, job_list)

response = client.get(f"{EXPLAINER_ROUTES['explain_jobs']}?limit=10&offset=0")
assert response.status_code == 200

expect_result = {
"name": mock_settings.SUMMARY_BASE_DIR,
"total": 2,
"explain_jobs": job_list
}

assert response.get_json() == expect_result

@patch.object(ExplainJobEncap, "query_meta")
def test_query_explain_job(self, mock_query_meta, client):
"""Test query a explain jobs' meta-data."""

job_meta = {
"train_id": "./mock_job_1",
"create_time": "2020-10-01 20:21:23",
"update_time": "2020-10-01 20:21:23",
"sample_count": 1999,
"classes": [
{
"id": 0,
"label": "car",
"sample_count": 1000
},
{
"id": 0,
"label": "person",
"sample_count": 999
}
],
"saliency": {
"min_confidence": 0.5,
"explainers": ["Gradient", "GradCAM"],
"metrics": ["Localization", "ClassSensitivity"]
},
"uncertainty": {
"enabled": False
}
}

mock_query_meta.return_value = job_meta

response = client.get(f"{EXPLAINER_ROUTES['job_metadata']}?train_id=.%2Fmock_job_1")
assert response.status_code == 200

expect_result = job_meta

assert response.get_json() == expect_result

@patch.object(SaliencyEncap, "query_saliency_maps")
def test_query_saliency_maps(self, mock_query_saliency_maps, client):
"""Test query saliency map results."""

samples = [
{
"name": "sample_1",
"labels": "car",
"image": "/image",
"inferences": [
{
"label": "car",
"confidence": 0.85,
"saliency_maps": [
{
"explainer": "Gradient",
"overlay": "/overlay"
},
{
"explainer": "GradCAM",
"overlay": "/overlay"
},
]
}
]
}
]

mock_query_saliency_maps.return_value = (1999, samples)

body_data = {
"train_id": "./mock_job_1",
"explainers": ["Gradient", "GradCAM"],
"offset": 0,
"limit": 1,
"sorted_name": "confidence",
"sorted_type": "descending"
}

response = client.post(EXPLAINER_ROUTES["saliency"], data=json.dumps(body_data))
assert response.status_code == 200

expect_result = {
"count": 1999,
"samples": samples
}

assert response.get_json() == expect_result

@patch.object(EvaluationEncap, "query_explainer_scores")
def test_query_query_evaluation(self, mock_query_explainer_scores, client):
"""Test query explainers' evaluation results."""

explainer_scores = [
{
"explainer": "Gradient",
"evaluations": [
{
"metric": "Localization",
"score": 0.5
}
],
"class_scores": [
{
"label": "car",
"evaluations": [
{
"metric": "Localization",
"score": 0.5
}
]
}
]
},
]

mock_query_explainer_scores.return_value = explainer_scores

response = client.get(f"{EXPLAINER_ROUTES['evaluation']}?train_id=.%2Fmock_job_1")
assert response.status_code == 200

expect_result = {"explainer_scores": explainer_scores}
assert response.get_json() == expect_result

@patch.object(ExplainJobEncap, "query_image_binary")
def test_query_image(self, mock_query_image_binary, client):
"""Test query a image's binary content."""

mock_query_image_binary.return_value = b'123'

response = client.get(f"{EXPLAINER_ROUTES['image']}?train_id=.%2Fmock_job_1&image_id=1&type=original")

assert response.status_code == 200
assert response.data == b'123'

+ 15
- 0
tests/ut/explainer/encapsulator/__init__.py View File

@@ -0,0 +1,15 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""UT for explainer.encapsulator."""

+ 115
- 0
tests/ut/explainer/encapsulator/mock_explain_manager.py View File

@@ -0,0 +1,115 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Mock ExplainManager and ExplainJob classes for UT."""
from datetime import datetime

from mindinsight.explainer.encapsulator.explain_job_encap import ExplainJobEncap


class MockExplainJob:
"""Mock ExplainJob."""
def __init__(self, train_id):
self.train_id = train_id
self.create_time = datetime.timestamp(
datetime.strptime("2020-10-01 20:21:23",
ExplainJobEncap.DATETIME_FORMAT))
self.latest_update_time = self.create_time
self.sample_count = 1999
self.min_confidence = 0.5
self.explainers = ["Gradient"]
self.metrics = ["Localization"]
self.all_classes = [
{
"id": 0,
"label": "car",
"sample_count": 1999
}
]
self.explainer_scores = [
{
"explainer": "Gradient",
"evaluations": [
{
"metric": "Localization",
"Score": 0.5
}
],
"class_scores": [
{
"label": "car",
"evaluations": [
{
"metric": "Localization",
"score": 0.5
}
]
}
]
}
]

def retrieve_image(self, image_id):
"""Get original image binary."""
if image_id == "1":
return b'123'
return None

def retrieve_overlay(self, image_id):
"""Get overlay image binary."""
if image_id == "4":
return b'456'
return None

def get_all_samples(self):
"""Get all mock samples."""
sample = {
"id": "123",
"name": "123",
"labels": ["car"],
"inferences": [
{
"label": "car",
"confidence": 0.75,
"saliency_maps": [
{
"explainer": "Gradient",
"overlay": "4"
}
]
}
]
}
return [sample]


class MockExplainManager:
"""Mock ExplainManger."""
def get_job_list(self, offset, limit):
"""Get all mock jobs."""
del offset, limit
job_list = [
{
"relative_path": "./mock_job_1",
"create_time": datetime.strptime("2020-10-01 20:21:23", ExplainJobEncap.DATETIME_FORMAT),
"update_time": datetime.strptime("2020-10-01 20:21:23", ExplainJobEncap.DATETIME_FORMAT)
}
]
return 1, job_list

def get_job(self, train_id):
"""Get a mock job."""
if train_id == "./mock_job_1":
return MockExplainJob(train_id)
return None

+ 52
- 0
tests/ut/explainer/encapsulator/test_evaluation_encap.py View File

@@ -0,0 +1,52 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test the module of explainer.evaluation_encap."""

from mindinsight.explainer.encapsulator.evaluation_encap import EvaluationEncap
from .mock_explain_manager import MockExplainManager


class TestEvaluationEncap:
"""Test case for EvaluationEncap."""
def setup(self):
"""Setup the test case."""
self.encapsulator = EvaluationEncap(MockExplainManager())

def test_query_explainer_scores(self):
"""Test query the explainer evaluation scores."""
explainer_scores = self.encapsulator.query_explainer_scores("./mock_job_1")
expected_result = [
{
"explainer": "Gradient",
"evaluations": [
{
"metric": "Localization",
"Score": 0.5
}
],
"class_scores": [
{
"label": "car",
"evaluations": [
{
"metric": "Localization",
"score": 0.5
}
]
}
]
}
]
assert explainer_scores == expected_result

+ 53
- 0
tests/ut/explainer/encapsulator/test_explain_job_encap.py View File

@@ -0,0 +1,53 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test the module of explainer.explain_job_encap."""

from mindinsight.explainer.encapsulator.explain_job_encap import ExplainJobEncap
from .mock_explain_manager import MockExplainManager


class TestExplainJobEncap:
"""Test case of ExplainJobEncap."""
def setup(self):
"""Setup the test case."""
self.encapsulator = ExplainJobEncap(MockExplainManager())

def test_query_explain_jobs(self):
"""Test query the explain job list."""
job_list = self.encapsulator.query_explain_jobs(offset=0, limit=10)
expected_result = (1, [
{
"train_id": "./mock_job_1",
"create_time": "2020-10-01 20:21:23",
"update_time": "2020-10-01 20:21:23"
}
])
assert job_list == expected_result

def test_query_meta(self):
"""Test query a explain job's meta-data."""
job = self.encapsulator.query_meta("./mock_job_1")
assert job is not None
assert job["train_id"] == "./mock_job_1"

def test_query_image_binary(self):
"""Test query images' binary content."""
image = self.encapsulator.query_image_binary("./mock_job_1", "1", "original")
assert image is not None
assert image == b'123'

image = self.encapsulator.query_image_binary("./mock_job_1", "4", "overlay")
assert image is not None
assert image == b'456'

+ 62
- 0
tests/ut/explainer/encapsulator/test_saliency_encap.py View File

@@ -0,0 +1,62 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test the module of explainer.saliency_encap."""

from mindinsight.explainer.encapsulator.saliency_encap import SaliencyEncap
from .mock_explain_manager import MockExplainManager


def _image_url_formatter(_, image_id, image_type):
"""Return image url."""
return f"{image_type}-{image_id}"


class TestEvaluationEncap:
"""Test case for EvaluationEncap."""
def setup(self):
"""Setup the test case."""
self.encapsulator = SaliencyEncap(_image_url_formatter, MockExplainManager())

def test_saliency_maps(self):
"""Test query the saliency map results."""
saliency_maps = \
self.encapsulator.query_saliency_maps(train_id="./mock_job_1",
labels=["car"],
explainers=["Gradient"],
limit=10,
offset=0,
sorted_name="confidence",
sorted_type="descending")
expected_result = (1, [
{
"id": "123",
"name": "123",
"labels": ["car"],
"image": "original-123",
"inferences": [
{
"label": "car",
"confidence": 0.75,
"saliency_maps": [
{
"explainer": "Gradient",
"overlay": "overlay-4"
}
]
}
]
}
])
assert saliency_maps == expected_result

Loading…
Cancel
Save