Browse Source

Mindspore Demo can run locally

Signed-off-by: Lj1ang <2872509481@qq.com>
tags/v0.6.0
Lj1ang 3 years ago
parent
commit
9b7071b967
13 changed files with 833 additions and 2 deletions
  1. +18
    -0
      examples/incremental-learning-dog-croissants-classification.Dockerfile
  2. +204
    -0
      examples/incremental_learning/dog_croissants_classification/README.md
  3. +99
    -0
      examples/incremental_learning/dog_croissants_classification/dog_croissants_classification.yaml
  4. +58
    -0
      examples/incremental_learning/dog_croissants_classification/training/dataset.py
  5. +48
    -0
      examples/incremental_learning/dog_croissants_classification/training/eval.py
  6. +69
    -0
      examples/incremental_learning/dog_croissants_classification/training/inference.py
  7. +114
    -0
      examples/incremental_learning/dog_croissants_classification/training/interface.py
  8. +34
    -0
      examples/incremental_learning/dog_croissants_classification/training/mobilenet_v2.py
  9. +67
    -0
      examples/incremental_learning/dog_croissants_classification/training/train.py
  10. +18
    -0
      incremental-learning-dog-croissants-classification.Dockerfile
  11. +27
    -1
      lib/sedna/algorithms/hard_example_mining/hard_example_mining.py
  12. +4
    -1
      lib/sedna/backend/base.py
  13. +73
    -0
      lib/sedna/backend/mindspore/__init__.py

+ 18
- 0
examples/incremental-learning-dog-croissants-classification.Dockerfile View File

@@ -0,0 +1,18 @@
FROM mindspore/mindspore-cpu:1.8.1

COPY ../lib/requirements.txt /home
# install requirements of sedna lib
RUN pip install -r /home/requirements.txt
RUN pip install Pillow
RUN pip install numpy
RUN pip install mindvision

ENV PYTHONPATH "/home/lib"

WORKDIR /home/work
COPY ../lib /home/lib

COPY incremental_learning/dog_croissants_classification/training /home/work/


ENTRYPOINT ["python"]

+ 204
- 0
examples/incremental_learning/dog_croissants_classification/README.md View File

@@ -0,0 +1,204 @@


## Prepare Model
auto-download

## Prepare for inference worker
```shell
mkdir -p /incremental_learning/infer/
mkdir -p /incremental_learning/he/
mkdir -p /data/dog_croissants/
mkdir /output
```

TODO:download dataset
```shell



```

download checkpoint
```shell
# need ckpt file under both two dir
mkdir -p /models/base_model
mkdir -p /models/deploy_model
cd /models/base_model
#wget https://download.mindspore.cn/vision/classification/mobilenet_v2_1.0_224.ckpt
```
## build docker file
```shell
$ docker build -f incremental-learning-dog-croissants-classification.Dockerfile -t test/dog:v0.1 .

```

