import os

# The TensorLayer backend must be selected before tensorlayer is imported.
os.environ['TL_BACKEND'] = 'tensorflow'

import multiprocessing
import time

import tensorflow as tf
import tensorlayer as tl

from densenet import densenet  # assumes densenet.py sits alongside this script

tl.logging.set_verbosity(tl.logging.DEBUG)
X_train, y_train, X_test, y_test = tl.files.load_cifar10_dataset(shape=(-1, 32, 32, 3), plotable=False)

# get the network
net = densenet("densenet-100")

# training settings
batch_size = 128
n_epoch = 500
learning_rate = 0.0001
print_freq = 5
n_step_epoch = int(len(y_train) / batch_size)
n_step = n_epoch * n_step_epoch
shuffle_buffer_size = 128

train_weights = net.trainable_weights
optimizer = tl.optimizers.Adam(learning_rate)
metrics = tl.metric.Accuracy()
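# `train_weights` is the list of parameters the optimizer will update on each
# step; `metrics` keeps a running top-1 accuracy between reset() calls.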


def generator_train():
    inputs = X_train
    targets = y_train
    if len(inputs) != len(targets):
        raise AssertionError("The length of inputs and targets must be equal")
    for _input, _target in zip(inputs, targets):
        yield _input, _target


def generator_test():
    inputs = X_test
    targets = y_test
    if len(inputs) != len(targets):
        raise AssertionError("The length of inputs and targets must be equal")
    for _input, _target in zip(inputs, targets):
        yield _input, _target
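# Both generators stream (image, label) pairs straight from the in-memory
# NumPy arrays, so tf.data can pipeline the augmentation below instead of
# materialising an augmented copy of the whole dataset.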


def _map_fn_train(img, target):
    # 1. Randomly crop a [height, width] section of the image.
    img = tf.image.random_crop(img, [24, 24, 3])
    # 2. Randomly flip the image horizontally.
    img = tf.image.random_flip_left_right(img)
    # 3. Randomly change brightness.
    img = tf.image.random_brightness(img, max_delta=63)
    # 4. Randomly change contrast.
    img = tf.image.random_contrast(img, lower=0.2, upper=1.8)
    # 5. Subtract off the mean and divide by the variance of the pixels.
    img = tf.image.per_image_standardization(img)
    target = tf.reshape(target, ())
    return img, target


def _map_fn_test(img, target):
    # 1. Crop the central [height, width] of the image.
    img = tf.image.resize_with_crop_or_pad(img, 24, 24)
    # 2. Subtract off the mean and divide by the variance of the pixels.
    img = tf.image.per_image_standardization(img)
    # Pin the static shape, which from_generator leaves undefined.
    img = tf.reshape(img, (24, 24, 3))
    target = tf.reshape(target, ())
    return img, target


# dataset API and augmentation
train_ds = tf.data.Dataset.from_generator(generator_train, output_types=(tf.float32, tf.int32))
train_ds = train_ds.map(_map_fn_train, num_parallel_calls=multiprocessing.cpu_count())
train_ds = train_ds.shuffle(shuffle_buffer_size)
train_ds = train_ds.batch(batch_size)
# Prefetch after batching so the buffer holds whole batches, not single images.
train_ds = train_ds.prefetch(buffer_size=tf.data.AUTOTUNE)

test_ds = tf.data.Dataset.from_generator(generator_test, output_types=(tf.float32, tf.int32))
test_ds = test_ds.map(_map_fn_test, num_parallel_calls=multiprocessing.cpu_count())
test_ds = test_ds.batch(batch_size)
test_ds = test_ds.prefetch(buffer_size=tf.data.AUTOTUNE)


class WithLoss(tl.layers.Module):

    def __init__(self, net, loss_fn):
        super(WithLoss, self).__init__()
        self._net = net
        self._loss_fn = loss_fn

    def forward(self, data, label):
        out = self._net(data)
        loss = self._loss_fn(out, label)
        return loss


net_with_loss = WithLoss(net, loss_fn=tl.cost.softmax_cross_entropy_with_logits)
net_with_train = tl.models.TrainOneStep(net_with_loss, optimizer, train_weights)
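# Each call to net_with_train runs net_with_loss forward, differentiates the
# returned loss with respect to train_weights, applies one optimizer update,
# and hands back the loss value.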

for epoch in range(n_epoch):
    start_time = time.time()
    net.set_train()
    train_loss, train_acc, n_iter = 0, 0, 0
    for X_batch, y_batch in train_ds:
        # Route the batch through the backend-agnostic tensor type.
        X_batch = tl.ops.convert_to_tensor(X_batch.numpy(), dtype=tl.float32)
        y_batch = tl.ops.convert_to_tensor(y_batch.numpy(), dtype=tl.int64)

        _loss_ce = net_with_train(X_batch, y_batch)
        train_loss += _loss_ce
        n_iter += 1

        # A second forward pass, used only for the running accuracy.
        _logits = net(X_batch)
        metrics.update(_logits, y_batch)
        train_acc += metrics.result()
        metrics.reset()

    if epoch + 1 == 1 or (epoch + 1) % print_freq == 0:
        print("Epoch {} of {} took {}".format(epoch + 1, n_epoch, time.time() - start_time))
        print("   train loss: {}".format(train_loss / n_iter))
        print("   train acc:  {}".format(train_acc / n_iter))
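
# test_ds is built above but never consumed. A minimal evaluation sketch under
# the same conventions (set_eval() puts the network in inference mode; the
# same Accuracy metric is reused per batch):
net.set_eval()
test_loss, test_acc, n_iter = 0, 0, 0
for X_batch, y_batch in test_ds:
    X_batch = tl.ops.convert_to_tensor(X_batch.numpy(), dtype=tl.float32)
    y_batch = tl.ops.convert_to_tensor(y_batch.numpy(), dtype=tl.int64)
    _logits = net(X_batch)
    test_loss += tl.cost.softmax_cross_entropy_with_logits(_logits, y_batch)
    metrics.update(_logits, y_batch)
    test_acc += metrics.result()
    metrics.reset()
    n_iter += 1
print("   test loss: {}".format(test_loss / n_iter))
print("   test acc:  {}".format(test_acc / n_iter))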