
Add Conv2DTranspose #735

tags/v0.40-tf2.4-tstring
Oceania2018 committed 4 years ago
Parent commit: 6a8665f111
35 changed files with 447 additions and 40 deletions
  1. +5 -0 src/TensorFlowNET.Core/Binding.Util.cs
  2. +10 -1 src/TensorFlowNET.Core/Framework/smart_module.cs
  3. +13 -0 src/TensorFlowNET.Core/Graphs/c_api.graph.cs
  4. +1 -1 src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs
  5. +42 -0 src/TensorFlowNET.Core/Operations/nn_impl.py.cs
  6. +4 -0 src/TensorFlowNET.Core/Tensors/TensorShape.Equals.cs
  7. +25 -0 src/TensorFlowNET.Keras/BackendImpl.cs
  8. +2 -2 src/TensorFlowNET.Keras/Engine/Functional.cs
  9. +3 -3 src/TensorFlowNET.Keras/Engine/Layer.Apply.cs
  10. +1 -1 src/TensorFlowNET.Keras/Engine/Layer.cs
  11. +1 -1 src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs
  12. +1 -1 src/TensorFlowNET.Keras/Engine/Model.Predict.cs
  13. +1 -1 src/TensorFlowNET.Keras/Engine/Model.Train.cs
  14. +4 -5 src/TensorFlowNET.Keras/Engine/Model.Training.cs
  15. +1 -1 src/TensorFlowNET.Keras/Layers/Activation/LeakyReLu.cs
  16. +150 -0 src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs
  17. +4 -1 src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs
  18. +1 -1 src/TensorFlowNET.Keras/Layers/Core/Dense.cs
  19. +1 -1 src/TensorFlowNET.Keras/Layers/Core/Embedding.cs
  20. +2 -2 src/TensorFlowNET.Keras/Layers/LSTM.cs
  21. +45 -0 src/TensorFlowNET.Keras/Layers/LayersApi.cs
  22. +1 -1 src/TensorFlowNET.Keras/Layers/Merging/Merge.cs
  23. +71 -3 src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs
  24. +1 -1 src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling2D.cs
  25. +1 -1 src/TensorFlowNET.Keras/Layers/Pooling/Pooling2D.cs
  26. +5 -2 src/TensorFlowNET.Keras/Layers/Regularization/Dropout.cs
  27. +1 -1 src/TensorFlowNET.Keras/Layers/Rescaling/Rescaling.cs
  28. +1 -1 src/TensorFlowNET.Keras/Layers/Reshaping/Flatten.cs
  29. +1 -1 src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs
  30. +1 -1 src/TensorFlowNET.Keras/Layers/Reshaping/UpSampling2D.cs
  31. +1 -1 src/TensorFlowNET.Keras/Layers/Reshaping/ZeroPadding2D.cs
  32. +1 -1 src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs
  33. +4 -4 src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
  34. +30 -0 src/TensorFlowNET.Keras/Utils/conv_utils.cs
  35. +11 -0 src/TensorFlowNET.Keras/Utils/tf_utils.cs

+5 -0 src/TensorFlowNET.Core/Binding.Util.cs

@@ -149,6 +149,8 @@ namespace Tensorflow
return ndArray.ndim == 0 ? 1 : ndArray.shape[0];
case IEnumerable enumerable:
return enumerable.OfType<object>().Count();
case TensorShape arr:
return arr.ndim;
}
throw new NotImplementedException("len() not implemented for type: " + a.GetType());
}
@@ -156,6 +158,9 @@ namespace Tensorflow
public static float min(float a, float b)
=> Math.Min(a, b);

public static int max(int a, int b)
=> Math.Max(a, b);

public static T[] list<T>(IEnumerable<T> list)
=> list.ToArray();



+10 -1 src/TensorFlowNET.Core/Framework/smart_module.cs

@@ -15,6 +15,8 @@
******************************************************************************/

using System;
using System.Linq;
using static Tensorflow.Binding;

namespace Tensorflow.Framework
{
@@ -52,7 +54,14 @@ namespace Tensorflow.Framework
{
var pred_value = tensor_util.constant_value(pred);
if (pred_value is null)
return pred.eval(new Session(pred.graph));
{
var result = range(pred.op.NumOutputs).Select(x => IntPtr.Zero).ToArray();
var evaluated = c_api.TF_TryEvaluateConstant(pred.graph, pred._as_tf_output(), result, tf.Status.Handle);
if (!evaluated || c_api.TF_GetCode(tf.Status.Handle) != TF_Code.TF_OK)
return null;
else
throw new NotImplementedException("");
}

return pred_value;
}


+13 -0 src/TensorFlowNET.Core/Graphs/c_api.graph.cs

@@ -322,5 +322,18 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)]

public static extern void TF_UpdateEdge(IntPtr graph, TF_Output new_src, TF_Input dst, SafeStatusHandle status);

/// <summary>
/// Attempts to evaluate `output`. This will only be possible if `output` doesn't
/// depend on any graph inputs (this function is safe to call if this isn't the
/// case though).
/// </summary>
/// <param name="graph"></param>
/// <param name="output"></param>
/// <param name="result"></param>
/// <param name="status"></param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern bool TF_TryEvaluateConstant(IntPtr graph, TF_Output output, IntPtr[] result, SafeStatusHandle status);
}
}
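
A minimal sketch of how this P/Invoke is consumed (it mirrors the smart_module.cs change above; `pred` is assumed to be a graph-mode Tensor):

    // One result slot per output of the producing op.
    var result = new IntPtr[pred.op.NumOutputs];
    var ok = c_api.TF_TryEvaluateConstant(pred.graph, pred._as_tf_output(), result, tf.Status.Handle);
    // ok == false (or a non-OK status code) means `pred` depends on graph inputs
    // and cannot be constant-folded.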

+1 -1 src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs

@@ -50,6 +50,6 @@ namespace Tensorflow.Keras.Engine
}

public override string ToString()
=> $"min_ndim={min_ndim}, , axes={axes.Count}";
=> $"ndim={ndim}, min_ndim={min_ndim}, axes={axes.Count}";
}
}

+42 -0 src/TensorFlowNET.Core/Operations/nn_impl.py.cs

@@ -21,6 +21,31 @@ namespace Tensorflow
{
public class nn_impl
{
public static Tensor conv2d_transpose(Tensor value = null,
IVariableV1 filter = null,
Tensor output_shape = null,
TensorShape strides = null,
string padding = "SAME",
string data_format = "NHWC",
string name = null,
TensorShape dilations = null)
{
if (dilations == null)
dilations = (1, 1, 1, 1);
return tf_with(ops.name_scope(name, "conv2d_transpose", new { value, filter, output_shape }), scope =>
{
return gen_nn_ops.conv2d_backprop_input(
input_sizes: output_shape,
filter: filter.AsTensor(),
out_backprop: value,
strides: strides,
padding: padding,
data_format: data_format,
dilations: dilations,
name: name);
});
}

/// <summary>
/// Normalizes along dimension `axis` using an L2 norm.
/// </summary>
@@ -83,6 +108,23 @@ namespace Tensorflow
});
}

public static Tensor batch_normalization(Tensor x,
Tensor mean,
Tensor variance,
Tensor offset,
Tensor scale,
float variance_epsilon = 0.001f,
string name = null)
{
return tf_with(ops.name_scope(name, "batchnorm", new { x, mean, variance, scale, offset }), scope =>
{
var inv = math_ops.rsqrt(variance + variance_epsilon);
inv *= scale;
return x * math_ops.cast(inv, x.dtype) + math_ops.cast(
offset == null ? (-mean * inv) : (offset - mean * inv), x.dtype);
});
}

/// <summary>
/// Batch normalization.
/// </summary>
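
Note: conv2d_transpose above is implemented as the gradient of a forward convolution with respect to its input (gen_nn_ops.conv2d_backprop_input), which is why callers must pass the desired output_shape explicitly. The batch_normalization helper folds the usual normalization into one multiply-add: with inv = scale * rsqrt(variance + epsilon), the result is x * inv + (offset - mean * inv), algebraically equal to scale * (x - mean) / sqrt(variance + epsilon) + offset. A scalar sanity check with hypothetical values (plain C#, not library code):

    double x = 2.0, mean = 1.0, variance = 4.0, scale = 3.0, offset = 0.5, eps = 0.001;
    double inv = scale / Math.Sqrt(variance + eps);
    double fused = x * inv + (offset - mean * inv);                          // form used above
    double direct = scale * (x - mean) / Math.Sqrt(variance + eps) + offset; // textbook form
    // fused and direct agree (both ≈ 1.9998)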


+4 -0 src/TensorFlowNET.Core/Tensors/TensorShape.Equals.cs

@@ -15,6 +15,10 @@ namespace Tensorflow
else if (rank != shape1.rank)
return false;
return Enumerable.SequenceEqual(shape1.dims, dims);
case int[] shape2:
if (rank != shape2.Length)
return false;
return Enumerable.SequenceEqual(dims, shape2);
default:
return false;
}
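
The new int[] branch lets a TensorShape be compared directly against a plain dimension array, which the Keras backend below relies on to test dilation_rate. A small sketch:

    var dilation = new TensorShape(1, 1);
    var isUnit = dilation.Equals(new[] { 1, 1 });   // true: same rank, same dims
    var other  = dilation.Equals(new[] { 2, 2 });   // false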


+25 -0 src/TensorFlowNET.Keras/BackendImpl.cs

@@ -317,5 +317,30 @@ namespace Tensorflow.Keras

return array_ops.concat(tensors, axis);
}

public Tensor conv2d_transpose(Tensor x,
IVariableV1 kernel,
Tensor output_shape,
TensorShape strides = null,
string padding = "valid",
string data_format = null,
TensorShape dilation_rate = null)
{
var force_transpose = false;
if (data_format == "channels_first" && !dilation_rate.Equals(new[] { 1, 1 }))
force_transpose = true;
// x, tf_data_format = _preprocess_conv2d_input(x, data_format, force_transpose)
var tf_data_format = "NHWC";
padding = padding.ToUpper();
strides = new TensorShape(1, strides[0], strides[1], 1);
if (dilation_rate.Equals(new[] { 1, 1 }))
x = nn_impl.conv2d_transpose(x, kernel, output_shape, strides,
padding: padding,
data_format: tf_data_format);
else
throw new NotImplementedException("");

return x;
}
}
}
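
A rough sketch of calling this backend helper directly (hypothetical tensors and shapes; Conv2DTranspose.Call below does essentially this):

    var y = keras.backend.conv2d_transpose(x, kernel, output_shape_tensor,
        strides: new TensorShape(2, 2),
        padding: "same",
        data_format: "channels_last",
        dilation_rate: new TensorShape(1, 1));
    // Internally the padding is upper-cased to "SAME" and the strides are expanded
    // to the NHWC form (1, 2, 2, 1) before nn_impl.conv2d_transpose is invoked.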

+2 -2 src/TensorFlowNET.Keras/Engine/Functional.cs

@@ -301,9 +301,9 @@ namespace Tensorflow.Keras.Engine
nodes_in_decreasing_depth.append(node);
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false)
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
{
return run_internal_graph(inputs, is_training);
return run_internal_graph(inputs, training.Value);
}

Tensors run_internal_graph(Tensors inputs, bool training = false, Tensors mask = null)


+3 -3 src/TensorFlowNET.Keras/Engine/Layer.Apply.cs

@@ -10,9 +10,9 @@ namespace Tensorflow.Keras.Engine
/// </summary>
/// <param name="input"></param>
/// <param name="state"></param>
/// <param name="is_training"></param>
/// <param name="training"></param>
/// <returns></returns>
public Tensors Apply(Tensors inputs, Tensor state = null, bool is_training = false)
public Tensors Apply(Tensors inputs, Tensor state = null, bool training = false)
{
callContext = callContext ?? new ThreadLocal<CallContext>()
{
@@ -38,7 +38,7 @@ namespace Tensorflow.Keras.Engine
if (!built)
MaybeBuild(inputs);

outputs = Call(inputs, state: state, is_training: is_training);
outputs = Call(inputs, state: state, training: training);

// memory leak
// _set_connectivity_metadata_(inputs, outputs);


+1 -1 src/TensorFlowNET.Keras/Engine/Layer.cs

@@ -155,7 +155,7 @@ namespace Tensorflow.Keras.Engine
/// <param name="state"></param>
/// <param name="is_training"></param>
/// <returns></returns>
protected virtual Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false)
protected virtual Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
{
throw new NotImplementedException("");
}


+1 -1 src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs

@@ -73,7 +73,7 @@ namespace Tensorflow.Keras.Engine
List<(string, Tensor)> test_step(Tensor x, Tensor y)
{
(x, y) = data_handler.DataAdapter.Expand1d(x, y);
var y_pred = Apply(x, is_training: false);
var y_pred = Apply(x, training: false);
var loss = compiled_loss.Call(y, y_pred);

compiled_metrics.update_state(y, y_pred);


+1 -1 src/TensorFlowNET.Keras/Engine/Model.Predict.cs

@@ -76,7 +76,7 @@ namespace Tensorflow.Keras.Engine

Tensors predict_step(Tensor data)
{
return Apply(data, is_training: false);
return Apply(data, training: false);
}
}
}

+1 -1 src/TensorFlowNET.Keras/Engine/Model.Train.cs

@@ -25,7 +25,7 @@ namespace Tensorflow.Keras.Engine
{
(x, y) = data_handler.DataAdapter.Expand1d(x, y);
using var tape = tf.GradientTape();
var y_pred = Apply(x, is_training: true);
var y_pred = Apply(x, training: true);
var loss = compiled_loss.Call(y, y_pred);

// For custom training steps, users can just write:


+4 -5 src/TensorFlowNET.Keras/Engine/Model.Training.cs

@@ -8,9 +8,10 @@ using Tensorflow.Keras.Saving;

namespace Tensorflow.Keras.Engine
{
public partial class Model
public partial class Model
{
public List<(IVariableV1, NDArray)> load_weights(string filepath, bool by_name = false, bool skip_mismatch = false, object options = null)
List<(IVariableV1, NDArray)> LoadedWeights;
public void load_weights(string filepath, bool by_name = false, bool skip_mismatch = false, object options = null)
{
long fileId = Hdf5.OpenFile(filepath, true);

@@ -25,10 +26,8 @@ namespace Tensorflow.Keras.Engine
throw new NotImplementedException("");
else
{
var weights = hdf5_format.load_weights_from_hdf5_group(fileId, Layers);
LoadedWeights = hdf5_format.load_weights_from_hdf5_group(fileId, Layers);
Hdf5.CloseFile(fileId);
// return a reference to prevent GC collect Variable.
return weights;
}
}



+1 -1 src/TensorFlowNET.Keras/Layers/Activation/LeakyReLu.cs

@@ -19,7 +19,7 @@ namespace Tensorflow.Keras.Layers
this.args = args;
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false)
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
{
return tf.nn.leaky_relu(inputs, alpha: alpha);
}


+150 -0 src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs

@@ -0,0 +1,150 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using static Tensorflow.Binding;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Utils;
using static Tensorflow.KerasApi;

namespace Tensorflow.Keras.Layers
{
public class Conv2DTranspose : Conv2D
{
public Conv2DTranspose(Conv2DArgs args) : base(args)
{

}

protected override void build(Tensors inputs)
{
var input_shape = inputs.shape;
if (len(input_shape) != 4)
throw new ValueError($"Inputs should have rank 4. Received input shape: {input_shape}");

var channel_axis = _get_channel_axis();
var input_dim = input_shape[-1];
var kernel_shape = new TensorShape(kernel_size[0], kernel_size[1], filters, input_dim);

kernel = add_weight(name: "kernel",
shape: kernel_shape,
initializer: kernel_initializer,
regularizer: kernel_regularizer,
trainable: true,
dtype: inputs.dtype);
if (use_bias)
bias = add_weight(name: "bias",
shape: filters,
initializer: bias_initializer,
trainable: true,
dtype: inputs.dtype);
built = true;
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
{
var inputs_shape = array_ops.shape(inputs);
var batch_size = inputs_shape[0];
var (h_axis, w_axis) = (1, 2);
if (data_format == "channels_first")
(h_axis, w_axis) = (2, 3);
var (height, width) = (-1, -1);
if(inputs.shape.rank > -1)
{
var dims = inputs.shape.dims;
(height, width) = (dims[h_axis], dims[w_axis]);
}
var (kernel_h, kernel_w) = kernel_size;
var (stride_h, stride_w) = strides;

var (out_pad_h, out_pad_w) = (-1, -1);

// Infer the dynamic output shape:
var out_height = conv_utils.deconv_output_length(height,
kernel_h,
padding: padding,
output_padding: out_pad_h,
stride: stride_h,
dilation: dilation_rate[0]);

var out_width = conv_utils.deconv_output_length(width,
kernel_w,
padding: padding,
output_padding: out_pad_w,
stride: stride_w,
dilation: dilation_rate[1]);

Tensor output_shape_tensor;
if (data_format == "channels_first")
output_shape_tensor = array_ops.stack(new object[] { batch_size, filters, out_height, out_width });
else
output_shape_tensor = array_ops.stack(new object[] { batch_size, out_height, out_width, filters });

var outputs = keras.backend.conv2d_transpose(
inputs,
kernel,
output_shape_tensor,
strides: strides,
padding: padding,
data_format: data_format,
dilation_rate: dilation_rate);

if (!tf.Context.executing_eagerly())
{
var out_shape = ComputeOutputShape(inputs.shape);
outputs.set_shape(out_shape);
}

if (use_bias)
throw new NotImplementedException("");

if (activation != null)
return activation(outputs);

return outputs;
}

public override TensorShape ComputeOutputShape(TensorShape input_shape)
{
var output_shape = input_shape.dims;
var (c_axis, h_axis, w_axis) = (3, 1, 2);
if (data_format == "channels_first")
(c_axis, h_axis, w_axis) = (1, 2, 3);

var (kernel_h, kernel_w) = kernel_size;
var (stride_h, stride_w) = strides;

var (out_pad_h, out_pad_w) = (-1, -1);
output_shape[c_axis] = filters;
output_shape[h_axis] = conv_utils.deconv_output_length(
output_shape[h_axis],
kernel_h,
padding: padding,
output_padding: out_pad_h,
stride: stride_h,
dilation: dilation_rate[0]);
output_shape[w_axis] = conv_utils.deconv_output_length(
output_shape[w_axis],
kernel_w,
padding: padding,
output_padding: out_pad_w,
stride: stride_w,
dilation: dilation_rate[1]);

return new TensorShape(output_shape);
}
}
}
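
A worked example of the shape logic above (hypothetical sizes, channels_last): for input (batch, 7, 7, 64) with filters = 32, kernel_size = (3, 3), strides = (2, 2) and padding "same", deconv_output_length returns 7 * 2 = 14 for each spatial axis, so output_shape_tensor stacks to (batch, 14, 14, 32); with padding "valid" each axis would be 7 * 2 + max(3 - 2, 0) = 15. Note that the bias path is not implemented yet (Call throws NotImplementedException when use_bias is set), so the layer is currently usable with use_bias: false.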

+4 -1 src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs

@@ -99,7 +99,7 @@ namespace Tensorflow.Keras.Layers
built = true;
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool training = false)
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = false)
{
var outputs = _convolution_op.Apply(inputs, kernel);
if (use_bias)
@@ -119,5 +119,8 @@ namespace Tensorflow.Keras.Layers

return outputs;
}

protected virtual int _get_channel_axis()
=> data_format == "channels_first" ? -1 - rank : -1;
}
}

+1 -1 src/TensorFlowNET.Keras/Layers/Core/Dense.cs

@@ -65,7 +65,7 @@ namespace Tensorflow.Keras.Layers
built = true;
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool training = false)
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
{
Tensor outputs = null;
var rank = inputs.rank;


+1 -1 src/TensorFlowNET.Keras/Layers/Core/Embedding.cs

@@ -62,7 +62,7 @@ namespace Tensorflow.Keras.Layers
built = true;
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false)
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
{
var dtype = inputs.dtype;
if (dtype != tf.int32 && dtype != tf.int64)


+2 -2 src/TensorFlowNET.Keras/Layers/LSTM.cs

@@ -26,9 +26,9 @@ namespace Tensorflow.Keras.Layers
.ToArray();
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false)
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
{
return base.Call(inputs, state: state, is_training: is_training);
return base.Call(inputs, state: state, training: training);
}
}
}

+45 -0 src/TensorFlowNET.Keras/Layers/LayersApi.cs

@@ -140,6 +140,51 @@ namespace Tensorflow.Keras.Layers
Activation = GetActivationByName(activation)
});

/// <summary>
/// Transposed convolution layer (sometimes called Deconvolution).
/// </summary>
/// <param name="filters"></param>
/// <param name="kernel_size"></param>
/// <param name="strides"></param>
/// <param name="padding"></param>
/// <param name="data_format"></param>
/// <param name="dilation_rate"></param>
/// <param name="activation"></param>
/// <param name="use_bias"></param>
/// <param name="kernel_initializer"></param>
/// <param name="bias_initializer"></param>
/// <param name="kernel_regularizer"></param>
/// <param name="bias_regularizer"></param>
/// <param name="activity_regularizer"></param>
/// <returns></returns>
public Conv2DTranspose Conv2DTranspose(int filters,
TensorShape kernel_size = null,
TensorShape strides = null,
string padding = "valid",
string data_format = null,
TensorShape dilation_rate = null,
string activation = null,
bool use_bias = true,
string kernel_initializer = null,
string bias_initializer = null,
string kernel_regularizer = null,
string bias_regularizer = null,
string activity_regularizer = null)
=> new Conv2DTranspose(new Conv2DArgs
{
Rank = 2,
Filters = filters,
KernelSize = kernel_size,
Strides = strides == null ? (1, 1) : strides,
Padding = padding,
DataFormat = data_format,
DilationRate = dilation_rate == null ? (1, 1) : dilation_rate,
UseBias = use_bias,
KernelInitializer = GetInitializerByName(kernel_initializer),
BiasInitializer = GetInitializerByName(bias_initializer),
Activation = GetActivationByName(activation)
});

public Dense Dense(int units,
Activation activation = null,
IInitializer kernel_initializer = null,

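A minimal usage sketch of the new API (a sketch only, assuming the keras.layers entry point, keras.Input, and the implicit tuple-to-TensorShape conversions used elsewhere in this repo; use_bias: false because the bias path of Conv2DTranspose.Call is not implemented yet):

    var inputs = keras.Input(shape: (7, 7, 64));
    var up = keras.layers.Conv2DTranspose(32,
        kernel_size: (3, 3),
        strides: (2, 2),
        padding: "same",
        activation: "relu",
        use_bias: false)
        .Apply(inputs);
    // With padding "same" and stride 2, `up` has shape (batch, 14, 14, 32).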

+1 -1 src/TensorFlowNET.Keras/Layers/Merging/Merge.cs

@@ -19,7 +19,7 @@ namespace Tensorflow.Keras.Layers
// output_shape = input_shape.dims[1^];
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false)
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
{
return _merge_function(inputs);
}


+71 -3 src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs

@@ -36,6 +36,7 @@ namespace Tensorflow.Keras.Layers
bool fused;
int[] axis;
string _data_format;
TensorShape kernel_size;
IInitializer beta_initializer => args.BetaInitializer;
IInitializer gamma_initializer => args.GammaInitializer;
IInitializer moving_mean_initializer => args.MovingMeanInitializer;
@@ -120,10 +121,35 @@ namespace Tensorflow.Keras.Layers
built = true;
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool training = false)
public override TensorShape ComputeOutputShape(TensorShape input_shape)
{
return input_shape;
}

(Tensor, Tensor) _moments(Tensors inputs, int[] reduction_axes, bool keep_dims)
{
var (mean, variance) = _calculate_mean_and_var(inputs, reduction_axes, keep_dims);
if (_support_zero_size_input())
throw new NotImplementedException("");
return (mean, variance);
}

(Tensor, Tensor) _calculate_mean_and_var(Tensors inputs, int[] reduction_axes, bool keep_dims)
{
return nn_impl.moments(inputs, reduction_axes, keep_dims: keep_dims);
}

bool _support_zero_size_input()
{
return false;
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
{
Tensor outputs = null;
var training_tensor = tf.logical_and(training, Trainable);
var training_tensor = training == null
? tf.placeholder(tf.@bool, TensorShape.Scalar)
: tf.logical_and(training.Value, Trainable);
if (fused)
{
// var training = tf.convert_to_tensor(training);
@@ -131,7 +157,49 @@ namespace Tensorflow.Keras.Layers
return outputs;
}

throw new NotImplementedException("BatchNormalization call");
var inputs_dtype = inputs.dtype.as_base_dtype();
var input_shape = inputs.shape;
var ndims = len(input_shape);
var reduction_axes = range(ndims).Where(x => !axis.Contains(x)).ToArray();

// Broadcasting only necessary for single-axis batch norm where the axis is
// not the last dimension
var broadcast_shape = range(ndims).Select(x => 1).ToArray();
broadcast_shape[axis[0]] = input_shape.dims[axis[0]];

var (scale, offset) = (gamma, beta);
var training_value = tf_utils.constant_value(training_tensor);

Tensor mean;
Tensor variance;
if (training_value.HasValue && training_value.Value == false)
{
(mean, variance) = (moving_mean.AsTensor(), moving_variance.AsTensor());
}
else
{
var keep_dims = len(axis) > 1;
(mean, variance) = _moments(inputs, reduction_axes, keep_dims: keep_dims);
mean = tf_utils.smart_cond(training_tensor,
() => new[] { mean },
() => new[] { ops.convert_to_tensor(moving_mean) }).FirstOrDefault();

variance = tf_utils.smart_cond(training_tensor,
() => new[] { variance },
() => new[] { ops.convert_to_tensor(moving_variance) }).FirstOrDefault();

var (new_mean, new_variance) = (mean, variance);
}

mean = math_ops.cast(mean, inputs.dtype);
variance = math_ops.cast(variance, inputs.dtype);
var offset_tensor = math_ops.cast(offset, inputs.dtype);
var scale_tensor = math_ops.cast(scale, inputs.dtype);
outputs = nn_impl.batch_normalization(inputs, mean, variance,
offset_tensor, scale_tensor, epsilon);
// If some components of the shape got lost due to adjustments, fix that.
outputs.set_shape(input_shape);
return outputs;
}

private Tensor _fused_batch_norm(Tensor inputs, Tensor training)


+1 -1 src/TensorFlowNET.Keras/Layers/Pooling/GlobalAveragePooling2D.cs

@@ -12,7 +12,7 @@ namespace Tensorflow.Keras.Layers
{
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false)
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
{
if (data_format == "channels_last")
return math_ops.reduce_mean(inputs, new int[] { 1, 2 }, false);


+1 -1 src/TensorFlowNET.Keras/Layers/Pooling/Pooling2D.cs

@@ -36,7 +36,7 @@ namespace Tensorflow.Keras.Layers
input_spec = new InputSpec(ndim: 4);
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false)
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
{
int[] pool_shape;
int[] strides;


+5 -2 src/TensorFlowNET.Keras/Layers/Regularization/Dropout.cs

@@ -15,9 +15,12 @@ namespace Tensorflow.Keras.Layers
this.args = args;
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false)
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
{
var output = tf_utils.smart_cond(is_training,
if (training == null)
training = false;

var output = tf_utils.smart_cond(training.Value,
() => tf.nn.dropout(inputs,
noise_shape: get_noise_shape(inputs),
seed: args.Seed,


+1 -1 src/TensorFlowNET.Keras/Layers/Rescaling/Rescaling.cs

@@ -17,7 +17,7 @@ namespace Tensorflow.Keras.Layers
this.args = args;
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false)
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
{
scale = math_ops.cast(args.Scale, args.DType);
offset = math_ops.cast(args.Offset, args.DType);


+1 -1 src/TensorFlowNET.Keras/Layers/Reshaping/Flatten.cs

@@ -22,7 +22,7 @@ namespace Tensorflow.Keras.Layers
_channels_first = args.DataFormat == "channels_first";
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false)
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
{
if (_channels_first)
{


+1 -1 src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs

@@ -19,7 +19,7 @@ namespace Tensorflow.Keras.Layers
this.args = args;
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false)
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
{
var shapes = new List<object>();
shapes.Add(array_ops.shape(inputs)[0]);


+1 -1 src/TensorFlowNET.Keras/Layers/Reshaping/UpSampling2D.cs

@@ -24,7 +24,7 @@ namespace Tensorflow.Keras.Layers
inputSpec = new InputSpec(ndim: 4);
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false)
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
{
return keras.backend.resize_images(inputs,
size[0], size[1],


+1 -1 src/TensorFlowNET.Keras/Layers/Reshaping/ZeroPadding2D.cs

@@ -26,7 +26,7 @@ namespace Tensorflow.Keras.Layers
this.input_spec = new InputSpec(ndim: 4);
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false)
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
{
return keras.backend.spatial_2d_padding(inputs,
padding: padding,


+1 -1 src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs

@@ -32,7 +32,7 @@ namespace Tensorflow.Keras.Layers
built = true;
}

protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false)
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
{
if (tf.Context.executing_eagerly())
return _defun_call(inputs);


+4 -4 src/TensorFlowNET.Keras/Tensorflow.Keras.csproj

@@ -55,10 +55,6 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
<PackageReference Include="SharpZipLib" Version="1.3.1" />
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\TensorFlowNET.Core\Tensorflow.Binding.csproj" />
</ItemGroup>

<ItemGroup>
<None Include="..\..\LICENSE">
<Pack>True</Pack>
@@ -70,4 +66,8 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
<Folder Include="Engine\Interfaces\" />
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\TensorFlowNET.Core\Tensorflow.Binding.csproj" />
</ItemGroup>

</Project>

+30 -0 src/TensorFlowNET.Keras/Utils/conv_utils.cs

@@ -14,7 +14,9 @@
limitations under the License.
******************************************************************************/

using System;
using System.Linq;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Utils
{
@@ -63,5 +65,33 @@ namespace Tensorflow.Keras.Utils
return ImageDataFormat.channels_last.ToString();
return value.ToLower();
}

public static int deconv_output_length(int input_length,
int filter_size,
string padding,
int output_padding = -1,
int stride = 0,
int dilation = 1)
{
// Get the dilated kernel size
filter_size = filter_size + (filter_size - 1) * (dilation - 1);

// Infer length if output padding is None, else compute the exact length
int length = -1;
if (output_padding == -1)
{
if (padding == "valid")
length = input_length * stride + max(filter_size - stride, 0);
else if (padding == "full")
length = input_length * stride - (stride + filter_size - 2);
else if (padding == "same")
length = input_length * stride;
}
else
{
throw new NotImplementedException("");
}
return length;
}
}
}
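
A few concrete values for the helper above (input_length = 7, filter_size = 3, stride = 2, dilation = 1, output_padding left at its -1 default, i.e. inferred):

    conv_utils.deconv_output_length(7, 3, "valid", stride: 2);  // 7 * 2 + max(3 - 2, 0) = 15
    conv_utils.deconv_output_length(7, 3, "same",  stride: 2);  // 7 * 2                 = 14
    conv_utils.deconv_output_length(7, 3, "full",  stride: 2);  // 14 - (2 + 3 - 2)      = 11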

+11 -0 src/TensorFlowNET.Keras/Utils/tf_utils.cs

@@ -37,6 +37,17 @@ namespace Tensorflow.Keras.Utils
return true;
}

public static Tensor[] smart_cond<T>(IVariableV1 pred,
Func<T[]> true_fn = null,
Func<T[]> false_fn = null,
string name = null)
{
return control_flow_ops.cond(pred.AsTensor(),
true_fn: true_fn,
false_fn: false_fn,
name: name);
}

public static Tensor[] smart_cond<T>(Tensor pred,
Func<T[]> true_fn = null,
Func<T[]> false_fn = null,

