Browse Source

Refactor convolutional layer.

tags/v0.30
Oceania2018 5 years ago
parent
commit
b63a44e1fe
46 changed files with 671 additions and 548 deletions
  1. +19
    -17
      src/TensorFlowNET.Core/APIs/tf.layers.cs
  2. +2
    -0
      src/TensorFlowNET.Core/APIs/tf.ops.cs
  3. +5
    -0
      src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs
  4. +20
    -0
      src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs
  5. +10
    -24
      src/TensorFlowNET.Core/Gradients/nn_grad.cs
  6. +24
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/BatchNormalizationArgs.cs
  7. +1
    -1
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Conv2DArgs.cs
  8. +4
    -3
      src/TensorFlowNET.Core/Keras/ArgsDefinition/ConvolutionalArgs.cs
  9. +1
    -1
      src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs
  10. +2
    -2
      src/TensorFlowNET.Core/Keras/ArgsDefinition/ModelArgs.cs
  11. +8
    -14
      src/TensorFlowNET.Core/Keras/BackendImpl.cs
  12. +67
    -0
      src/TensorFlowNET.Core/Keras/Engine/Functional.cs
  13. +8
    -0
      src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs
  14. +7
    -3
      src/TensorFlowNET.Core/Keras/Engine/KerasHistory.cs
  15. +89
    -21
      src/TensorFlowNET.Core/Keras/Engine/Layer.cs
  16. +4
    -5
      src/TensorFlowNET.Core/Keras/Engine/Model.cs
  17. +3
    -6
      src/TensorFlowNET.Core/Keras/KerasApi.cs
  18. +50
    -78
      src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs
  19. +3
    -4
      src/TensorFlowNET.Core/Keras/Layers/Conv2D.cs
  20. +16
    -13
      src/TensorFlowNET.Core/Keras/Layers/Convolutional.cs
  21. +2
    -3
      src/TensorFlowNET.Core/Keras/Layers/InputLayer.cs
  22. +52
    -8
      src/TensorFlowNET.Core/Keras/Layers/LayersApi.cs
  23. +12
    -0
      src/TensorFlowNET.Core/Keras/Regularizers.cs
  24. +11
    -0
      src/TensorFlowNET.Core/Keras/Regularizers/IRegularizer.cs
  25. +21
    -0
      src/TensorFlowNET.Core/Keras/Regularizers/L2.cs
  26. +10
    -0
      src/TensorFlowNET.Core/Keras/Regularizers/RegularizerArgs.cs
  27. +11
    -12
      src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs
  28. +1
    -3
      src/TensorFlowNET.Core/Layers/Layer.cs
  29. +2
    -2
      src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs
  30. +0
    -84
      src/TensorFlowNET.Core/Operations/NnOps/Convolution.cs
  31. +100
    -0
      src/TensorFlowNET.Core/Operations/NnOps/ConvolutionInternal.cs
  32. +0
    -83
      src/TensorFlowNET.Core/Operations/NnOps/_NonAtrousConvolution.cs
  33. +0
    -76
      src/TensorFlowNET.Core/Operations/NnOps/_WithSpaceToBatch.cs
  34. +58
    -38
      src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
  35. +5
    -5
      src/TensorFlowNET.Core/Operations/gen_random_ops.cs
  36. +10
    -10
      src/TensorFlowNET.Core/Operations/nn_ops.cs
  37. +1
    -1
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  38. +1
    -1
      src/TensorFlowNET.Core/Tensors/TensorShape.cs
  39. +13
    -4
      src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs
  40. +1
    -1
      src/TensorFlowNET.Core/Variables/IVariableV1.cs
  41. +2
    -2
      src/TensorFlowNET.Core/Variables/RefVariable.cs
  42. +2
    -2
      src/TensorFlowNET.Core/Variables/ResourceVariable.cs
  43. +5
    -20
      src/TensorFlowNET.Core/ops.cs
  44. +1
    -0
      src/TensorFlowNET.Core/ops.name_scope.cs
  45. +2
    -1
      test/TensorFlowNET.UnitTest/Keras/ModelSaveTest.cs
  46. +5
    -0
      test/TensorFlowNET.UnitTest/ManagedAPI/BitwiseApiTest.cs

+ 19
- 17
src/TensorFlowNET.Core/APIs/tf.layers.cs View File

@@ -92,7 +92,7 @@ namespace Tensorflow
/// <param name="renorm"></param>
/// <param name="renorm_momentum"></param>
/// <returns></returns>
public Tensor batch_normalization(Tensor inputs,
public Tensors batch_normalization(Tensor inputs,
int axis = -1,
float momentum = 0.99f,
float epsilon = 0.001f,
@@ -108,22 +108,24 @@ namespace Tensorflow
bool renorm = false,
float renorm_momentum = 0.99f)
{
var layer = new BatchNormalization(
axis: axis,
momentum: momentum,
epsilon: epsilon,
center: center,
scale: scale,
beta_initializer: beta_initializer,
gamma_initializer: gamma_initializer,
moving_mean_initializer: moving_mean_initializer,
moving_variance_initializer: moving_variance_initializer,
renorm: renorm,
renorm_momentum: renorm_momentum,
trainable: trainable,
name: name);

return layer.apply(inputs, training: training).Item1;
var layer = new BatchNormalization(new BatchNormalizationArgs
{
Axis = axis,
Momentum = momentum,
Epsilon = epsilon,
Center = center,
Scale = scale,
BetaInitializer = beta_initializer,
GammaInitializer = gamma_initializer,
MovingMeanInitializer = moving_mean_initializer,
MovingVarianceInitializer = moving_variance_initializer,
Renorm = renorm,
RenormMomentum = renorm_momentum,
Trainable = trainable,
Name = name
});

return layer.Apply(inputs);
}

/// <summary>


+ 2
- 0
src/TensorFlowNET.Core/APIs/tf.ops.cs View File

@@ -41,6 +41,8 @@ namespace Tensorflow

/// <summary>
/// A context manager that lifts ops out of control-flow scopes and function-building graphs.
/// When eager execution is enabled, code inside an init_scope block runs with
/// eager execution enabled even when tracing a `tf.function`.
/// </summary>
public void init_scope()
=> ops.init_scope();


+ 5
- 0
src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs View File

@@ -227,6 +227,11 @@ namespace Tensorflow.Eager
input_handle = input.EagerTensorHandle;
flattened_inputs.Add(input);
break;
case ResourceVariable variable:
var var_tensor = variable.AsTensor();
input_handle = var_tensor.EagerTensorHandle;
flattened_inputs.Add(var_tensor);
break;
default:
var tensor = tf.convert_to_tensor(inputs);
input_handle = tensor.EagerTensorHandle;


+ 20
- 0
src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs View File

@@ -57,6 +57,26 @@ namespace Tensorflow.Eager
return this;
}

/// <summary>
/// _create_substitute_placeholder
/// </summary>
/// <returns></returns>
public Tensor AsPlaceholder(string name = null)
{
Tensor placeholder = null;
tf_with(ops.control_dependencies(null), delegate
{
placeholder = tf.placeholder(dtype, shape: shape, name: name ?? this.name);
});
// custom_gradient.copy_handle_data(value, placeholder)
return placeholder;
}

void copy_handle_data()
{

}

public override IntPtr ToPointer()
=> EagerTensorHandle?.DangerousGetHandle() ?? IntPtr.Zero;



+ 10
- 24
src/TensorFlowNET.Core/Gradients/nn_grad.cs View File

