You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

predict.py 1.4 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. import os
  2. import time
  3. from sedna.datasources import BaseDataSource
  4. from sedna.core.lifelong_learning import LifelongLearning
  5. from basemodel import Model
  6. def preprocess(samples):
  7. data = BaseDataSource(data_type="test")
  8. data.x = [samples]
  9. return data
  10. def postprocess(samples):
  11. image_names, imgs = [], []
  12. for sample in samples:
  13. img = sample.get("image")
  14. image_names.append("{}.png".format(str(time.time())))
  15. imgs.append(img)
  16. return image_names, imgs
  17. def init_ll_job():
  18. estimator = Model(num_class=31,
  19. save_predicted_image=True,
  20. merge=True)
  21. task_allocation = {
  22. "method": "TaskAllocationDefault"
  23. }
  24. unseen_task_allocation = {
  25. "method": "UnseenTaskAllocationDefault"
  26. }
  27. ll_job = LifelongLearning(
  28. estimator,
  29. unseen_estimator=unseen_task_processing,
  30. task_definition=None,
  31. task_relationship_discovery=None,
  32. task_allocation=task_allocation,
  33. task_remodeling=None,
  34. inference_integrate=None,
  35. task_update_decision=None,
  36. unseen_task_allocation=unseen_task_allocation,
  37. unseen_sample_recognition=None,
  38. unseen_sample_re_recognition=None)
  39. return ll_job
  40. def unseen_task_processing():
  41. return "Warning: unseen sample detected."