## Create Incremental Job
```shell
WORKER_NODE="edge-node"
```
Create Dataset
```shell
kubectl create -f - <<EOF
apiVersion: sedna.io/v1alpha1
kind: Dataset
metadata:
name: incremental-dataset
spec:
url: "/data/dog_croissants/train_data.txt"
format: "txt"
nodeName: $WORKER_NODE
EOF
```
Create initial Model to simulate the inital model in incremental learning scenoario
```shell
kubectl create -f - <<EOF
apiVersion: sedna.io/v1alpha1
kind: Model
metadata:
name: initial-model
spec:
url : "/models/base_model/base_model.ckpt"
format: "ckpt"
EOF
```
Create Deploy Model
```shell
kubectl create -f - <<EOF
apiVersion: sedna.io/v1alpha1
kind: Model
metadata:
name: deploy-model
spec:
url : "/models/deploy_model/deploy_model.ckpt"
format: "ckpt"
EOF
```
create the job
```shell
IMAGE=lj1ang/dog:v0.40
kubectl create -f - <<EOF
apiVersion: sedna.io/v1alpha1
kind: IncrementalLearningJob
metadata:
name: dog-croissants-classification-demo
spec:
initialModel:
name: "initial-model"
dataset:
name: "incremental-dataset"
trainProb: 0.8
trainSpec:
template:
spec:
nodeName: $WORKER_NODE
containers:
- image: $IMAGE
name: train-worker
imagePullPolicy: IfNotPresent
args: [ "train.py" ]
env:
- name: "batch_size"
value: "2"
- name: "epochs"
value: "2"
- name: "input_shape"
value: "224"
- name: "class_names"
value: "Croissants, Dog"
- name: "num_parallel_workers"
value: "2"
trigger:
checkPeriodSeconds: 60
timer:
start: 02:00
end: 20:00
condition:
operator: ">"
threshold: 50
metric: num_of_samples
evalSpec:
template:
spec:
nodeName: $WORKER_NODE
containers:
- image: $IMAGE
name: eval-worker
imagePullPolicy: IfNotPresent
args: [ "eval.py" ]
env:
- name: "input_shape"
value: "224"
- name: "batch_size"
value: "2"
- name: "num_parallel_workers"
value: "2"
- name: "class_names"
value: "Croissants, Dog"
deploySpec:
model:
name: "deploy-model"
hotUpdateEnabled: true
pollPeriodSeconds: 60
trigger:
condition:
operator: ">"
threshold: 0.1
metric: precision_delta
hardExampleMining:
name: "Random"
parameters:
- key: "random_ratio"
value: "0.3"
template:
spec:
nodeName: $WORKER_NODE
containers:
- image: $IMAGE
name: infer-worker
imagePullPolicy: IfNotPresent
args: [ "inference.py" ]
env:
- name: "input_shape"
value: "224"
- name: "infer_url"
value: "/infer"
- name: "HE_SAVED_URL"
value: "/he_saved_url"
volumeMounts:
- name: localinferdir
mountPath: /infer
- name: hedir
mountPath: /he_saved_url
resources: # user defined resources
limits:
memory: 3Gi
volumes: # user defined volumes
- name: localinferdir
hostPath:
path: /incremental_learning/infer/
type: DirectoryOrCreate
- name: hedir
hostPath:
path: /incremental_learning/he/
type: DirectoryOrCreate
outputDir: "/output"
EOF
```
## trigger
```shell
cd /data/helmet_detection
wget https://kubeedge.obs.cn-north-1.myhuaweicloud.com/examples/helmet-detection/dataset.tar.gz
tar -zxvf dataset.tar.gz
```
## delete
```shell
kubectl delete dataset incremental-dataset
kubectl delete model initial-model
kubectl delete model deploy-model
kubectl delete IncrementalLearningJob dog-croissants-classification-demo
```
```shell
ctr -n k8s.io image pull registry.aliyuncs.com/google_containers/pause:3.5
ctr -n k8s.io image tag registry.aliyuncs.com/google_containers/pause:3.5 k8s.gcr.io/pause:3.5

```

+ 99
- 0
examples/incremental_learning/dog_croissants_classification/dog_croissants_classification.yaml View File

@@ -0,0 +1,99 @@
apiVersion: sedna.io/v1alpha1
kind: IncrementalLearningJob
metadata:
name: dog-croissants-classification-demo
spec:
initialModel:
name: "initial-model"
dataset:
name: "incremental-dataset"
trainProb: 0.8
trainSpec:
template:
spec:
nodeName: $WORKER_NODE
containers:
- image: $IMAGE
name: train-worker
imagePullPolicy: IfNotPresent
args: [ "train.py" ]
env:
- name: "batch_size"
value: "8"
- name: "epochs"
value: "10"
- name: "input_shape"
value: "224"
- name: "class_names"
value: "Croissants, Dog"
trigger:
checkPeriodSeconds: 60
timer:
start: 02:00
end: 20:00
condition:
operator: ">"
threshold: 30
metric: num_of_samples
evalSpec:
template:
spec:
nodeName: $WORKER_NODE
containers:
- image: $IMAGE
name: eval-worker
imagePullPolicy: IfNotPresent
args: [ "eval.py" ]
env:
- name: "input_shape"
value: "224"
- name: "class_names"
value: "Croissants, Dog"
deploySpec:
model:
name: "deploy-model"
hotUpdateEnabled: true
pollPeriodSeconds: 60
trigger:
condition:
operator: ">"
threshold: 0.1
metric: precision_delta
hardExampleMining:
name: "Random"
parameters:
- key: "random_ratio"
value: "0.3"
template:
spec:
nodeName: $WORKER_NODE
containers:
- image: $IMAGE
name: infer-worker
imagePullPolicy: IfNotPresent
args: [ "inference.py" ]
env:
- name: "input_shape"
value: "224"
- name: "infer_url"
value: ""
- name: "HE_SAVED_URL"
value: "/he_saved_url"
volumeMounts:
- name: localinferdir
mountPath: /data/DogCroissants/infer
- name: hedir
mountPath: /he_saved_url
resources: # user defined resources
limits:
memory: 2Gi
volumes: # user defined volumes
- name: localinferdir
hostPath:
path: /incremental_learning/data/DogCroissants/infer
type: DirectoryOrCreate
- name: hedir
hostPath:
path: /incremental_learning/he/
type: DirectoryOrCreate
outputDir: "/output"

