You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

dataflow_examples.py 1.8 kB

4 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. #! /usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. import tensorlayer as tl
  4. from tensorlayer.dataflow import Dataset
  5. import numpy as np
  # Load CIFAR-10; every image array is reshaped to (-1, 32, 32, 3).
  # NOTE(review): X_test / y_test are not consumed anywhere in the visible script.
  X_train, y_train, X_test, y_test = tl.files.load_cifar10_dataset(
      plotable=False,
      shape=(-1, 32, 32, 3),
  )
  def generator_train(inputs=None, targets=None):
      """Yield ``(image, label)`` pairs for ``Dataset.from_generator``.

      Args:
          inputs: Sequence of images to yield. Defaults to the module-level
              ``X_train``, looked up at call time.
          targets: Sequence of labels aligned with ``inputs``. Defaults to the
              module-level ``y_train``, looked up at call time.

      Yields:
          Tuples ``(image, np.array(label))``: each label is wrapped in a NumPy
          array before being handed to the dataset.

      Raises:
          AssertionError: If ``inputs`` and ``targets`` differ in length. As
              with any generator function, this is raised on the first
              ``next()`` call, not when the generator object is created.
      """
      # Resolving the defaults at call time keeps the zero-argument call used
      # by Dataset.from_generator working, while letting the same generator
      # serve other splits, e.g. generator_train(X_test, y_test).
      if inputs is None:
          inputs = X_train
      if targets is None:
          targets = y_train
      if len(inputs) != len(targets):
          # AssertionError (not ValueError) is kept so any existing caller that
          # catches it keeps working.
          raise AssertionError("The length of inputs and targets should be equal")
      for _input, _target in zip(inputs, targets):
          yield (_input, np.array(_target))
  # Input-pipeline hyper-parameters.
  n_epoch = 10  # passed to Dataset.repeat: number of passes over the generator
  shuffle_buffer_size = 128  # passed to Dataset.shuffle: elements held in the buffer
  batch_size = 128  # passed to Dataset.batch: examples per yielded batch
  18. import tensorflow as tf
  def _map_fn_train(img, target):
      """Apply random training-time augmentation to a single example.

      Args:
          img: One image tensor. Assumed to be HWC with 3 channels and at least
              24x24 spatially; the generator yields slices of ``X_train``,
              which is loaded with shape (-1, 32, 32, 3). TODO confirm dtype.
          target: The label for ``img``; reshaped to a scalar.

      Returns:
          ``(img, target)``: ``img`` is a standardized, randomly cropped,
          flipped and colour-jittered 24x24x3 image; ``target`` is a rank-0
          tensor.
      """
      # 1. Randomly crop a 24x24x3 section of the image.
      img = tf.image.random_crop(img, [24, 24, 3])
      # 2. Randomly flip the image horizontally.
      img = tf.image.random_flip_left_right(img)
      # 3. Randomly shift brightness.
      #    NOTE(review): max_delta=63 presumably assumes raw pixel values in
      #    [0, 255]; confirm against what load_cifar10_dataset returns.
      img = tf.image.random_brightness(img, max_delta=63)
      # 4. Randomly change contrast.
      img = tf.image.random_contrast(img, lower=0.2, upper=1.8)
      # 5. Subtract the per-image mean and divide by the per-image (adjusted)
      #    standard deviation -- not the variance -- giving each image zero
      #    mean and unit variance. This runs after the jitter steps, so they
      #    operate on un-normalized pixel values.
      img = tf.image.per_image_standardization(img)
      # Force a scalar label so batching yields a 1-D label vector.
      target = tf.reshape(target, ())
      return img, target
  32. import multiprocessing
  # Training input pipeline:
  # generator -> parallel augmentation -> n_epoch repeats -> shuffle -> batch -> prefetch.
  # NOTE(review): assumes tensorlayer.dataflow.Dataset exposes the tf.data.Dataset
  # API (from_generator/map/repeat/shuffle/batch/prefetch) -- confirm for the
  # installed TensorLayer backend.
  train_ds = Dataset.from_generator(
      generator=generator_train, output_types=(tl.float32, tl.int32)
  )
  train_ds = train_ds.map(_map_fn_train, num_parallel_calls=multiprocessing.cpu_count())
  train_ds = train_ds.repeat(n_epoch)
  # shuffle follows repeat, so the buffer can mix examples across epoch boundaries.
  train_ds = train_ds.shuffle(shuffle_buffer_size)
  train_ds = train_ds.batch(batch_size)
  # Prefetch is the final stage and works on whole batches: the next batch is
  # prepared while the current one is consumed. AUTOTUNE lets the runtime pick
  # the buffer size instead of holding thousands of elements in memory.
  train_ds = train_ds.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

  for X_batch, y_batch in train_ds:
      print(X_batch.shape, y_batch.shape)

TensorLayer 3.0 是一款支持以多种深度学习框架作为计算后端的深度学习库，计划兼容 TensorFlow、PyTorch、MindSpore 和 Paddle。