import os
os.environ["MODEL_URLS"] = "s3://kubeedge/sedna-robo/kb/index.pkl"
os.environ["S3_ENDPOINT_URL"] = "https://obs.cn-north-1.myhuaweicloud.com"
os.environ["SECRET_ACCESS_KEY"] = "OYPxi4uD9k5E90z0Od3Ug99symbJZ0AfyB4oveQc"
os.environ["ACCESS_KEY_ID"] = "EMPTKHQUGPO2CDUFD2YR"
os.environ["KB_SERVER"] = "http://0.0.0.0:9020"
import time
os.environ["MODEL_URLS"] = "s3://kubeedge/sedna-robo/kb_next/index.pkl"
# set in yaml
os.environ["unseen_save_url"] = "s3://kubeedge/sedna-robo/unseen_samples/"
os.environ["OOD_model"] = "s3://kubeedge/sedna-robo/models/lr_model.model"
os.environ["OOD_thresh"] = "0.0001"
import cv2
import numpy as np
from PIL import Image
import warnings
os.environ["robo_skill"] = "ramp_detection"
os.environ["ramp_detection"] = "s3://kubeedge/sedna-robo/models/ramp_train1_200.pth"
os.environ["curb_detection"] = "s3://kubeedge/sedna-robo/models/2048x1024_80.pth"

from sedna.datasources import BaseDataSource, TxtDataParse
from sedna.core.lifelong_learning import LifelongLearning
from sedna.common.config import Context
from basemodel import Model, preprocess_frames


def preprocess(samples):
    # The body below the first line was truncated in the source; this is a
    # minimal reconstruction, assuming sedna's convention of carrying the
    # samples of a BaseDataSource in its `.x` attribute.
    data = BaseDataSource(data_type="test")
    data.x = [samples]
    return data


def postprocess(samples):
    # (body elided in the source; it derives the image names and predicted
    # images returned below)
    return image_names, imgs


def _load_txt_dataset(dataset_url):
    # Rebase the relative sample paths from the index file onto the
    # directory of the original dataset.
    original_dataset_url = Context.get_parameters('original_dataset_url', "")
    dataset_urls = dataset_url.split()
    dataset_urls = [
        os.path.join(os.path.dirname(original_dataset_url), url)
        for url in dataset_urls]
    return dataset_urls[:-1], dataset_urls[-1]
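
# Worked example with hypothetical paths: if `original_dataset_url` is
# "s3://bucket/data/index.txt" and an index line reads
# "imgs/001.png gts/001.png", this returns
# (["s3://bucket/data/imgs/001.png"], "s3://bucket/data/gts/001.png").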


def init_ll_job():
    robo_skill = Context.get_parameters("robo_skill", "ramp_detection")
    estimator = Model(num_class=31,
                      weight_path=Context.get_parameters(robo_skill),
                      save_predicted_image=True,
                      merge=True)
    # ... (the `task_allocation` configuration referenced below is elided
    # in the source) ...
    unseen_task_allocation = {
        "method": "UnseenTaskAllocationDefault"
    }
    unseen_sample_recognition = {
        "method": "OodIdentification",
        "param": {
            "OOD_thresh": float(Context.get_parameters("OOD_thresh")),
            "OOD_model": Context.get_parameters("OOD_model"),
            "OOD_backup_model": Context.get_parameters(robo_skill),
            "preprocess_func": preprocess_frames,
            "base_model": Model
        }
    }
    ll_job = LifelongLearning(
        estimator,
        unseen_estimator=unseen_task_processing,
        # ... (task_definition argument elided in the source) ...
        task_relationship_discovery=None,
        task_allocation=task_allocation,
        task_remodeling=None,
        inference_integrate=None,
        task_update_decision=None,
        unseen_task_allocation=unseen_task_allocation,
        unseen_sample_recognition=unseen_sample_recognition,
        unseen_sample_re_recognition=None)
    return ll_job
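
# Each config dict above selects an algorithm by name: sedna resolves the
# "method" string in its class registry and passes the "param" entries to
# the chosen class, so OodIdentification receives the OOD threshold, the
# OOD model, the backup model, the preprocess function and the base model.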


def unseen_task_processing():
    return "Warning: unseen sample detected."


def predict():
    ll_job = init_ll_job()
    camera_address = Context.get_parameters('video_url')

    # use video streams for testing
    camera = cv2.VideoCapture(camera_address)
    fps = 10
    nframe = 0
    while True:
        ret, frame_bgr = camera.read()
        if not ret:
            # the stream dropped: wait, then reconnect
            time.sleep(5)
            camera = cv2.VideoCapture(camera_address)
            continue

        # process only one frame out of every `fps`
        if nframe % fps:
            nframe += 1
            continue

        img_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
        nframe += 1
        if nframe % 1000 == 1:  # log every 1000 frames
            warnings.warn(f"camera is open, current frame index is {nframe}")

        img_rgb = cv2.resize(np.array(img_rgb), (2048, 1024),
                             interpolation=cv2.INTER_CUBIC)
        img_rgb = Image.fromarray(img_rgb)
        # depth and label are placeholders: inference only needs the image
        data = {'image': img_rgb, "depth": img_rgb, "label": img_rgb}
        data = preprocess(data)
        print("Inference results:", ll_job.inference(
            data=data, post_process=postprocess))


def predict_batch():
    ll_job = init_ll_job()
    test_dataset_url = Context.get_parameters("test_dataset_url")
    test_data = TxtDataParse(data_type="test", func=_load_txt_dataset)
    test_data.parse(test_dataset_url, use_raw=False)
    return ll_job.inference(data=test_data)
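
# predict_batch() expects `test_dataset_url` to point at a txt index whose
# lines are rebased by _load_txt_dataset, e.g. (hypothetical content):
#
#   imgs/001.png gts/001.png
#   imgs/002.png gts/002.png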


if __name__ == '__main__':
    # predict() loops over the stream and never returns, so there is
    # nothing meaningful to print here
    predict()