|
- # Copyright 2021 The KubeEdge Authors.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- import glob
- from PIL import Image
- from sedna.common.config import Context
- from sedna.core.incremental_learning import IncrementalLearning
- from interface import Estimator
- import shutil
- import mindspore as ms
- from mobilenet_v2 import mobilenet_v2_fine_tune
-
# Fallback value handed to Context.get_parameters when looking up HE_SAVED_URL.
_DEFAULT_HE_SAVED_URL = './tmp'
# Directory into which output_deal() copies images flagged as hard examples.
he_saved_url = Context.get_parameters("HE_SAVED_URL", _DEFAULT_HE_SAVED_URL)
-
-
def output_deal(is_hard_example, infer_image_path, *, save_dir=None):
    """Copy an inferred image into the hard-example directory if it was flagged.

    The copy is named ``<category>_<file name>``, where ``<category>`` is the
    name of the directory that directly contains the image.

    Args:
        is_hard_example: Truthy when hard-example mining flagged the image.
        infer_image_path: Path of the image that was just inferred.
        save_dir: Destination directory. Defaults to the module-level
            ``he_saved_url`` (the ``HE_SAVED_URL`` parameter). The directory
            is created if it does not exist yet.
    """
    import os  # local import keeps this helper self-contained

    if not is_hard_example:
        return
    if save_dir is None:
        save_dir = he_saved_url

    # os.path copes with the platform's separators, unlike splitting on "/".
    img_name = os.path.basename(infer_image_path)
    img_category = os.path.basename(os.path.dirname(infer_image_path))

    # shutil.copy raises FileNotFoundError if the target directory is missing.
    os.makedirs(save_dir, exist_ok=True)
    shutil.copy(infer_image_path,
                os.path.join(save_dir, f"{img_category}_{img_name}"))
-
-
def main(*, poll_interval=1.0):
    """Run continuous inference over the images found under ``infer_url``.

    Builds a Sedna ``IncrementalLearning`` instance whose hard-example-mining
    algorithm comes from config (``random_ratio=0.3``) and loads the
    MobileNetV2 fine-tuned evaluation network from ``model_url``. It then
    polls ``<infer_url>/*/*`` forever. Each newly seen image is inferred
    once, the result is printed, and hard examples are handed to
    ``output_deal``.

    Args:
        poll_interval: Seconds to sleep after a scan that found no new images,
            so an idle worker does not busy-spin.

    This function never returns.
    """
    import time  # local import: only the polling loop needs it

    hard_example_mining = IncrementalLearning.get_hem_algorithm_from_config(
        random_ratio=0.3
    )
    incremental_instance = IncrementalLearning(
        estimator=Estimator, hard_example_mining=hard_example_mining
    )

    # Read parameters from the deployment config.
    # NOTE(review): class_names is forwarded exactly as Context returns it
    # (presumably a string) - confirm the Estimator parses it as intended.
    class_names = Context.get_parameters("class_name")
    input_shape = int(Context.get_parameters("input_shape"))
    model_url = Context.get_parameters("model_url")
    # f-string: string concatenation would raise TypeError if the value is None.
    print(f"model_url={model_url}")

    # Build the evaluation network from the checkpoint and wrap it for inference.
    network = mobilenet_v2_fine_tune(base_model_url=model_url).get_eval_network()
    model = ms.Model(network)

    infer_dataset_url = Context.get_parameters("infer_url")
    print(infer_dataset_url)

    # Images are laid out as <infer_url>/<category>/<image>.
    pattern = infer_dataset_url + "/*/*"
    # Paths already handled. Without this, every pass would re-infer (and
    # possibly re-copy) the same images forever.
    processed = set()
    while True:
        new_images = sorted(set(glob.glob(pattern)) - processed)
        if not new_images:
            time.sleep(poll_interval)
            continue
        for each_img in new_images:
            # Marked before opening so an unreadable file is not retried on
            # every pass.
            processed.add(each_img)
            try:
                infer_data = Image.open(each_img)
            except OSError as err:  # includes PIL.UnidentifiedImageError
                # Skip a non-image or vanished file instead of killing the worker.
                print(f"{each_img}--->skipped: {err}")
                continue
            # PIL loads lazily and keeps the file open; close it after inference.
            with infer_data:
                results, _, is_hard_example = incremental_instance.inference(
                    data=infer_data,
                    model=model,
                    class_names=class_names,
                    input_shape=input_shape,
                )
            hard_example = ("is hard example" if is_hard_example
                            else "is not hard example")
            print(f"{each_img}--->{results}-->{hard_example}")
            output_deal(is_hard_example, each_img)
-
-
if __name__ == "__main__":
    # Start the endless inference loop only when executed as a script,
    # never on import.
    main()
|