
dropout.py 1.4 kB

#! /usr/bin/python
# -*- coding: utf-8 -*-

import tensorlayer as tl
from tensorlayer import logging
from tensorlayer.layers.core import Module

__all__ = [
    'Dropout',
]


class Dropout(Module):
    """
    The :class:`Dropout` class is a noise layer which randomly sets some
    activations to zero according to a keeping probability.

    Parameters
    ----------
    keep : float
        The keeping probability.
        The lower the probability, the more activations are set to zero.
    seed : int or None
        The seed for random dropout.
    name : None or str
        A unique layer name.

    """

    def __init__(self, keep, seed=0, name=None):  #"dropout"):
        super(Dropout, self).__init__(name)
        self.keep = keep
        self.seed = seed

        # Dropout has no trainable weights, so it can be built immediately.
        self.build()
        self._built = True

        logging.info("Dropout %s: keep: %f " % (self.name, self.keep))

    def __repr__(self):
        s = ('{classname}(keep={keep}')
        if self.name is not None:
            s += ', name=\'{name}\''
        s += ')'
        return s.format(classname=self.__class__.__name__, **self.__dict__)

    def build(self, inputs_shape=None):
        # Create the backend dropout op once; it is reused on every forward call.
        self.dropout = tl.ops.Dropout(keep=self.keep, seed=self.seed)

    # @tf.function
    def forward(self, inputs):
        # Dropout is applied only in training mode; in evaluation mode the
        # inputs pass through unchanged.
        if self.is_train:
            outputs = self.dropout(inputs)
        else:
            outputs = inputs
        return outputs
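
The train/eval branching in forward() can be illustrated without any backend. The sketch below is a pure-Python stand-in, not the TensorLayer implementation: the function name dropout_forward is hypothetical, and it assumes the common "inverted dropout" convention in which kept activations are scaled by 1/keep. The scaling actually applied by tl.ops.Dropout depends on the selected backend and is not shown in this file.

```python
import random


def dropout_forward(inputs, keep, is_train, seed=0):
    """Hypothetical stand-in for Dropout.forward on a flat list of floats.

    Assumption: inverted dropout, i.e. kept values are scaled by 1 / keep
    so that the expected activation is unchanged. The real backend op
    behind tl.ops.Dropout may behave differently.
    """
    if not 0.0 < keep <= 1.0:
        raise ValueError("keep must be in the interval (0, 1]")
    if not is_train:
        # Evaluation mode: pass the inputs through untouched.
        return list(inputs)
    rng = random.Random(seed)  # a fixed seed gives a reproducible mask
    return [x / keep if rng.random() < keep else 0.0 for x in inputs]


values = [1.0] * 10000
train_out = dropout_forward(values, keep=0.8, is_train=True)
eval_out = dropout_forward(values, keep=0.8, is_train=False)

# Fraction of activations that survived; it should be close to keep.
kept_fraction = sum(1 for v in train_out if v != 0.0) / len(train_out)
```

With a lower keep value, more entries of train_out are zero, which matches the docstring's description of the keeping probability; eval_out is always identical to the input.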

TensorLayer3.0 is a deep learning library that supports multiple deep learning frameworks as computational backends. It plans to support TensorFlow, PyTorch, MindSpore, and Paddle.