Browse Source

Add keras.layers.Reshape.

tags/v0.30
Oceania2018 4 years ago
parent
commit
905c4a9ee8
5 changed files with 64 additions and 3 deletions
  1. +7
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/ReshapeArgs.cs
  2. +11
    -0
      src/TensorFlowNET.Keras/Layers/LayersApi.Reshaping.cs
  3. +2
    -2
      src/TensorFlowNET.Keras/Layers/LayersApi.cs
  4. +34
    -0
      src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs
  5. +10
    -1
      test/TensorFlowNET.UnitTest/Keras/Layers.Reshaping.Test.cs

+ 7
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/ReshapeArgs.cs View File

@@ -0,0 +1,7 @@
namespace Tensorflow.Keras.ArgsDefinition
{
public class ReshapeArgs : LayerArgs
{
public TensorShape TargetShape { get; set; }
}
}

+ 11
- 0
src/TensorFlowNET.Keras/Layers/LayersApi.Reshaping.cs View File

@@ -34,5 +34,16 @@ namespace Tensorflow.Keras.Layers
{
Size = size ?? (2, 2)
});

/// <summary>
/// Layer that reshapes inputs into the given shape.
/// </summary>
/// <param name="target_shape"></param>
/// <returns></returns>
public Reshape Reshape(TensorShape target_shape)
=> new Reshape(new ReshapeArgs
{
TargetShape = target_shape
});
}
}

+ 2
- 2
src/TensorFlowNET.Keras/Layers/LayersApi.cs View File

@@ -372,8 +372,8 @@ namespace Tensorflow.Keras.Layers
InputShape = input_shape
});

public Add Add(params Tensor[] inputs)
=> new Add(new MergeArgs { Inputs = inputs });
public Add Add()
=> new Add(new MergeArgs { });

public GlobalAveragePooling2D GlobalAveragePooling2D()
=> new GlobalAveragePooling2D(new Pooling2DArgs { });


+ 34
- 0
src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs View File

@@ -0,0 +1,34 @@
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using static Tensorflow.KerasApi;
using static Tensorflow.Binding;
using System.Collections.Generic;
using System;

namespace Tensorflow.Keras.Layers
{
/// <summary>
/// Layer that reshapes inputs into the given shape.
/// </summary>
public class Reshape : Layer
{
ReshapeArgs args;
public Reshape(ReshapeArgs args)
: base(args)
{
this.args = args;
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false)
{
var shape = new List<int> { inputs.shape[0] };
shape.AddRange(args.TargetShape.dims);

var result = array_ops.reshape(inputs, shape.ToArray());
if (!tf.Context.executing_eagerly())
// result = result.set_shape(compute_output_shape(inputs.shape));
throw new NotImplementedException("");
return result;
}
}
}

+ 10
- 1
test/TensorFlowNET.UnitTest/Keras/Layers.Reshaping.Test.cs View File

@@ -1,6 +1,6 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using NumSharp;
using Tensorflow;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

namespace TensorFlowNET.UnitTest.Keras
@@ -26,5 +26,14 @@ namespace TensorFlowNET.UnitTest.Keras
var y = keras.layers.UpSampling2D(size: (1, 2)).Apply(x);
Assert.AreEqual((2, 2, 2, 3), y.shape);
}

[TestMethod]
public void Reshape()
{
var inputs = tf.zeros((10, 5, 20));
var outputs = keras.layers.LeakyReLU().Apply(inputs);
outputs = keras.layers.Reshape((20, 5)).Apply(outputs);
Assert.AreEqual((10, 20, 5), outputs.shape);
}
}
}

Loading…
Cancel
Save