Browse Source

Align keras.Input with tensorflow python.

tags/v0.100.4-load-saved-model
Yaohui Liu Haiping 2 years ago
parent
commit
559d471407
4 changed files with 65 additions and 38 deletions
  1. +8
    -2
      src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
  2. +12
    -28
      src/TensorFlowNET.Keras/KerasInterface.cs
  3. +44
    -7
      src/TensorFlowNET.Keras/Layers/LayersApi.cs
  4. +1
    -1
      test/TensorFlowNET.Keras.UnitTest/Layers/AttentionTest.cs

+ 8
- 2
src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs View File

@@ -1,4 +1,5 @@
using System;
using Tensorflow.Framework.Models;
using Tensorflow.NumPy;
using static Google.Protobuf.Reflection.FieldDescriptorProto.Types;

@@ -133,11 +134,16 @@ namespace Tensorflow.Keras.Layers
public ILayer GlobalMaxPooling1D(string data_format = "channels_last");
public ILayer GlobalMaxPooling2D(string data_format = "channels_last");

public Tensors Input(Shape shape,
public Tensors Input(Shape shape = null,
int batch_size = -1,
string name = null,
TF_DataType dtype = TF_DataType.DtInvalid,
bool sparse = false,
bool ragged = false);
Tensor tensor = null,
bool ragged = false,
TypeSpec type_spec = null,
Shape batch_input_shape = null,
Shape batch_shape = null);
public ILayer InputLayer(Shape input_shape,
string name = null,
bool sparse = false,


+ 12
- 28
src/TensorFlowNET.Keras/KerasInterface.cs View File

@@ -12,6 +12,7 @@ using Tensorflow.Keras.Models;
using Tensorflow.Keras.Optimizers;
using Tensorflow.Keras.Utils;
using System.Threading;
using Tensorflow.Framework.Models;

namespace Tensorflow.Keras
{
@@ -66,33 +67,16 @@ namespace Tensorflow.Keras
/// If set, the layer will not create a placeholder tensor.
/// </param>
/// <returns></returns>
public Tensor Input(Shape shape = null,
int batch_size = -1,
Shape batch_input_shape = null,
TF_DataType dtype = TF_DataType.DtInvalid,
string name = null,
bool sparse = false,
bool ragged = false,
Tensor tensor = null)
{
if (batch_input_shape != null)
shape = batch_input_shape.dims.Skip(1).ToArray();

var args = new InputLayerArgs
{
Name = name,
InputShape = shape,
BatchInputShape = batch_input_shape,
BatchSize = batch_size,
DType = dtype,
Sparse = sparse,
Ragged = ragged,
InputTensor = tensor
};

var layer = new InputLayer(args);

return layer.InboundNodes[0].Outputs;
}
public Tensors Input(Shape shape = null,
int batch_size = -1,
string name = null,
TF_DataType dtype = TF_DataType.DtInvalid,
bool sparse = false,
Tensor tensor = null,
bool ragged = false,
TypeSpec type_spec = null,
Shape batch_input_shape = null,
Shape batch_shape = null) => keras.layers.Input(shape, batch_size, name,
dtype, sparse, tensor, ragged, type_spec, batch_input_shape, batch_shape);
}
}

+ 44
- 7
src/TensorFlowNET.Keras/Layers/LayersApi.cs View File

@@ -1,4 +1,5 @@
using System;
using Tensorflow.Framework.Models;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.ArgsDefinition.Core;
using Tensorflow.Keras.ArgsDefinition.Rnn;
@@ -471,20 +472,56 @@ namespace Tensorflow.Keras.Layers
/// In this case, values of 'None' in the 'shape' argument represent ragged dimensions. For more information about RaggedTensors, see this guide.
/// </param>
/// <returns>A tensor.</returns>
public Tensors Input(Shape shape,
public Tensors Input(Shape shape = null,
int batch_size = -1,
string name = null,
TF_DataType dtype = TF_DataType.DtInvalid,
bool sparse = false,
bool ragged = false)
Tensor tensor = null,
bool ragged = false,
TypeSpec type_spec = null,
Shape batch_input_shape = null,
Shape batch_shape = null)
{
var input_layer = new InputLayer(new InputLayerArgs
if(sparse && ragged)
{
throw new ValueError("Cannot set both `sparse` and `ragged` to `true` in a Keras `Input`.");
}

InputLayerArgs input_layer_config = new()
{
InputShape = shape,
BatchSize= batch_size,
Name = name,
DType = dtype,
Sparse = sparse,
Ragged = ragged
});
Ragged = ragged,
InputTensor = tensor,
// skip the `type_spec`
};

if(shape is not null && batch_input_shape is not null)
{
throw new ValueError("Only provide the `shape` OR `batch_input_shape` argument "
+ "to Input, not both at the same time.");
}

if(batch_input_shape is null && shape is null && tensor is null && type_spec is null)
{
throw new ValueError("Please provide to Input a `shape` or a `tensor` or a `type_spec` argument. Note that " +
"`shape` does not include the batch dimension.");
}

if(batch_input_shape is not null)
{
shape = batch_input_shape["1:"];
input_layer_config.BatchInputShape = batch_input_shape;
}
else
{
input_layer_config.BatchSize = batch_size;
input_layer_config.InputShape = shape;
}

var input_layer = new InputLayer(input_layer_config);

return input_layer.InboundNodes[0].Outputs;
}


+ 1
- 1
test/TensorFlowNET.Keras.UnitTest/Layers/AttentionTest.cs View File

@@ -158,7 +158,7 @@ namespace TensorFlowNET.Keras.UnitTest
var value = keras.Input(shape: (2, 8));
var mask_tensor = keras.Input(shape:(4, 2));
var attention_layer = keras.layers.MultiHeadAttention(num_heads: 2, key_dim: 2);
attention_layer.Apply(new[] { query, value, mask_tensor });
attention_layer.Apply(new Tensor[] { query, value, mask_tensor });

var from_data = 10 * np.random.randn(batch_size, 4, 8);
var to_data = 10 * np.random.randn(batch_size, 2, 8);


Loading…
Cancel
Save