
dropout.py

#! /usr/bin/python
# -*- coding: utf-8 -*-

import tensorlayer as tl
from tensorlayer import logging
from tensorlayer.layers.core import Module

__all__ = [
    'Dropout',
]

class Dropout(Module):
    """
    The :class:`Dropout` class is a noise layer that randomly sets some
    activations to zero according to a keep probability. Dropout is only
    applied in training mode; in evaluation mode the inputs are returned unchanged.

    Parameters
    ----------
    keep : float
        The keep probability, i.e. the probability that an activation is retained.
        The lower the probability, the more activations are set to zero.
    seed : int or None
        The seed for random dropout. Defaults to 0.
    name : None or str
        A unique layer name.

    Examples
    --------
    >>> net = tl.layers.Input([10, 200])
    >>> net = tl.layers.Dropout(keep=0.2)(net)

    """

    def __init__(self, keep, seed=0, name=None):
        super(Dropout, self).__init__(name)
        self.keep = keep
        self.seed = seed
        # Create the backend dropout op immediately; no input shape is needed.
        self.build()
        self._built = True

        logging.info("Dropout %s: keep: %f " % (self.name, self.keep))

    def __repr__(self):
        s = ('{classname}(keep={keep}')
        if self.name is not None:
            s += ', name=\'{name}\''
        s += ')'
        return s.format(classname=self.__class__.__name__, **self.__dict__)

    def build(self, inputs_shape=None):
        # Delegate the masking to the backend-specific dropout op.
        self.dropout = tl.ops.Dropout(keep=self.keep, seed=self.seed)

    # @tf.function
    def forward(self, inputs):
        # Drop activations only while training; act as the identity at evaluation time.
        if self.is_train:
            outputs = self.dropout(inputs)
        else:
            outputs = inputs
        return outputs
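The layer itself only switches on `is_train` (which it reads from its `Module` base class) and hands the actual masking to `tl.ops.Dropout`, whose exact behavior depends on the backend. The sketch below illustrates what that op is assumed to do, namely inverted dropout: each element is kept with probability `keep`, zeroed otherwise, and survivors are scaled by `1 / keep` (the convention TensorFlow's `tf.nn.dropout` follows). The `dropout` function and `ToyDropout` class are hypothetical, framework-free stand-ins written with NumPy; they are not part of TensorLayer.

```python
import numpy as np


def dropout(x, keep, rng):
    """Inverted dropout (hypothetical helper): keep each element with
    probability `keep`, zero it otherwise, and scale the survivors by
    1 / keep so each element's expected value is unchanged."""
    if not 0.0 < keep <= 1.0:
        raise ValueError("keep must be in (0, 1]")
    mask = rng.random(x.shape) < keep  # True where the activation is kept
    return np.where(mask, x / keep, 0.0)


class ToyDropout:
    """Hypothetical stand-in mirroring Dropout.forward: apply dropout in
    training mode, return the inputs unchanged in evaluation mode."""

    def __init__(self, keep, seed=0):
        self.keep = keep
        self.rng = np.random.default_rng(seed)  # seeded for reproducibility
        self.is_train = True

    def forward(self, inputs):
        if self.is_train:
            return dropout(inputs, self.keep, self.rng)
        return inputs


# Same shape and keep value as the docstring example.
layer = ToyDropout(keep=0.2, seed=0)
x = np.ones((10, 200))

train_out = layer.forward(x)
print("fraction kept:", np.mean(train_out != 0))  # approximately 0.2
print("mean activation:", train_out.mean())       # approximately 1.0

layer.is_train = False
eval_out = layer.forward(x)
print("eval is identity:", np.array_equal(eval_out, x))  # True
```

Scaling the kept activations during training, rather than multiplying by `keep` at inference, is what allows `forward` to return `inputs` untouched in evaluation mode; if a given backend op used a different convention, training and evaluation outputs would differ in scale.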

TensorLayer 3.0 is a deep learning library that supports multiple deep learning frameworks as computational backends. It plans to support TensorFlow, PyTorch, MindSpore, and Paddle.