You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

incremental_learning.py 5.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. import logging
  2. import os
  3. import tensorflow as tf
  4. import neptune
  5. from neptune.common.config import BaseConfig
  6. from neptune.common.constant import K8sResourceKindStatus, K8sResourceKind
  7. from neptune.common.utils import clean_folder, remove_path_prefix
  8. from neptune.hard_example_mining import CrossEntropyFilter, IBTFilter, \
  9. ThresholdFilter
  10. from neptune.joint_inference import TSLittleModel
  11. from neptune.lc_client import LCClient
  12. LOG = logging.getLogger(__name__)
  13. class IncrementalConfig(BaseConfig):
  14. def __init__(self):
  15. BaseConfig.__init__(self)
  16. self.model_urls = os.getenv("MODEL_URLS")
  17. self.base_model_url = os.getenv("BASE_MODEL_URL")
  18. self.saved_model_name = "model.pb"
  19. def train(model, train_data, epochs, batch_size, class_names, input_shape,
  20. obj_threshold, nms_threshold):
  21. """The train endpoint of incremental learning.
  22. :param model: the train model
  23. :param train_data: the data use for train
  24. :param epochs: the number of epochs for training the model
  25. :param batch_size: the number of samples in a training
  26. :param class_names:
  27. :param input_shape:
  28. :param obj_threshold:
  29. :param nms_threshold:
  30. """
  31. il_config = IncrementalConfig()
  32. clean_folder(il_config.model_url)
  33. model.train(train_data, []) # validation data is empty.
  34. tf.reset_default_graph()
  35. model.save_model_pb(il_config.saved_model_name)
  36. ckpt_model_url = remove_path_prefix(il_config.model_url,
  37. il_config.data_path_prefix)
  38. pb_model_url = remove_path_prefix(
  39. os.path.join(il_config.model_url, il_config.saved_model_name),
  40. il_config.data_path_prefix)
  41. # TODO delete metrics whether affect lc
  42. ckpt_result = {
  43. "format": "ckpt",
  44. "url": ckpt_model_url,
  45. }
  46. pb_result = {
  47. "format": "pb",
  48. "url": pb_model_url,
  49. }
  50. results = [ckpt_result, pb_result]
  51. message = {
  52. "name": il_config.worker_name,
  53. "namespace": il_config.namespace,
  54. "ownerName": il_config.job_name,
  55. "ownerKind": K8sResourceKind.INCREMENTAL_JOB.value,
  56. "kind": "train",
  57. "status": K8sResourceKindStatus.COMPLETED.value,
  58. "results": results
  59. }
  60. LCClient.send(il_config.worker_name, message)
  61. def evaluate(model, test_data, class_names, input_shape):
  62. """The evaluation endpoint of incremental job.
  63. :param model: the model used for evaluation
  64. :param test_data:
  65. :param class_names:
  66. :param input_shape: the input shape of model
  67. """
  68. il_config = IncrementalConfig()
  69. results = []
  70. for model_url in il_config.model_urls.split(';'):
  71. precision, recall, all_precisions, all_recalls = model(
  72. model_path=model_url,
  73. test_dataset=test_data,
  74. class_names=class_names,
  75. input_shape=input_shape)
  76. result = {
  77. "format": "pb",
  78. "url": remove_path_prefix(model_url, il_config.data_path_prefix),
  79. "metrics": {
  80. "recall": recall,
  81. "precision": precision
  82. }
  83. }
  84. results.append(result)
  85. message = {
  86. "name": il_config.worker_name,
  87. "namespace": il_config.namespace,
  88. "ownerName": il_config.job_name,
  89. "ownerKind": K8sResourceKind.INCREMENTAL_JOB.value,
  90. "kind": "eval",
  91. "status": K8sResourceKindStatus.COMPLETED.value,
  92. "results": results
  93. }
  94. LCClient.send(il_config.worker_name, message)
  95. class TSModel(TSLittleModel):
  96. def __init__(self, preprocess=None, postprocess=None, input_shape=(0, 0),
  97. create_input_feed=None, create_output_fetch=None):
  98. TSLittleModel.__init__(self, preprocess, postprocess, input_shape,
  99. create_input_feed, create_output_fetch)
  100. class InferenceResult:
  101. def __init__(self, is_hard_example, infer_result):
  102. self.is_hard_example = is_hard_example
  103. self.infer_result = infer_result
  104. class Inference:
  105. def __init__(self, model: TSModel, hard_example_mining_algorithm=None):
  106. if hard_example_mining_algorithm is None:
  107. hem_name = BaseConfig.hem_name
  108. if hem_name == "IBT":
  109. threshold_box = float(neptune.context.get_hem_parameters(
  110. "threshold_box", 0.8
  111. ))
  112. threshold_img = float(neptune.context.get_hem_parameters(
  113. "threshold_img", 0.8
  114. ))
  115. hard_example_mining_algorithm = IBTFilter(threshold_img,
  116. threshold_box)
  117. elif hem_name == "CrossEntropy":
  118. threshold_cross_entropy = float(
  119. neptune.context.get_hem_parameters(
  120. "threshold_cross_entropy", 0.5
  121. )
  122. )
  123. hard_example_mining_algorithm = CrossEntropyFilter(
  124. threshold_cross_entropy)
  125. else:
  126. hard_example_mining_algorithm = ThresholdFilter()
  127. self.hard_example_mining_algorithm = hard_example_mining_algorithm
  128. self.model = model
  129. def inference(self, img_data) -> InferenceResult:
  130. result = self.model.inference(img_data)
  131. rsl = deal_infer_rsl(result)
  132. is_hard_example = self.hard_example_mining_algorithm.hard_judge(rsl)
  133. if is_hard_example:
  134. return InferenceResult(True, result)
  135. else:
  136. return InferenceResult(False, result)
  137. def deal_infer_rsl(model_output):
  138. all_classes, all_scores, all_bboxes = model_output
  139. rsl = []
  140. for c, s, bbox in zip(all_classes, all_scores, all_bboxes):
  141. bbox[0], bbox[1], bbox[2], bbox[3] = bbox[1], bbox[0], bbox[3], bbox[2]
  142. rsl.append(bbox.tolist() + [s, c])
  143. return rsl