@@ -138,30 +138,16 @@ namespace Tensorflow.Gradients
return new Tensor[]
{
gen_nn_ops.conv2d_backprop_input(new Conv2dParams
{
InputSizes = shape[0],
Filter = op.inputs[1],
OutBackProp = grads[0],
Dilations = dilations,
Strides = strides,
Padding = padding.ToString(),
ExplicitPaddings = explicit_paddings,
UseCudnnOnGpu = (bool)use_cudnn_on_gpu,
DataFormat = data_format.ToString(),
}),
gen_nn_ops.conv2d_backprop_filter(new Conv2dParams
{
Input = op.inputs[0],
FilterSizes = shape[1],
OutBackProp = grads[0],
Dilations = dilations,
Strides = strides,
Padding = padding.ToString(),
ExplicitPaddings = explicit_paddings,
UseCudnnOnGpu = (bool)use_cudnn_on_gpu,
DataFormat = data_format.ToString()
})
gen_nn_ops.conv2d_backprop_input(shape[0], op.inputs[1], grads[0],
strides, padding, use_cudnn_on_gpu, explicit_paddings,
dilations: dilations,
data_format: data_format),
gen_nn_ops.conv2d_backprop_filter(op.inputs[0], shape[1], grads[0],
strides, padding,
dilations: dilations,
explicit_paddings: explicit_paddings,
use_cudnn_on_gpu: use_cudnn_on_gpu,
data_format: data_format)
};
}



+ 24
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/BatchNormalizationArgs.cs View File

@@ -0,0 +1,24 @@
using System;
using System.Collections.Generic;
using System.Text;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.ArgsDefinition
{
public class BatchNormalizationArgs : LayerArgs
{
public TensorShape Axis { get; set; } = -1;
public float Momentum { get; set; } = 0.99f;
public float Epsilon { get; set; } = 1e-3f;
public bool Center { get; set; } = true;
public bool Scale { get; set; } = true;
public IInitializer BetaInitializer { get; set; } = tf.zeros_initializer;
public IInitializer GammaInitializer { get; set; } = tf.ones_initializer;
public IInitializer MovingMeanInitializer { get; set; } = tf.zeros_initializer;
public IInitializer MovingVarianceInitializer { get; set; } = tf.ones_initializer;
public IRegularizer BetaRegularizer { get; set; }
public IRegularizer GammaRegularizer { get; set; }
public bool Renorm { get; set; }
public float RenormMomentum { get; set; } = 0.99f;
}
}

+ 1
- 1
src/TensorFlowNET.Core/Keras/ArgsDefinition/Conv2DArgs.cs View File

@@ -4,7 +4,7 @@ using System.Text;

namespace Tensorflow.Keras.ArgsDefinition
{
public class Conv2DArgs : ConvArgs
public class Conv2DArgs : ConvolutionalArgs
{
}


src/TensorFlowNET.Core/Keras/ArgsDefinition/ConvArgs.cs → src/TensorFlowNET.Core/Keras/ArgsDefinition/ConvolutionalArgs.cs View File

@@ -5,10 +5,11 @@ using static Tensorflow.Binding;

namespace Tensorflow.Keras.ArgsDefinition
{
public class ConvArgs : LayerArgs
public class ConvolutionalArgs : LayerArgs
{
public int Rank { get; set; } = 2;
public int Filters { get; set; }
public int NumSpatialDims { get; set; } = Unknown;
public TensorShape KernelSize { get; set; } = 5;

/// <summary>
@@ -24,8 +25,8 @@ namespace Tensorflow.Keras.ArgsDefinition
public bool UseBias { get; set; }
public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer;
public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer;
public IInitializer KernelRegularizer { get; set; }
public IInitializer BiasRegularizer { get; set; }
public IRegularizer KernelRegularizer { get; set; }
public IRegularizer BiasRegularizer { get; set; }
public Action KernelConstraint { get; set; }
public Action BiasConstraint { get; set; }
}

+ 1
- 1
src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs View File

@@ -46,7 +46,7 @@ namespace Tensorflow.Keras.ArgsDefinition
/// <summary>
/// Regularizer function applied to the output of the layer(its "activation").
/// </summary>
public IInitializer ActivityRegularizer { get; set; }
public IRegularizer ActivityRegularizer { get; set; }

public bool Autocast { get; set; }
}


+ 2
- 2
src/TensorFlowNET.Core/Keras/ArgsDefinition/ModelArgs.cs View File

@@ -6,7 +6,7 @@ namespace Tensorflow.Keras.ArgsDefinition
{
public class ModelArgs : LayerArgs
{
public Tensor[] Inputs { get; set; }
public Tensor[] Outputs { get; set; }
public Tensors Inputs { get; set; }
public Tensors Outputs { get; set; }
}
}

+ 8
- 14
src/TensorFlowNET.Core/Keras/BackendImpl.cs View File

@@ -42,7 +42,7 @@ namespace Tensorflow.Keras
/// for various layer names in each graph.
/// Allows to give unique autogenerated names to layers, in a graph-specific way.
/// </summary>
public Dictionary<Graph, Dictionary<(string, string), int>> PER_GRAPH_LAYER_NAME_UIDS = new Dictionary<Graph, Dictionary<(string, string), int>>();
public Dictionary<Graph, Dictionary<string, int>> PER_GRAPH_LAYER_NAME_UIDS = new Dictionary<Graph, Dictionary<string, int>>();
public Dictionary<string, IVariableV1> _GRAPH_VARIABLES = new Dictionary<string, IVariableV1>();
public Dictionary<string, Optimizer> _GRAPH_TF_OPTIMIZERS = new Dictionary<string, Optimizer>();

@@ -80,25 +80,19 @@ namespace Tensorflow.Keras
return ops.get_default_graph();
}

public int get_uid(string prefix, string @namespace = "")
public int get_uid(string prefix)
{
var graph = tf.get_default_graph();
if (!PER_GRAPH_LAYER_NAME_UIDS.ContainsKey(graph))
PER_GRAPH_LAYER_NAME_UIDS.Add(graph, new defaultdict<(string, string), int>());
PER_GRAPH_LAYER_NAME_UIDS[graph][(@namespace, prefix)] += 1;
PER_GRAPH_LAYER_NAME_UIDS.Add(graph, new defaultdict<string, int>());
if (!PER_GRAPH_LAYER_NAME_UIDS[graph].ContainsKey(prefix))
PER_GRAPH_LAYER_NAME_UIDS[graph][prefix] = 0;
PER_GRAPH_LAYER_NAME_UIDS[graph][prefix] += 1;

return PER_GRAPH_LAYER_NAME_UIDS[graph][(@namespace, prefix)];
return PER_GRAPH_LAYER_NAME_UIDS[graph][prefix];
}
public int get_uid((string, string) name)
{
var graph = tf.get_default_graph();
if (!PER_GRAPH_LAYER_NAME_UIDS.ContainsKey(graph))
PER_GRAPH_LAYER_NAME_UIDS.Add(graph, new defaultdict<(string, string), int>());
PER_GRAPH_LAYER_NAME_UIDS[graph][(name)] += 1;

return PER_GRAPH_LAYER_NAME_UIDS[graph][name];
}
public void reset_uids() => PER_GRAPH_LAYER_NAME_UIDS = new Dictionary<Graph, Dictionary<(string, string), int>>();
public void reset_uids() => PER_GRAPH_LAYER_NAME_UIDS = new Dictionary<Graph, Dictionary<string, int>>();
public void clear_session()
{
ops.reset_default_graph();


+ 67
- 0
src/TensorFlowNET.Core/Keras/Engine/Functional.cs View File

@@ -0,0 +1,67 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;

namespace Tensorflow.Keras.Engine
{
/// <summary>
/// A `Functional` model is a `Model` defined as a directed graph of layers.
/// </summary>
public class Functional : Model
{
TensorShape _build_input_shape;
bool _compute_output_and_mask_jointly;
bool _expects_training_arg;
bool _expects_mask_arg;
bool _autocast;
List<Layer> _output_layers;
List<Layer> _input_layers;
List<KerasHistory> _input_coordinates;
List<KerasHistory> _output_coordinates;

public Functional(Tensors inputs, Tensors outputs)
: base(new ModelArgs
{
Inputs = inputs,
Outputs = outputs
})
{
_input_layers = new List<Layer>();
_output_layers = new List<Layer>();
_input_coordinates = new List<KerasHistory>();
_output_coordinates = new List<KerasHistory>();
_init_graph_network(inputs, outputs);
}

void _init_graph_network(Tensors inputs, Tensors outputs)
{
_is_graph_network = true;
this.inputs = inputs;
this.outputs = outputs;
built = true;
_build_input_shape = inputs.shape;
_compute_output_and_mask_jointly = true;
_expects_training_arg = true;
_expects_mask_arg = true;
// A graph network does not autocast inputs, as its layers will cast them instead.
_autocast = false;

// Build self._output_layers:
foreach(var x in outputs)
{
var (layer, node_index, tensor_index) = x.KerasHistory;
_output_layers.append(layer);
_output_coordinates.append(new KerasHistory(layer, node_index, tensor_index));
}

// Build self._input_layers:
foreach(var x in inputs)
{
var (layer, node_index, tensor_index) = x.KerasHistory;
_input_layers.append(layer);
_input_coordinates.append(new KerasHistory(layer, node_index, tensor_index));
}
}
}
}

+ 8
- 0
src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs View File

@@ -15,6 +15,7 @@
******************************************************************************/

using System.Collections.Generic;
using System.Linq;

namespace Tensorflow.Keras.Engine
{
@@ -27,6 +28,7 @@ namespace Tensorflow.Keras.Engine
public int? min_ndim;
Dictionary<int, int> axes;
TensorShape shape;
public int[] AllAxisDim;

public InputSpec(TF_DataType dtype = TF_DataType.DtInvalid,
int? ndim = null,
@@ -42,6 +44,12 @@ namespace Tensorflow.Keras.Engine
this.shape = shape;
if (ndim == null && shape != null)
this.ndim = shape.ndim;

if(axes != null)
AllAxisDim = axes.Select(x => x.Value).ToArray();
}

public override string ToString()
=> $"min_ndim={min_ndim}, , axes={axes.Count}";
}
}

