|
- # Copyright 2021 The KubeEdge Authors.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- from __future__ import division
- import os
- import PIL
- import numpy as np
- from PIL import Image
- import mindspore as ms
- import mindspore.nn as nn
- from mindvision.engine.loss import CrossEntropySmooth
- from mindvision.engine.callback import ValAccMonitor
- from mobilenet_v2 import mobilenet_v2_fine_tune
-
# Executed at import time: export BACKEND_TYPE for the rest of the process.
# NOTE(review): presumably read by the hosting framework to select its
# MindSpore backend -- confirm against the consumer of this variable.
os.environ["BACKEND_TYPE"] = "MINDSPORE"
-
-
def preprocess(img: "PIL.Image.Image", size=(224, 224)):
    """Turn an image into a normalized float32 batch containing one sample.

    The image is converted to RGB, resized to ``size``, standardized per
    channel with the mean/std constants below, moved from
    height-width-channel to channel-height-width order, and given a
    leading batch axis.

    Args:
        img: Source image. Only its ``convert`` and ``resize`` methods are
            used, and the resized result is handed to ``numpy.array``. The
            annotation is quoted so it is not evaluated when the function
            is defined.
        size: Target size forwarded to ``resize`` (Pillow expects
            ``(width, height)``). Defaults to ``(224, 224)``, the value
            that used to be hard-coded.

    Returns:
        numpy.ndarray: float32 array of shape ``(1, 3, height, width)``.
    """
    # Per-channel statistics scaled to the 0-255 pixel range (the values
    # commonly used with ImageNet-pretrained models).
    mean = np.array([0.485 * 255, 0.456 * 255, 0.406 * 255])
    std = np.array([0.229 * 255, 0.224 * 255, 0.225 * 255])

    rgb = img.convert("RGB").resize(tuple(size))
    pixels = np.array(rgb)

    # Normalize in float64 first and only then narrow to float32, exactly
    # as before, so the numeric results stay unchanged.
    normalized = ((pixels - mean) / std).astype(np.float32)

    # HWC -> CHW, then prepend the batch dimension.
    chw = np.transpose(normalized, (2, 0, 1))
    return np.expand_dims(chw, axis=0)
-
-
class Estimator:
    """Train, evaluate, run and export a two-class MobileNetV2 fine-tune model.

    Networks are built through ``mobilenet_v2_fine_tune`` and driven with
    ``mindspore.Model``. After ``train`` has run, ``trained_ckpt_url`` holds
    the path of the best checkpoint, which ``save`` later re-exports.
    """

    def __init__(self, **kwargs):
        """Create an estimator; ``kwargs`` are accepted but not used."""
        # Path of the best checkpoint from the last train() call; stays
        # None until train() has completed.
        self.trained_ckpt_url = None

    @staticmethod
    def _best_ckpt_path(ckpt_directory):
        """Return the path of ``best.ckpt`` inside ``ckpt_directory``."""
        # os.path.join instead of `dir + "/best.ckpt"`: an empty directory
        # no longer resolves to "/best.ckpt" at the filesystem root, and a
        # trailing separator does not get doubled.
        return os.path.join(ckpt_directory, "best.ckpt")

    def train(self, train_data, base_model_url, trained_ckpt_url,
              valid_data=None, epochs=10, **kwargs):
        """Fine-tune the network and remember where the best checkpoint is.

        Args:
            train_data: Dataset passed to ``Model.train``.
            base_model_url: Passed to ``mobilenet_v2_fine_tune`` to build
                the training network.
            trained_ckpt_url: Directory given to ``ValAccMonitor`` as
                ``ckpt_directory``.
            valid_data: Dataset given to ``ValAccMonitor``.
                NOTE(review): defaults to None but is handed to the
                callback unconditionally -- confirm None is tolerated.
            epochs: Number of training epochs.
            **kwargs: Accepted but not used.
        """
        network = mobilenet_v2_fine_tune(base_model_url).get_train_network()
        network_opt = nn.Momentum(
            params=network.trainable_params(),
            learning_rate=0.01,
            momentum=0.9,
        )
        network_loss = CrossEntropySmooth(
            sparse=True,
            reduction="mean",
            smooth_factor=0.1,
            classes_num=2,
        )
        metrics = {"Accuracy": nn.Accuracy()}
        model = ms.Model(
            network,
            loss_fn=network_loss,
            optimizer=network_opt,
            metrics=metrics,
        )
        callbacks = [
            ValAccMonitor(
                model,
                valid_data,
                epochs,
                save_best_ckpt=True,
                ckpt_directory=trained_ckpt_url,
            ),
            ms.TimeMonitor(),
        ]
        model.train(epochs, train_data, callbacks=callbacks)
        # Assumes ValAccMonitor keeps its default best-checkpoint file name
        # ("best.ckpt") inside ckpt_directory -- TODO confirm.
        self.trained_ckpt_url = self._best_ckpt_path(trained_ckpt_url)

    def evaluate(self, data, model_path="", class_name="",
                 input_shape=(224, 224), **kwargs):
        """Evaluate the checkpoint at ``model_path`` on ``data``.

        Args:
            data: Dataset passed to ``Model.eval``.
            model_path: Passed to ``mobilenet_v2_fine_tune`` to build the
                evaluation network.
            class_name: Accepted but not used.
            input_shape: Accepted but not used.
            **kwargs: Accepted but not used.

        Returns:
            Whatever ``Model.eval`` returns for the ``'acc'`` metric; the
            value is also printed.
        """
        # Build the evaluation network from the given checkpoint.
        network = mobilenet_v2_fine_tune(model_path).get_eval_network()
        # Same loss configuration as in training.
        network_loss = CrossEntropySmooth(
            sparse=True,
            reduction="mean",
            smooth_factor=0.1,
            classes_num=2,
        )
        model = ms.Model(
            network,
            loss_fn=network_loss,
            optimizer=None,
            metrics={'acc'},
        )
        acc = model.eval(data, dataset_sink_mode=False)
        print(acc)
        return acc

    def predict(self, data, model, input_shape=None, **kwargs):
        """Run ``model`` on a single image.

        Args:
            data: Image handed to the module-level ``preprocess``.
            model: Object exposing ``predict``; it receives the
                preprocessed array wrapped in ``ms.Tensor``.
            input_shape: Accepted but not used.
            **kwargs: Accepted but not used.

        Returns:
            The value returned by ``model.predict``.
        """
        preprocessed_data = preprocess(data)
        return model.predict(ms.Tensor(preprocessed_data))

    def save(self, model_path=None):
        """Export the trained network as a checkpoint at ``model_path``.

        Does nothing when ``model_path`` is empty or None. Otherwise the
        evaluation network is rebuilt from ``self.trained_ckpt_url`` and
        written with ``ms.save_checkpoint``.

        NOTE(review): ``self.trained_ckpt_url`` is still None if ``train``
        has not been called -- confirm ``mobilenet_v2_fine_tune`` accepts
        that.
        """
        if not model_path:
            return
        network = mobilenet_v2_fine_tune(
            self.trained_ckpt_url
        ).get_eval_network()
        ms.save_checkpoint(network, model_path)
-
-
-
-
-
-
-
-
-
|