
pretrained_resnet50.py

#! /usr/bin/python
# -*- coding: utf-8 -*-
"""
ResNet50 for ImageNet using TL models
"""
import time

import numpy as np
import tensorlayer as tl

from examples.model_zoo.imagenet_classes import class_names
from examples.model_zoo.resnet import ResNet50

tl.logging.set_verbosity(tl.logging.DEBUG)

# Build ResNet50 and load the pretrained ImageNet weights
resnet = ResNet50(pretrained=True)
resnet.set_eval()  # switch the model to inference mode

# Load the image, resize it to 224x224 and reorder channels from RGB to BGR
img1 = tl.vis.read_image('data/tiger.jpeg')
img1 = tl.prepro.imresize(img1, (224, 224))[:, :, ::-1]
# Subtract the per-channel ImageNet mean, given in B, G, R order
img1 = img1 - np.array([103.939, 116.779, 123.68]).reshape((1, 1, 3))
# Cast to float32 and add a batch dimension
img1 = img1.astype(np.float32)[np.newaxis, ...]

# Run inference and convert the logits to class probabilities
start_time = time.time()
output = resnet(img1)
prob = tl.ops.softmax(output)[0].numpy()
print("Inference time: %.3fs" % (time.time() - start_time))

# Print the five most likely classes with their probabilities
preds = (np.argsort(prob)[::-1])[0:5]
for p in preds:
    print(class_names[p], prob[p])
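The preprocessing and top-5 steps in the script are plain array operations. The following sketch reproduces them with NumPy alone so they can be checked without downloading the pretrained weights. It assumes the image has already been resized to 224x224 RGB. The helper names `preprocess`, `softmax` and `top_k` are hypothetical and are not part of TensorLayer, and random arrays stand in for a real photo and for the model's 1000 ImageNet logits.

```python
import numpy as np

# Per-channel means in B, G, R order, the same values the script subtracts
BGR_MEAN = np.array([103.939, 116.779, 123.68], dtype=np.float32).reshape((1, 1, 3))


def preprocess(rgb_image):
    """Turn an already-resized (H, W, 3) RGB array into a (1, H, W, 3) float32 BGR batch."""
    bgr = rgb_image[:, :, ::-1].astype(np.float32)  # RGB -> BGR, as float to avoid uint8 wraparound
    centered = bgr - BGR_MEAN                        # subtract the per-channel mean
    return centered[np.newaxis, ...]                 # add the batch dimension


def softmax(logits):
    """Numerically stable softmax over the last axis."""
    shifted = logits - np.max(logits, axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=-1, keepdims=True)


def top_k(prob, k=5):
    """Indices of the k largest probabilities, highest first."""
    return np.argsort(prob)[::-1][:k]


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    fake_image = rng.integers(0, 256, size=(224, 224, 3)).astype(np.uint8)
    batch = preprocess(fake_image)
    print(batch.shape, batch.dtype)  # (1, 224, 224, 3) float32

    fake_logits = rng.normal(size=(1, 1000))  # stand-in for resnet(img1)
    prob = softmax(fake_logits)[0]
    print(top_k(prob))
```

Casting to float32 before the subtraction keeps the whole pipeline in float32; the script instead subtracts in float64 and casts afterwards, which yields the same values up to rounding.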

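TensorLayer 3.0 is a deep learning library that can run on several deep learning frameworks as its computational backend. Support is planned for TensorFlow, PyTorch, MindSpore and Paddle.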