diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs index 38c75606..d417fa44 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.cs @@ -191,7 +191,7 @@ namespace Tensorflow.Keras.Engine tf.Context.eager_mode(isFunc: tf.Context.is_build_function()); } - build(inputs); + build(inputs.shape); if (need_restore_mode) tf.Context.restore_mode(); @@ -199,7 +199,7 @@ namespace Tensorflow.Keras.Engine built = true; } - protected virtual void build(Tensors inputs) + public virtual void build(Shape input_shape) { built = true; } diff --git a/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs b/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs index 3efda364..6e790a26 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs @@ -6,30 +6,38 @@ using Tensorflow.Keras.Engine; using static Tensorflow.Binding; namespace Tensorflow.Keras.Layers { - /// - /// ELU Layer: - /// x = 0 when x > 0, x = alpha( e^x-1 ) elsewhere - /// - public class ELU : Layer { - ELUArgs args; - float alpha => args.Alpha; - public ELU ( ELUArgs args ) : base(args) { - this.args = args; - } - protected override void build ( Tensors inputs ) { - if ( alpha < 0f ) { - throw new ValueError("Alpha must be a number greater than 0."); - } - built = true; - } - protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { - Tensor output = inputs; - output = tf.where(output > 0f, output, - tf.multiply(alpha, tf.sub(tf.exp(output), 1f))); - return output; - } - public override Shape ComputeOutputShape ( Shape input_shape ) { - return input_shape; + /// + /// ELU Layer: + /// x = 0 when x > 0, x = alpha( e^x-1 ) elsewhere + /// + public class ELU : Layer + { + ELUArgs args; + float alpha => args.Alpha; + public ELU(ELUArgs args) : base(args) + { + this.args = args; + } + + public override void build(Shape input_shape) + { + if (alpha < 0f) + { + throw new ValueError("Alpha must be a number greater than 0."); } - } + built = true; + } + + protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + { + Tensor output = inputs; + output = tf.where(output > 0f, output, + tf.multiply(alpha, tf.sub(tf.exp(output), 1f))); + return output; + } + public override Shape ComputeOutputShape(Shape input_shape) + { + return input_shape; + } + } } diff --git a/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs b/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs index aecb3da2..aba175de 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs @@ -6,19 +6,24 @@ using Tensorflow.Keras.Engine; using static Tensorflow.Binding; namespace Tensorflow.Keras.Layers { - public class Exponential : Layer { - public Exponential ( LayerArgs args ) : base(args) { - // Exponential has no args - } - protected override void build ( Tensors inputs ) { - built = true; - } - protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { - Tensor output = inputs; - return tf.exp(output); - } - public override Shape ComputeOutputShape ( Shape input_shape ) { - return input_shape; - } - } + public class Exponential : Layer + { + public Exponential(LayerArgs args) : base(args) + { + // Exponential has no args + } + public override void build(Shape input_shape) + { + built = true; + } + protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + { + Tensor output = inputs; + return tf.exp(output); + } + public override Shape ComputeOutputShape(Shape input_shape) + { + return input_shape; + } + } } diff --git a/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs b/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs index 388302da..b12d7dee 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs @@ -15,7 +15,7 @@ namespace Tensorflow.Keras.Layers { public SELU ( LayerArgs args ) : base(args) { // SELU has no arguments } - protected override void build ( Tensors inputs ) { + public override void build(Shape input_shape) { if ( alpha < 0f ) { throw new ValueError("Alpha must be a number greater than 0."); } diff --git a/src/TensorFlowNET.Keras/Layers/Attention/Attention.cs b/src/TensorFlowNET.Keras/Layers/Attention/Attention.cs index 51a40b58..6f6dd7e8 100644 --- a/src/TensorFlowNET.Keras/Layers/Attention/Attention.cs +++ b/src/TensorFlowNET.Keras/Layers/Attention/Attention.cs @@ -90,9 +90,10 @@ namespace Tensorflow.Keras.Layers }.Contains(this.score_mode)) throw new ValueError("Received: score_mode={score_mode}. Acceptable values are: [\"dot\", \"concat\"]"); } - + // Creates variable when `use_scale` is True or `score_mode` is `concat`. - protected override void build(Tensors inputs) { + public override void build(Shape input_shape) + { if (this.use_scale) this.scale = this.add_weight(name: "scale", shape: 1, @@ -110,7 +111,7 @@ namespace Tensorflow.Keras.Layers trainable: true); else this.concat_score_weight = null; - base.build(inputs); + base.build(input_shape); } /// diff --git a/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs b/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs index 9ef4db18..e0a337ca 100644 --- a/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs +++ b/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs @@ -29,9 +29,8 @@ namespace Tensorflow.Keras.Layers } - protected override void build(Tensors inputs) + public override void build(Shape input_shape) { - var input_shape = inputs.shape; if (len(input_shape) != 4) throw new ValueError($"Inputs should have rank 4. Received input shape: {input_shape}"); @@ -43,14 +42,12 @@ namespace Tensorflow.Keras.Layers shape: kernel_shape, initializer: kernel_initializer, regularizer: kernel_regularizer, - trainable: true, - dtype: inputs.dtype); + trainable: true); if (use_bias) bias = add_weight(name: "bias", shape: filters, initializer: bias_initializer, - trainable: true, - dtype: inputs.dtype); + trainable: true); built = true; } diff --git a/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs b/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs index 5ac2dd00..912a429b 100644 --- a/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs +++ b/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs @@ -57,9 +57,8 @@ namespace Tensorflow.Keras.Layers _tf_data_format = conv_utils.convert_data_format(data_format, rank + 2); } - protected override void build(Tensors inputs) + public override void build(Shape input_shape) { - Shape input_shape = inputs.shape; int channel_axis = data_format == "channels_first" ? 1 : -1; var input_channel = channel_axis < 0 ? input_shape.dims[input_shape.ndim + channel_axis] : diff --git a/src/TensorFlowNET.Keras/Layers/Core/Dense.cs b/src/TensorFlowNET.Keras/Layers/Core/Dense.cs index f3956811..e4c22745 100644 --- a/src/TensorFlowNET.Keras/Layers/Core/Dense.cs +++ b/src/TensorFlowNET.Keras/Layers/Core/Dense.cs @@ -41,9 +41,8 @@ namespace Tensorflow.Keras.Layers this.inputSpec = new InputSpec(min_ndim: 2); } - protected override void build(Tensors inputs) + public override void build(Shape input_shape) { - Shape input_shape = inputs.shape; var last_dim = input_shape.dims.Last(); var axes = new Dictionary(); axes[-1] = (int)last_dim; diff --git a/src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs b/src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs index 2bd987a7..0f387570 100644 --- a/src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs +++ b/src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs @@ -119,9 +119,8 @@ namespace Tensorflow.Keras.Layers this.bias_constraint = args.BiasConstraint; } - protected override void build(Tensors inputs) + public override void build(Shape input_shape) { - var input_shape = inputs.shape; var shape_data = _analyze_einsum_string(this.equation, this.bias_axes, input_shape, this.partial_output_shape); var kernel_shape = shape_data.Item1; var bias_shape = shape_data.Item2; @@ -141,7 +140,7 @@ namespace Tensorflow.Keras.Layers trainable: true); else this.bias = null; - base.build(inputs); + base.build(input_shape); } public override Shape ComputeOutputShape(Shape input_shape) diff --git a/src/TensorFlowNET.Keras/Layers/Core/Embedding.cs b/src/TensorFlowNET.Keras/Layers/Core/Embedding.cs index f16fcfa6..79f4e5ce 100644 --- a/src/TensorFlowNET.Keras/Layers/Core/Embedding.cs +++ b/src/TensorFlowNET.Keras/Layers/Core/Embedding.cs @@ -54,7 +54,7 @@ namespace Tensorflow.Keras.Layers SupportsMasking = mask_zero; } - protected override void build(Tensors inputs) + public override void build(Shape input_shape) { tf.Context.eager_mode(); embeddings = add_weight(shape: (input_dim, output_dim), diff --git a/src/TensorFlowNET.Keras/Layers/Cropping/Cropping1D.cs b/src/TensorFlowNET.Keras/Layers/Cropping/Cropping1D.cs index cf71e184..45f5bf0f 100644 --- a/src/TensorFlowNET.Keras/Layers/Cropping/Cropping1D.cs +++ b/src/TensorFlowNET.Keras/Layers/Cropping/Cropping1D.cs @@ -2,49 +2,61 @@ using Tensorflow.Keras.Engine; namespace Tensorflow.Keras.Layers { - public class Cropping1D : Layer { - CroppingArgs args; - public Cropping1D ( CroppingArgs args ) : base(args) { - this.args = args; - } + public class Cropping1D : Layer + { + CroppingArgs args; + public Cropping1D(CroppingArgs args) : base(args) + { + this.args = args; + } - protected override void build ( Tensors inputs ) { - if ( args.cropping.rank != 1 ) { - // throw an ValueError exception - throw new ValueError(""); - } - else if ( args.cropping.shape[0] > 2 || args.cropping.shape[0] < 1 ) { - throw new ValueError("The `cropping` argument must be a tuple of 2 integers."); - } - built = true; + public override void build(Shape input_shape) + { + if (args.cropping.rank != 1) + { + // throw an ValueError exception + throw new ValueError(""); + } + else if (args.cropping.shape[0] > 2 || args.cropping.shape[0] < 1) + { + throw new ValueError("The `cropping` argument must be a tuple of 2 integers."); } + built = true; + } - protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { - Tensor output = inputs; - if ( output.rank != 3 ) { - // throw an ValueError exception - throw new ValueError("Expected dim=3, found dim=" + output.rank); - } - if ( args.cropping.shape[0] == 1 ) { - int crop_start = args.cropping[0]; - output = output[new Slice(), new Slice(crop_start, ( int ) output.shape[1] - crop_start), new Slice()]; - } - else { - int crop_start = args.cropping[0], crop_end = args.cropping[1]; - output = output[new Slice(), new Slice(crop_start, ( int ) (output.shape[1]) - crop_end), new Slice()]; - } - return output; + protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + { + Tensor output = inputs; + if (output.rank != 3) + { + // throw an ValueError exception + throw new ValueError("Expected dim=3, found dim=" + output.rank); + } + if (args.cropping.shape[0] == 1) + { + int crop_start = args.cropping[0]; + output = output[new Slice(), new Slice(crop_start, (int)output.shape[1] - crop_start), new Slice()]; } + else + { + int crop_start = args.cropping[0], crop_end = args.cropping[1]; + output = output[new Slice(), new Slice(crop_start, (int)(output.shape[1]) - crop_end), new Slice()]; + } + return output; + } - public override Shape ComputeOutputShape ( Shape input_shape ) { - if ( args.cropping.shape[0] == 1 ) { - int crop = args.cropping[0]; - return new Shape(( int ) (input_shape[0]), ( int ) (input_shape[1] - crop * 2), ( int ) (input_shape[2])); - } - else { - int crop_start = args.cropping[0], crop_end = args.cropping[1]; - return new Shape(( int ) (input_shape[0]), ( int ) (input_shape[1] - crop_start - crop_end), ( int ) (input_shape[2])); - } + public override Shape ComputeOutputShape(Shape input_shape) + { + if (args.cropping.shape[0] == 1) + { + int crop = args.cropping[0]; + return new Shape((int)(input_shape[0]), (int)(input_shape[1] - crop * 2), (int)(input_shape[2])); + } + else + { + int crop_start = args.cropping[0], crop_end = args.cropping[1]; + return new Shape((int)(input_shape[0]), (int)(input_shape[1] - crop_start - crop_end), (int)(input_shape[2])); } - } + } + } } diff --git a/src/TensorFlowNET.Keras/Layers/Cropping/Cropping2D.cs b/src/TensorFlowNET.Keras/Layers/Cropping/Cropping2D.cs index 340ba42d..6cb03e1e 100644 --- a/src/TensorFlowNET.Keras/Layers/Cropping/Cropping2D.cs +++ b/src/TensorFlowNET.Keras/Layers/Cropping/Cropping2D.cs @@ -12,7 +12,7 @@ namespace Tensorflow.Keras.Layers { public Cropping2D ( Cropping2DArgs args ) : base(args) { this.args = args; } - protected override void build ( Tensors inputs ) { + public override void build(Shape input_shape) { built = true; } protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { diff --git a/src/TensorFlowNET.Keras/Layers/Cropping/Cropping3D.cs b/src/TensorFlowNET.Keras/Layers/Cropping/Cropping3D.cs index df102c1f..2d6751bf 100644 --- a/src/TensorFlowNET.Keras/Layers/Cropping/Cropping3D.cs +++ b/src/TensorFlowNET.Keras/Layers/Cropping/Cropping3D.cs @@ -11,7 +11,7 @@ namespace Tensorflow.Keras.Layers { this.args = args; } - protected override void build ( Tensors inputs ) { + public override void build(Shape input_shape) { built = true; } diff --git a/src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs b/src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs index 676d5752..5f821760 100644 --- a/src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs +++ b/src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs @@ -23,7 +23,7 @@ namespace Tensorflow.Keras.Layers this.args = args; } - protected override void build(Tensors inputs) + public override void build(Shape input_shape) { /*var shape_set = new HashSet(); var reduced_inputs_shapes = inputs.Select(x => x.shape).ToArray(); diff --git a/src/TensorFlowNET.Keras/Layers/Merging/Merge.cs b/src/TensorFlowNET.Keras/Layers/Merging/Merge.cs index be8f574e..0363d58f 100644 --- a/src/TensorFlowNET.Keras/Layers/Merging/Merge.cs +++ b/src/TensorFlowNET.Keras/Layers/Merging/Merge.cs @@ -14,7 +14,7 @@ namespace Tensorflow.Keras.Layers } - protected override void build(Tensors inputs) + public override void build(Shape input_shape) { // output_shape = input_shape.dims[1^]; } diff --git a/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs b/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs index da8e8c03..dac92f81 100644 --- a/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs +++ b/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs @@ -53,9 +53,8 @@ namespace Tensorflow.Keras.Layers axis = args.Axis.dims.Select(x => (int)x).ToArray(); } - protected override void build(Tensors inputs) + public override void build(Shape input_shape) { - Shape input_shape = inputs.shape; var ndims = input_shape.ndim; foreach (var (idx, x) in enumerate(axis)) if (x < 0) diff --git a/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs b/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs index 51c6423c..5eebd735 100644 --- a/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs +++ b/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs @@ -49,9 +49,8 @@ namespace Tensorflow.Keras.Layers axis = args.Axis.axis; } - protected override void build(Tensors inputs) + public override void build(Shape input_shape) { - Shape input_shape = inputs.shape; var ndims = input_shape.ndim; foreach (var (idx, x) in enumerate(axis)) if (x < 0) diff --git a/src/TensorFlowNET.Keras/Layers/Preprocessing/TextVectorization.cs b/src/TensorFlowNET.Keras/Layers/Preprocessing/TextVectorization.cs index 6d37eaa1..4c52af9b 100644 --- a/src/TensorFlowNET.Keras/Layers/Preprocessing/TextVectorization.cs +++ b/src/TensorFlowNET.Keras/Layers/Preprocessing/TextVectorization.cs @@ -35,14 +35,14 @@ namespace Tensorflow.Keras.Layers var shape = data.output_shapes[0]; if (shape.ndim == 1) data = data.map(tensor => array_ops.expand_dims(tensor, -1)); - build(data.variant_tensor); + build(data.variant_tensor.shape); var preprocessed_inputs = data.map(_preprocess); _index_lookup_layer.adapt(preprocessed_inputs); } - protected override void build(Tensors inputs) + public override void build(Shape input_shape) { - base.build(inputs); + base.build(input_shape); } Tensors _preprocess(Tensors inputs) diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Permute.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Permute.cs index 08089900..868506b6 100644 --- a/src/TensorFlowNET.Keras/Layers/Reshaping/Permute.cs +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Permute.cs @@ -7,32 +7,39 @@ using static Tensorflow.Binding; using Tensorflow.Keras.ArgsDefinition; namespace Tensorflow.Keras.Layers { - public class Permute : Layer { - int[] dims, permute; - public Permute ( PermuteArgs args ) : base(args) { - this.dims = args.dims; + public class Permute : Layer + { + int[] dims, permute; + public Permute(PermuteArgs args) : base(args) + { + this.dims = args.dims; + } + public override void build(Shape input_shape) + { + var rank = input_shape.rank; + if (dims.Length != rank - 1) + { + throw new ValueError("Dimensions must match."); } - protected override void build ( Tensors inputs ) { - var rank = inputs.rank; - if ( dims.Length != rank - 1 ) { - throw new ValueError("Dimensions must match."); - } - permute = new int[inputs.rank]; - dims.CopyTo(permute, 1); - built = true; + permute = new int[input_shape.rank]; + dims.CopyTo(permute, 1); + built = true; + } + protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) + { + Tensor outputs = inputs; + return tf.transpose(outputs, new Axis(permute)); + } + public override Shape ComputeOutputShape(Shape input_shape) + { + Shape output_shape = new Shape(input_shape.dims); + for (int i = 0; i < dims.Length; i += 1) + { + var d = dims[i]; + var target_dim = input_shape[d]; + output_shape[i + 1] = target_dim; } - protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { - Tensor outputs = inputs; - return tf.transpose(outputs, new Axis(permute)); - } - public override Shape ComputeOutputShape ( Shape input_shape ) { - Shape output_shape = new Shape(input_shape.dims); - for ( int i = 0; i < dims.Length; i += 1 ) { - var d = dims[i]; - var target_dim = input_shape[d]; - output_shape[i + 1] = target_dim; - } - return output_shape; - } - } + return output_shape; + } + } } diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs index 58b700fe..c8366ff4 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs @@ -15,9 +15,8 @@ namespace Tensorflow.Keras.Layers.Rnn this.args = args; } - protected override void build(Tensors inputs) + public override void build(Shape input_shape) { - var input_shape = inputs.shape; var input_dim = input_shape[-1]; kernel = add_weight("kernel", (input_shape[-1], args.Units), diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs index de50c361..10b28e76 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs @@ -13,7 +13,7 @@ namespace Tensorflow.Keras.Layers.Rnn } - protected override void build(Tensors inputs) + public override void build(Shape input_shape) { }