+ 58
- 0
examples/incremental_learning/dog_croissants_classification/training/dataset.py View File

@@ -0,0 +1,58 @@
# Copyright 2021 The KubeEdge Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import mindspore.dataset as ds
import mindspore.dataset.vision as vision
from sedna.datasources import BaseDataSource


class ImgDataset(BaseDataSource):
def __init__(self, data_type="train", func=None):
super(ImgDataset, self).__init__(data_type=data_type, func=func)

def parse(self, *args, path=None, train=True, image_shape=224, batch_size=2,num_parallel_workers=1, **kwargs):
dataset = ds.ImageFolderDataset(
path, num_parallel_workers=num_parallel_workers,
class_indexing={"croissants": 0, "dog": 1})
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
if train:
trans = [
vision.RandomCropDecodeResize(image_shape, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
vision.RandomHorizontalFlip(prob=0.5),
vision.Normalize(mean=mean, std=std),
vision.HWC2CHW()
]
else:
trans = [
vision.Decode(),
vision.Resize(256),
vision.CenterCrop(image_shape),
vision.Normalize(mean=mean, std=std),
vision.HWC2CHW()
]
dataset = dataset.map(operations=trans,
input_columns="image",
num_parallel_workers=num_parallel_workers)
dataset = dataset.batch(batch_size, drop_remainder=True)
return dataset

'''
def download_dataset(self):
dataset_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/beginner/DogCroissants.zip"
path = "./datasets"
dl = DownLoad()
dl.download_and_extract_archive(dataset_url, path)
'''


+ 48
- 0
examples/incremental_learning/dog_croissants_classification/training/eval.py View File

@@ -0,0 +1,48 @@
# Copyright 2021 The KubeEdge Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os.path

from sedna.common.config import Context
from sedna.core.incremental_learning import IncrementalLearning

from interface import Estimator
from dataset import ImgDataset

def main():

class_names=Context.get_parameters("class_name")
print(Context.get_parameters("model_path"))
#read parameters from deployment config
input_shape=int(Context.get_parameters("input_shape"))
batch_size=int(Context.get_parameters("batch_size"))
original_dataset_url=Context.get_parameters("ORIGINAL_DATASET_URL")
num_parallel_workers=int(Context.get_parameters("num_parallel_workers"))
if original_dataset_url:
print("ORIGINAL_DATASET_URL"+ original_dataset_url)
else:
print("ORIGINAL_DATASET_URL: NULL" )
eval_dataset_path=os.path.dirname(original_dataset_url)+r"/eval"
test_data=ImgDataset(data_type="eval").parse(path=eval_dataset_path,
train=False,
image_shape=input_shape,
batch_size=batch_size,
num_parallel_workers=num_parallel_workers)
incremental_instance = IncrementalLearning(estimator=Estimator)
return incremental_instance.evaluate(test_data,
class_names=class_names,
input_shape=input_shape)

if __name__ == "__main__":
main()


+ 69
- 0
examples/incremental_learning/dog_croissants_classification/training/inference.py View File

@@ -0,0 +1,69 @@
# Copyright 2021 The KubeEdge Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import os


from PIL import Image
from sedna.common.config import Context
from sedna.core.incremental_learning import IncrementalLearning
from interface import Estimator
import shutil
import mindspore as ms
from mobilenet_v2 import mobilenet_v2_fine_tune


he_saved_url = Context.get_parameters("HE_SAVED_URL", './tmp')

def output_deal(is_hard_example, infer_image_path):
img_name=infer_image_path.split(r"/")[-1]
img_category = infer_image_path.split(r"/")[-2]
if is_hard_example:
shutil.copy(infer_image_path,f"{he_saved_url}/{img_category}_{img_name}")

def main():

hard_example_mining = IncrementalLearning.get_hem_algorithm_from_config(
random_ratio=0.3
)
incremental_instance = IncrementalLearning(estimator=Estimator, hard_example_mining=hard_example_mining)
class_names=Context.get_parameters("class_name")
#read parameters from deployment config
input_shape=int(Context.get_parameters("input_shape"))
# load ckpt
model_url=Context.get_parameters("model_url")
print("model_url=" + model_url)
# load model ckpt here
network = mobilenet_v2_fine_tune(base_model_url=model_url).get_eval_network()
#ms.load_checkpoint(model_url, network)
model = ms.Model(network)
# load dataset
#train_dataset_url = BaseConfig.train_dataset_url
infer_dataset_url=Context.get_parameters("infer_url")
print(infer_dataset_url)
# get each image unber infer_dataset_url with wildcard
while True:
for each_img in glob.glob(infer_dataset_url+"/*/*"):
infer_data=Image.open(each_img)
results, _, is_hard_example = incremental_instance.inference(data=infer_data,
model=model,
class_names=class_names,
input_shape=input_shape)
hard_example="is hard example" if is_hard_example else "is not hard example"
print(f"{each_img}--->{results}-->{hard_example}")
output_deal(is_hard_example, each_img)

if __name__ == "__main__":
main()


+ 114
- 0
examples/incremental_learning/dog_croissants_classification/training/interface.py View File

@@ -0,0 +1,114 @@
# Copyright 2021 The KubeEdge Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import division


import os

import PIL
import numpy as np
from PIL import Image
import mindspore as ms
import mindspore.nn as nn
from mindvision.engine.loss import CrossEntropySmooth
from mindvision.engine.callback import ValAccMonitor
from mobilenet_v2 import mobilenet_v2_fine_tune


os.environ['BACKEND_TYPE'] = 'MINDSPORE'

def preprocess(img:PIL.Image.Image):
#image=Image.open(img_path).convert("RGB").resize((224 ,224))
image=img.convert("RGB").resize((224,224))
mean = np.array([0.485 * 255, 0.456 * 255, 0.406 * 255])
std = np.array([0.229 * 255, 0.224 * 255, 0.225 * 255])
image = np.array(image)
image = (image - mean) / std
image = image.astype(np.float32)

image = np.transpose(image, (2, 0, 1))
image = np.expand_dims(image, axis=0)
return image

class Estimator:


def __init__(self,**kwargs):
self.trained_ckpt_url=None


# TODO:save url
# example : https://www.mindspore.cn/doc/programming_guide/zh-CN/r1.0/train.html#id3
def train(self, train_data,base_model_url, trained_ckpt_url, valid_data=None,epochs=10, **kwargs):
network=mobilenet_v2_fine_tune(base_model_url).get_train_network()
network_opt=nn.Momentum(params=network.trainable_params(),learning_rate=0.01,momentum=0.9)
network_loss=CrossEntropySmooth(sparse=True, reduction="mean", smooth_factor=0.1, classes_num=2)
metrics = {"Accuracy" : nn.Accuracy()}
model=ms.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics=metrics)
num_epochs = epochs
#best_ckpt_name=deploy_model_url.split(r"/")[-1]
#ckpt_dir=deploy_model_url.replace(best_ckpt_name, "")
model.train(num_epochs, train_data, callbacks=[ValAccMonitor(model, valid_data, num_epochs, save_best_ckpt=True, ckpt_directory=trained_ckpt_url), ms.TimeMonitor()])
self.trained_ckpt_url=trained_ckpt_url+"/best.ckpt"
# sedna will save model checkpoint in the path which is the value of MODEL_URL or MODEL_PATH
#ms.save_checkpoint(network, deploy_model_url)


