# Copyright 2021 The KubeEdge Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import csv
import json
import time

from interface import DATACONF, Estimator, feature_process
from sedna.common.config import Context
from sedna.datasources import CSVDataParse
from sedna.core.lifelong_learning import LifelongLearning


def _follow(file_handle, poll_interval=1):
    """Yield lines from ``file_handle`` as they become available, forever.

    When no data is available the read position is restored to where it was
    before the attempted read, and the generator sleeps ``poll_interval``
    seconds before trying again. This generator never terminates on its own.
    """
    while True:
        where = file_handle.tell()
        line = file_handle.readline()
        if not line:
            time.sleep(poll_interval)
            file_handle.seek(where)
            continue
        # NOTE(review): a last line that the producer has only partially
        # written (no trailing newline yet) is yielded as-is.
        yield line


def main():
    """Continuously run lifelong-learning inference over a growing CSV file.

    The file named by the ``infer_dataset_url`` parameter is read: its first
    line is taken as the CSV header, and every following line is parsed and
    passed to ``LifelongLearning.inference``. Each row, with the first
    prediction appended, is written tab-separated to
    ``<infer_dataset_url>_out.csv``; rows flagged as unseen are additionally
    written to ``unseen_sample.csv`` under the ``UTD_SAVED_URL`` directory.

    The function polls the input file forever and therefore does not return
    normally; the files are closed when the loop is interrupted by an
    exception (for example ``KeyboardInterrupt``).
    """
    # The second argument is presumably the default used when the parameter
    # is not set -- confirm against sedna's Context.get_parameters.
    utd = Context.get_parameters("UTD_NAME", "TaskAttr")
    attribute = json.dumps({"attribute": DATACONF["ATTRIBUTES"]})
    utd_parameters = Context.get_parameters("UTD_PARAMETERS", {})
    ut_saved_url = Context.get_parameters("UTD_SAVED_URL", "/tmp")

    ll_job = LifelongLearning(
        estimator=Estimator,
        task_mining="TaskMiningByDataAttr",
        task_mining_param=attribute,
        unseen_task_detect=utd,
        unseen_task_detect_param=utd_parameters)

    infer_dataset_url = Context.get_parameters('infer_dataset_url')
    unseen_sample_url = os.path.join(ut_saved_url, "unseen_sample.csv")
    output_sample_url = f"{infer_dataset_url}_out.csv"

    # ``with`` guarantees all three files are closed even though the
    # processing loop below only ever ends through an exception.
    with open(infer_dataset_url, "r", encoding="utf-8") as file_handle, \
            open(unseen_sample_url, "w", encoding="utf-8") as unseen_sample, \
            open(output_sample_url, "w", encoding="utf-8") as output_sample:
        header = list(csv.reader([file_handle.readline().strip()]))[0]
        infer_data = CSVDataParse(data_type="test", func=feature_process)

        out_header = "\t".join(header + ['pred']) + "\n"
        unseen_sample.write(out_header)
        unseen_sample.flush()
        output_sample.write(out_header)
        output_sample.flush()

        for line in _follow(file_handle):
            rows = list(csv.reader([line.strip()]))[0]
            if not rows:
                # Blank line: there is nothing to parse or predict on.
                continue

            data = dict(zip(header, rows))
            infer_data.parse(data, label=DATACONF["LABEL"])
            rsl, is_unseen, _ = ll_job.inference(infer_data)
            rows.append(list(rsl)[0])

            record = "\t".join(map(str, rows)) + "\n"
            if is_unseen:
                unseen_sample.write(record)
                unseen_sample.flush()
            output_sample.write(record)
            # Flush per record so consumers of the output file see results
            # promptly and nothing is lost if the process is killed.
            output_sample.flush()


if __name__ == '__main__':
    main()