+ 7
- 3
src/TensorFlowNET.Core/Keras/Engine/KerasHistory.cs View File

@@ -20,10 +20,14 @@ namespace Tensorflow.Keras.Engine
this.tensor_index = tensor_index;
}

public void Deconstruct(out Layer layer, out int node_index, out int tensor_index)
{
layer = this.layer;
node_index = this.node_index;
tensor_index = this.tensor_index;
}

public static implicit operator Layer(KerasHistory history)
=> history.layer;

public static implicit operator (Layer, int, int)(KerasHistory history)
=> (history.layer, history.node_index, history.tensor_index);
}
}

+ 89
- 21
src/TensorFlowNET.Core/Keras/Engine/Layer.cs View File

@@ -72,10 +72,10 @@ namespace Tensorflow.Keras.Engine
protected List<IVariableV1> nonTrainableWeights;
public List<IVariableV1> non_trainable_variables => nonTrainableWeights;

string name;
protected string name;
protected string base_name;
public string Name => name;

protected string baseName;
protected bool computePreviousMask;
protected List<Operation> updates;
public TensorShape BatchInputShape => args.BatchInputShape;
@@ -98,9 +98,9 @@ namespace Tensorflow.Keras.Engine
// Indicates whether `build` needs to be called upon layer call, to create
// the layer's weights.
built = false;
this.SupportsMasking = false;
SupportsMasking = false;

_init_set_name(name);
_init_set_name(args.Name);
trainableWeights = new List<IVariableV1>();
nonTrainableWeights = new List<IVariableV1>();
computePreviousMask = false;
@@ -124,23 +124,25 @@ namespace Tensorflow.Keras.Engine
/// <returns></returns>
public Tensors Apply(Tensors inputs, Tensor state = null, bool is_training = false)
{
Tensors outputs = null;

callContext = callContext ?? new ThreadLocal<CallContext>()
{
Value = new CallContext()
};

if (_in_functional_construction_mode(inputs))
return _functional_construction_call(inputs);

Tensors outputs = null;

var eager = tf.executing_eagerly();
using var ctxManager = CallContext.enter();

string nameScope = "";
if (eager)
nameScope = name;
nameScope = Name;
else
nameScope = _name_scope();

// using var graph = tf.keras.backend.get_graph().as_default();
if (!inputs.IsEagerTensor)
tf.Context.graph_mode();

@@ -162,6 +164,46 @@ namespace Tensorflow.Keras.Engine
return outputs;
}

bool _in_functional_construction_mode(Tensors inputs)
{
return inputs.Count(x => !x.IsEagerTensor) == inputs.Count();
}

Tensors _functional_construction_call(Tensors inputs)
{
bool mask_arg_passed_by_framework = false;
bool training_arg_passed_by_framework = false;
Tensor training_value = null;
if(training_value == null)
{
training_arg_passed_by_framework = true;
}

Tensors outputs = null;
using var ctxManager = CallContext.enter();

// using var graph = tf.keras.backend.get_graph().as_default();

if (!inputs.IsEagerTensor)
tf.Context.graph_mode();

tf_with(ops.name_scope(_name_scope()), scope =>
{
MaybeBuild(inputs);

outputs = call(inputs);

outputs = _set_connectivity_metadata_(inputs, outputs);
_handle_activity_regularization(inputs, outputs);
_set_mask_metadata(inputs, outputs, null);
});

if (!inputs.IsEagerTensor)
tf.Context.restore_mode();

return outputs;
}

private Tensors _set_connectivity_metadata_(Tensors inputs, Tensors outputs)
{
/*var returnOutputs = new List<Tensor>();
@@ -219,8 +261,12 @@ namespace Tensorflow.Keras.Engine
if (DType == TF_DataType.DtInvalid)
args.DType = inputs.dtype;

var input_shapes = inputs.shape;
build(input_shapes);
tf.init_scope();

//tf.Context.eager_mode();
build(inputs.shape);
//tf.Context.restore_mode();
built = true;
}

@@ -229,10 +275,16 @@ namespace Tensorflow.Keras.Engine
built = true;
}

protected virtual void add_loss(Func<Tensor> losses)
{
}

protected virtual IVariableV1 add_weight(string name,
TensorShape shape,
TF_DataType dtype = TF_DataType.DtInvalid,
IInitializer initializer = null,
IRegularizer regularizer = null,
bool? trainable = null,
Func<VariableArgs, IVariableV1> getter = null)
{
@@ -251,7 +303,7 @@ namespace Tensorflow.Keras.Engine
else if (dtype.is_integer())
initializer = tf.zeros_initializer;
else
throw new ValueError($"An initializer for variable {name} of type {dtype.as_base_dtype()} is required for layer {this.Name}");
throw new ValueError($"An initializer for variable {name} of type {dtype.as_base_dtype()} is required for layer {name}");
}

var args = new VariableArgs
@@ -266,6 +318,12 @@ namespace Tensorflow.Keras.Engine
};
var variable = _add_variable_with_custom_getter(args);

if(regularizer != null)
{
var name_in_scope = variable.Name.Split(':')[0];
_handle_weight_regularization(name_in_scope, variable, regularizer);
}

//backend.track_variable(variable);
if (trainable == true)
trainableWeights.Add(variable);
@@ -275,6 +333,20 @@ namespace Tensorflow.Keras.Engine
return variable;
}

/// <summary>
/// Create lambdas which compute regularization losses.
/// </summary>
/// <param name="name"></param>
/// <param name="variable"></param>
/// <param name="regularizer"></param>
void _handle_weight_regularization(string name, IVariableV1 variable, IRegularizer regularizer)
{
add_loss(() => regularizer.Apply(new RegularizerArgs
{
}));
}

protected virtual void add_update(Tensor[] updates, bool inputs = false)
{
var updates_op = updates.Select(x => x.op).ToArray();
@@ -284,17 +356,13 @@ namespace Tensorflow.Keras.Engine
// Determine layer name (non-unique).
protected virtual void _init_set_name(string name, bool zero_based = true)
{
var base_name = name;
base_name = name;
this.name = name;
if (name == null)
(this.name, baseName) = _make_unique_name();
}

protected virtual (string, string) _make_unique_name()
{
string base_name = generic_utils.to_snake_case(this.GetType().Name);
string name = base_layer_utils.unique_layer_name(base_name);
return (name, base_name);
{
base_name = generic_utils.to_snake_case(this.GetType().Name);
this.name = base_layer_utils.unique_layer_name(base_name, zero_based: zero_based);
}
}
}
}

+ 4
- 5
src/TensorFlowNET.Core/Keras/Engine/Model.cs View File

@@ -23,15 +23,14 @@ namespace Tensorflow.Keras.Engine
string loss;
IOptimizer optimizer;
IVariableV1 _steps_per_execution;
protected bool _is_graph_network;
protected Tensors inputs;
protected Tensors outputs;

public Model(ModelArgs args)
: base(args)
{
// Build _output_layers
/*foreach(var x in args.Outputs)
{
var layer = x.KerasHistory;
}*/
}

