You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

predict.py 1.4 kB

(line numbers 1–53)
  1. import time
  2. from sedna.datasources import BaseDataSource
  3. from sedna.core.lifelong_learning import LifelongLearning
  4. from basemodel import Model
  5. def preprocess(samples):
  6. data = BaseDataSource(data_type="test")
  7. data.x = [samples]
  8. return data
  9. def postprocess(samples):
  10. image_names, imgs = [], []
  11. for sample in samples:
  12. img = sample.get("image")
  13. image_names.append("{}.png".format(str(time.time())))
  14. imgs.append(img)
  15. return image_names, imgs
  16. def init_ll_job():
  17. estimator = Model(num_class=31,
  18. save_predicted_image=True,
  19. merge=True)
  20. task_allocation = {
  21. "method": "TaskAllocationDefault"
  22. }
  23. unseen_task_allocation = {
  24. "method": "UnseenTaskAllocationDefault"
  25. }
  26. ll_job = LifelongLearning(
  27. estimator,
  28. unseen_estimator=unseen_task_processing,
  29. task_definition=None,
  30. task_relationship_discovery=None,
  31. task_allocation=task_allocation,
  32. task_remodeling=None,
  33. inference_integrate=None,
  34. task_update_decision=None,
  35. unseen_task_allocation=unseen_task_allocation,
  36. unseen_sample_recognition=None,
  37. unseen_sample_re_recognition=None)
  38. return ll_job
  39. def unseen_task_processing():
  40. return "Warning: unseen sample detected."