Browse Source

Allow ComputeOutputShape to override in subclass #660.

tags/v0.30
Oceania2018 4 years ago
parent
commit
6d66092ed6
2 changed files with 7 additions and 3 deletions
  1. +5
    -1
      src/TensorFlowNET.Keras/Engine/Layer.Layers.cs
  2. +2
    -2
      src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs

+ 5
- 1
src/TensorFlowNET.Keras/Engine/Layer.Layers.cs View File

@@ -1,4 +1,5 @@
using System.Collections.Generic;
using System;
using System.Collections.Generic;

namespace Tensorflow.Keras.Engine
{
@@ -11,5 +12,8 @@ namespace Tensorflow.Keras.Engine
{
_layers.AddRange(layers);
}

public virtual TensorShape ComputeOutputShape(TensorShape input_shape)
=> throw new NotImplementedException("");
}
}

+ 2
- 2
src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs View File

@@ -26,11 +26,11 @@ namespace Tensorflow.Keras.Layers

var result = array_ops.reshape(inputs, shape.ToArray());
if (!tf.Context.executing_eagerly())
result.set_shape(compute_output_shape(inputs.shape));
result.set_shape(ComputeOutputShape(inputs.shape));
return result;
}

TensorShape compute_output_shape(TensorShape input_shape)
public override TensorShape ComputeOutputShape(TensorShape input_shape)
{
if (input_shape.dims[0] == -1)
{


Loading…
Cancel
Save