public void compile(string optimizerName, string lossName)


+ 3
- 6
src/TensorFlowNET.Core/Keras/KerasApi.cs View File

@@ -16,6 +16,7 @@ namespace Tensorflow
{
public KerasDataset datasets { get; } = new KerasDataset();
public Initializers initializers { get; } = new Initializers();
public Regularizers regularizers { get; } = new Regularizers();
public LayersApi layers { get; } = new LayersApi();
public LossesApi losses { get; } = new LossesApi();
public Activations activations { get; } = new Activations();
@@ -36,12 +37,8 @@ namespace Tensorflow
/// <param name="input"></param>
/// <param name="output"></param>
/// <returns></returns>
public Model Model(Tensor input, Tensor output)
=> new Model(new ModelArgs
{
Inputs = new[] { input },
Outputs = new[] { output }
});
public Functional Model(Tensors inputs, Tensors outputs)
=> new Functional(inputs, outputs);

/// <summary>
/// Instantiate a Keras tensor.


+ 50
- 78
src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs View File

@@ -15,73 +15,41 @@
******************************************************************************/

using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Utils;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Layers
{
public class BatchNormalization : Tensorflow.Layers.Layer
public class BatchNormalization : Layer
{
#pragma warning disable CS0414 // The field 'BatchNormalization._USE_V2_BEHAVIOR' is assigned but its value is never used
private bool _USE_V2_BEHAVIOR = true;
#pragma warning restore CS0414 // The field 'BatchNormalization._USE_V2_BEHAVIOR' is assigned but its value is never used
private float momentum;
private float epsilon;
private bool center;
private bool scale;
private bool renorm;
private bool fused;
#pragma warning disable CS0414 // The field 'BatchNormalization._bessels_correction_test_only' is assigned but its value is never used
private bool _bessels_correction_test_only;
#pragma warning restore CS0414 // The field 'BatchNormalization._bessels_correction_test_only' is assigned but its value is never used
private int[] axis;
private string _data_format;
private IInitializer beta_initializer;
private IInitializer gamma_initializer;
private IInitializer moving_mean_initializer;
private IInitializer moving_variance_initializer;
private IVariableV1 gamma;
private IVariableV1 beta;
private RefVariable moving_mean;
private RefVariable moving_variance;

public BatchNormalization(int axis = -1,
float momentum = 0.99f,
float epsilon = 0.001f,
bool center = true,
bool scale = true,
IInitializer beta_initializer = null,
IInitializer gamma_initializer = null,
IInitializer moving_mean_initializer = null,
IInitializer moving_variance_initializer = null,
bool renorm = false,
float renorm_momentum = 0.99f,
bool trainable = true,
string name = null) : base(trainable: trainable,
name: name)
BatchNormalizationArgs args;

float momentum => args.Momentum;
float epsilon => args.Epsilon;
bool center => args.Center;
bool scale => args.Scale;
bool renorm => args.Renorm;
bool fused;
int[] axis;
string _data_format;
IInitializer beta_initializer => args.BetaInitializer;
IInitializer gamma_initializer => args.GammaInitializer;
IInitializer moving_mean_initializer;
IInitializer moving_variance_initializer;
IRegularizer gamma_regularizer => args.GammaRegularizer;
IVariableV1 gamma;
IVariableV1 beta;
IVariableV1 moving_mean;
IVariableV1 moving_variance;

public BatchNormalization(BatchNormalizationArgs args) : base(args)
{
this.axis = new int[] { axis };
this.momentum = momentum;
this.epsilon = epsilon;
this.center = center;
this.scale = scale;
if (beta_initializer == null)
beta_initializer = tf.zeros_initializer;
if (gamma_initializer == null)
gamma_initializer = tf.ones_initializer;
if (moving_mean_initializer == null)
moving_mean_initializer = tf.zeros_initializer;
if (moving_variance_initializer == null)
moving_variance_initializer = tf.ones_initializer;
this.beta_initializer = beta_initializer;
this.gamma_initializer = gamma_initializer;
this.moving_mean_initializer = moving_mean_initializer;
this.moving_variance_initializer = moving_variance_initializer;
this.renorm = renorm;
this.fused = true;
this.SupportsMasking = true;
this._bessels_correction_test_only = true;
this.args = args;
axis = args.Axis.dims;
}

protected override void build(TensorShape input_shape)
@@ -91,12 +59,25 @@ namespace Tensorflow.Keras.Layers
if (x < 0)
axis[idx] = ndims + x;

fused = ndims == 4;

if (fused)
if (Enumerable.SequenceEqual(axis, new int[] { 3 }))
{
if (Enumerable.SequenceEqual(axis, new int[] { 1 }))
_data_format = "NCHW";
else if (Enumerable.SequenceEqual(axis, new int[] { 3 }))
_data_format = "NHWC";
else
throw new ValueError($"Unsupported axis, fused batch norm only supports axis == [1] or axis == [3]");
}

var axis_to_dim = new Dictionary<int, int>();
foreach(var x in axis)
axis_to_dim[x] = input_shape[x];

inputSpec = new InputSpec(ndim: ndims, axes: axis_to_dim);
var param_dtype = DType == TF_DataType.DtInvalid ? TF_DataType.TF_FLOAT : DType;
var param_shape = new int[] { input_shape.dims[axis[0]] };
var param_shape = inputSpec.AllAxisDim;

if (scale)
gamma = add_weight("gamma",
@@ -116,26 +97,17 @@ namespace Tensorflow.Keras.Layers
else
throw new NotImplementedException("add_weight beta");

if(_scope != null)
{
}

moving_mean = (RefVariable)add_weight("moving_mean",
moving_mean = add_weight("moving_mean",
param_shape,
dtype: param_dtype,
initializer: moving_mean_initializer,
synchronization: VariableSynchronization.OnRead,
trainable: false,
aggregation: VariableAggregation.Mean);
trainable: false);

moving_variance = (RefVariable)add_weight("moving_variance",
moving_variance = add_weight("moving_variance",
shape: param_shape,
dtype: param_dtype,
initializer: moving_variance_initializer,
synchronization: VariableSynchronization.OnRead,
trainable: false,
aggregation: VariableAggregation.Mean);
trainable: false);

if (renorm)
throw new NotImplementedException("build when renorm is true");
@@ -178,8 +150,8 @@ namespace Tensorflow.Keras.Layers
inputs,
gamma,
beta,
mean: moving_mean,
variance: moving_variance,
mean: moving_mean.AsTensor(),
variance: moving_variance.AsTensor(),
epsilon: epsilon,
is_training: false,
data_format: _data_format);
@@ -202,8 +174,8 @@ namespace Tensorflow.Keras.Layers
if(training_value == null)
{
var mean_update = _assign_moving_average(moving_mean, mean, momentum_tensor);
var variance_update = _assign_moving_average(moving_variance, variance, momentum_tensor);
var mean_update = _assign_moving_average(moving_mean.AsTensor(), mean, momentum_tensor);
var variance_update = _assign_moving_average(moving_variance.AsTensor(), variance, momentum_tensor);
add_update(new Tensor[] { mean_update }, inputs: true);
add_update(new Tensor[] { variance_update }, inputs: true);
}


+ 3
- 4
src/TensorFlowNET.Core/Keras/Layers/Conv2D.cs View File

@@ -19,12 +19,11 @@ using Tensorflow.Operations.Activation;

namespace Tensorflow.Keras.Layers
{
public class Conv2D : Conv
public class Conv2D : Convolutional
{
public Conv2D(Conv2DArgs args)
: base(args)
public Conv2D(Conv2DArgs args) : base(args)
{
}
}
}

src/TensorFlowNET.Core/Keras/Layers/Conv.cs → src/TensorFlowNET.Core/Keras/Layers/Convolutional.cs View File

@@ -20,13 +20,13 @@ using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Utils;
using Tensorflow.Operations;
using Tensorflow.Operations.Activation;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Layers
{
public class Conv : Layer
public class Convolutional : Layer
{
ConvArgs args;
ConvolutionalArgs args;
protected int rank => args.Rank;
protected int filters => args.Filters;
protected TensorShape kernel_size => args.KernelSize;
@@ -37,13 +37,14 @@ namespace Tensorflow.Keras.Layers
protected Activation activation => args.Activation;
protected bool use_bias => args.UseBias;
protected IInitializer kernel_initializer => args.KernelInitializer;
protected IRegularizer kernel_regularizer => args.KernelRegularizer;
protected IInitializer bias_initializer => args.BiasInitializer;
protected IVariableV1 kernel;
protected IVariableV1 bias;
protected Convolution _convolution_op;
string _tf_data_format;
ConvolutionInternal _convolution_op;
protected string _tf_data_format;

public Conv(ConvArgs args) : base(args)
public Convolutional(ConvolutionalArgs args) : base(args)
{
this.args = args;
args.KernelSize = conv_utils.normalize_tuple(args.KernelSize.dims, args.Rank, "kernel_size");
@@ -65,6 +66,7 @@ namespace Tensorflow.Keras.Layers
kernel = add_weight(name: "kernel",
shape: kernel_shape,
initializer: kernel_initializer,
regularizer: kernel_regularizer,
trainable: true,
dtype: DType);
if (use_bias)
@@ -76,7 +78,7 @@ namespace Tensorflow.Keras.Layers

var axes = new Dictionary<int, int>();
axes.Add(-1, input_channel);
inputSpec = new InputSpec(ndim: rank + 2, axes: axes);
inputSpec = new InputSpec(min_ndim: rank + 2, axes: axes);

string tf_padding;
if (padding == "causal")
@@ -84,20 +86,21 @@ namespace Tensorflow.Keras.Layers
else
tf_padding = padding.ToUpper();

_convolution_op = nn_ops.Convolution(input_shape,
kernel.shape,
tf_padding,
string tf_op_name = GetType().Name;
_convolution_op = nn_ops.convolution_internal(tf_padding,
strides,
dilation_rate,
data_format: _tf_data_format);
data_format: _tf_data_format,
name: tf_op_name);

built = true;
}

protected override Tensors call(Tensors inputs, Tensor state = null, bool training = false)
{
var outputs = _convolution_op.__call__(inputs, kernel);
var outputs = _convolution_op.Apply(inputs, kernel);
if (use_bias)
{
if (data_format == "channels_first")

+ 2
- 3
src/TensorFlowNET.Core/Keras/Layers/InputLayer.cs View File

@@ -47,10 +47,10 @@ namespace Tensorflow.Keras.Layers
}

// moved to base class
if (string.IsNullOrEmpty(Name))
if (string.IsNullOrEmpty(args.Name))
{
var prefix = "input";
args.Name = prefix + '_' + tf.keras.backend.get_uid(prefix);
name = prefix + '_' + tf.keras.backend.get_uid(prefix);
}
if(args.DType == TF_DataType.DtInvalid)
@@ -91,7 +91,6 @@ namespace Tensorflow.Keras.Layers
// input_tensor._keras_mask = None
new Node(this, new NodeArgs
{
InputTensors = args.InputTensor,
Outputs = args.InputTensor
});



+ 52
- 8
src/TensorFlowNET.Core/Keras/Layers/LayersApi.cs View File

@@ -11,15 +11,35 @@ namespace Tensorflow.Keras.Layers
{
public Conv2D Conv2D(int filters,
TensorShape kernel_size = null,
TensorShape strides = null,
string padding = "valid",
string activation = "relu")
=> new Conv2D(new Conv2DArgs
{
Filters = filters,
KernelSize = kernel_size,
Padding = padding,
Activation = GetActivationByName(activation)
});
string data_format = null,
TensorShape dilation_rate = null,
int groups = 1,
string activation = null,
bool use_bias = true,
IInitializer kernel_initializer = null,
IInitializer bias_initializer = null,
IRegularizer kernel_regularizer = null,
IRegularizer bias_regularizer = null,
IRegularizer activity_regularizer = null)
=> new Conv2D(new Conv2DArgs
{
Rank = 2,
Filters = filters,
KernelSize = kernel_size,
Strides = strides == null ? (1, 1) : strides,
Padding = padding,
DataFormat = data_format,
DilationRate = dilation_rate == null ? (1, 1) : dilation_rate,
Groups = groups,
KernelRegularizer = kernel_regularizer,
KernelInitializer = kernel_initializer == null ? tf.glorot_uniform_initializer : kernel_initializer,
BiasInitializer = bias_initializer == null ? tf.zeros_initializer : bias_initializer,
BiasRegularizer = bias_regularizer,
ActivityRegularizer = activity_regularizer,
Activation = GetActivationByName(activation)
});


public Dense Dense(int units,
@@ -65,6 +85,30 @@ namespace Tensorflow.Keras.Layers
DataFormat = data_format
});

/// <summary>
/// `Input()` is used to instantiate a Keras tensor.
/// </summary>
/// <param name="shape">A shape tuple not including the batch size.</param>
/// <param name="name"></param>
/// <param name="sparse"></param>
/// <param name="ragged"></param>
/// <returns></returns>
public Tensors Input(TensorShape shape,
string name = null,
bool sparse = false,
bool ragged = false)
{
var input_layer = new InputLayer(new InputLayerArgs
{
InputShape = shape,
Name = name,
Sparse = sparse,
Ragged = ragged
});

return input_layer.InboundNodes[0].Outputs;
}

public MaxPooling2D MaxPooling2D(TensorShape pool_size = null,
TensorShape strides = null,
string padding = "valid")


+ 12
- 0
src/TensorFlowNET.Core/Keras/Regularizers.cs View File

@@ -0,0 +1,12 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras
{
public class Regularizers
{
public IRegularizer l2(float l2 = 0.01f)
=> new L2(l2);
}
}

+ 11
- 0
src/TensorFlowNET.Core/Keras/Regularizers/IRegularizer.cs View File

@@ -0,0 +1,11 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras
{
public interface IRegularizer
{
Tensor Apply(RegularizerArgs args);
}
}

+ 21
- 0
src/TensorFlowNET.Core/Keras/Regularizers/L2.cs View File

@@ -0,0 +1,21 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras
{
public class L2 : IRegularizer
{
float l2;

public L2(float l2 = 0.01f)
{
this.l2 = l2;
}

public Tensor Apply(RegularizerArgs args)
{
throw new NotImplementedException();
}
}
}

+ 10
- 0
src/TensorFlowNET.Core/Keras/Regularizers/RegularizerArgs.cs View File

@@ -0,0 +1,10 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras
{
public class RegularizerArgs
{
}
}

+ 11
- 12
src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs View File

@@ -55,8 +55,8 @@ namespace Tensorflow.Keras.Utils
/// </summary>
/// <param name="name"></param>
/// <returns></returns>
public static string unique_layer_name(string name, Dictionary<(string, string), int> name_uid_map = null,
string[] avoid_names = null, string @namespace = "", bool zero_based = false)
public static string unique_layer_name(string name, Dictionary<string, int> name_uid_map = null,
string[] avoid_names = null, bool zero_based = false)
{
if (name_uid_map == null)
name_uid_map = get_default_graph_uid_map();
@@ -66,41 +66,40 @@ namespace Tensorflow.Keras.Utils
string proposed_name = null;
while (proposed_name == null || avoid_names.Contains(proposed_name))
{
var name_key = (@namespace, name);
if (!name_uid_map.ContainsKey(name_key))
name_uid_map[name_key] = 0;
if (!name_uid_map.ContainsKey(name))
name_uid_map[name] = 0;

if (zero_based)
{
int number = name_uid_map[name_key];
int number = name_uid_map[name];
if (number > 0)
proposed_name = $"{name}_{number}";
else
proposed_name = name;

name_uid_map[name_key] += 1;
name_uid_map[name] += 1;
}
else
{
name_uid_map[name_key] += 1;
proposed_name = $"{name}_{name_uid_map[name_key]}";
name_uid_map[name] += 1;
proposed_name = $"{name}_{name_uid_map[name]}";
}
}

return proposed_name;
}

public static Dictionary<(string, string), int> get_default_graph_uid_map()
public static Dictionary<string, int> get_default_graph_uid_map()
{
var graph = ops.get_default_graph();
Dictionary<(string, string), int> name_uid_map = null;
Dictionary<string, int> name_uid_map = null;
if (tf.keras.backend.PER_GRAPH_LAYER_NAME_UIDS.ContainsKey(graph))
{
name_uid_map = tf.keras.backend.PER_GRAPH_LAYER_NAME_UIDS[graph];
}
else
{
name_uid_map = new Dictionary<(string, string), int>();
name_uid_map = new Dictionary<string, int>();
tf.keras.backend.PER_GRAPH_LAYER_NAME_UIDS[graph] = name_uid_map;
}



+ 1
- 3
src/TensorFlowNET.Core/Layers/Layer.cs View File

@@ -183,8 +183,6 @@ namespace Tensorflow.Layers
});
}



protected override string _name_scope()
{
return _current_scope.original_name_scope;
@@ -202,7 +200,7 @@ namespace Tensorflow.Layers
}
else
{
tf_with(tf.variable_scope(scope, default_name: baseName), captured_scope =>
tf_with(tf.variable_scope(scope, default_name: base_name), captured_scope =>
{
// convert variable_scope to VariableScope
_scope = captured_scope;


+ 2
- 2
src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs View File

@@ -41,8 +41,8 @@ namespace Tensorflow.Operations.Initializers
public Tensor Apply(InitializerArgs args)
{
if (args.DType == TF_DataType.DtInvalid)
args.DType = this.dtype;
return random_ops.random_normal(args.Shape, mean, stddev, dtype, seed: seed);
args.DType = dtype;
return random_ops.random_normal(args.Shape, mean, stddev, args.DType, seed: seed);
}
}
}

+ 0
- 84
src/TensorFlowNET.Core/Operations/NnOps/Convolution.cs View File

@@ -1,84 +0,0 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System.Linq;

namespace Tensorflow.Operations
{
public class Convolution
{
public TensorShape input_shape;
public TensorShape filter_shape;
public string data_format;
public int[] strides;
public string name;
public _WithSpaceToBatch conv_op;

public Convolution(TensorShape input_shape,
TensorShape filter_shape,
string padding,
int[] strides,
int[] dilation_rate,
string name = null,
string data_format = null)
{
var num_total_dims = filter_shape.ndim;
var num_spatial_dims = num_total_dims - 2;
int input_channels_dim;
int[] spatial_dims;
if (string.IsNullOrEmpty(data_format) || !data_format.StartsWith("NC"))
{
input_channels_dim = input_shape.dims[num_spatial_dims + 1];
spatial_dims = Enumerable.Range(1, num_spatial_dims).ToArray();
}
else
{
input_channels_dim = input_shape.dims[1];
spatial_dims = Enumerable.Range(2, num_spatial_dims).ToArray();
}

this.input_shape = input_shape;
this.filter_shape = filter_shape;
this.data_format = data_format;
this.strides = strides;
this.name = name;

conv_op = new _WithSpaceToBatch(
input_shape,
dilation_rate: dilation_rate,
padding: padding,
build_op: _build_op,
filter_shape: filter_shape,
spatial_dims: spatial_dims,
data_format: data_format);
}

public _NonAtrousConvolution _build_op(int _, string padding)
{
return new _NonAtrousConvolution(input_shape,
filter_shape: filter_shape,
padding: padding,
data_format: data_format,
strides: strides,
name: name);
}

public Tensor __call__(Tensor inp, IVariableV1 filter)
{
return conv_op.__call__(inp, filter);
}
}
}

+ 100
- 0
src/TensorFlowNET.Core/Operations/NnOps/ConvolutionInternal.cs View File

@@ -0,0 +1,100 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using System.Collections.Generic;
using System.Linq;
using System.Xml;
using Tensorflow.Keras.ArgsDefinition;
using static Tensorflow.Binding;

namespace Tensorflow.Operations
{
internal class ConvolutionInternal
{
ConvolutionalArgs args;
string data_format => args.DataFormat;
string name;
string padding => args.Padding;

public ConvolutionInternal(ConvolutionalArgs args)
{
this.args = args;
name = args.Name;
}

public Tensor Apply(Tensors input, IVariableV1 filters)
{
var filters_rank = filters.shape.rank;
var inputs_rank = input.shape.rank;
var num_spatial_dims = args.NumSpatialDims;
if (num_spatial_dims == Unknown)
num_spatial_dims = filters_rank - 2;

// Channel dimension.
var num_batch_dims = inputs_rank - num_spatial_dims - 1;
if (!new[] { 1, 2, 3 }.Contains(num_spatial_dims))
throw new ValueError($"num_spatial_dims (input.shape.ndims - num_batch_dims - 1) must be one " +
$"of 1, 2 or 3 but saw {num_spatial_dims}. num_batch_dims: {num_batch_dims}.");

var channel_index = num_batch_dims + num_spatial_dims;
var dilations = _get_sequence(args.DilationRate, num_spatial_dims, channel_index);
var strides = _get_sequence(args.Strides, num_spatial_dims, channel_index);

Tensor result = null;
tf_with(ops.name_scope(name, default_name: null, (input, filters)), scope =>
{
name = scope;
if (num_spatial_dims == 2)
result = gen_nn_ops.conv2d(new Conv2dParams
{
Input = input,
Filter = filters.AsTensor(),
Strides = strides,
Padding = padding,
DataFormat = data_format,
Dilations = dilations,
Name = name
});
else
throw new NotImplementedException("");
});

return result;
}

int[] _get_sequence(int[] value, int n, int channel_index)
{
var seq = new List<int>();

if (channel_index == 1)
{
seq.Add(1);
seq.Add(1);
seq.AddRange(value);
}
else
{
seq.Add(1);
seq.AddRange(value);
seq.Add(1);
}

return seq.ToArray();
}
}
}

+ 0
- 83
src/TensorFlowNET.Core/Operations/NnOps/_NonAtrousConvolution.cs View File

@@ -1,83 +0,0 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using System.Linq;

namespace Tensorflow.Operations
{
public class _NonAtrousConvolution
{
public string padding;
public string name;
public int[] strides;
public string data_format;
private Func<Conv2dParams, Tensor> conv_op;

public _NonAtrousConvolution(TensorShape input_shape,
TensorShape filter_shape,
string padding,
string data_format,
int[] strides,
string name)
{
this.padding = padding;
this.name = name;
var conv_dims = input_shape.ndim - 2;
if (conv_dims == 1)
{
throw new NotImplementedException("_NonAtrousConvolution conv_dims 1");
}
else if (conv_dims == 2)
{
var list = strides.ToList();

if (string.IsNullOrEmpty(data_format) || data_format == "NHWC")
{
data_format = "NHWC";
list.Insert(0, 1);
list.Add(1);
}
else if (data_format == "NCHW")
list.InsertRange(0, new int[] { 1, 1 });
else
throw new ValueError("data_format must be \"NHWC\" or \"NCHW\".");

strides = list.ToArray();
this.strides = strides;
this.data_format = data_format;
conv_op = gen_nn_ops.conv2d;
}
else if (conv_dims == 3)
{
throw new NotImplementedException("_NonAtrousConvolution conv_dims 3");
}
}

public Tensor __call__(Tensor inp, IVariableV1 filter)
{
return conv_op(new Conv2dParams
{
Input = inp,
Filter = filter.AsTensor(),
Strides = strides,
Padding = padding,
DataFormat = data_format,
Name = name
});
}
}
}

+ 0
- 76
src/TensorFlowNET.Core/Operations/NnOps/_WithSpaceToBatch.cs View File

@@ -1,76 +0,0 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using System.Linq;

namespace Tensorflow.Operations
{
public class _WithSpaceToBatch
{
private _NonAtrousConvolution call;

public _WithSpaceToBatch(TensorShape input_shape,
int[] dilation_rate,
string padding,
Func<int, string, _NonAtrousConvolution> build_op,
TensorShape filter_shape = null,
int[] spatial_dims = null,
string data_format = null)
{
var dilation_rate_tensor = ops.convert_to_tensor(dilation_rate, TF_DataType.TF_INT32, name: "dilation_rate");
var rate_shape = dilation_rate_tensor.TensorShape;
var num_spatial_dims = rate_shape.dims[0];
#pragma warning disable CS0219 // Variable is assigned but its value is never used
int starting_spatial_dim = -1;
#pragma warning restore CS0219 // Variable is assigned but its value is never used
if (!string.IsNullOrEmpty(data_format) && data_format.StartsWith("NC"))
starting_spatial_dim = 2;
else
starting_spatial_dim = 1;

if (spatial_dims == null)
throw new NotImplementedException("_WithSpaceToBatch spatial_dims");

var orig_spatial_dims = spatial_dims;
spatial_dims = spatial_dims.OrderBy(x => x).ToArray();
if (!Enumerable.SequenceEqual(spatial_dims, orig_spatial_dims) || spatial_dims.Any(x => x < 1))
throw new ValueError("spatial_dims must be a montonically increasing sequence of positive integers");

int expected_input_rank = -1;
if (!string.IsNullOrEmpty(data_format) && data_format.StartsWith("NC"))
expected_input_rank = spatial_dims.Last();
else
expected_input_rank = spatial_dims.Last() + 1;

var const_rate = tensor_util.constant_value(dilation_rate_tensor);
var rate_or_const_rate = dilation_rate;
if(!(const_rate is null))
{
if (const_rate.Data<int>().Count(x => x == 1) == const_rate.size)
{
call = build_op(num_spatial_dims, padding);
return;
}
}
}

public Tensor __call__(Tensor inp, IVariableV1 filter)
{
return call.__call__(inp, filter);
}
}
}

+ 58
- 38
src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs View File

@@ -78,35 +78,45 @@ namespace Tensorflow.Operations
/// </summary>
/// <param name="parameters"></param>
/// <returns></returns>
public static Tensor conv2d_backprop_filter(Conv2dParams parameters)
public static Tensor conv2d_backprop_filter(Tensor input, Tensor filter_sizes, Tensor out_backprop,
int[] strides, string padding, bool use_cudnn_on_gpu = true,
int[] explicit_paddings = null,
string data_format = "NHWC",
int[] dilations = null,
string name = null)
{
if (explicit_paddings == null)
explicit_paddings = new int[0];
if (dilations == null)
dilations = new int[] { 1, 1, 1, 1 };

if (tf.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Conv2DBackpropFilter", parameters.Name,
"Conv2DBackpropFilter", name,
null,
parameters.Input, parameters.FilterSizes, parameters.OutBackProp,
"strides", parameters.Strides,
"use_cudnn_on_gpu", parameters.UseCudnnOnGpu,
"padding", parameters.Padding,
"explicit_paddings", parameters.ExplicitPaddings,
"data_format", parameters.DataFormat,
"dilations", parameters.Dilations);
input, filter_sizes, out_backprop,
"strides", strides,
"use_cudnn_on_gpu", use_cudnn_on_gpu,
"padding", padding,
"explicit_paddings", explicit_paddings,
"data_format", data_format,
"dilations", dilations);

return results[0];
}

var _op = tf.OpDefLib._apply_op_helper("Conv2DBackpropFilter", name: parameters.Name, args: new
var _op = tf.OpDefLib._apply_op_helper("Conv2DBackpropFilter", name: name, args: new
{
input = parameters.Input,
filter_sizes = parameters.FilterSizes,
out_backprop = parameters.OutBackProp,
strides = parameters.Strides,
padding = parameters.Padding,
use_cudnn_on_gpu = parameters.UseCudnnOnGpu,
explicit_paddings = parameters.ExplicitPaddings,
data_format = parameters.DataFormat,
dilations = parameters.Dilations
input,
filter_sizes,
out_backprop,
strides,
padding,
use_cudnn_on_gpu,
explicit_paddings,
data_format,
dilations
});

return _op.outputs[0];
@@ -117,35 +127,45 @@ namespace Tensorflow.Operations
/// </summary>
/// <param name="parameters"></param>
/// <returns></returns>
public static Tensor conv2d_backprop_input(Conv2dParams parameters)
public static Tensor conv2d_backprop_input(Tensor input_sizes, Tensor filter, Tensor out_backprop,
int[] strides, string padding, bool use_cudnn_on_gpu = true,
int[] explicit_paddings = null,
string data_format= "NHWC",
int[] dilations = null,
string name = null)
{
if (explicit_paddings == null)
explicit_paddings = new int[0];
if (dilations == null)
dilations = new int[] { 1, 1, 1, 1 };

if (tf.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Conv2DBackpropInput", parameters.Name,
"Conv2DBackpropInput", name,
null,
parameters.InputSizes, parameters.Filter, parameters.OutBackProp,
"strides", parameters.Strides,
"use_cudnn_on_gpu", parameters.UseCudnnOnGpu,
"padding", parameters.Padding,
"explicit_paddings", parameters.ExplicitPaddings,
"data_format", parameters.DataFormat,
"dilations", parameters.Dilations);
input_sizes, filter, out_backprop,
"strides", strides,
"use_cudnn_on_gpu", use_cudnn_on_gpu,
"padding", padding,
"explicit_paddings", explicit_paddings,
"data_format", data_format,
"dilations", dilations);

return results[0];
}

var _op = tf.OpDefLib._apply_op_helper("Conv2DBackpropInput", name: parameters.Name, args: new
var _op = tf.OpDefLib._apply_op_helper("Conv2DBackpropInput", name: name, args: new
{
input_sizes = parameters.InputSizes,
filter = parameters.Filter,
out_backprop = parameters.OutBackProp,
strides = parameters.Strides,
padding = parameters.Padding,
use_cudnn_on_gpu = parameters.UseCudnnOnGpu,
explicit_paddings = parameters.ExplicitPaddings,
data_format = parameters.DataFormat,
dilations = parameters.Dilations
input_sizes,
filter,
out_backprop,
strides,
padding,
use_cudnn_on_gpu,
explicit_paddings,
data_format,
dilations
});

return _op.outputs[0];


+ 5
- 5
src/TensorFlowNET.Core/Operations/gen_random_ops.cs View File

@@ -33,11 +33,6 @@ namespace Tensorflow
/// <returns></returns>
public static Tensor random_standard_normal(Tensor shape, TF_DataType dtype = TF_DataType.DtInvalid, int? seed = null, int? seed2 = null, string name = null)
{
if (!seed.HasValue)
seed = 0;
if (!seed2.HasValue)
seed2 = 0;

if (tf.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
@@ -51,6 +46,11 @@ namespace Tensorflow
return results[0];
}

if (!seed.HasValue)
seed = 0;
if (!seed2.HasValue)
seed2 = 0;

var _op = tf.OpDefLib._apply_op_helper("RandomStandardNormal",
name: name,
args: new { shape, dtype, seed, seed2 });


+ 10
- 10
src/TensorFlowNET.Core/Operations/nn_ops.cs View File

@@ -16,6 +16,7 @@

using System;
using System.Linq;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Operations;
using static Tensorflow.Binding;

@@ -23,19 +24,18 @@ namespace Tensorflow
{
public class nn_ops
{
public static Convolution Convolution(TensorShape input_shape,
TensorShape filter_shape,
string padding,
internal static ConvolutionInternal convolution_internal(string padding,
int[] strides,
int[] dilation_rate,
string name = null,
string data_format = null) => new Convolution(input_shape,
filter_shape,
padding,
strides,
dilation_rate,
name: name,
data_format: data_format);
string data_format = null) => new ConvolutionInternal(new ConvolutionalArgs
{
Padding = padding,
Strides = strides,
DilationRate = dilation_rate,
DataFormat = data_format,
Name = name
});

/// <summary>
/// Adds `bias` to `value`.


+ 1
- 1
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -64,7 +64,7 @@ namespace Tensorflow
/// The string name of this tensor.<br/>
/// Tensor.name is meaningless when eager execution is enabled.
/// </summary>
public string name => $"{(op == null ? "<unnamed>" : $"{op.name}:{_value_index}")}";
public virtual string name => $"{(op == null ? "<unnamed>" : $"{op.name}:{_value_index}")}";

/// <summary>
/// The index of this tensor in the outputs of its Operation.


+ 1
- 1
src/TensorFlowNET.Core/Tensors/TensorShape.cs View File

@@ -132,7 +132,7 @@ namespace Tensorflow
}
}

public int this[int index] => dims[index];
public int this[int index] => index < 0 ? dims[ndim + index] : dims[index];

/// <summary>
/// Returns True iff `self` is fully defined in every dimension.


+ 13
- 4
src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs View File

@@ -8,7 +8,7 @@ using static Tensorflow.Binding;

namespace Tensorflow
{
public class BaseResourceVariable : DisposableObject, IVariableV1
public class BaseResourceVariable : DisposableObject
{
protected string _name;
public virtual string Name => _handle_name;
@@ -92,7 +92,8 @@ namespace Tensorflow
return assign_op;
}

public Tensor value() => tf.executing_eagerly() ? _read_variable_op() : GraphElement;
public Tensor value()
=> GraphElement ?? _read_variable_op();

protected Tensor _read_variable_op()
{
@@ -159,7 +160,15 @@ namespace Tensorflow
{
}

public Tensor AsTensor()
=> tf.executing_eagerly() ? read_value() : GraphElement;
public Tensor AsTensor(bool as_ref = true)
{
if (!as_ref && GraphElement != null)
return GraphElement;

if (as_ref)
return tf.executing_eagerly() ? read_value() : GraphElement;
else
return _read_variable_op();
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Variables/IVariableV1.cs View File

@@ -49,6 +49,6 @@ namespace Tensorflow
public TensorShape shape { get; }
Tensor assign_add<T>(T delta, bool use_locking = false, string name = null, bool read_value = true);
Tensor assign<T>(T value, bool use_locking = false, string name = null, bool read_value = true);
Tensor AsTensor();
Tensor AsTensor(bool as_ref = true);
}
}

+ 2
- 2
src/TensorFlowNET.Core/Variables/RefVariable.cs View File

@@ -152,7 +152,7 @@ namespace Tensorflow
if (trainable && !collections.Contains(tf.GraphKeys.TRAINABLE_VARIABLES))
collections.Add(tf.GraphKeys.TRAINABLE_VARIABLES);

tf_with(ops.init_scope2(), delegate
tf_with(ops.init_scope(), init_scope =>
{
var values = init_from_fn ? new object[0] : new object[] { initial_value };
tf_with(ops.name_scope(name, "Variable", values), scope =>
@@ -222,7 +222,7 @@ namespace Tensorflow

public Tensor value() => _snapshot;

public Tensor AsTensor() => _snapshot;
public Tensor AsTensor(bool as_ref = true) => _snapshot;

public Tensor _as_graph_element() => _variable;



+ 2
- 2
src/TensorFlowNET.Core/Variables/ResourceVariable.cs View File

@@ -26,7 +26,7 @@ namespace Tensorflow
/// <summary>
/// Variable based on resource handles.
/// </summary>
public partial class ResourceVariable : BaseResourceVariable
public partial class ResourceVariable : BaseResourceVariable, IVariableV1
{
Tensor _cached_value;
public string Device => handle.Device;
@@ -90,7 +90,7 @@ namespace Tensorflow
collections.Add(tf.GraphKeys.TRAINABLE_VARIABLES);
_in_graph_mode = !tf.Context.executing_eagerly();
tf_with(ops.init_scope2(), delegate
tf_with(ops.init_scope(), init_scope =>
{
var values = init_from_fn ? new object[0] : new object[] { initial_value };
tf_with(ops.name_scope(name, "Variable", values), scope =>


+ 5
- 20
src/TensorFlowNET.Core/ops.cs View File

@@ -239,11 +239,8 @@ namespace Tensorflow
/// A context manager that lifts ops out of control-flow scopes and function-building graphs.
/// </summary>
/// <returns></returns>
public static void init_scope()
public static NameScope init_scope()
{
if (tf.Context.executing_eagerly())
return;

// Retrieve the active name scope: entering an `init_scope` preserves
// the name scope of the current context.
var default_graph = get_default_graph();
@@ -257,25 +254,11 @@ namespace Tensorflow

tf_with(ops.control_dependencies(null), delegate
{
var outer_graph = get_default_graph();
// var outer_graph = get_default_graph();
// outer_device_stack = None
});
}

public static ITensorFlowObject init_scope2()
{
// Retrieve the active name scope: entering an `init_scope` preserves
// the name scope of the current context.
var default_graph = get_default_graph();
var scope = default_graph.get_name_scope();
if (!String.IsNullOrEmpty(scope) && !scope.EndsWith("/"))
// Names that end with trailing slashes are treated by `name_scope` as
// absolute.
scope += "/";
// inner_device_stack = default_graph._device_function_stack
// var outer_context = default_graph.as_default;

return ops.control_dependencies(null);
return ops.name_scope(scope);
}

private static int uid_number = -1;
@@ -460,6 +443,8 @@ namespace Tensorflow
{
case NDArray nd:
return constant_op.constant(nd, dtype: dtype, name: name);
case EagerTensor tensor:
return tf.executing_eagerly() ? tensor : tensor.AsPlaceholder(name: name);
case Tensor tensor:
return tensor;
case Tensor[] tensors:


+ 1
- 0
src/TensorFlowNET.Core/ops.name_scope.cs View File

@@ -90,6 +90,7 @@ namespace Tensorflow
return (scope_name, old_name);
}

[DebuggerHidden]
public void Dispose()
{
if (tf.Context.executing_eagerly())


+ 2
- 1
test/TensorFlowNET.UnitTest/Keras/ModelSaveTest.cs View File

@@ -28,7 +28,8 @@ namespace TensorFlowNET.UnitTest.Keras

// Create a simple model.
var inputs = keras.Input(shape: 32);
var outputs = keras.layers.Dense(1).Apply(inputs);
var dense_layer = keras.layers.Dense(1);
var outputs = dense_layer.Apply(inputs);
var model = keras.Model(inputs, outputs);
model.compile("adam", "mean_squared_error");
return model;


+ 5
- 0
test/TensorFlowNET.UnitTest/ManagedAPI/BitwiseApiTest.cs View File

@@ -8,6 +8,11 @@ namespace TensorFlowNET.UnitTest.ManagedAPI
[TestClass]
public class BitwiseApiTest : TFNetApiTest
{
[TestInitialize]
public void Init()
{
tf.enable_eager_execution();
}

[TestMethod]
public void BitwiseAnd()


Loading…
Cancel
Save