def evaluate(self,data,model_path="",class_name="",input_shape=(224,224),**kwargs):
# load
network = mobilenet_v2_fine_tune(model_path).get_eval_network()
# eval
network_loss = CrossEntropySmooth(sparse=True,
reduction="mean",
smooth_factor=0.1,
classes_num=2)
model = ms.Model(network, loss_fn=network_loss, optimizer=None, metrics={'acc'})
acc=model.eval(data, dataset_sink_mode=False)
print(acc)
return acc


def predict(self, data,model, input_shape=None, **kwargs):
# load

# preprocess
preprocessed_data=preprocess(data)
# predict
pre=model.predict(ms.Tensor(preprocessed_data))
result=np.argmax(pre)
class_name={0:"Croissants", 1:"Dog"}
#print(class_name[result])
#return class_name[result]
return pre

def load(self, model_url):
pass

def save(self, model_path=None):
if not model_path:
return
#model_dir, model_name = os.path.split(model_path)
network = mobilenet_v2_fine_tune(self.trained_ckpt_url).get_eval_network()
ms.save_checkpoint(network, model_path)











+ 34
- 0
examples/incremental_learning/dog_croissants_classification/training/mobilenet_v2.py View File

@@ -0,0 +1,34 @@
import mindspore as ms
from mindvision.classification.models import mobilenet_v2
from mindvision.dataset import DownLoad



