You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

array_ops.py 4.3 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. #! /usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. """A file containing functions related to array manipulation."""
  4. from tensorflow.python.eager import context
  5. from tensorflow.python.framework import constant_op, dtypes, ops, tensor_shape
  6. from tensorflow.python.framework.constant_op import constant
  7. from tensorflow.python.framework.ops import convert_to_tensor
  8. from tensorflow.python.ops.array_ops import shape_internal
  9. from tensorflow.python.ops.gen_array_ops import fill, reshape
  10. __all__ = ['alphas', 'alphas_like']
  11. def alphas(shape, alpha_value, name=None):
  12. """Creates a tensor with all elements set to `alpha_value`.
  13. This operation returns a tensor of type `dtype` with shape `shape` and all
  14. elements set to alpha.
  15. Parameters
  16. ----------
  17. shape: A list of integers, a tuple of integers, or a 1-D `Tensor` of type `int32`.
  18. The shape of the desired tensor
  19. alpha_value: `float32`, `float64`, `int8`, `uint8`, `int16`, `uint16`, int32`, `int64`
  20. The value used to fill the resulting `Tensor`.
  21. name: str
  22. A name for the operation (optional).
  23. Returns
  24. -------
  25. A `Tensor` with all elements set to alpha.
  26. Examples
  27. --------
  28. >>> tl.alphas([2, 3], tf.int32) # [[alpha, alpha, alpha], [alpha, alpha, alpha]]
  29. """
  30. with ops.name_scope(name, "alphas", [shape]) as name:
  31. alpha_tensor = convert_to_tensor(alpha_value)
  32. alpha_dtype = dtypes.as_dtype(alpha_tensor.dtype).base_dtype
  33. if not isinstance(shape, ops.Tensor):
  34. try:
  35. shape = constant_op._tensor_shape_tensor_conversion_function(tensor_shape.TensorShape(shape))
  36. except (TypeError, ValueError):
  37. shape = ops.convert_to_tensor(shape, dtype=dtypes.int32)
  38. if not shape._shape_tuple():
  39. shape = reshape(shape, [-1]) # Ensure it's a vector
  40. try:
  41. output = constant(alpha_value, shape=shape, dtype=alpha_dtype, name=name)
  42. except (TypeError, ValueError):
  43. output = fill(shape, constant(alpha_value, dtype=alpha_dtype), name=name)
  44. if output.dtype.base_dtype != alpha_dtype:
  45. raise AssertionError("Dtypes do not corresponds: %s and %s" % (output.dtype.base_dtype, alpha_dtype))
  46. return output
  47. def alphas_like(tensor, alpha_value, name=None, optimize=True):
  48. """Creates a tensor with all elements set to `alpha_value`.
  49. Given a single tensor (`tensor`), this operation returns a tensor of the same
  50. type and shape as `tensor` with all elements set to `alpha_value`.
  51. Parameters
  52. ----------
  53. tensor: tf.Tensor
  54. The Tensorflow Tensor that will be used as a template.
  55. alpha_value: `float32`, `float64`, `int8`, `uint8`, `int16`, `uint16`, int32`, `int64`
  56. The value used to fill the resulting `Tensor`.
  57. name: str
  58. A name for the operation (optional).
  59. optimize: bool
  60. if true, attempt to statically determine the shape of 'tensor' and encode it as a constant.
  61. Returns
  62. -------
  63. A `Tensor` with all elements set to `alpha_value`.
  64. Examples
  65. --------
  66. >>> tensor = tf.constant([[1, 2, 3], [4, 5, 6]])
  67. >>> tl.alphas_like(tensor, 0.5) # [[0.5, 0.5, 0.5], [0.5, 0.5, 0.5]]
  68. """
  69. with ops.name_scope(name, "alphas_like", [tensor]) as name:
  70. tensor = ops.convert_to_tensor(tensor, name="tensor")
  71. if context.in_eager_mode(): # and dtype is not None and dtype != tensor.dtype:
  72. ret = alphas(shape_internal(tensor, optimize=optimize), alpha_value=alpha_value, name=name)
  73. else: # if context.in_graph_mode():
  74. # For now, variant types must be created via zeros_like; as we need to
  75. # pass the input variant object to the proper zeros callback.
  76. if (optimize and tensor.shape.is_fully_defined()):
  77. # We can produce a zeros tensor independent of the value of 'tensor',
  78. # since the shape is known statically.
  79. ret = alphas(tensor.shape, alpha_value=alpha_value, name=name)
  80. # elif dtype is not None and dtype != tensor.dtype and dtype != dtypes.variant:
  81. else:
  82. ret = alphas(shape_internal(tensor, optimize=optimize), alpha_value=alpha_value, name=name)
  83. ret.set_shape(tensor.get_shape())
  84. return ret

TensorLayer 3.0 是一款兼容多种深度学习框架作为计算后端的深度学习库。计划兼容 TensorFlow、PyTorch、MindSpore 和 Paddle。