You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_explainer_api.py 6.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test the module of backend/explainer/explainer_api."""
import json
from unittest.mock import patch

from mindinsight.explainer.encapsulator.explain_job_encap import ExplainJobEncap
from mindinsight.explainer.encapsulator.saliency_encap import SaliencyEncap
from mindinsight.explainer.encapsulator.evaluation_encap import EvaluationEncap

from .conftest import EXPLAINER_ROUTES
  22. class TestExplainerApi:
  23. """Test the restful api of search_model."""
  24. @patch("mindinsight.backend.explainer.explainer_api.settings")
  25. @patch.object(ExplainJobEncap, "query_explain_jobs")
  26. def test_query_explain_jobs(self, mock_query_explain_jobs, mock_settings, client):
  27. """Test query all explain jobs information in the SUMMARY_BASE_DIR."""
  28. mock_settings.SUMMARY_BASE_DIR = "mock_base_dir"
  29. job_list = [
  30. {
  31. "train_id": "./mock_job_1",
  32. "create_time": "2020-10-01 20:21:23",
  33. "update_time": "2020-10-01 20:21:23",
  34. },
  35. {
  36. "train_id": "./mock_job_2",
  37. "create_time": "2020-10-02 20:21:23",
  38. "update_time": "2020-10-02 20:21:23",
  39. }
  40. ]
  41. mock_query_explain_jobs.return_value = (2, job_list)
  42. response = client.get(f"{EXPLAINER_ROUTES['explain_jobs']}?limit=10&offset=0")
  43. assert response.status_code == 200
  44. expect_result = {
  45. "name": mock_settings.SUMMARY_BASE_DIR,
  46. "total": 2,
  47. "explain_jobs": job_list
  48. }
  49. assert response.get_json() == expect_result
  50. @patch.object(ExplainJobEncap, "query_meta")
  51. def test_query_explain_job(self, mock_query_meta, client):
  52. """Test query a explain jobs' meta-data."""
  53. job_meta = {
  54. "train_id": "./mock_job_1",
  55. "create_time": "2020-10-01 20:21:23",
  56. "update_time": "2020-10-01 20:21:23",
  57. "sample_count": 1999,
  58. "classes": [
  59. {
  60. "id": 0,
  61. "label": "car",
  62. "sample_count": 1000
  63. },
  64. {
  65. "id": 0,
  66. "label": "person",
  67. "sample_count": 999
  68. }
  69. ],
  70. "saliency": {
  71. "min_confidence": 0.5,
  72. "explainers": ["Gradient", "GradCAM"],
  73. "metrics": ["Localization", "ClassSensitivity"]
  74. },
  75. "uncertainty": {
  76. "enabled": False
  77. }
  78. }
  79. mock_query_meta.return_value = job_meta
  80. response = client.get(f"{EXPLAINER_ROUTES['job_metadata']}?train_id=.%2Fmock_job_1")
  81. assert response.status_code == 200
  82. expect_result = job_meta
  83. assert response.get_json() == expect_result
  84. @patch.object(SaliencyEncap, "query_saliency_maps")
  85. def test_query_saliency_maps(self, mock_query_saliency_maps, client):
  86. """Test query saliency map results."""
  87. samples = [
  88. {
  89. "name": "sample_1",
  90. "labels": "car",
  91. "image": "/image",
  92. "inferences": [
  93. {
  94. "label": "car",
  95. "confidence": 0.85,
  96. "saliency_maps": [
  97. {
  98. "explainer": "Gradient",
  99. "overlay": "/overlay"
  100. },
  101. {
  102. "explainer": "GradCAM",
  103. "overlay": "/overlay"
  104. },
  105. ]
  106. }
  107. ]
  108. }
  109. ]
  110. mock_query_saliency_maps.return_value = (1999, samples)
  111. body_data = {
  112. "train_id": "./mock_job_1",
  113. "explainers": ["Gradient", "GradCAM"],
  114. "offset": 0,
  115. "limit": 1,
  116. "sorted_name": "confidence",
  117. "sorted_type": "descending"
  118. }
  119. response = client.post(EXPLAINER_ROUTES["saliency"], data=json.dumps(body_data))
  120. assert response.status_code == 200
  121. expect_result = {
  122. "count": 1999,
  123. "samples": samples
  124. }
  125. assert response.get_json() == expect_result
  126. @patch.object(EvaluationEncap, "query_explainer_scores")
  127. def test_query_query_evaluation(self, mock_query_explainer_scores, client):
  128. """Test query explainers' evaluation results."""
  129. explainer_scores = [
  130. {
  131. "explainer": "Gradient",
  132. "evaluations": [
  133. {
  134. "metric": "Localization",
  135. "score": 0.5
  136. }
  137. ],
  138. "class_scores": [
  139. {
  140. "label": "car",
  141. "evaluations": [
  142. {
  143. "metric": "Localization",
  144. "score": 0.5
  145. }
  146. ]
  147. }
  148. ]
  149. },
  150. ]
  151. mock_query_explainer_scores.return_value = explainer_scores
  152. response = client.get(f"{EXPLAINER_ROUTES['evaluation']}?train_id=.%2Fmock_job_1")
  153. assert response.status_code == 200
  154. expect_result = {"explainer_scores": explainer_scores}
  155. assert response.get_json() == expect_result
  156. @patch.object(ExplainJobEncap, "query_image_binary")
  157. def test_query_image(self, mock_query_image_binary, client):
  158. """Test query a image's binary content."""
  159. mock_query_image_binary.return_value = b'123'
  160. response = client.get(f"{EXPLAINER_ROUTES['image']}?train_id=.%2Fmock_job_1&image_id=1&type=original")
  161. assert response.status_code == 200
  162. assert response.data == b'123'