class mobilenet_v2_fine_tune:
# TODO: save model
def __init__(self, base_model_url=None):
models_download_url = "https://download.mindspore.cn/vision/classification/mobilenet_v2_1.0_224.ckpt"
dl=DownLoad()
if base_model_url==None:
dl.download_url(models_download_url)
else:
dl.download_url(models_download_url,filename=base_model_url )
self.network = mobilenet_v2(num_classes=2, resize=224)
#print("base_model_url == "+base_model_url)
self.param_dict = ms.load_checkpoint(base_model_url)


def get_train_network(self):
self.filter_list = [x.name for x in self.network.head.classifier.get_parameters()]
for key in list(self.param_dict.keys()):
for name in self.filter_list:
if name in key:
print("Delete parameter from checkpoint: ", key)
del self.param_dict[key]
break
ms.load_param_into_net(self.network, self.param_dict)
return self.network

def get_eval_network(self):
ms.load_param_into_net(self.network, self.param_dict)
return self.network

+ 67
- 0
examples/incremental_learning/dog_croissants_classification/training/train.py View File

@@ -0,0 +1,67 @@
# Copyright 2021 The KubeEdge Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

from sedna.common.config import Context, BaseConfig
from sedna.core.incremental_learning import IncrementalLearning

from interface import Estimator
from dataset import ImgDataset


def main():
# base_model_url means the low accuracy model
base_model_url=Context.get_parameters("base_model_url")
# model_url means the checkpoint file that has been trained
# model_url is used for estimator.save, not in train.
trained_ckpt_url = Context.get_parameters("model_url")
#read parameters from deployment config
input_shape=int(Context.get_parameters("input_shape"))
epochs=int(Context.get_parameters('epochs'))
batch_size=int(Context.get_parameters("batch_size"))
num_parallel_workers=int(Context.get_parameters("num_parallel_workers"))
print("num_parallel_workers="+str(num_parallel_workers))
# load dataset
train_dataset_url=os.path.dirname(Context.get_parameters("ORIGINAL_DATASET_URL"))+"/train"
valid_dataset_url=os.path.dirname(Context.get_parameters("ORIGINAL_DATASET_URL"))+"/val"
if train_dataset_url:
print("train_dataset_url " + train_dataset_url)
else:
print("train_dataset_url: NULL ")
if valid_dataset_url:
print("valid_dataset_urlL : " + valid_dataset_url)
else:
print("valid_dataset_url : NULL")


train_data = ImgDataset(data_type="train").parse(path=train_dataset_url,
train=True,
image_shape=input_shape,
batch_size=batch_size,
num_parallel_workers=num_parallel_workers)
valid_data=ImgDataset(data_type="eval").parse(path=valid_dataset_url,
train=False,
image_shape=input_shape,
batch_size=batch_size,
num_parallel_workers=num_parallel_workers)
incremental_instance = IncrementalLearning(estimator=Estimator)
return incremental_instance.train(train_data=train_data,
base_model_url=base_model_url,
trained_ckpt_url=trained_ckpt_url,
valid_data=valid_data,
epochs=1)

if __name__ == "__main__":
main()
print("train_phase_done")

+ 18
- 0
incremental-learning-dog-croissants-classification.Dockerfile View File

@@ -0,0 +1,18 @@
FROM mindspore/mindspore-cpu:1.7.1

