
Merge pull request #32 from jaypume/incremental_learning

modify the dataset format to relative path
tags/v0.1.0
KubeEdge Bot committed 4 years ago (via GitHub)
commit f77563236f
3 changed files with 35 additions and 42 deletions
  1. examples/federated_learning/surface_defect_detection/training_worker/train.py (+14 -2)
  2. examples/incremental_learning/helmet_detection_incremental_train/training/train.py (+8 -9)
  3. lib/sedna/dataset/dataset.py (+13 -31)

examples/federated_learning/surface_defect_detection/training_worker/train.py (+14 -2)

@@ -16,13 +16,25 @@ import numpy as np
 from tensorflow import keras
 
 import sedna
-from sedna.ml_model import save_model
 from network import GlobalModelInspectionCNN
+from sedna.ml_model import save_model
+
+
+def image_process(line):
+    import keras.preprocessing.image as img_preprocessing
+    file_path, label = line.split(',')
+    img = img_preprocessing.load_img(file_path).resize((128, 128))
+    data = img_preprocessing.img_to_array(img) / 255.0
+    label = [0, 1] if int(label) == 0 else [1, 0]
+    data = np.array(data)
+    label = np.array(label)
+    return [data, label]
 
 
 def main():
     # load dataset.
-    train_data = sedna.load_train_dataset(data_format="txt", with_image=True)
+    train_data = sedna.load_train_dataset(data_format="txt",
+                                          preprocess_fun=image_process)
 
     x = np.array([tup[0] for tup in train_data])
     y = np.array([tup[1] for tup in train_data])
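The effect of this hunk: image decoding moves out of the library and into the example, supplied through the new preprocess_fun hook. A minimal sketch of the resulting call pattern, assuming the same comma-separated txt index and that the worker's environment carries the dataset URL via Sedna's BaseConfig; the function name to_sample and its return shape are illustrative, not part of the PR:

    import numpy as np
    import sedna

    def to_sample(line):
        # Hypothetical hook: the loader passes each raw index line
        # ("relative/path,label") to this function and collects the results.
        file_path, label = line.strip().split(',')
        one_hot = np.array([0, 1] if int(label) == 0 else [1, 0])
        return [file_path, one_hot]

    train_data = sedna.load_train_dataset(data_format="txt",
                                          preprocess_fun=to_sample)

Omitting preprocess_fun returns the raw index lines unchanged, which is exactly what the helmet-detection example below relies on.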


examples/incremental_learning/helmet_detection_incremental_train/training/train.py (+8 -9)

@@ -30,8 +30,7 @@ def main():
     class_names = sedna.context.get_parameters("class_names")
 
     # load dataset.
-    train_data = sedna.load_train_dataset(data_format='txt',
-                                          with_image=False)
+    train_data = sedna.load_train_dataset(data_format='txt')
 
     # read parameters from deployment config.
     obj_threshold = sedna.context.get_parameters("obj_threshold")
@@ -88,13 +87,13 @@ def main():
     model = Interface()
 
     sedna.incremental_learning.train(model=model,
-        train_data=train_data,
-        epochs=epochs,
-        batch_size=batch_size,
-        class_names=class_names,
-        input_shape=input_shape,
-        obj_threshold=obj_threshold,
-        nms_threshold=nms_threshold)
+                                     train_data=train_data,
+                                     epochs=epochs,
+                                     batch_size=batch_size,
+                                     class_names=class_names,
+                                     input_shape=input_shape,
+                                     obj_threshold=obj_threshold,
+                                     nms_threshold=nms_threshold)
 
 
 if __name__ == '__main__':
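With with_image gone, loading a 'txt' dataset needs no extra flag: the loader hands back one string per index line, already prefixed with the index file's directory (see the dataset.py hunk below). A hedged sketch of what this worker now receives; the example path in the comment is made up, and the real line format depends on the index file:

    import sedna

    # Each sample is one raw line from the txt index,
    # e.g. '<index-dir>/images/0001.jpg,0\n'
    train_data = sedna.load_train_dataset(data_format='txt')
    for sample in train_data[:3]:
        print(sample)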


lib/sedna/dataset/dataset.py (+13 -31)

@@ -18,71 +18,53 @@ choice 1: should be compatible with tensorflow.data.Dataset
 choice 2: a high level Dataset object not compatible with tensorflow,
 but it's unified in our framework.
 """
 
-import fileinput
 import logging
 import os
 
 import numpy as np
 
 from sedna.common.config import BaseConfig
 
 LOG = logging.getLogger(__name__)
 
 
-def _load_dataset(dataset_url, format, **kwargs):
+def _load_dataset(dataset_url, format, preprocess_fun=None, **kwargs):
     if dataset_url is None:
         LOG.warning(f'dataset_url is None, please check the url.')
         return None
     if format == 'txt':
         LOG.info(
             f"dataset format is txt, now loading txt from [{dataset_url}]")
-        if kwargs.get('with_image'):
-            return _load_txt_dataset_with_image(dataset_url)
+        samples = _load_txt_dataset(dataset_url)
+        if preprocess_fun:
+            new_samples = [preprocess_fun(s) for s in samples]
         else:
-            return _load_txt_dataset(dataset_url)
+            new_samples = samples
+        return new_samples
 
 
-def load_train_dataset(data_format, **kwargs):
+def load_train_dataset(data_format, preprocess_fun=None, **kwargs):
     """
     :param data_format: txt
     :param kwargs:
     :return: Dataset
     """
-    return _load_dataset(BaseConfig.train_dataset_url, data_format, **kwargs)
+    return _load_dataset(BaseConfig.train_dataset_url, data_format,
+                         preprocess_fun, **kwargs)
 
 
-def load_test_dataset(data_format, **kwargs):
+def load_test_dataset(data_format, preprocess_fun=None, **kwargs):
     """
     :param data_format: txt
     :param kwargs:
     :return: Dataset
     """
-    return _load_dataset(BaseConfig.test_dataset_url, data_format, **kwargs)
+    return _load_dataset(BaseConfig.test_dataset_url, data_format,
+                         preprocess_fun, **kwargs)
 
 
 def _load_txt_dataset(dataset_url):
     LOG.info(f'dataset_url is {dataset_url}, now reading dataset_url')
-    root_path = BaseConfig.data_path_prefix
+    root_path = os.path.dirname(dataset_url)
     with open(dataset_url) as f:
         lines = f.readlines()
     new_lines = [root_path + os.path.sep + l for l in lines]
     return new_lines
-
-
-def _load_txt_dataset_with_image(dataset_url):
-    import keras.preprocessing.image as img_preprocessing
-    root_path = os.path.dirname(dataset_url)
-    img_data = []
-    img_label = []
-    for line in fileinput.input(dataset_url):
-        file_path, label = line.split(',')
-        file_path = (file_path.replace("\\", os.path.sep)
-                     .replace("/", os.path.sep))
-        file_path = os.path.join(root_path, file_path)
-        img = img_preprocessing.load_img(file_path).resize((128, 128))
-        img_data.append(img_preprocessing.img_to_array(img) / 255.0)
-        img_label += [(0, 1)] if int(label) == 0 else [(1, 0)]
-    data_set = [(np.array(line[0]), np.array(line[1]))
-                for line in zip(img_data, img_label)]
-    return data_set
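This hunk is the core of the commit: sample paths in a txt index are now resolved relative to the index file itself rather than to a globally configured data_path_prefix, so a dataset directory can be mounted anywhere as long as the index and the samples move together. A standalone sketch of the new rule; it mirrors _load_txt_dataset above rather than importing sedna, and the index entry is illustrative:

    import os
    import tempfile

    def load_txt_dataset(dataset_url):
        # Same logic as the new _load_txt_dataset: prefix every line with
        # the directory that contains the index file.
        root_path = os.path.dirname(dataset_url)
        with open(dataset_url) as f:
            lines = f.readlines()
        return [root_path + os.path.sep + line for line in lines]

    with tempfile.TemporaryDirectory() as data_dir:
        index = os.path.join(data_dir, 'train_data.txt')
        with open(index, 'w') as f:
            f.write('images/0001.jpg,0\n')   # illustrative relative entry
        print(load_txt_dataset(index))
        # ['<data_dir>/images/0001.jpg,0\n'] -- prefix comes from the index location

Note the returned strings keep their trailing newline and label field; hooks such as image_process above are responsible for splitting and stripping them.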
