You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

DenseNet121-ImageNet.py 4.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. import time
  2. import multiprocessing
  3. import tensorflow as tf
  4. import os
  5. os.environ['TL_BACKEND'] = 'tensorflow'
  6. import tensorlayer as tl
  7. from .densenet import densenet
  8. tl.logging.set_verbosity(tl.logging.DEBUG)
  def load_ImageNet_dataset(shape=(-1, 256, 256, 3), plotable=False):
      """Return the ImageNet dataset that has already been loaded locally.

      NOTE(review): the actual loading code is not shown here. Neither
      ``shape`` nor ``plotable`` is read by this body, and the four names
      returned below must be bound elsewhere before the call -- confirm.

      Returns:
          The tuple ``(X_train, y_train, X_test, y_test)``.
      """
      dataset_splits = (X_train, y_train, X_test, y_test)
      return dataset_splits
  # Build the "densenet-121" network and fetch the dataset splits.
  net = densenet("densenet-121")
  X_train, y_train, X_test, y_test = load_ImageNet_dataset(
      shape=(-1, 256, 256, 3),
      plotable=False,
  )

  # Hyper-parameters of the training run.
  batch_size = 128
  n_epoch = 500
  learning_rate = 0.0001
  print_freq = 5
  shuffle_buffer_size = 128

  # Step counts derived from the number of training labels.
  n_step_epoch = int(len(y_train) / batch_size)
  n_step = n_epoch * n_step_epoch

  # Objects consumed by the training step: weights, optimizer and metric.
  train_weights = net.trainable_weights
  optimizer = tl.optimizers.Adam(learning_rate)
  metrics = tl.metric.Accuracy()
  def generator_train():
      """Yield ``(input, target)`` pairs from the training split, one sample at a time.

      Reads the module-level ``X_train`` and ``y_train``.

      Raises:
          AssertionError: if the two sequences differ in length (raised when
              the generator is first advanced).
      """
      samples, labels = X_train, y_train
      if len(samples) != len(labels):
          raise AssertionError("The length of inputs and targets should be equal")
      yield from zip(samples, labels)
  def generator_test():
      """Yield ``(input, target)`` pairs from the test split, one sample at a time.

      Reads the module-level ``X_test`` and ``y_test``.

      Raises:
          AssertionError: if the two sequences differ in length (raised when
              the generator is first advanced).
      """
      samples, labels = X_test, y_test
      if len(samples) != len(labels):
          raise AssertionError("The length of inputs and targets should be equal")
      yield from zip(samples, labels)
  def _map_fn_train(img, target):
      """Augment and standardise one training image and flatten its target to a scalar.

      Returns:
          ``(image, target)`` where the image has been randomly cropped to
          ``[24, 24, 3]``, randomly flipped, brightness/contrast jittered and
          standardised, and the target has been reshaped to shape ``()``.
      """
      # Random [24, 24, 3] crop followed by a random horizontal flip.
      cropped = tf.image.random_crop(img, [24, 24, 3])
      augmented = tf.image.random_flip_left_right(cropped)
      # Random photometric jitter: brightness first, then contrast.
      # NOTE(review): max_delta=63 presumably assumes 0-255 pixel values -- confirm.
      augmented = tf.image.random_brightness(augmented, max_delta=63)
      augmented = tf.image.random_contrast(augmented, lower=0.2, upper=1.8)
      # Scale the image to zero mean and unit variance.
      standardized = tf.image.per_image_standardization(augmented)
      return standardized, tf.reshape(target, ())
  def _map_fn_test(img, target):
      """Resize and standardise one test image and flatten its target to a scalar.

      Returns:
          ``(image, target)`` where the image has shape ``(24, 24, 3)`` and the
          target has shape ``()``.
      """
      # resize_with_pad resizes and pads to 24x24; it does not crop the image.
      resized = tf.image.resize_with_pad(img, 24, 24)
      # Scale the image to zero mean and unit variance.
      standardized = tf.image.per_image_standardization(resized)
      # Pin the image shape explicitly.
      standardized = tf.reshape(standardized, (24, 24, 3))
      return standardized, tf.reshape(target, ())
  # Training pipeline: generator -> augmentation -> shuffle -> prefetch -> batch.
  train_ds = (
      tf.data.Dataset.from_generator(generator_train, output_types=(tf.float32, tf.int32))
      .map(_map_fn_train, num_parallel_calls=multiprocessing.cpu_count())
      .shuffle(shuffle_buffer_size)
      .prefetch(buffer_size=4096)
      .batch(batch_size)
  )

  # Test pipeline: generator -> preprocessing -> prefetch -> batch (no shuffling).
  test_ds = (
      tf.data.Dataset.from_generator(generator_test, output_types=(tf.float32, tf.int32))
      .map(_map_fn_test, num_parallel_calls=multiprocessing.cpu_count())
      .prefetch(buffer_size=4096)
      .batch(batch_size)
  )
  class WithLoss(tl.layers.Module):
      """Pair a network with a loss function so that a forward pass returns the loss."""

      def __init__(self, net, loss_fn):
          """Store the wrapped network and the loss callable."""
          super().__init__()
          self._net = net
          self._loss_fn = loss_fn

      def forward(self, data, label):
          """Feed ``data`` through the network and return the loss against ``label``."""
          return self._loss_fn(self._net(data), label)
  # Wrap the network with its loss, then with a single-step trainer.
  net_with_loss = WithLoss(net, loss_fn=tl.cost.softmax_cross_entropy_with_logits)
  net_with_train = tl.models.TrainOneStep(net_with_loss, optimizer, train_weights)

  for epoch in range(n_epoch):
      start_time = time.time()
      net.set_train()
      train_loss, train_acc, n_iter = 0, 0, 0
      for n_iter, (X_batch, y_batch) in enumerate(train_ds, start=1):
          # Convert the tf.data batch to backend tensors via NumPy.
          X_batch = tl.ops.convert_to_tensor(X_batch.numpy(), dtype=tl.float32)
          y_batch = tl.ops.convert_to_tensor(y_batch.numpy(), dtype=tl.int64)

          # One training step; its return value is accumulated as the loss.
          _loss_ce = net_with_train(X_batch, y_batch)
          train_loss += _loss_ce

          # A second forward pass after the update feeds the accuracy metric.
          _logits = net(X_batch)
          metrics.update(_logits, y_batch)
          train_acc += metrics.result()
          metrics.reset()

      # NOTE(review): the original nesting of this summary is ambiguous (it may
      # have been printed per batch); it is placed once per epoch here -- confirm.
      print("Epoch {} of {} took {}".format(epoch + 1, n_epoch, time.time() - start_time))
      print(" train loss: {}".format(train_loss / n_iter))
      print(" train acc: {}".format(train_acc / n_iter))

TensorLayer 3.0 是一款兼容多种深度学习框架作为计算后端的深度学习库，计划兼容 TensorFlow、PyTorch、MindSpore 和 Paddle。