#! /usr/bin/python
# -*- coding: utf-8 -*-
"""Shape-extension layers: :class:`ExpandDims` and :class:`Tile`."""

import tensorlayer as tl
from tensorlayer import logging
from tensorlayer.layers.core import Module

__all__ = [
    'ExpandDims',
    'Tile',
]


class ExpandDims(Module):
    """
    The :class:`ExpandDims` class inserts a dimension of 1 into a tensor's shape,
    see `tf.expand_dims() <https://www.tensorflow.org/api_docs/python/tf/expand_dims>`__ .

    The layer holds no shape-dependent state: the underlying
    ``tl.ops.ExpandDims`` operator is created in the constructor and the
    layer is flagged as built immediately.

    Parameters
    ----------
    axis : int
        The dimension index at which to expand the shape of input.
    name : str
        A unique layer name. If None, a unique name will be automatically assigned.

    Examples
    --------
    >>> x = tl.layers.Input([10, 3], name='in')
    >>> y = tl.layers.ExpandDims(axis=-1)(x)
    [10, 3, 1]

    """

    def __init__(self, axis, name=None):
        super(ExpandDims, self).__init__(name)
        self.axis = axis

        # build() does not look at the shape, so a placeholder is passed
        # and the layer is marked as built right away.
        self.build((None, ))
        self._built = True

        message = "ExpandDims %s: axis: %d" % (self.name, self.axis)
        logging.info(message)

    def __repr__(self):
        """Return ``ClassName(axis=...,name=...)`` filled from instance attributes."""
        template = '{classname}(axis={axis},name={name})'
        return template.format(classname=self.__class__.__name__, **vars(self))

    def build(self, inputs_shape):
        """Create the backend operator; ``inputs_shape`` is accepted but unused."""
        self.expand_dims = tl.ops.ExpandDims(axis=self.axis)

    def forward(self, inputs):
        """Apply the expand-dims operator to ``inputs`` and return the result."""
        return self.expand_dims(inputs)


class Tile(Module):
    """
    The :class:`Tile` class constructs a tensor by tiling a given tensor,
    see `tf.tile() <https://www.tensorflow.org/api_docs/python/tf/tile>`__ .

    The underlying ``tl.ops.Tile`` operator is created in the constructor;
    ``multiples`` is stored on the layer and handed to the operator on
    every forward call.

    Parameters
    ----------
    multiples: tensor
        Must be one of the following types: int32, int64.
        1-D Length must be the same as the number of dimensions in input.
    name : None or str
        A unique layer name.

    Examples
    --------
    >>> x = tl.layers.Input([10, 3], name='in')
    >>> y = tl.layers.Tile(multiples=[2, 3])(x)

    """

    def __init__(self, multiples=None, name=None):
        super(Tile, self).__init__(name)
        self.multiples = multiples

        # build() does not look at the shape, so a placeholder is passed
        # and the layer is marked as built right away.
        self.build((None, ))
        self._built = True

        message = "Tile %s: multiples: %s" % (self.name, self.multiples)
        logging.info(message)

    def __repr__(self):
        """Return ``ClassName(multiples=...,name=...)`` filled from instance attributes."""
        template = '{classname}(multiples={multiples},name={name})'
        return template.format(classname=self.__class__.__name__, **vars(self))

    def build(self, inputs_shape):
        """Create the backend operator; ``inputs_shape`` is accepted but unused."""
        self.tile = tl.ops.Tile()

    def forward(self, inputs):
        """Tile ``inputs`` according to ``self.multiples`` and return the result."""
        return self.tile(inputs, multiples=self.multiples)