You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

inference.py 2.9 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. # Copyright 2021 The KubeEdge Authors.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import csv
  16. import json
  17. import time
  18. from sedna.common.config import Context
  19. from sedna.datasources import CSVDataParse
  20. from sedna.core.lifelong_learning import LifelongLearning
  21. from interface import DATACONF, Estimator, feature_process
  22. def main():
  23. utd = Context.get_parameters("UTD_NAME", "TaskAttrFilter")
  24. attribute = json.dumps({"attribute": DATACONF["ATTRIBUTES"]})
  25. utd_parameters = Context.get_parameters("UTD_PARAMETERS", {})
  26. ut_saved_url = Context.get_parameters("UTD_SAVED_URL", "/tmp")
  27. task_mining = {
  28. "method": "TaskMiningByDataAttr",
  29. "param": attribute
  30. }
  31. unseen_task_detect = {
  32. "method": utd,
  33. "param": utd_parameters
  34. }
  35. ll_service = LifelongLearning(
  36. estimator=Estimator,
  37. task_mining=task_mining,
  38. task_definition=None,
  39. task_relationship_discovery=None,
  40. task_remodeling=None,
  41. inference_integrate=None,
  42. unseen_task_detect=unseen_task_detect)
  43. infer_dataset_url = Context.get_parameters('infer_dataset_url')
  44. file_handle = open(infer_dataset_url, "r", encoding="utf-8")
  45. header = list(csv.reader([file_handle.readline().strip()]))[0]
  46. infer_data = CSVDataParse(data_type="test", func=feature_process)
  47. unseen_sample = open(os.path.join(ut_saved_url, "unseen_sample.csv"),
  48. "w", encoding="utf-8")
  49. unseen_sample.write("\t".join(header + ['pred']) + "\n")
  50. output_sample = open(f"{infer_dataset_url}_out.csv", "w", encoding="utf-8")
  51. output_sample.write("\t".join(header + ['pred']) + "\n")
  52. while 1:
  53. where = file_handle.tell()
  54. line = file_handle.readline()
  55. if not line:
  56. time.sleep(1)
  57. file_handle.seek(where)
  58. continue
  59. reader = list(csv.reader([line.strip()]))
  60. rows = reader[0]
  61. data = dict(zip(header, rows))
  62. infer_data.parse(data, label=DATACONF["LABEL"])
  63. rsl, is_unseen, target_task = ll_service.inference(infer_data)
  64. rows.append(list(rsl)[0])
  65. output = "\t".join(map(str, rows)) + "\n"
  66. if is_unseen:
  67. unseen_sample.write(output)
  68. output_sample.write(output)
  69. unseen_sample.close()
  70. output_sample.close()
  71. if __name__ == '__main__':
  72. print(main())