@@ -203,7 +203,10 @@ class TSLittleModel(BaseModel):
         input_feed = self.create_input_feed(self.session, new_image,
                                             img_data_np)
         output_fetch = self.create_output_fetch(self.session)
-        return self.session.run(output_fetch, input_feed)
+        output = self.session.run(output_fetch, input_feed)
+        if self.postprocess:
+            output = self.postprocess(output)
+        return output
 
 
 class LCReporter(threading.Thread):
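
# --- Illustrative sketch, not part of the patch ---
# The hunk above makes TSLittleModel.inference apply an optional
# `self.postprocess` callable to the raw session output before returning it.
# How `postprocess` gets attached to the model is not shown in this diff, so
# the helper below only mirrors the new return path with stand-in callables.

def run_with_postprocess(session_run, postprocess=None):
    """Mimic the patched return path: run, optionally postprocess, return."""
    output = session_run()            # stands in for session.run(output_fetch, input_feed)
    if postprocess:                   # same guard as the patched code
        output = postprocess(output)
    return output

# Toy usage: drop low-confidence detections from a fake raw output.
detections = run_with_postprocess(
    lambda: [("car", 0.91), ("dog", 0.18)],
    postprocess=lambda dets: [d for d in dets if d[1] >= 0.5],
)
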
@@ -294,13 +297,10 @@ class JointInference:
     :param little_model: the little model entity for edge inference
     :param hard_example_mining_algorithm: the algorithm for judging hard
         example
-    :param pre_hook: the pre function of edge inference
-    :param post_hook: the post function of edge inference
     """
 
     def __init__(self, little_model: BaseModel,
-                 hard_example_mining_algorithm=None,
-                 pre_hook=None, post_hook=None):
+                 hard_example_mining_algorithm=None):
         self.little_model = little_model
         self.big_model = BigModelClient()
         # TODO how to deal process use-defined cloud_offload_algorithm,
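
# --- Illustrative sketch, not part of the patch ---
# With this hunk the JointInference constructor accepts only the little model
# and an optional hard example mining algorithm; pre_hook/post_hook are gone,
# so per-image pre/post handling is expected to live in the model itself. The
# stub below only mirrors the narrowed signature (BigModelClient wiring, the
# TODO branch and the LCReporter setup are omitted).

class JointInferenceSignatureDemo:
    def __init__(self, little_model, hard_example_mining_algorithm=None):
        self.little_model = little_model
        self.hard_example_mining_algorithm = hard_example_mining_algorithm

# A caller no longer passes pre_hook=... / post_hook=...:
demo = JointInferenceSignatureDemo(little_model=object())
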
@@ -329,8 +329,6 @@ class JointInference:
             hard_example_mining_algorithm = ThresholdFilter()
 
         self.hard_example_mining_algorithm = hard_example_mining_algorithm
-        self.pre_hook = pre_hook
-        self.post_hook = post_hook
 
         self.lc_reporter = LCReporter()
         self.lc_reporter.setDaemon(True)
@@ -339,13 +337,10 @@ class JointInference:
     def inference(self, img_data) -> InferenceResult:
         """Image inference function."""
-        img_data_pre = img_data
-        if self.pre_hook:
-            img_data_pre = self.pre_hook(img_data_pre)
-        edge_result = self.little_model.inference(img_data_pre)
-        if self.post_hook:
-            edge_result = self.post_hook(edge_result)
+        edge_result = self.little_model.inference(img_data)
         is_hard_example = self.hard_example_mining_algorithm.hard_judge(
-            edge_result)
+            edge_result
+        )
         if not is_hard_example:
             LOG.debug("not hard example, use edge result directly")
             self.lc_reporter.update_for_edge_inference()
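
# --- Illustrative sketch, not part of the patch ---
# The simplified edge path now reads: run the little model, then let the hard
# example mining algorithm decide whether the result can stay at the edge. The
# toy function below reproduces only that decision flow; the real method
# returns an InferenceResult and offloads hard examples to the big model.

def decide_edge_or_cloud(little_model_infer, hard_judge, img_data):
    edge_result = little_model_infer(img_data)   # pre/post handling lives in the model now
    if not hard_judge(edge_result):              # easy sample: keep the edge result
        return "edge", edge_result
    return "cloud", edge_result                  # hard sample: would go to the big model

# Toy usage with a low-confidence detection that a threshold-style judge flags as hard.
where, result = decide_edge_or_cloud(
    lambda img: [("person", 0.42)],
    hard_judge=lambda dets: max(score for _, score in dets) < 0.5,
    img_data=None,
)
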