You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

eval.py 1.8 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. # Copyright 2021 The KubeEdge Authors.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. from sedna.common.config import Context
  16. from sedna.core.incremental_learning import IncrementalLearning
  17. from sedna.datasources import TxtDataParse
  18. from interface import Estimator
  19. def _load_txt_dataset(dataset_url):
  20. # use original dataset url,
  21. # see https://github.com/kubeedge/sedna/issues/35
  22. original_dataset_url = Context.get_parameters('original_dataset_url')
  23. return os.path.join(os.path.dirname(original_dataset_url), dataset_url)
  24. def main():
  25. # load dataset.
  26. test_dataset_url = Context.get_parameters('test_dataset_url')
  27. valid_data = TxtDataParse(data_type="test", func=_load_txt_dataset)
  28. valid_data.parse(test_dataset_url, use_raw=True)
  29. # read parameters from deployment config.
  30. class_names = Context.get_parameters("class_names")
  31. class_names = [label.strip() for label in class_names.split(',')]
  32. input_shape = Context.get_parameters("input_shape")
  33. input_shape = tuple(int(shape) for shape in input_shape.split(','))
  34. incremental_instance = IncrementalLearning(estimator=Estimator)
  35. return incremental_instance.evaluate(valid_data, class_names=class_names,
  36. input_shape=input_shape)
  37. if __name__ == '__main__':
  38. main()