COPY lib/requirements.txt /home
# install requirements of sedna lib
RUN pip install -r /home/requirements.txt
RUN pip install Pillow
RUN pip install numpy
RUN pip install mindvision

ENV PYTHONPATH "/home/lib"

WORKDIR /home/work
COPY lib /home/lib

COPY examples/incremental_learning/dog_croissants_classification/training /home/work/


ENTRYPOINT ["python"]

+ 27
- 1
lib/sedna/algorithms/hard_example_mining/hard_example_mining.py View File

@@ -16,7 +16,7 @@

import abc
import math
import random
from sedna.common.class_factory import ClassFactory, ClassType

__all__ = ('ThresholdFilter', 'CrossEntropyFilter', 'IBTFilter')
@@ -177,3 +177,29 @@ class IBTFilter(BaseFilter, abc.ABC):
if float(box_score) <= self.threshold_box]
return (len(confidence_score_list) / len(infer_result)
>= (1 - self.threshold_img))

@ClassFactory.register(ClassType.HEM, alias="Random")
class RandomFilter(BaseFilter):
"""judge a image is hard example or not randomly

Parameters
----------
random_ratio: int
value: between 0 and 1
with a model having very high accuracy like 98%, use this
function to define an input is hard example or not. just
a meaningless but needed function in sedna incremental learning
inference

Returns
-------
is hard sample: bool
`True` means hard sample, `False` means not.
"""
def __init__(self, random_ratio=0.3, **kwargs):
self.random_ratio=random_ratio

def __call__(self, *args, **kwargs):
if random.uniform(0,1) < self.random_ratio:
return True
return False

+ 4
- 1
lib/sedna/backend/base.py View File

@@ -25,6 +25,7 @@ class BackendBase:
self.framework = ""
self.estimator = estimator
self.use_cuda = True if kwargs.get("use_cuda") else False
self.use_npu = True if kwargs.get("use_npu") else False
self.fine_tune = fine_tune
self.model_save_path = kwargs.get("model_save_path") or "/tmp"
self.default_name = kwargs.get("model_name")
@@ -35,7 +36,9 @@ class BackendBase:
if self.default_name:
return self.default_name
model_postfix = {"pytorch": [".pth", ".pt"],
"keras": ".pb", "tensorflow": ".pb"}
"keras": ".pb",
"tensorflow": ".pb",
"mindspore": ".ckpt"}
continue_flag = "_finetune_" if self.fine_tune else ""
post_fix = model_postfix.get(self.framework, ".pkl")
return f"model{continue_flag}{self.framework}{post_fix}"


+ 73
- 0
lib/sedna/backend/mindspore/__init__.py View File

@@ -0,0 +1,73 @@
# Copyright 2021 The KubeEdge Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import mindspore.context as context
from sedna.backend.base import BackendBase
from sedna.common.file_ops import FileOps


class MSBackend(BackendBase):
def __init__(self, estimator, fine_tune=True, **kwargs):
super(MSBackend, self).__init__(estimator=estimator,
fine_tune=fine_tune,
**kwargs)
self.framework = "mindspore"
if self.use_npu:
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
elif self.use_cuda:
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
else:
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")

if callable(self.estimator):
self.estimator = self.estimator()

def train(self, train_data, valid_data=None, **kwargs):
if callable(self.estimator):
self.estimator = self.estimator()
if self.fine_tune and FileOps.exists(self.model_save_path):
self.finetune()
self.has_load = True
varkw = self.parse_kwargs(self.estimator.train, **kwargs)
return self.estimator.train(train_data=train_data,
valid_data=valid_data,
**varkw)

def predict(self, data, **kwargs):
if not self.has_load:
self.load()
varkw = self.parse_kwargs(self.estimator.predict, **kwargs)
return self.estimator.predict(data=data, **varkw)

def evaluate(self, data, **kwargs):
if not self.has_load:
self.load()
varkw = self.parse_kwargs(self.estimator.evaluate, **kwargs)
return self.estimator.evaluate(data, **varkw)

def finetune(self):
"""todo: no support yet"""

def load_weights(self):
model_path = FileOps.join_path(self.model_save_path, self.model_name)
if os.path.exists(model_path):
self.estimator.load_weights(model_path)

def get_weights(self):
"""todo: no support yet"""

def set_weights(self, weights):
"""todo: no support yet"""

Loading…
Cancel
Save