diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs
index 38c75606..d417fa44 100644
--- a/src/TensorFlowNET.Keras/Engine/Layer.cs
+++ b/src/TensorFlowNET.Keras/Engine/Layer.cs
@@ -191,7 +191,7 @@ namespace Tensorflow.Keras.Engine
tf.Context.eager_mode(isFunc: tf.Context.is_build_function());
}
- build(inputs);
+ build(inputs.shape);
if (need_restore_mode)
tf.Context.restore_mode();
@@ -199,7 +199,7 @@ namespace Tensorflow.Keras.Engine
built = true;
}
- protected virtual void build(Tensors inputs)
+ public virtual void build(Shape input_shape)
{
built = true;
}
diff --git a/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs b/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs
index 3efda364..6e790a26 100644
--- a/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs
+++ b/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs
@@ -6,30 +6,38 @@ using Tensorflow.Keras.Engine;
using static Tensorflow.Binding;
namespace Tensorflow.Keras.Layers {
- ///
- /// ELU Layer:
- /// x = 0 when x > 0, x = alpha( e^x-1 ) elsewhere
- ///
- public class ELU : Layer {
- ELUArgs args;
- float alpha => args.Alpha;
- public ELU ( ELUArgs args ) : base(args) {
- this.args = args;
- }
- protected override void build ( Tensors inputs ) {
- if ( alpha < 0f ) {
- throw new ValueError("Alpha must be a number greater than 0.");
- }
- built = true;
- }
- protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) {
- Tensor output = inputs;
- output = tf.where(output > 0f, output,
- tf.multiply(alpha, tf.sub(tf.exp(output), 1f)));
- return output;
- }
- public override Shape ComputeOutputShape ( Shape input_shape ) {
- return input_shape;
+ ///
+ /// ELU Layer:
+ /// x = 0 when x > 0, x = alpha( e^x-1 ) elsewhere
+ ///
+ public class ELU : Layer
+ {
+ ELUArgs args;
+ float alpha => args.Alpha;
+ public ELU(ELUArgs args) : base(args)
+ {
+ this.args = args;
+ }
+
+ public override void build(Shape input_shape)
+ {
+ if (alpha < 0f)
+ {
+ throw new ValueError("Alpha must be a number greater than 0.");
}
- }
+ built = true;
+ }
+
+ protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
+ {
+ Tensor output = inputs;
+ output = tf.where(output > 0f, output,
+ tf.multiply(alpha, tf.sub(tf.exp(output), 1f)));
+ return output;
+ }
+ public override Shape ComputeOutputShape(Shape input_shape)
+ {
+ return input_shape;
+ }
+ }
}
diff --git a/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs b/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs
index aecb3da2..aba175de 100644
--- a/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs
+++ b/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs
@@ -6,19 +6,24 @@ using Tensorflow.Keras.Engine;
using static Tensorflow.Binding;
namespace Tensorflow.Keras.Layers {
- public class Exponential : Layer {
- public Exponential ( LayerArgs args ) : base(args) {
- // Exponential has no args
- }
- protected override void build ( Tensors inputs ) {
- built = true;
- }
- protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) {
- Tensor output = inputs;
- return tf.exp(output);
- }
- public override Shape ComputeOutputShape ( Shape input_shape ) {
- return input_shape;
- }
- }
+ public class Exponential : Layer
+ {
+ public Exponential(LayerArgs args) : base(args)
+ {
+ // Exponential has no args
+ }
+ public override void build(Shape input_shape)
+ {
+ built = true;
+ }
+ protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
+ {
+ Tensor output = inputs;
+ return tf.exp(output);
+ }
+ public override Shape ComputeOutputShape(Shape input_shape)
+ {
+ return input_shape;
+ }
+ }
}
diff --git a/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs b/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs
index 388302da..b12d7dee 100644
--- a/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs
+++ b/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs
@@ -15,7 +15,7 @@ namespace Tensorflow.Keras.Layers {
public SELU ( LayerArgs args ) : base(args) {
// SELU has no arguments
}
- protected override void build ( Tensors inputs ) {
+ public override void build(Shape input_shape) {
if ( alpha < 0f ) {
throw new ValueError("Alpha must be a number greater than 0.");
}
diff --git a/src/TensorFlowNET.Keras/Layers/Attention/Attention.cs b/src/TensorFlowNET.Keras/Layers/Attention/Attention.cs
index 51a40b58..6f6dd7e8 100644
--- a/src/TensorFlowNET.Keras/Layers/Attention/Attention.cs
+++ b/src/TensorFlowNET.Keras/Layers/Attention/Attention.cs
@@ -90,9 +90,10 @@ namespace Tensorflow.Keras.Layers
}.Contains(this.score_mode))
throw new ValueError("Received: score_mode={score_mode}. Acceptable values are: [\"dot\", \"concat\"]");
}
-
+
// Creates variable when `use_scale` is True or `score_mode` is `concat`.
- protected override void build(Tensors inputs) {
+ public override void build(Shape input_shape)
+ {
if (this.use_scale)
this.scale = this.add_weight(name: "scale",
shape: 1,
@@ -110,7 +111,7 @@ namespace Tensorflow.Keras.Layers
trainable: true);
else
this.concat_score_weight = null;
- base.build(inputs);
+ base.build(input_shape);
}
///
diff --git a/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs b/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs
index 9ef4db18..e0a337ca 100644
--- a/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs
+++ b/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs
@@ -29,9 +29,8 @@ namespace Tensorflow.Keras.Layers
}
- protected override void build(Tensors inputs)
+ public override void build(Shape input_shape)
{
- var input_shape = inputs.shape;
if (len(input_shape) != 4)
throw new ValueError($"Inputs should have rank 4. Received input shape: {input_shape}");
@@ -43,14 +42,12 @@ namespace Tensorflow.Keras.Layers
shape: kernel_shape,
initializer: kernel_initializer,
regularizer: kernel_regularizer,
- trainable: true,
- dtype: inputs.dtype);
+ trainable: true);
if (use_bias)
bias = add_weight(name: "bias",
shape: filters,
initializer: bias_initializer,
- trainable: true,
- dtype: inputs.dtype);
+ trainable: true);
built = true;
}
diff --git a/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs b/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs
index 5ac2dd00..912a429b 100644
--- a/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs
+++ b/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs
@@ -57,9 +57,8 @@ namespace Tensorflow.Keras.Layers
_tf_data_format = conv_utils.convert_data_format(data_format, rank + 2);
}
- protected override void build(Tensors inputs)
+ public override void build(Shape input_shape)
{
- Shape input_shape = inputs.shape;
int channel_axis = data_format == "channels_first" ? 1 : -1;
var input_channel = channel_axis < 0 ?
input_shape.dims[input_shape.ndim + channel_axis] :
diff --git a/src/TensorFlowNET.Keras/Layers/Core/Dense.cs b/src/TensorFlowNET.Keras/Layers/Core/Dense.cs
index f3956811..e4c22745 100644
--- a/src/TensorFlowNET.Keras/Layers/Core/Dense.cs
+++ b/src/TensorFlowNET.Keras/Layers/Core/Dense.cs
@@ -41,9 +41,8 @@ namespace Tensorflow.Keras.Layers
this.inputSpec = new InputSpec(min_ndim: 2);
}
- protected override void build(Tensors inputs)
+ public override void build(Shape input_shape)
{
- Shape input_shape = inputs.shape;
var last_dim = input_shape.dims.Last();
var axes = new Dictionary();
axes[-1] = (int)last_dim;
diff --git a/src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs b/src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs
index 2bd987a7..0f387570 100644
--- a/src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs
+++ b/src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs
@@ -119,9 +119,8 @@ namespace Tensorflow.Keras.Layers
this.bias_constraint = args.BiasConstraint;
}
- protected override void build(Tensors inputs)
+ public override void build(Shape input_shape)
{
- var input_shape = inputs.shape;
var shape_data = _analyze_einsum_string(this.equation, this.bias_axes, input_shape, this.partial_output_shape);
var kernel_shape = shape_data.Item1;
var bias_shape = shape_data.Item2;
@@ -141,7 +140,7 @@ namespace Tensorflow.Keras.Layers
trainable: true);
else
this.bias = null;
- base.build(inputs);
+ base.build(input_shape);
}
public override Shape ComputeOutputShape(Shape input_shape)
diff --git a/src/TensorFlowNET.Keras/Layers/Core/Embedding.cs b/src/TensorFlowNET.Keras/Layers/Core/Embedding.cs
index f16fcfa6..79f4e5ce 100644
--- a/src/TensorFlowNET.Keras/Layers/Core/Embedding.cs
+++ b/src/TensorFlowNET.Keras/Layers/Core/Embedding.cs
@@ -54,7 +54,7 @@ namespace Tensorflow.Keras.Layers
SupportsMasking = mask_zero;
}
- protected override void build(Tensors inputs)
+ public override void build(Shape input_shape)
{
tf.Context.eager_mode();
embeddings = add_weight(shape: (input_dim, output_dim),
diff --git a/src/TensorFlowNET.Keras/Layers/Cropping/Cropping1D.cs b/src/TensorFlowNET.Keras/Layers/Cropping/Cropping1D.cs
index cf71e184..45f5bf0f 100644
--- a/src/TensorFlowNET.Keras/Layers/Cropping/Cropping1D.cs
+++ b/src/TensorFlowNET.Keras/Layers/Cropping/Cropping1D.cs
@@ -2,49 +2,61 @@
using Tensorflow.Keras.Engine;
namespace Tensorflow.Keras.Layers {
- public class Cropping1D : Layer {
- CroppingArgs args;
- public Cropping1D ( CroppingArgs args ) : base(args) {
- this.args = args;
- }
+ public class Cropping1D : Layer
+ {
+ CroppingArgs args;
+ public Cropping1D(CroppingArgs args) : base(args)
+ {
+ this.args = args;
+ }
- protected override void build ( Tensors inputs ) {
- if ( args.cropping.rank != 1 ) {
- // throw an ValueError exception
- throw new ValueError("");
- }
- else if ( args.cropping.shape[0] > 2 || args.cropping.shape[0] < 1 ) {
- throw new ValueError("The `cropping` argument must be a tuple of 2 integers.");
- }
- built = true;
+ public override void build(Shape input_shape)
+ {
+ if (args.cropping.rank != 1)
+ {
+ // throw an ValueError exception
+ throw new ValueError("");
+ }
+ else if (args.cropping.shape[0] > 2 || args.cropping.shape[0] < 1)
+ {
+ throw new ValueError("The `cropping` argument must be a tuple of 2 integers.");
}
+ built = true;
+ }
- protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) {
- Tensor output = inputs;
- if ( output.rank != 3 ) {
- // throw an ValueError exception
- throw new ValueError("Expected dim=3, found dim=" + output.rank);
- }
- if ( args.cropping.shape[0] == 1 ) {
- int crop_start = args.cropping[0];
- output = output[new Slice(), new Slice(crop_start, ( int ) output.shape[1] - crop_start), new Slice()];
- }
- else {
- int crop_start = args.cropping[0], crop_end = args.cropping[1];
- output = output[new Slice(), new Slice(crop_start, ( int ) (output.shape[1]) - crop_end), new Slice()];
- }
- return output;
+ protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
+ {
+ Tensor output = inputs;
+ if (output.rank != 3)
+ {
+ // throw an ValueError exception
+ throw new ValueError("Expected dim=3, found dim=" + output.rank);
+ }
+ if (args.cropping.shape[0] == 1)
+ {
+ int crop_start = args.cropping[0];
+ output = output[new Slice(), new Slice(crop_start, (int)output.shape[1] - crop_start), new Slice()];
}
+ else
+ {
+ int crop_start = args.cropping[0], crop_end = args.cropping[1];
+ output = output[new Slice(), new Slice(crop_start, (int)(output.shape[1]) - crop_end), new Slice()];
+ }
+ return output;
+ }
- public override Shape ComputeOutputShape ( Shape input_shape ) {
- if ( args.cropping.shape[0] == 1 ) {
- int crop = args.cropping[0];
- return new Shape(( int ) (input_shape[0]), ( int ) (input_shape[1] - crop * 2), ( int ) (input_shape[2]));
- }
- else {
- int crop_start = args.cropping[0], crop_end = args.cropping[1];
- return new Shape(( int ) (input_shape[0]), ( int ) (input_shape[1] - crop_start - crop_end), ( int ) (input_shape[2]));
- }
+ public override Shape ComputeOutputShape(Shape input_shape)
+ {
+ if (args.cropping.shape[0] == 1)
+ {
+ int crop = args.cropping[0];
+ return new Shape((int)(input_shape[0]), (int)(input_shape[1] - crop * 2), (int)(input_shape[2]));
+ }
+ else
+ {
+ int crop_start = args.cropping[0], crop_end = args.cropping[1];
+ return new Shape((int)(input_shape[0]), (int)(input_shape[1] - crop_start - crop_end), (int)(input_shape[2]));
}
- }
+ }
+ }
}
diff --git a/src/TensorFlowNET.Keras/Layers/Cropping/Cropping2D.cs b/src/TensorFlowNET.Keras/Layers/Cropping/Cropping2D.cs
index 340ba42d..6cb03e1e 100644
--- a/src/TensorFlowNET.Keras/Layers/Cropping/Cropping2D.cs
+++ b/src/TensorFlowNET.Keras/Layers/Cropping/Cropping2D.cs
@@ -12,7 +12,7 @@ namespace Tensorflow.Keras.Layers {
public Cropping2D ( Cropping2DArgs args ) : base(args) {
this.args = args;
}
- protected override void build ( Tensors inputs ) {
+ public override void build(Shape input_shape) {
built = true;
}
protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) {
diff --git a/src/TensorFlowNET.Keras/Layers/Cropping/Cropping3D.cs b/src/TensorFlowNET.Keras/Layers/Cropping/Cropping3D.cs
index df102c1f..2d6751bf 100644
--- a/src/TensorFlowNET.Keras/Layers/Cropping/Cropping3D.cs
+++ b/src/TensorFlowNET.Keras/Layers/Cropping/Cropping3D.cs
@@ -11,7 +11,7 @@ namespace Tensorflow.Keras.Layers {
this.args = args;
}
- protected override void build ( Tensors inputs ) {
+ public override void build(Shape input_shape) {
built = true;
}
diff --git a/src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs b/src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs
index 676d5752..5f821760 100644
--- a/src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs
+++ b/src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs
@@ -23,7 +23,7 @@ namespace Tensorflow.Keras.Layers
this.args = args;
}
- protected override void build(Tensors inputs)
+ public override void build(Shape input_shape)
{
/*var shape_set = new HashSet();
var reduced_inputs_shapes = inputs.Select(x => x.shape).ToArray();
diff --git a/src/TensorFlowNET.Keras/Layers/Merging/Merge.cs b/src/TensorFlowNET.Keras/Layers/Merging/Merge.cs
index be8f574e..0363d58f 100644
--- a/src/TensorFlowNET.Keras/Layers/Merging/Merge.cs
+++ b/src/TensorFlowNET.Keras/Layers/Merging/Merge.cs
@@ -14,7 +14,7 @@ namespace Tensorflow.Keras.Layers
}
- protected override void build(Tensors inputs)
+ public override void build(Shape input_shape)
{
// output_shape = input_shape.dims[1^];
}
diff --git a/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs b/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs
index da8e8c03..dac92f81 100644
--- a/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs
+++ b/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs
@@ -53,9 +53,8 @@ namespace Tensorflow.Keras.Layers
axis = args.Axis.dims.Select(x => (int)x).ToArray();
}
- protected override void build(Tensors inputs)
+ public override void build(Shape input_shape)
{
- Shape input_shape = inputs.shape;
var ndims = input_shape.ndim;
foreach (var (idx, x) in enumerate(axis))
if (x < 0)
diff --git a/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs b/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs
index 51c6423c..5eebd735 100644
--- a/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs
+++ b/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs
@@ -49,9 +49,8 @@ namespace Tensorflow.Keras.Layers
axis = args.Axis.axis;
}
- protected override void build(Tensors inputs)
+ public override void build(Shape input_shape)
{
- Shape input_shape = inputs.shape;
var ndims = input_shape.ndim;
foreach (var (idx, x) in enumerate(axis))
if (x < 0)
diff --git a/src/TensorFlowNET.Keras/Layers/Preprocessing/TextVectorization.cs b/src/TensorFlowNET.Keras/Layers/Preprocessing/TextVectorization.cs
index 6d37eaa1..4c52af9b 100644
--- a/src/TensorFlowNET.Keras/Layers/Preprocessing/TextVectorization.cs
+++ b/src/TensorFlowNET.Keras/Layers/Preprocessing/TextVectorization.cs
@@ -35,14 +35,14 @@ namespace Tensorflow.Keras.Layers
var shape = data.output_shapes[0];
if (shape.ndim == 1)
data = data.map(tensor => array_ops.expand_dims(tensor, -1));
- build(data.variant_tensor);
+ build(data.variant_tensor.shape);
var preprocessed_inputs = data.map(_preprocess);
_index_lookup_layer.adapt(preprocessed_inputs);
}
- protected override void build(Tensors inputs)
+ public override void build(Shape input_shape)
{
- base.build(inputs);
+ base.build(input_shape);
}
Tensors _preprocess(Tensors inputs)
diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Permute.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Permute.cs
index 08089900..868506b6 100644
--- a/src/TensorFlowNET.Keras/Layers/Reshaping/Permute.cs
+++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Permute.cs
@@ -7,32 +7,39 @@ using static Tensorflow.Binding;
using Tensorflow.Keras.ArgsDefinition;
namespace Tensorflow.Keras.Layers {
- public class Permute : Layer {
- int[] dims, permute;
- public Permute ( PermuteArgs args ) : base(args) {
- this.dims = args.dims;
+ public class Permute : Layer
+ {
+ int[] dims, permute;
+ public Permute(PermuteArgs args) : base(args)
+ {
+ this.dims = args.dims;
+ }
+ public override void build(Shape input_shape)
+ {
+ var rank = input_shape.rank;
+ if (dims.Length != rank - 1)
+ {
+ throw new ValueError("Dimensions must match.");
}
- protected override void build ( Tensors inputs ) {
- var rank = inputs.rank;
- if ( dims.Length != rank - 1 ) {
- throw new ValueError("Dimensions must match.");
- }
- permute = new int[inputs.rank];
- dims.CopyTo(permute, 1);
- built = true;
+ permute = new int[input_shape.rank];
+ dims.CopyTo(permute, 1);
+ built = true;
+ }
+ protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
+ {
+ Tensor outputs = inputs;
+ return tf.transpose(outputs, new Axis(permute));
+ }
+ public override Shape ComputeOutputShape(Shape input_shape)
+ {
+ Shape output_shape = new Shape(input_shape.dims);
+ for (int i = 0; i < dims.Length; i += 1)
+ {
+ var d = dims[i];
+ var target_dim = input_shape[d];
+ output_shape[i + 1] = target_dim;
}
- protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) {
- Tensor outputs = inputs;
- return tf.transpose(outputs, new Axis(permute));
- }
- public override Shape ComputeOutputShape ( Shape input_shape ) {
- Shape output_shape = new Shape(input_shape.dims);
- for ( int i = 0; i < dims.Length; i += 1 ) {
- var d = dims[i];
- var target_dim = input_shape[d];
- output_shape[i + 1] = target_dim;
- }
- return output_shape;
- }
- }
+ return output_shape;
+ }
+ }
}
diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs
index 58b700fe..c8366ff4 100644
--- a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs
+++ b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs
@@ -15,9 +15,8 @@ namespace Tensorflow.Keras.Layers.Rnn
this.args = args;
}
- protected override void build(Tensors inputs)
+ public override void build(Shape input_shape)
{
- var input_shape = inputs.shape;
var input_dim = input_shape[-1];
kernel = add_weight("kernel", (input_shape[-1], args.Units),
diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs
index de50c361..10b28e76 100644
--- a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs
+++ b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs
@@ -13,7 +13,7 @@ namespace Tensorflow.Keras.Layers.Rnn
}
- protected override void build(Tensors inputs)
+ public override void build(Shape input_shape)
{
}