You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

incremental_learning.py 5.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
# Incremental-learning worker endpoints: train, evaluate, and inference
# with hard-example mining, reporting results to the LC server.
import logging
import os
import tensorflow as tf
import neptune
from neptune.common.config import BaseConfig
from neptune.common.constant import K8sResourceKindStatus, K8sResourceKind
from neptune.common.utils import clean_folder, remove_path_prefix
from neptune.hard_example_mining import CrossEntropyFilter, IBTFilter, \
    ThresholdFilter
from neptune.joint_inference import TSLittleModel
from neptune.lc_client import LCClient

# Module-level logger (currently unused in this chunk; kept for callers).
LOG = logging.getLogger(__name__)
  13. class IncrementalConfig(BaseConfig):
  14. def __init__(self):
  15. BaseConfig.__init__(self)
  16. self.model_urls = os.getenv("MODEL_URLS")
  17. self.base_model_url = os.getenv("BASE_MODEL_URL")
  18. def train(model, train_data, epochs, batch_size, class_names, input_shape,
  19. obj_threshold, nms_threshold):
  20. """The train endpoint of incremental learning.
  21. :param model: the train model
  22. :param train_data: the data use for train
  23. :param epochs: the number of epochs for training the model
  24. :param batch_size: the number of samples in a training
  25. :param class_names:
  26. :param input_shape:
  27. :param obj_threshold:
  28. :param nms_threshold:
  29. """
  30. il_config = IncrementalConfig()
  31. clean_folder(il_config.model_url)
  32. model.train(train_data, []) # validation data is empty.
  33. tf.reset_default_graph()
  34. model.save_model_pb()
  35. ckpt_model_url = remove_path_prefix(il_config.model_url,
  36. il_config.data_path_prefix)
  37. pb_model_url = remove_path_prefix(
  38. os.path.join(il_config.model_url, 'model.pb'),
  39. il_config.data_path_prefix)
  40. # TODO delete metrics whether affect lc
  41. ckpt_result = {
  42. "format": "ckpt",
  43. "url": ckpt_model_url,
  44. }
  45. pb_result = {
  46. "format": "pb",
  47. "url": pb_model_url,
  48. }
  49. results = [ckpt_result, pb_result]
  50. message = {
  51. "name": il_config.worker_name,
  52. "namespace": il_config.namespace,
  53. "ownerName": il_config.job_name,
  54. "ownerKind": K8sResourceKind.INCREMENTAL_JOB.value,
  55. "kind": "train",
  56. "status": K8sResourceKindStatus.COMPLETED.value,
  57. "results": results
  58. }
  59. LCClient.send(il_config.worker_name, message)
  60. def evaluate(model, test_data, class_names, input_shape):
  61. """The evaluation endpoint of incremental job.
  62. :param model: the model used for evaluation
  63. :param test_data:
  64. :param class_names:
  65. :param input_shape: the input shape of model
  66. """
  67. il_config = IncrementalConfig()
  68. results = []
  69. for model_url in il_config.model_urls.split(';'):
  70. precision, recall, all_precisions, all_recalls = model(
  71. model_path=model_url,
  72. test_dataset=test_data,
  73. class_names=class_names,
  74. input_shape=input_shape)
  75. result = {
  76. "format": "pb",
  77. "url": remove_path_prefix(model_url, il_config.data_path_prefix),
  78. "metrics": {
  79. "recall": recall,
  80. "precision": precision
  81. }
  82. }
  83. results.append(result)
  84. message = {
  85. "name": il_config.worker_name,
  86. "namespace": il_config.namespace,
  87. "ownerName": il_config.job_name,
  88. "ownerKind": K8sResourceKind.INCREMENTAL_JOB.value,
  89. "kind": "eval",
  90. "status": K8sResourceKindStatus.COMPLETED.value,
  91. "results": results
  92. }
  93. LCClient.send(il_config.worker_name, message)
  94. class TSModel(TSLittleModel):
  95. def __init__(self, preprocess=None, postprocess=None, input_shape=(0, 0),
  96. create_input_feed=None, create_output_fetch=None):
  97. TSLittleModel.__init__(self, preprocess, postprocess, input_shape,
  98. create_input_feed, create_output_fetch)
  99. class InferenceResult:
  100. def __init__(self, is_hard_example, infer_result):
  101. self.is_hard_example = is_hard_example
  102. self.infer_result = infer_result
  103. class Inference:
  104. def __init__(self, model: TSModel, hard_example_mining_algorithm=None):
  105. if hard_example_mining_algorithm is None:
  106. hem_name = BaseConfig.hem_name
  107. if hem_name == "IBT":
  108. threshold_box = float(neptune.context.get_hem_parameters(
  109. "threshold_box", 0.8
  110. ))
  111. threshold_img = float(neptune.context.get_hem_parameters(
  112. "threshold_img", 0.8
  113. ))
  114. hard_example_mining_algorithm = IBTFilter(threshold_img,
  115. threshold_box)
  116. elif hem_name == "CrossEntropy":
  117. threshold_cross_entropy = float(
  118. neptune.context.get_hem_parameters(
  119. "threshold_cross_entropy", 0.5
  120. )
  121. )
  122. hard_example_mining_algorithm = CrossEntropyFilter(
  123. threshold_cross_entropy)
  124. else:
  125. hard_example_mining_algorithm = ThresholdFilter()
  126. self.hard_example_mining_algorithm = hard_example_mining_algorithm
  127. self.model = model
  128. def inference(self, img_data) -> InferenceResult:
  129. result = self.model.inference(img_data)
  130. bboxes = deal_infer_rsl(result)
  131. is_hard_example = self.hard_example_mining_algorithm.hard_judge(bboxes)
  132. if is_hard_example:
  133. return InferenceResult(True, result)
  134. else:
  135. return InferenceResult(False, result)
  136. def deal_infer_rsl(model_output):
  137. all_classes, all_scores, all_bboxes = model_output
  138. bboxes = []
  139. for c, s, bbox in zip(all_classes, all_scores, all_bboxes):
  140. bbox[0], bbox[1], bbox[2], bbox[3] = bbox[1], bbox[0], bbox[3], bbox[2]
  141. bboxes.append(bbox.tolist() + [s, c])
  142. return bboxes