
inference.py

# Copyright 2021 The KubeEdge Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import csv
import json
import time

from interface import DATACONF, Estimator, feature_process

from sedna.common.config import Context
from sedna.datasources import CSVDataParse
from sedna.core.lifelong_learning import LifelongLearning
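
# ``interface`` is the example-local module that defines the dataset schema
# (DATACONF), the base-model wrapper (Estimator), and the feature pipeline
# (feature_process); a sketch of what it might export follows the script
# below. The sedna.* modules come from the Sedna library itself.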

def main():
    # Configuration comes from Sedna's Context (worker parameters /
    # environment variables), with defaults for local runs.
    utd = Context.get_parameters("UTD_NAME", "TaskAttr")
    attribute = json.dumps({"attribute": DATACONF["ATTRIBUTES"]})
    utd_parameters = Context.get_parameters("UTD_PARAMETERS", {})
    ut_saved_url = Context.get_parameters("UTD_SAVED_URL", "/tmp")

    # Lifelong-learning job: each sample is routed to a task by its data
    # attributes, and an unseen-task detector flags samples that no
    # known task covers.
    ll_job = LifelongLearning(
        estimator=Estimator,
        task_mining="TaskMiningByDataAttr",
        task_mining_param=attribute,
        unseen_task_detect=utd,
        unseen_task_detect_param=utd_parameters)

    infer_dataset_url = Context.get_parameters('infer_dataset_url')
    file_handle = open(infer_dataset_url, "r", encoding="utf-8")
    # The first line of the dataset is the CSV header.
    header = list(csv.reader([file_handle.readline().strip()]))[0]
    infer_data = CSVDataParse(data_type="test", func=feature_process)

    # Two tab-separated outputs, each with a 'pred' column appended to
    # the header: one for every sample, one for samples flagged as unseen.
    unseen_sample = open(os.path.join(ut_saved_url, "unseen_sample.csv"),
                         "w", encoding="utf-8")
    unseen_sample.write("\t".join(header + ['pred']) + "\n")
    output_sample = open(f"{infer_dataset_url}_out.csv", "w", encoding="utf-8")
    output_sample.write("\t".join(header + ['pred']) + "\n")

    # Follow the dataset file tail -f style: when no new line is
    # available yet, wait a second and re-read from the same offset.
    while True:
        where = file_handle.tell()
        line = file_handle.readline()
        if not line:
            time.sleep(1)
            file_handle.seek(where)
            continue
        rows = list(csv.reader([line.strip()]))[0]
        data = dict(zip(header, rows))
        infer_data.parse(data, label=DATACONF["LABEL"])
        rsl, is_unseen, target_task = ll_job.inference(infer_data)

        rows.append(list(rsl)[0])
        # Unseen samples go to both files, so they can later be labeled
        # and fed back for retraining.
        if is_unseen:
            unseen_sample.write("\t".join(map(str, rows)) + "\n")
        output_sample.write("\t".join(map(str, rows)) + "\n")

    # Only reached if the loop above is ever broken out of.
    unseen_sample.close()
    output_sample.close()


if __name__ == '__main__':
    print(main())
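
For readers without the full example checked out, here is a minimal sketch of what the sibling `interface` module has to export for this script to run. The three names (DATACONF, Estimator, feature_process) are fixed by the import above; every column name, the fillna-based feature pipeline, and the estimator internals below are illustrative assumptions, not the example's real code.

import pandas as pd

# Hypothetical dataset schema: which columns describe a task and which
# column is the prediction target. The real column names live in the
# example's own interface.py.
DATACONF = {
    "ATTRIBUTES": ["season", "building_type"],  # hypothetical attribute columns
    "LABEL": "comfort_level",                   # hypothetical label column
}


def feature_process(df: pd.DataFrame) -> pd.DataFrame:
    """Feature pipeline applied by CSVDataParse to each parsed row.

    This stand-in only fills missing values; a real pipeline would
    encode categoricals, select columns, and so on.
    """
    return df.fillna(0)


class Estimator:
    """Base-model wrapper with the train/predict/save/load surface a
    Sedna estimator is expected to provide (sketched, not verbatim)."""

    def __init__(self, **kwargs):
        self.model = None

    def train(self, train_data, valid_data=None, **kwargs):
        # Fit the underlying model on the parsed training data.
        raise NotImplementedError

    def predict(self, data, **kwargs):
        # Return an iterable with one prediction per input row;
        # inference.py takes the first element via list(rsl)[0].
        raise NotImplementedError

    def save(self, model_path):
        raise NotImplementedError

    def load(self, model_url):
        raise NotImplementedError

Sedna's Context.get_parameters falls back to environment variables, so for a local dry run you would export infer_dataset_url (the path to a headered CSV) and optionally UTD_SAVED_URL before launching python inference.py; the script then tails the CSV and appends a prediction for each new row as it arrives.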