|
- # Copyright 2023 The KubeEdge Authors.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- import os
- import time
-
- from PIL import Image
- from sedna.datasources import BaseDataSource
- from sedna.common.config import Context
- from sedna.common.log import LOGGER
- from sedna.common.file_ops import FileOps
- from sedna.core.lifelong_learning import LifelongLearning
-
- from interface import Estimator
-
-
def unseen_sample_postprocess(sample, save_url):
    """Persist a sample that was flagged as belonging to an unseen task.

    Args:
        sample: either a dict whose ``"image"`` entry is an image object with
            a ``save`` method (in this script, the PIL image built by
            ``predict``), which is saved under a timestamp-based ``.png``
            name; or an indexable whose first element is a local file path,
            which is uploaded via ``FileOps.upload(..., clean=False)`` under
            its own basename.
        save_url: destination location the file name is joined onto with
            ``FileOps.join_path``.
    """
    if not isinstance(sample, dict):
        source_path = sample[0]
        target_name = os.path.basename(source_path)
        destination = FileOps.join_path(save_url, target_name)
        FileOps.upload(source_path, destination, clean=False)
        return

    picture = sample.get("image")
    # The current time stamp serves as a (practically) unique file name.
    target_name = "{}.png".format(str(time.time()))
    destination = FileOps.join_path(save_url, target_name)
    picture.save(destination)
-
-
def preprocess(samples):
    """Wrap a single inference sample in a Sedna test data source.

    Args:
        samples: one sample (in this script, the dict built by ``predict``).
            It becomes the sole element of the data source's ``x`` list.

    Returns:
        BaseDataSource: a source created with ``data_type="test"`` whose
        ``x`` attribute is ``[samples]``.
    """
    source = BaseDataSource(data_type="test")
    wrapped = [samples]
    source.x = wrapped
    return source
-
-
def init_ll_job():
    """Build the lifelong-learning job used by the inference loop.

    The estimator is created with the ``num_class`` job parameter (default
    24), ``save_predicted_image=True`` and ``merge=True``. Task allocation
    uses the ``TaskAllocationStream`` method, unseen-task allocation uses
    ``UnseenTaskAllocationDefault``, and ``unseen_task_processing`` is passed
    as the ``unseen_estimator``. Every other lifelong-learning component is
    passed as ``None``.

    Returns:
        LifelongLearning: the configured job.
    """
    # Sedna's Context.get_parameters reads job parameters from environment
    # variables, so a configured value arrives as a string while the default
    # is an int. Normalise to int so the estimator always gets the same type.
    num_class = int(Context.get_parameters("num_class", 24))
    estimator = Estimator(num_class=num_class,
                          save_predicted_image=True,
                          merge=True)

    task_allocation = {
        "method": "TaskAllocationStream"
    }
    unseen_task_allocation = {
        "method": "UnseenTaskAllocationDefault"
    }

    ll_job = LifelongLearning(
        estimator,
        unseen_estimator=unseen_task_processing,
        task_definition=None,
        task_relationship_discovery=None,
        task_allocation=task_allocation,
        task_remodeling=None,
        inference_integrate=None,
        task_update_decision=None,
        unseen_task_allocation=unseen_task_allocation,
        unseen_sample_recognition=None,
        unseen_sample_re_recognition=None)
    return ll_job
-
-
def unseen_task_processing():
    """Placeholder estimator for samples routed to an unseen task.

    Performs no processing; it only returns a fixed warning message.
    """
    warning_message = "Warning: unseen sample detected."
    return warning_message
-
-
def predict():
    """Run a never-ending simulated inference service over a folder of images.

    Files are read from the directory named by the ``test_data`` job
    parameter in sorted file-name order. Each file is opened as an RGB image
    and passed to the lifelong-learning job. The loop sleeps one second
    after every image. When the folder is exhausted it starts over, and the
    logged image counter keeps increasing across passes.

    Raises:
        ValueError: if ``test_data`` is not configured or the directory
            contains no regular files.
    """
    ll_job = init_ll_job()
    test_data_dir = Context.get_parameters("test_data")
    if not test_data_dir:
        # os.listdir(None) would silently list the current working directory.
        raise ValueError("Parameter 'test_data' is not set.")

    # Sort for a reproducible order (os.listdir order is arbitrary) and skip
    # sub-directories, which Image.open cannot read.
    test_data = sorted(
        name for name in os.listdir(test_data_dir)
        if os.path.isfile(os.path.join(test_data_dir, name)))
    if not test_data:
        # With nothing to iterate, the loop below would spin forever without
        # ever reaching time.sleep, burning a full CPU core.
        raise ValueError(f"No test data found in {test_data_dir}.")

    test_data_num = len(test_data)
    count = 0

    # Simulate a permanent inference service
    while True:
        for i, data in enumerate(test_data, start=1):
            image_index = i + count
            LOGGER.info(f"Start to inference image {image_index}")

            test_data_url = os.path.join(test_data_dir, data)
            # convert() returns a new, fully loaded image, so the file can be
            # closed immediately.
            with Image.open(test_data_url) as img:
                img_rgb = img.convert("RGB")
            sample = {'image': img_rgb, "depth": img_rgb, "label": img_rgb}
            predict_data = preprocess(sample)
            prediction, is_unseen, _ = ll_job.inference(
                predict_data,
                unseen_sample_postprocess=unseen_sample_postprocess)
            LOGGER.info(f"Image {image_index} is unseen task: {is_unseen}")
            LOGGER.info(
                f"Image {image_index} prediction result: {prediction}")
            time.sleep(1.0)

        count += test_data_num
-
-
if __name__ == '__main__':
    # Entry point: start the simulated, never-ending inference service.
    predict()
|