Browse Source

Add keras Resizing layer.

tags/v0.60-tf.numpy
Oceania2018 4 years ago
parent
commit
76a342c91d
5 changed files with 87 additions and 10 deletions
  1. +14
    -10
      src/TensorFlowNET.Core/APIs/tf.random.cs
  2. +9
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/ResizingArgs.cs
  3. +30
    -0
      src/TensorFlowNET.Keras/Layers/Preprocessing/Resizing.cs
  4. +26
    -0
      src/TensorFlowNET.Keras/Preprocessings/Preprocessing.Resizing.cs
  5. +8
    -0
      test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs

+ 14
- 10
src/TensorFlowNET.Core/APIs/tf.random.cs View File

@@ -62,6 +62,19 @@ namespace Tensorflow
int? seed = null,
string name = null,
TF_DataType output_dtype = TF_DataType.DtInvalid) => random_ops.multinomial(logits, num_samples, seed: seed, name: name, output_dtype: output_dtype);

public Tensor uniform(TensorShape shape,
float minval = 0,
float maxval = 1,
TF_DataType dtype = TF_DataType.TF_FLOAT,
int? seed = null,
string name = null)
{
if (dtype.is_integer())
return random_ops.random_uniform_int(shape, (int)minval, (int)maxval, dtype, seed, name);
else
return random_ops.random_uniform(shape, minval, maxval, dtype, seed, name);
}
}

public Tensor random_uniform(TensorShape shape,
@@ -70,16 +83,7 @@ namespace Tensorflow
TF_DataType dtype = TF_DataType.TF_FLOAT,
int? seed = null,
string name = null)
{
if (dtype.is_integer())
{
return random_ops.random_uniform_int(shape, (int)minval, (int)maxval, dtype, seed, name);
}
else
{
return random_ops.random_uniform(shape, minval, maxval, dtype, seed, name);
}
}
=> random.uniform(shape, minval: minval, maxval: maxval, dtype: dtype, seed: seed, name: name);

public Tensor truncated_normal(TensorShape shape,
float mean = 0.0f,


+ 9
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/ResizingArgs.cs View File

@@ -0,0 +1,9 @@
namespace Tensorflow.Keras.ArgsDefinition
{
public class ResizingArgs : LayerArgs
{
public int Height { get; set; }
public int Width { get; set; }
public string Interpolation { get; set; } = "bilinear";
}
}

+ 30
- 0
src/TensorFlowNET.Keras/Layers/Preprocessing/Resizing.cs View File

@@ -0,0 +1,30 @@
using System;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;

namespace Tensorflow.Keras.Layers
{
/// <summary>
/// Resize the batched image input to target height and width.
/// The input should be a 4-D tensor in the format of NHWC.
/// </summary>
public class Resizing : Layer
{
ResizingArgs args;
public Resizing(ResizingArgs args) : base(args)
{
this.args = args;
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
{
return image_ops_impl.resize_images_v2(inputs, new[] { args.Height, args.Width }, method: args.Interpolation);
}

public override TensorShape ComputeOutputShape(TensorShape input_shape)
{
return new TensorShape(input_shape.dims[0], args.Height, args.Width, input_shape.dims[3]);
}
}
}

+ 26
- 0
src/TensorFlowNET.Keras/Preprocessings/Preprocessing.Resizing.cs View File

@@ -0,0 +1,26 @@
using System;
using System.IO;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Layers;
using static Tensorflow.KerasApi;

namespace Tensorflow.Keras
{
public partial class Preprocessing
{
/// <summary>
/// Image resizing layer
/// </summary>
/// <param name="height"></param>
/// <param name="width"></param>
/// <param name="interpolation"></param>
/// <returns></returns>
public Resizing Resizing(int height, int width, string interpolation = "bilinear")
=> new Resizing(new ResizingArgs
{
Height = height,
Width = width,
Interpolation = interpolation
});
}
}

+ 8
- 0
test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs View File

@@ -131,5 +131,13 @@ namespace TensorFlowNET.Keras.UnitTest
Assert.AreEqual((32, 4), output.shape);
}

[TestMethod]
public void Resizing()
{
var inputs = tf.random.uniform((10, 32, 32, 3));
var layer = keras.layers.preprocessing.Resizing(16, 16);
var output = layer.Apply(inputs);
Assert.AreEqual((10, 16, 16, 3), output.shape);
}
}
}

Loading…
Cancel
Save