You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

pretrained_vgg16.py 725 B

4 years ago
1234567891011121314151617181920212223242526272829
  #! /usr/bin/python
# -*- coding: utf-8 -*-
"""VGG-16 for ImageNet using TL models.

Reads ``data/tiger.jpeg``, runs it through the ``vgg16`` model from the
example model zoo, prints the elapsed inference time, and then prints the
five highest-probability class names with their probabilities.
"""
import time

import numpy as np
import tensorflow as tf
import tensorlayer as tl

from examples.model_zoo.imagenet_classes import class_names
from examples.model_zoo.vgg import vgg16

tl.logging.set_verbosity(tl.logging.DEBUG)

# Get the whole model.
# NOTE(review): pretrained=True presumably loads released weights — confirm
# against examples.model_zoo.vgg.
vgg = vgg16(pretrained=True)
vgg.set_eval()

img = tl.vis.read_image('data/tiger.jpeg')
# Resize to 224x224, convert to float32 and divide by 255.
# NOTE(review): assumes the decoded image holds 8-bit pixel values, so this
# maps them into [0, 1] — confirm that this is the range vgg16 expects.
img = tl.prepro.imresize(img, (224, 224)).astype(np.float32) / 255

# perf_counter() is monotonic, so the measured interval cannot be skewed by
# system clock adjustments the way a time.time() difference can.
start_time = time.perf_counter()
output = vgg(img)
# Index [0] takes the first row of the softmax output.
# NOTE(review): presumably the model adds/broadcasts a batch axis — confirm.
probs = tf.nn.softmax(output)[0].numpy()
# "%.5f" formats the duration as a float. The previous "%.5s" truncated the
# float's *string* form to 5 characters, which silently dropped the exponent
# of values rendered in scientific notation (e.g. 9.5e-05 printed as 9.536).
print(" End time : %.5fs" % (time.perf_counter() - start_time))

# Indices of the five largest probabilities, highest first.
preds = np.argsort(probs)[::-1][:5]
for p in preds:
    print(class_names[p], probs[p])

TensorLayer 3.0 是一款兼容多种深度学习框架作为计算后端的深度学习库，计划兼容 TensorFlow、PyTorch、MindSpore 和 Paddle。