You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

evaluate.py 1.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
import os

from sedna.common.config import Context
from sedna.core.lifelong_learning import LifelongLearning
from sedna.datasources import TxtDataParse

from accuracy import accuracy
from basemodel import Model
  7. def _load_txt_dataset(dataset_url):
  8. # use original dataset url
  9. original_dataset_url = Context.get_parameters('original_dataset_url', "")
  10. dataset_urls = dataset_url.split()
  11. dataset_urls = [
  12. os.path.join(
  13. os.path.dirname(original_dataset_url),
  14. dataset_url) for dataset_url in dataset_urls]
  15. return dataset_urls[:-1], dataset_urls[-1]
  16. def eval():
  17. estimator = Model(num_class=31)
  18. eval_dataset_url = Context.get_parameters("test_dataset_url")
  19. eval_data = TxtDataParse(data_type="eval", func=_load_txt_dataset)
  20. eval_data.parse(eval_dataset_url, use_raw=False)
  21. task_allocation = {
  22. "method": "TaskAllocationSimple"
  23. }
  24. ll_job = LifelongLearning(estimator,
  25. task_definition=None,
  26. task_relationship_discovery=None,
  27. task_allocation=task_allocation,
  28. task_remodeling=None,
  29. inference_integrate=None,
  30. task_update_decision=None,
  31. unseen_task_allocation=None,
  32. unseen_sample_recognition=None,
  33. unseen_sample_re_recognition=None
  34. )
  35. ll_job.evaluate(eval_data, metrics=accuracy)
  36. if __name__ == '__main__':
  37. print(eval())