From 6d66092ed641cb2a4931ec0dd7d32ed3f22645b1 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 6 Dec 2020 21:31:24 -0600 Subject: [PATCH] Allow ComputeOutputShape to override in subclass #660. --- src/TensorFlowNET.Keras/Engine/Layer.Layers.cs | 6 +++++- src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs | 4 ++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/TensorFlowNET.Keras/Engine/Layer.Layers.cs b/src/TensorFlowNET.Keras/Engine/Layer.Layers.cs index dcbfa1e6..ceb3afa4 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.Layers.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.Layers.cs @@ -1,4 +1,5 @@ -using System.Collections.Generic; +using System; +using System.Collections.Generic; namespace Tensorflow.Keras.Engine { @@ -11,5 +12,8 @@ namespace Tensorflow.Keras.Engine { _layers.AddRange(layers); } + + public virtual TensorShape ComputeOutputShape(TensorShape input_shape) + => throw new NotImplementedException(""); } } diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs index b358c719..28c7be3e 100644 --- a/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs @@ -26,11 +26,11 @@ namespace Tensorflow.Keras.Layers var result = array_ops.reshape(inputs, shape.ToArray()); if (!tf.Context.executing_eagerly()) - result.set_shape(compute_output_shape(inputs.shape)); + result.set_shape(ComputeOutputShape(inputs.shape)); return result; } - TensorShape compute_output_shape(TensorShape input_shape) + public override TensorShape ComputeOutputShape(TensorShape input_shape) { if (input_shape.dims[0] == -1) {