You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_explainer_api.py 6.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Test the module of backend/explainer/explainer_api."""
  16. import json
  17. from unittest.mock import patch
  18. from mindinsight.explainer.encapsulator.explain_job_encap import ExplainJobEncap
  19. from mindinsight.explainer.encapsulator.saliency_encap import SaliencyEncap
  20. from mindinsight.explainer.encapsulator.evaluation_encap import EvaluationEncap
  21. from mindinsight.explainer.encapsulator.datafile_encap import DatafileEncap
  22. from .conftest import EXPLAINER_ROUTES
  class TestExplainerApi:
      """Test the RESTful APIs exposed by backend/explainer/explainer_api.

      Each test patches the relevant encapsulator method, so the tests check the
      API layer's routing, request parsing and response serialization against a
      canned encapsulator result.
      """

      # Stacked @patch decorators are applied bottom-up, so the mock created by the
      # decorator closest to the function is passed as the first mock argument.
      @patch("mindinsight.backend.explainer.explainer_api.settings")
      @patch.object(ExplainJobEncap, "query_explain_jobs")
      def test_query_explain_jobs(self, mock_query_explain_jobs, mock_settings, client):
          """Test querying the information of all explain jobs under SUMMARY_BASE_DIR."""
          mock_settings.SUMMARY_BASE_DIR = "mock_base_dir"
          job_list = [
              {
                  "train_id": "./mock_job_1",
                  "create_time": "2020-10-01 20:21:23",
                  "update_time": "2020-10-01 20:21:23",
              },
              {
                  "train_id": "./mock_job_2",
                  "create_time": "2020-10-02 20:21:23",
                  "update_time": "2020-10-02 20:21:23",
              },
          ]
          # The mocked encapsulator returns a (total, jobs) pair.
          mock_query_explain_jobs.return_value = (2, job_list)

          response = client.get(f"{EXPLAINER_ROUTES['explain_jobs']}?limit=10&offset=0")

          assert response.status_code == 200
          expect_result = {
              "name": mock_settings.SUMMARY_BASE_DIR,
              "total": 2,
              "explain_jobs": job_list,
          }
          assert response.get_json() == expect_result

      @patch.object(ExplainJobEncap, "query_meta")
      def test_query_explain_job(self, mock_query_meta, client):
          """Test querying the metadata of a single explain job."""
          job_meta = {
              "train_id": "./mock_job_1",
              "create_time": "2020-10-01 20:21:23",
              "update_time": "2020-10-01 20:21:23",
              # Total of the per-class sample counts below (1000 + 999).
              "sample_count": 1999,
              # Each class carries its own distinct id.
              "classes": [
                  {
                      "id": 0,
                      "label": "car",
                      "sample_count": 1000,
                  },
                  {
                      "id": 1,
                      "label": "person",
                      "sample_count": 999,
                  },
              ],
              "saliency": {
                  "min_confidence": 0.5,
                  "explainers": ["Gradient", "GradCAM"],
                  "metrics": ["Localization", "ClassSensitivity"],
              },
              "uncertainty": {
                  "enabled": False,
              },
          }
          mock_query_meta.return_value = job_meta

          # "%2F" is the URL-encoded "/", i.e. train_id="./mock_job_1".
          response = client.get(f"{EXPLAINER_ROUTES['job_metadata']}?train_id=.%2Fmock_job_1")

          assert response.status_code == 200
          # The endpoint is expected to return the encapsulator's metadata unchanged.
          expect_result = job_meta
          assert response.get_json() == expect_result

      @patch.object(SaliencyEncap, "query_saliency_maps")
      def test_query_saliency_maps(self, mock_query_saliency_maps, client):
          """Test querying saliency map results."""
          samples = [
              {
                  "name": "sample_1",
                  "labels": "car",
                  "image": "/image",
                  "inferences": [
                      {
                          "label": "car",
                          "confidence": 0.85,
                          "saliency_maps": [
                              {
                                  "explainer": "Gradient",
                                  "overlay": "/overlay",
                              },
                              {
                                  "explainer": "GradCAM",
                                  "overlay": "/overlay",
                              },
                          ],
                      },
                  ],
              },
          ]
          # The mocked encapsulator returns a (total count, page of samples) pair.
          mock_query_saliency_maps.return_value = (1999, samples)

          # The saliency endpoint receives its query parameters as a JSON request body.
          body_data = {
              "train_id": "./mock_job_1",
              "explainers": ["Gradient", "GradCAM"],
              "offset": 0,
              "limit": 1,
              "sorted_name": "confidence",
              "sorted_type": "descending",
          }
          response = client.post(EXPLAINER_ROUTES["saliency"], data=json.dumps(body_data))

          assert response.status_code == 200
          expect_result = {
              "count": 1999,
              "samples": samples,
          }
          assert response.get_json() == expect_result

      # NOTE(review): the method name repeats "query"; kept as-is so the collected
      # test id does not change.
      @patch.object(EvaluationEncap, "query_explainer_scores")
      def test_query_query_evaluation(self, mock_query_explainer_scores, client):
          """Test querying the explainers' evaluation results."""
          explainer_scores = [
              {
                  "explainer": "Gradient",
                  "evaluations": [
                      {
                          "metric": "Localization",
                          "score": 0.5,
                      },
                  ],
                  "class_scores": [
                      {
                          "label": "car",
                          "evaluations": [
                              {
                                  "metric": "Localization",
                                  "score": 0.5,
                              },
                          ],
                      },
                  ],
              },
          ]
          mock_query_explainer_scores.return_value = explainer_scores

          response = client.get(f"{EXPLAINER_ROUTES['evaluation']}?train_id=.%2Fmock_job_1")

          assert response.status_code == 200
          # The endpoint wraps the encapsulator's list under "explainer_scores".
          expect_result = {"explainer_scores": explainer_scores}
          assert response.get_json() == expect_result

      @patch.object(DatafileEncap, "query_image_binary")
      def test_query_image(self, mock_query_image_binary, client):
          """Test querying an image's binary content."""
          mock_query_image_binary.return_value = b'123'

          response = client.get(
              f"{EXPLAINER_ROUTES['image']}"
              "?train_id=.%2Fmock_job_1&path=1&type=original"
          )

          assert response.status_code == 200
          # The raw bytes from the encapsulator are sent back as the response body.
          assert response.data == b'123'