diff --git a/src/TensorFlowNET.Core/APIs/tf.layers.cs b/src/TensorFlowNET.Core/APIs/tf.layers.cs
index 3485fbd5..7330e957 100644
--- a/src/TensorFlowNET.Core/APIs/tf.layers.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.layers.cs
@@ -92,7 +92,7 @@ namespace Tensorflow
///
///
///
- public Tensor batch_normalization(Tensor inputs,
+ public Tensors batch_normalization(Tensor inputs,
int axis = -1,
float momentum = 0.99f,
float epsilon = 0.001f,
@@ -108,22 +108,24 @@ namespace Tensorflow
bool renorm = false,
float renorm_momentum = 0.99f)
{
- var layer = new BatchNormalization(
- axis: axis,
- momentum: momentum,
- epsilon: epsilon,
- center: center,
- scale: scale,
- beta_initializer: beta_initializer,
- gamma_initializer: gamma_initializer,
- moving_mean_initializer: moving_mean_initializer,
- moving_variance_initializer: moving_variance_initializer,
- renorm: renorm,
- renorm_momentum: renorm_momentum,
- trainable: trainable,
- name: name);
-
- return layer.apply(inputs, training: training).Item1;
+ var layer = new BatchNormalization(new BatchNormalizationArgs
+ {
+ Axis = axis,
+ Momentum = momentum,
+ Epsilon = epsilon,
+ Center = center,
+ Scale = scale,
+ BetaInitializer = beta_initializer,
+ GammaInitializer = gamma_initializer,
+ MovingMeanInitializer = moving_mean_initializer,
+ MovingVarianceInitializer = moving_variance_initializer,
+ Renorm = renorm,
+ RenormMomentum = renorm_momentum,
+ Trainable = trainable,
+ Name = name
+ });
+
+ return layer.Apply(inputs);
}
///
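
Usage sketch (illustrative, not part of the patch; assumes graph mode): the wrapper keeps its call shape, but the result is now the `Tensors` returned by `layer.Apply`.

    var x = tf.placeholder(tf.float32, shape: new TensorShape(-1, 8, 8, 3));
    Tensors y = tf.layers.batch_normalization(x, momentum: 0.99f, epsilon: 0.001f);
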
diff --git a/src/TensorFlowNET.Core/APIs/tf.ops.cs b/src/TensorFlowNET.Core/APIs/tf.ops.cs
index c651bba9..d8109676 100644
--- a/src/TensorFlowNET.Core/APIs/tf.ops.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.ops.cs
@@ -41,6 +41,8 @@ namespace Tensorflow
/// <summary>
/// A context manager that lifts ops out of control-flow scopes and function-building graphs.
+ /// When eager execution is enabled, code inside an init_scope block runs with
+ /// eager execution enabled even when tracing a `tf.function`.
/// </summary>
public void init_scope()
=> ops.init_scope();
diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs
index 84e27cc6..d1c7eb13 100644
--- a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs
+++ b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs
@@ -227,6 +227,11 @@ namespace Tensorflow.Eager
input_handle = input.EagerTensorHandle;
flattened_inputs.Add(input);
break;
+ case ResourceVariable variable:
+ var var_tensor = variable.AsTensor();
+ input_handle = var_tensor.EagerTensorHandle;
+ flattened_inputs.Add(var_tensor);
+ break;
default:
var tensor = tf.convert_to_tensor(inputs);
input_handle = tensor.EagerTensorHandle;
diff --git a/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs b/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs
index 809c4cea..5733e08d 100644
--- a/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs
+++ b/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs
@@ -57,6 +57,26 @@ namespace Tensorflow.Eager
return this;
}
+ /// <summary>
+ /// _create_substitute_placeholder
+ /// </summary>
+ /// <returns></returns>
+ public Tensor AsPlaceholder(string name = null)
+ {
+ Tensor placeholder = null;
+ tf_with(ops.control_dependencies(null), delegate
+ {
+ placeholder = tf.placeholder(dtype, shape: shape, name: name ?? this.name);
+ });
+ // custom_gradient.copy_handle_data(value, placeholder)
+ return placeholder;
+ }
+
+ void copy_handle_data()
+ {
+
+ }
+
public override IntPtr ToPointer()
=> EagerTensorHandle?.DangerousGetHandle() ?? IntPtr.Zero;
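
A hedged sketch of how `AsPlaceholder` might be used, mirroring Python's `_create_substitute_placeholder`; the constant and the name are illustrative and assume eager execution so the cast to `EagerTensor` holds.

    var value = tf.constant(new[] { 1.0f, 2.0f }) as EagerTensor;
    Tensor substitute = value.AsPlaceholder("substitute_input"); // placeholder with value's dtype and shape
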
diff --git a/src/TensorFlowNET.Core/Gradients/nn_grad.cs b/src/TensorFlowNET.Core/Gradients/nn_grad.cs
index e2564ff5..b3e4039c 100644
--- a/src/TensorFlowNET.Core/Gradients/nn_grad.cs
+++ b/src/TensorFlowNET.Core/Gradients/nn_grad.cs
@@ -138,30 +138,16 @@ namespace Tensorflow.Gradients
return new Tensor[]
{
- gen_nn_ops.conv2d_backprop_input(new Conv2dParams
- {
- InputSizes = shape[0],
- Filter = op.inputs[1],
- OutBackProp = grads[0],
- Dilations = dilations,
- Strides = strides,
- Padding = padding.ToString(),
- ExplicitPaddings = explicit_paddings,
- UseCudnnOnGpu = (bool)use_cudnn_on_gpu,
- DataFormat = data_format.ToString(),
- }),
- gen_nn_ops.conv2d_backprop_filter(new Conv2dParams
- {
- Input = op.inputs[0],
- FilterSizes = shape[1],
- OutBackProp = grads[0],
- Dilations = dilations,
- Strides = strides,
- Padding = padding.ToString(),
- ExplicitPaddings = explicit_paddings,
- UseCudnnOnGpu = (bool)use_cudnn_on_gpu,
- DataFormat = data_format.ToString()
- })
+ gen_nn_ops.conv2d_backprop_input(shape[0], op.inputs[1], grads[0],
+ strides, padding, use_cudnn_on_gpu, explicit_paddings,
+ dilations: dilations,
+ data_format: data_format),
+ gen_nn_ops.conv2d_backprop_filter(op.inputs[0], shape[1], grads[0],
+ strides, padding,
+ dilations: dilations,
+ explicit_paddings: explicit_paddings,
+ use_cudnn_on_gpu: use_cudnn_on_gpu,
+ data_format: data_format)
};
}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/BatchNormalizationArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/BatchNormalizationArgs.cs
new file mode 100644
index 00000000..888082c7
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/BatchNormalizationArgs.cs
@@ -0,0 +1,24 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using static Tensorflow.Binding;
+
+namespace Tensorflow.Keras.ArgsDefinition
+{
+ public class BatchNormalizationArgs : LayerArgs
+ {
+ public TensorShape Axis { get; set; } = -1;
+ public float Momentum { get; set; } = 0.99f;
+ public float Epsilon { get; set; } = 1e-3f;
+ public bool Center { get; set; } = true;
+ public bool Scale { get; set; } = true;
+ public IInitializer BetaInitializer { get; set; } = tf.zeros_initializer;
+ public IInitializer GammaInitializer { get; set; } = tf.ones_initializer;
+ public IInitializer MovingMeanInitializer { get; set; } = tf.zeros_initializer;
+ public IInitializer MovingVarianceInitializer { get; set; } = tf.ones_initializer;
+ public IRegularizer BetaRegularizer { get; set; }
+ public IRegularizer GammaRegularizer { get; set; }
+ public bool Renorm { get; set; }
+ public float RenormMomentum { get; set; } = 0.99f;
+ }
+}
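
Construction sketch (values illustrative): `Axis` is a `TensorShape`, so the `-1` default and any int assignment go through the implicit int-to-TensorShape conversion, and `BatchNormalization` reads it back via `args.Axis.dims`.

    var bn = new BatchNormalization(new BatchNormalizationArgs
    {
        Axis = 3,       // channels-last; becomes new[] { 3 } inside the layer
        Momentum = 0.9f,
        GammaRegularizer = tf.keras.regularizers.l2(0.01f) // stored on the args; application is still a stub
    });
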
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Conv2DArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Conv2DArgs.cs
index be0ef74e..838954fc 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Conv2DArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Conv2DArgs.cs
@@ -4,7 +4,7 @@ using System.Text;
namespace Tensorflow.Keras.ArgsDefinition
{
- public class Conv2DArgs : ConvArgs
+ public class Conv2DArgs : ConvolutionalArgs
{
}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/ConvArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/ConvolutionalArgs.cs
similarity index 82%
rename from src/TensorFlowNET.Core/Keras/ArgsDefinition/ConvArgs.cs
rename to src/TensorFlowNET.Core/Keras/ArgsDefinition/ConvolutionalArgs.cs
index b96a6ba7..00d1706b 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/ConvArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/ConvolutionalArgs.cs
@@ -5,10 +5,11 @@ using static Tensorflow.Binding;
namespace Tensorflow.Keras.ArgsDefinition
{
- public class ConvArgs : LayerArgs
+ public class ConvolutionalArgs : LayerArgs
{
public int Rank { get; set; } = 2;
public int Filters { get; set; }
+ public int NumSpatialDims { get; set; } = Unknown;
public TensorShape KernelSize { get; set; } = 5;
///
@@ -24,8 +25,8 @@ namespace Tensorflow.Keras.ArgsDefinition
public bool UseBias { get; set; }
public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer;
public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer;
- public IInitializer KernelRegularizer { get; set; }
- public IInitializer BiasRegularizer { get; set; }
+ public IRegularizer KernelRegularizer { get; set; }
+ public IRegularizer BiasRegularizer { get; set; }
public Action KernelConstraint { get; set; }
public Action BiasConstraint { get; set; }
}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs
index aaf89a0c..182e616e 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs
@@ -46,7 +46,7 @@ namespace Tensorflow.Keras.ArgsDefinition
/// <summary>
/// Regularizer function applied to the output of the layer (its "activation").
/// </summary>
- public IInitializer ActivityRegularizer { get; set; }
+ public IRegularizer ActivityRegularizer { get; set; }
public bool Autocast { get; set; }
}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/ModelArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/ModelArgs.cs
index 70238405..b1f3569c 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/ModelArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/ModelArgs.cs
@@ -6,7 +6,7 @@ namespace Tensorflow.Keras.ArgsDefinition
{
public class ModelArgs : LayerArgs
{
- public Tensor[] Inputs { get; set; }
- public Tensor[] Outputs { get; set; }
+ public Tensors Inputs { get; set; }
+ public Tensors Outputs { get; set; }
}
}
diff --git a/src/TensorFlowNET.Core/Keras/BackendImpl.cs b/src/TensorFlowNET.Core/Keras/BackendImpl.cs
index 84b244a2..ef9b3d97 100644
--- a/src/TensorFlowNET.Core/Keras/BackendImpl.cs
+++ b/src/TensorFlowNET.Core/Keras/BackendImpl.cs
@@ -42,7 +42,7 @@ namespace Tensorflow.Keras
/// for various layer names in each graph.
/// Allows to give unique autogenerated names to layers, in a graph-specific way.
///
- public Dictionary<Graph, Dictionary<(string, string), int>> PER_GRAPH_LAYER_NAME_UIDS = new Dictionary<Graph, Dictionary<(string, string), int>>();
+ public Dictionary<Graph, Dictionary<string, int>> PER_GRAPH_LAYER_NAME_UIDS = new Dictionary<Graph, Dictionary<string, int>>();
public Dictionary _GRAPH_VARIABLES = new Dictionary();
public Dictionary _GRAPH_TF_OPTIMIZERS = new Dictionary();
@@ -80,25 +80,19 @@ namespace Tensorflow.Keras
return ops.get_default_graph();
}
- public int get_uid(string prefix, string @namespace = "")
+ public int get_uid(string prefix)
{
var graph = tf.get_default_graph();
if (!PER_GRAPH_LAYER_NAME_UIDS.ContainsKey(graph))
- PER_GRAPH_LAYER_NAME_UIDS.Add(graph, new defaultdict<(string, string), int>());
- PER_GRAPH_LAYER_NAME_UIDS[graph][(@namespace, prefix)] += 1;
+ PER_GRAPH_LAYER_NAME_UIDS.Add(graph, new defaultdict<string, int>());
+ if (!PER_GRAPH_LAYER_NAME_UIDS[graph].ContainsKey(prefix))
+ PER_GRAPH_LAYER_NAME_UIDS[graph][prefix] = 0;
+ PER_GRAPH_LAYER_NAME_UIDS[graph][prefix] += 1;
- return PER_GRAPH_LAYER_NAME_UIDS[graph][(@namespace, prefix)];
+ return PER_GRAPH_LAYER_NAME_UIDS[graph][prefix];
}
- public int get_uid((string, string) name)
- {
- var graph = tf.get_default_graph();
- if (!PER_GRAPH_LAYER_NAME_UIDS.ContainsKey(graph))
- PER_GRAPH_LAYER_NAME_UIDS.Add(graph, new defaultdict<(string, string), int>());
- PER_GRAPH_LAYER_NAME_UIDS[graph][(name)] += 1;
- return PER_GRAPH_LAYER_NAME_UIDS[graph][name];
- }
- public void reset_uids() => PER_GRAPH_LAYER_NAME_UIDS = new Dictionary<Graph, Dictionary<(string, string), int>>();
+ public void reset_uids() => PER_GRAPH_LAYER_NAME_UIDS = new Dictionary<Graph, Dictionary<string, int>>();
public void clear_session()
{
ops.reset_default_graph();
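
Expected behaviour of the simplified bookkeeping (a sketch; counters are per default graph and keyed by prefix only):

    tf.keras.backend.get_uid("conv2d"); // 1
    tf.keras.backend.get_uid("conv2d"); // 2
    tf.keras.backend.get_uid("dense");  // 1
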
diff --git a/src/TensorFlowNET.Core/Keras/Engine/Functional.cs b/src/TensorFlowNET.Core/Keras/Engine/Functional.cs
new file mode 100644
index 00000000..fe2f0728
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/Engine/Functional.cs
@@ -0,0 +1,67 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Keras.ArgsDefinition;
+
+namespace Tensorflow.Keras.Engine
+{
+ /// <summary>
+ /// A `Functional` model is a `Model` defined as a directed graph of layers.
+ /// </summary>
+ public class Functional : Model
+ {
+ TensorShape _build_input_shape;
+ bool _compute_output_and_mask_jointly;
+ bool _expects_training_arg;
+ bool _expects_mask_arg;
+ bool _autocast;
+ List<Layer> _output_layers;
+ List<Layer> _input_layers;
+ List<KerasHistory> _input_coordinates;
+ List<KerasHistory> _output_coordinates;
+
+ public Functional(Tensors inputs, Tensors outputs)
+ : base(new ModelArgs
+ {
+ Inputs = inputs,
+ Outputs = outputs
+ })
+ {
+ _input_layers = new List<Layer>();
+ _output_layers = new List<Layer>();
+ _input_coordinates = new List<KerasHistory>();
+ _output_coordinates = new List<KerasHistory>();
+ _init_graph_network(inputs, outputs);
+ }
+
+ void _init_graph_network(Tensors inputs, Tensors outputs)
+ {
+ _is_graph_network = true;
+ this.inputs = inputs;
+ this.outputs = outputs;
+ built = true;
+ _build_input_shape = inputs.shape;
+ _compute_output_and_mask_jointly = true;
+ _expects_training_arg = true;
+ _expects_mask_arg = true;
+ // A graph network does not autocast inputs, as its layers will cast them instead.
+ _autocast = false;
+
+ // Build self._output_layers:
+ foreach(var x in outputs)
+ {
+ var (layer, node_index, tensor_index) = x.KerasHistory;
+ _output_layers.append(layer);
+ _output_coordinates.append(new KerasHistory(layer, node_index, tensor_index));
+ }
+
+ // Build self._input_layers:
+ foreach(var x in inputs)
+ {
+ var (layer, node_index, tensor_index) = x.KerasHistory;
+ _input_layers.append(layer);
+ _input_coordinates.append(new KerasHistory(layer, node_index, tensor_index));
+ }
+ }
+ }
+}
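
Usage sketch, assuming layers stamp `KerasHistory` on their outputs via `Apply` (layer choice and shapes are illustrative):

    var inputs = tf.keras.layers.Input(shape: new TensorShape(32));
    var outputs = tf.keras.layers.Dense(10).Apply(inputs);
    var model = new Functional(inputs, outputs); // _init_graph_network walks the KerasHistory of outputs/inputs
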
diff --git a/src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs b/src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs
index 2041fe7d..cae054ce 100644
--- a/src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs
+++ b/src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs
@@ -15,6 +15,7 @@
******************************************************************************/
using System.Collections.Generic;
+using System.Linq;
namespace Tensorflow.Keras.Engine
{
@@ -27,6 +28,7 @@ namespace Tensorflow.Keras.Engine
public int? min_ndim;
Dictionary<int, int> axes;
TensorShape shape;
+ public int[] AllAxisDim;
public InputSpec(TF_DataType dtype = TF_DataType.DtInvalid,
int? ndim = null,
@@ -42,6 +44,12 @@ namespace Tensorflow.Keras.Engine
this.shape = shape;
if (ndim == null && shape != null)
this.ndim = shape.ndim;
+
+ if(axes != null)
+ AllAxisDim = axes.Select(x => x.Value).ToArray();
}
+
+ public override string ToString()
+ => $"min_ndim={min_ndim}, , axes={axes.Count}";
}
}
diff --git a/src/TensorFlowNET.Core/Keras/Engine/KerasHistory.cs b/src/TensorFlowNET.Core/Keras/Engine/KerasHistory.cs
index 832124e4..dd32f473 100644
--- a/src/TensorFlowNET.Core/Keras/Engine/KerasHistory.cs
+++ b/src/TensorFlowNET.Core/Keras/Engine/KerasHistory.cs
@@ -20,10 +20,14 @@ namespace Tensorflow.Keras.Engine
this.tensor_index = tensor_index;
}
+ public void Deconstruct(out Layer layer, out int node_index, out int tensor_index)
+ {
+ layer = this.layer;
+ node_index = this.node_index;
+ tensor_index = this.tensor_index;
+ }
+
public static implicit operator Layer(KerasHistory history)
=> history.layer;
-
- public static implicit operator (Layer, int, int)(KerasHistory history)
- => (history.layer, history.node_index, history.tensor_index);
}
}
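
With `Deconstruct`, call sites such as `Functional._init_graph_network` above can unpack the history directly; `keras_tensor` below stands for any tensor a layer's `Apply` has stamped:

    var (layer, node_index, tensor_index) = keras_tensor.KerasHistory;
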
diff --git a/src/TensorFlowNET.Core/Keras/Engine/Layer.cs b/src/TensorFlowNET.Core/Keras/Engine/Layer.cs
index 8c943235..b9df4ce7 100644
--- a/src/TensorFlowNET.Core/Keras/Engine/Layer.cs
+++ b/src/TensorFlowNET.Core/Keras/Engine/Layer.cs
@@ -72,10 +72,10 @@ namespace Tensorflow.Keras.Engine
protected List nonTrainableWeights;
public List non_trainable_variables => nonTrainableWeights;
- string name;
+ protected string name;
+ protected string base_name;
public string Name => name;
-
- protected string baseName;
+
protected bool computePreviousMask;
protected List updates;
public TensorShape BatchInputShape => args.BatchInputShape;
@@ -98,9 +98,9 @@ namespace Tensorflow.Keras.Engine
// Indicates whether `build` needs to be called upon layer call, to create
// the layer's weights.
built = false;
- this.SupportsMasking = false;
+ SupportsMasking = false;
- _init_set_name(name);
+ _init_set_name(args.Name);
trainableWeights = new List<IVariableV1>();
nonTrainableWeights = new List<IVariableV1>();
computePreviousMask = false;
@@ -124,23 +124,25 @@ namespace Tensorflow.Keras.Engine
///
public Tensors Apply(Tensors inputs, Tensor state = null, bool is_training = false)
{
- Tensors outputs = null;
-
callContext = callContext ?? new ThreadLocal<CallContext>()
{
Value = new CallContext()
};
+ if (_in_functional_construction_mode(inputs))
+ return _functional_construction_call(inputs);
+
+ Tensors outputs = null;
+
var eager = tf.executing_eagerly();
using var ctxManager = CallContext.enter();
string nameScope = "";
if (eager)
- nameScope = name;
+ nameScope = Name;
else
nameScope = _name_scope();
- // using var graph = tf.keras.backend.get_graph().as_default();
if (!inputs.IsEagerTensor)
tf.Context.graph_mode();
@@ -162,6 +164,46 @@ namespace Tensorflow.Keras.Engine
return outputs;
}
+ bool _in_functional_construction_mode(Tensors inputs)
+ {
+ return inputs.Count(x => !x.IsEagerTensor) == inputs.Count();
+ }
+
+ Tensors _functional_construction_call(Tensors inputs)
+ {
+ bool mask_arg_passed_by_framework = false;
+ bool training_arg_passed_by_framework = false;
+ Tensor training_value = null;
+ if(training_value == null)
+ {
+ training_arg_passed_by_framework = true;
+ }
+
+ Tensors outputs = null;
+ using var ctxManager = CallContext.enter();
+
+ // using var graph = tf.keras.backend.get_graph().as_default();
+
+ if (!inputs.IsEagerTensor)
+ tf.Context.graph_mode();
+
+ tf_with(ops.name_scope(_name_scope()), scope =>
+ {
+ MaybeBuild(inputs);
+
+ outputs = call(inputs);
+
+ outputs = _set_connectivity_metadata_(inputs, outputs);
+ _handle_activity_regularization(inputs, outputs);
+ _set_mask_metadata(inputs, outputs, null);
+ });
+
+ if (!inputs.IsEagerTensor)
+ tf.Context.restore_mode();
+
+ return outputs;
+ }
+
private Tensors _set_connectivity_metadata_(Tensors inputs, Tensors outputs)
{
/*var returnOutputs = new List();
@@ -219,8 +261,12 @@ namespace Tensorflow.Keras.Engine
if (DType == TF_DataType.DtInvalid)
args.DType = inputs.dtype;
- var input_shapes = inputs.shape;
- build(input_shapes);
+ tf.init_scope();
+
+ //tf.Context.eager_mode();
+ build(inputs.shape);
+ //tf.Context.restore_mode();
+
built = true;
}
@@ -229,10 +275,16 @@ namespace Tensorflow.Keras.Engine
built = true;
}
+ protected virtual void add_loss(Func<Tensor> losses)
+ {
+
+ }
+
protected virtual IVariableV1 add_weight(string name,
TensorShape shape,
TF_DataType dtype = TF_DataType.DtInvalid,
IInitializer initializer = null,
+ IRegularizer regularizer = null,
bool? trainable = null,
Func getter = null)
{
@@ -251,7 +303,7 @@ namespace Tensorflow.Keras.Engine
else if (dtype.is_integer())
initializer = tf.zeros_initializer;
else
- throw new ValueError($"An initializer for variable {name} of type {dtype.as_base_dtype()} is required for layer {this.Name}");
+ throw new ValueError($"An initializer for variable {name} of type {dtype.as_base_dtype()} is required for layer {name}");
}
var args = new VariableArgs
@@ -266,6 +318,12 @@ namespace Tensorflow.Keras.Engine
};
var variable = _add_variable_with_custom_getter(args);
+ if(regularizer != null)
+ {
+ var name_in_scope = variable.Name.Split(':')[0];
+ _handle_weight_regularization(name_in_scope, variable, regularizer);
+ }
+
//backend.track_variable(variable);
if (trainable == true)
trainableWeights.Add(variable);
@@ -275,6 +333,20 @@ namespace Tensorflow.Keras.Engine
return variable;
}
+ /// <summary>
+ /// Create lambdas which compute regularization losses.
+ /// </summary>
+ /// <param name="name"></param>
+ /// <param name="variable"></param>
+ /// <param name="regularizer"></param>
+ void _handle_weight_regularization(string name, IVariableV1 variable, IRegularizer regularizer)
+ {
+ add_loss(() => regularizer.Apply(new RegularizerArgs
+ {
+
+ }));
+ }
+
protected virtual void add_update(Tensor[] updates, bool inputs = false)
{
var updates_op = updates.Select(x => x.op).ToArray();
@@ -284,17 +356,13 @@ namespace Tensorflow.Keras.Engine
// Determine layer name (non-unique).
protected virtual void _init_set_name(string name, bool zero_based = true)
{
- var base_name = name;
+ base_name = name;
this.name = name;
if (name == null)
- (this.name, baseName) = _make_unique_name();
- }
-
- protected virtual (string, string) _make_unique_name()
- {
- string base_name = generic_utils.to_snake_case(this.GetType().Name);
- string name = base_layer_utils.unique_layer_name(base_name);
- return (name, base_name);
+ {
+ base_name = generic_utils.to_snake_case(this.GetType().Name);
+ this.name = base_layer_utils.unique_layer_name(base_name, zero_based: zero_based);
+ }
}
}
}
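
Control-flow sketch for the new `Apply` path (types from this PR, values illustrative): when every input is symbolic (non-eager), the call is routed through `_functional_construction_call`, which builds the layer under a name scope and records connectivity metadata; eager inputs keep the original path.

    var conv = new Conv2D(new Conv2DArgs { Filters = 8, KernelSize = 3 });
    var symbolic = tf.keras.layers.Input(shape: new TensorShape(28, 28, 1));
    var y = conv.Apply(symbolic); // no eager tensors -> _functional_construction_call
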
diff --git a/src/TensorFlowNET.Core/Keras/Engine/Model.cs b/src/TensorFlowNET.Core/Keras/Engine/Model.cs
index b5e2b0c8..c816e85d 100644
--- a/src/TensorFlowNET.Core/Keras/Engine/Model.cs
+++ b/src/TensorFlowNET.Core/Keras/Engine/Model.cs
@@ -23,15 +23,14 @@ namespace Tensorflow.Keras.Engine
string loss;
IOptimizer optimizer;
IVariableV1 _steps_per_execution;
+ protected bool _is_graph_network;
+ protected Tensors inputs;
+ protected Tensors outputs;
public Model(ModelArgs args)
: base(args)
{
- // Build _output_layers
- /*foreach(var x in args.Outputs)
- {
- var layer = x.KerasHistory;
- }*/
+
}
public void compile(string optimizerName, string lossName)
diff --git a/src/TensorFlowNET.Core/Keras/KerasApi.cs b/src/TensorFlowNET.Core/Keras/KerasApi.cs
index 603dd2cf..5d08e8e8 100644
--- a/src/TensorFlowNET.Core/Keras/KerasApi.cs
+++ b/src/TensorFlowNET.Core/Keras/KerasApi.cs
@@ -16,6 +16,7 @@ namespace Tensorflow
{
public KerasDataset datasets { get; } = new KerasDataset();
public Initializers initializers { get; } = new Initializers();
+ public Regularizers regularizers { get; } = new Regularizers();
public LayersApi layers { get; } = new LayersApi();
public LossesApi losses { get; } = new LossesApi();
public Activations activations { get; } = new Activations();
@@ -36,12 +37,8 @@ namespace Tensorflow
///
///
///
- public Model Model(Tensor input, Tensor output)
- => new Model(new ModelArgs
- {
- Inputs = new[] { input },
- Outputs = new[] { output }
- });
+ public Functional Model(Tensors inputs, Tensors outputs)
+ => new Functional(inputs, outputs);
///
/// Instantiate a Keras tensor.
diff --git a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs
index d7664493..3d6287cb 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs
@@ -15,73 +15,41 @@
******************************************************************************/
using System;
+using System.Collections.Generic;
using System.Linq;
+using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Utils;
using static Tensorflow.Binding;
namespace Tensorflow.Keras.Layers
{
- public class BatchNormalization : Tensorflow.Layers.Layer
+ public class BatchNormalization : Layer
{
-#pragma warning disable CS0414 // The field 'BatchNormalization._USE_V2_BEHAVIOR' is assigned but its value is never used
- private bool _USE_V2_BEHAVIOR = true;
-#pragma warning restore CS0414 // The field 'BatchNormalization._USE_V2_BEHAVIOR' is assigned but its value is never used
- private float momentum;
- private float epsilon;
- private bool center;
- private bool scale;
- private bool renorm;
- private bool fused;
-#pragma warning disable CS0414 // The field 'BatchNormalization._bessels_correction_test_only' is assigned but its value is never used
- private bool _bessels_correction_test_only;
-#pragma warning restore CS0414 // The field 'BatchNormalization._bessels_correction_test_only' is assigned but its value is never used
- private int[] axis;
- private string _data_format;
- private IInitializer beta_initializer;
- private IInitializer gamma_initializer;
- private IInitializer moving_mean_initializer;
- private IInitializer moving_variance_initializer;
- private IVariableV1 gamma;
- private IVariableV1 beta;
- private RefVariable moving_mean;
- private RefVariable moving_variance;
-
- public BatchNormalization(int axis = -1,
- float momentum = 0.99f,
- float epsilon = 0.001f,
- bool center = true,
- bool scale = true,
- IInitializer beta_initializer = null,
- IInitializer gamma_initializer = null,
- IInitializer moving_mean_initializer = null,
- IInitializer moving_variance_initializer = null,
- bool renorm = false,
- float renorm_momentum = 0.99f,
- bool trainable = true,
- string name = null) : base(trainable: trainable,
- name: name)
+ BatchNormalizationArgs args;
+
+ float momentum => args.Momentum;
+ float epsilon => args.Epsilon;
+ bool center => args.Center;
+ bool scale => args.Scale;
+ bool renorm => args.Renorm;
+ bool fused;
+ int[] axis;
+ string _data_format;
+ IInitializer beta_initializer => args.BetaInitializer;
+ IInitializer gamma_initializer => args.GammaInitializer;
+ IInitializer moving_mean_initializer;
+ IInitializer moving_variance_initializer;
+ IRegularizer gamma_regularizer => args.GammaRegularizer;
+ IVariableV1 gamma;
+ IVariableV1 beta;
+ IVariableV1 moving_mean;
+ IVariableV1 moving_variance;
+
+ public BatchNormalization(BatchNormalizationArgs args) : base(args)
{
- this.axis = new int[] { axis };
- this.momentum = momentum;
- this.epsilon = epsilon;
- this.center = center;
- this.scale = scale;
- if (beta_initializer == null)
- beta_initializer = tf.zeros_initializer;
- if (gamma_initializer == null)
- gamma_initializer = tf.ones_initializer;
- if (moving_mean_initializer == null)
- moving_mean_initializer = tf.zeros_initializer;
- if (moving_variance_initializer == null)
- moving_variance_initializer = tf.ones_initializer;
- this.beta_initializer = beta_initializer;
- this.gamma_initializer = gamma_initializer;
- this.moving_mean_initializer = moving_mean_initializer;
- this.moving_variance_initializer = moving_variance_initializer;
- this.renorm = renorm;
- this.fused = true;
- this.SupportsMasking = true;
- this._bessels_correction_test_only = true;
+ this.args = args;
+ axis = args.Axis.dims;
}
protected override void build(TensorShape input_shape)
@@ -91,12 +59,25 @@ namespace Tensorflow.Keras.Layers
if (x < 0)
axis[idx] = ndims + x;
+ fused = ndims == 4;
+
if (fused)
- if (Enumerable.SequenceEqual(axis, new int[] { 3 }))
+ {
+ if (Enumerable.SequenceEqual(axis, new int[] { 1 }))
+ _data_format = "NCHW";
+ else if (Enumerable.SequenceEqual(axis, new int[] { 3 }))
_data_format = "NHWC";
+ else
+ throw new ValueError($"Unsupported axis, fused batch norm only supports axis == [1] or axis == [3]");
+ }
+
+ var axis_to_dim = new Dictionary<int, int>();
+ foreach(var x in axis)
+ axis_to_dim[x] = input_shape[x];
+ inputSpec = new InputSpec(ndim: ndims, axes: axis_to_dim);
var param_dtype = DType == TF_DataType.DtInvalid ? TF_DataType.TF_FLOAT : DType;
- var param_shape = new int[] { input_shape.dims[axis[0]] };
+ var param_shape = inputSpec.AllAxisDim;
if (scale)
gamma = add_weight("gamma",
@@ -116,26 +97,17 @@ namespace Tensorflow.Keras.Layers
else
throw new NotImplementedException("add_weight beta");
- if(_scope != null)
- {
-
- }
-
- moving_mean = (RefVariable)add_weight("moving_mean",
+ moving_mean = add_weight("moving_mean",
param_shape,
dtype: param_dtype,
initializer: moving_mean_initializer,
- synchronization: VariableSynchronization.OnRead,
- trainable: false,
- aggregation: VariableAggregation.Mean);
+ trainable: false);
- moving_variance = (RefVariable)add_weight("moving_variance",
+ moving_variance = add_weight("moving_variance",
shape: param_shape,
dtype: param_dtype,
initializer: moving_variance_initializer,
- synchronization: VariableSynchronization.OnRead,
- trainable: false,
- aggregation: VariableAggregation.Mean);
+ trainable: false);
if (renorm)
throw new NotImplementedException("build when renorm is true");
@@ -178,8 +150,8 @@ namespace Tensorflow.Keras.Layers
inputs,
gamma,
beta,
- mean: moving_mean,
- variance: moving_variance,
+ mean: moving_mean.AsTensor(),
+ variance: moving_variance.AsTensor(),
epsilon: epsilon,
is_training: false,
data_format: _data_format);
@@ -202,8 +174,8 @@ namespace Tensorflow.Keras.Layers
if(training_value == null)
{
- var mean_update = _assign_moving_average(moving_mean, mean, momentum_tensor);
- var variance_update = _assign_moving_average(moving_variance, variance, momentum_tensor);
+ var mean_update = _assign_moving_average(moving_mean.AsTensor(), mean, momentum_tensor);
+ var variance_update = _assign_moving_average(moving_variance.AsTensor(), variance, momentum_tensor);
add_update(new Tensor[] { mean_update }, inputs: true);
add_update(new Tensor[] { variance_update }, inputs: true);
}
diff --git a/src/TensorFlowNET.Core/Keras/Layers/Conv2D.cs b/src/TensorFlowNET.Core/Keras/Layers/Conv2D.cs
index 9fe38ad2..371d6cfd 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/Conv2D.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/Conv2D.cs
@@ -19,12 +19,11 @@ using Tensorflow.Operations.Activation;
namespace Tensorflow.Keras.Layers
{
- public class Conv2D : Conv
+ public class Conv2D : Convolutional
{
- public Conv2D(Conv2DArgs args)
- : base(args)
+ public Conv2D(Conv2DArgs args) : base(args)
{
-
+
}
}
}
diff --git a/src/TensorFlowNET.Core/Keras/Layers/Conv.cs b/src/TensorFlowNET.Core/Keras/Layers/Convolutional.cs
similarity index 85%
rename from src/TensorFlowNET.Core/Keras/Layers/Conv.cs
rename to src/TensorFlowNET.Core/Keras/Layers/Convolutional.cs
index b26f5465..43739c7e 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/Conv.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/Convolutional.cs
@@ -20,13 +20,13 @@ using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Utils;
using Tensorflow.Operations;
-using Tensorflow.Operations.Activation;
+using static Tensorflow.Binding;
namespace Tensorflow.Keras.Layers
{
- public class Conv : Layer
+ public class Convolutional : Layer
{
- ConvArgs args;
+ ConvolutionalArgs args;
protected int rank => args.Rank;
protected int filters => args.Filters;
protected TensorShape kernel_size => args.KernelSize;
@@ -37,13 +37,14 @@ namespace Tensorflow.Keras.Layers
protected Activation activation => args.Activation;
protected bool use_bias => args.UseBias;
protected IInitializer kernel_initializer => args.KernelInitializer;
+ protected IRegularizer kernel_regularizer => args.KernelRegularizer;
protected IInitializer bias_initializer => args.BiasInitializer;
protected IVariableV1 kernel;
protected IVariableV1 bias;
- protected Convolution _convolution_op;
- string _tf_data_format;
+ ConvolutionInternal _convolution_op;
+ protected string _tf_data_format;
- public Conv(ConvArgs args) : base(args)
+ public Convolutional(ConvolutionalArgs args) : base(args)
{
this.args = args;
args.KernelSize = conv_utils.normalize_tuple(args.KernelSize.dims, args.Rank, "kernel_size");
@@ -65,6 +66,7 @@ namespace Tensorflow.Keras.Layers
kernel = add_weight(name: "kernel",
shape: kernel_shape,
initializer: kernel_initializer,
+ regularizer: kernel_regularizer,
trainable: true,
dtype: DType);
if (use_bias)
@@ -76,7 +78,7 @@ namespace Tensorflow.Keras.Layers
var axes = new Dictionary<int, int>();
axes.Add(-1, input_channel);
- inputSpec = new InputSpec(ndim: rank + 2, axes: axes);
+ inputSpec = new InputSpec(min_ndim: rank + 2, axes: axes);
string tf_padding;
if (padding == "causal")
@@ -84,20 +86,21 @@ namespace Tensorflow.Keras.Layers
else
tf_padding = padding.ToUpper();
-
- _convolution_op = nn_ops.Convolution(input_shape,
- kernel.shape,
- tf_padding,
+ string tf_op_name = GetType().Name;
+
+
+ _convolution_op = nn_ops.convolution_internal(tf_padding,
strides,
dilation_rate,
- data_format: _tf_data_format);
+ data_format: _tf_data_format,
+ name: tf_op_name);
built = true;
}
protected override Tensors call(Tensors inputs, Tensor state = null, bool training = false)
{
- var outputs = _convolution_op.__call__(inputs, kernel);
+ var outputs = _convolution_op.Apply(inputs, kernel);
if (use_bias)
{
if (data_format == "channels_first")
diff --git a/src/TensorFlowNET.Core/Keras/Layers/InputLayer.cs b/src/TensorFlowNET.Core/Keras/Layers/InputLayer.cs
index 7d31bd40..8cdaf101 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/InputLayer.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/InputLayer.cs
@@ -47,10 +47,10 @@ namespace Tensorflow.Keras.Layers
}
// moved to base class
- if (string.IsNullOrEmpty(Name))
+ if (string.IsNullOrEmpty(args.Name))
{
var prefix = "input";
- args.Name = prefix + '_' + tf.keras.backend.get_uid(prefix);
+ name = prefix + '_' + tf.keras.backend.get_uid(prefix);
}
if(args.DType == TF_DataType.DtInvalid)
@@ -91,7 +91,6 @@ namespace Tensorflow.Keras.Layers
// input_tensor._keras_mask = None
new Node(this, new NodeArgs
{
- InputTensors = args.InputTensor,
Outputs = args.InputTensor
});
diff --git a/src/TensorFlowNET.Core/Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Core/Keras/Layers/LayersApi.cs
index fc0b209f..51c1056a 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/LayersApi.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/LayersApi.cs
@@ -11,15 +11,35 @@ namespace Tensorflow.Keras.Layers
{
public Conv2D Conv2D(int filters,
TensorShape kernel_size = null,
+ TensorShape strides = null,
string padding = "valid",
- string activation = "relu")
- => new Conv2D(new Conv2DArgs
- {
- Filters = filters,
- KernelSize = kernel_size,
- Padding = padding,
- Activation = GetActivationByName(activation)
- });
+ string data_format = null,
+ TensorShape dilation_rate = null,
+ int groups = 1,
+ string activation = null,
+ bool use_bias = true,
+ IInitializer kernel_initializer = null,
+ IInitializer bias_initializer = null,
+ IRegularizer kernel_regularizer = null,
+ IRegularizer bias_regularizer = null,
+ IRegularizer activity_regularizer = null)
+ => new Conv2D(new Conv2DArgs
+ {
+ Rank = 2,
+ Filters = filters,
+ KernelSize = kernel_size,
+ Strides = strides == null ? (1, 1) : strides,
+ Padding = padding,
+ DataFormat = data_format,
+ DilationRate = dilation_rate == null ? (1, 1) : dilation_rate,
+ Groups = groups,
+ KernelRegularizer = kernel_regularizer,
+ KernelInitializer = kernel_initializer == null ? tf.glorot_uniform_initializer : kernel_initializer,
+ BiasInitializer = bias_initializer == null ? tf.zeros_initializer : bias_initializer,
+ BiasRegularizer = bias_regularizer,
+ ActivityRegularizer = activity_regularizer,
+ Activation = GetActivationByName(activation)
+ });
public Dense Dense(int units,
@@ -65,6 +85,30 @@ namespace Tensorflow.Keras.Layers
DataFormat = data_format
});
+ /// <summary>
+ /// `Input()` is used to instantiate a Keras tensor.
+ /// </summary>
+ /// <param name="shape">A shape tuple not including the batch size.</param>
+ /// <param name="name"></param>
+ /// <param name="sparse"></param>
+ /// <param name="ragged"></param>
+ /// <returns></returns>
+ public Tensors Input(TensorShape shape,
+ string name = null,
+ bool sparse = false,
+ bool ragged = false)
+ {
+ var input_layer = new InputLayer(new InputLayerArgs
+ {
+ InputShape = shape,
+ Name = name,
+ Sparse = sparse,
+ Ragged = ragged
+ });
+
+ return input_layer.InboundNodes[0].Outputs;
+ }
+
public MaxPooling2D MaxPooling2D(TensorShape pool_size = null,
TensorShape strides = null,
string padding = "valid")
diff --git a/src/TensorFlowNET.Core/Keras/Regularizers.cs b/src/TensorFlowNET.Core/Keras/Regularizers.cs
new file mode 100644
index 00000000..1102b62b
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/Regularizers.cs
@@ -0,0 +1,12 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Keras
+{
+ public class Regularizers
+ {
+ public IRegularizer l2(float l2 = 0.01f)
+ => new L2(l2);
+ }
+}
diff --git a/src/TensorFlowNET.Core/Keras/Regularizers/IRegularizer.cs b/src/TensorFlowNET.Core/Keras/Regularizers/IRegularizer.cs
new file mode 100644
index 00000000..a54a81c7
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/Regularizers/IRegularizer.cs
@@ -0,0 +1,11 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Keras
+{
+ public interface IRegularizer
+ {
+ Tensor Apply(RegularizerArgs args);
+ }
+}
diff --git a/src/TensorFlowNET.Core/Keras/Regularizers/L2.cs b/src/TensorFlowNET.Core/Keras/Regularizers/L2.cs
new file mode 100644
index 00000000..c0fa7078
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/Regularizers/L2.cs
@@ -0,0 +1,21 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Keras
+{
+ public class L2 : IRegularizer
+ {
+ float l2;
+
+ public L2(float l2 = 0.01f)
+ {
+ this.l2 = l2;
+ }
+
+ public Tensor Apply(RegularizerArgs args)
+ {
+ throw new NotImplementedException();
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Keras/Regularizers/RegularizerArgs.cs b/src/TensorFlowNET.Core/Keras/Regularizers/RegularizerArgs.cs
new file mode 100644
index 00000000..18bf87a5
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/Regularizers/RegularizerArgs.cs
@@ -0,0 +1,10 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Keras
+{
+ public class RegularizerArgs
+ {
+ }
+}
diff --git a/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs b/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs
index de9f479b..c49618cf 100644
--- a/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs
+++ b/src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs
@@ -55,8 +55,8 @@ namespace Tensorflow.Keras.Utils
///
///
///
- public static string unique_layer_name(string name, Dictionary<(string, string), int> name_uid_map = null,
- string[] avoid_names = null, string @namespace = "", bool zero_based = false)
+ public static string unique_layer_name(string name, Dictionary<string, int> name_uid_map = null,
+ string[] avoid_names = null, bool zero_based = false)
{
if (name_uid_map == null)
name_uid_map = get_default_graph_uid_map();
@@ -66,41 +66,40 @@ namespace Tensorflow.Keras.Utils
string proposed_name = null;
while (proposed_name == null || avoid_names.Contains(proposed_name))
{
- var name_key = (@namespace, name);
- if (!name_uid_map.ContainsKey(name_key))
- name_uid_map[name_key] = 0;
+ if (!name_uid_map.ContainsKey(name))
+ name_uid_map[name] = 0;
if (zero_based)
{
- int number = name_uid_map[name_key];
+ int number = name_uid_map[name];
if (number > 0)
proposed_name = $"{name}_{number}";
else
proposed_name = name;
- name_uid_map[name_key] += 1;
+ name_uid_map[name] += 1;
}
else
{
- name_uid_map[name_key] += 1;
- proposed_name = $"{name}_{name_uid_map[name_key]}";
+ name_uid_map[name] += 1;
+ proposed_name = $"{name}_{name_uid_map[name]}";
}
}
return proposed_name;
}
- public static Dictionary<(string, string), int> get_default_graph_uid_map()
+ public static Dictionary<string, int> get_default_graph_uid_map()
{
var graph = ops.get_default_graph();
- Dictionary<(string, string), int> name_uid_map = null;
+ Dictionary<string, int> name_uid_map = null;
if (tf.keras.backend.PER_GRAPH_LAYER_NAME_UIDS.ContainsKey(graph))
{
name_uid_map = tf.keras.backend.PER_GRAPH_LAYER_NAME_UIDS[graph];
}
else
{
- name_uid_map = new Dictionary<(string, string), int>();
+ name_uid_map = new Dictionary<string, int>();
tf.keras.backend.PER_GRAPH_LAYER_NAME_UIDS[graph] = name_uid_map;
}
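
Expected naming behaviour after dropping the namespace key (a sketch; the counters live in the default graph's uid map shared with `get_uid`):

    base_layer_utils.unique_layer_name("conv2d");                  // "conv2d_1"
    base_layer_utils.unique_layer_name("conv2d");                  // "conv2d_2"
    base_layer_utils.unique_layer_name("dense", zero_based: true); // "dense"
    base_layer_utils.unique_layer_name("dense", zero_based: true); // "dense_1"
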
diff --git a/src/TensorFlowNET.Core/Layers/Layer.cs b/src/TensorFlowNET.Core/Layers/Layer.cs
index e07677e5..688e8266 100644
--- a/src/TensorFlowNET.Core/Layers/Layer.cs
+++ b/src/TensorFlowNET.Core/Layers/Layer.cs
@@ -183,8 +183,6 @@ namespace Tensorflow.Layers
});
}
-
-
protected override string _name_scope()
{
return _current_scope.original_name_scope;
@@ -202,7 +200,7 @@ namespace Tensorflow.Layers
}
else
{
- tf_with(tf.variable_scope(scope, default_name: baseName), captured_scope =>
+ tf_with(tf.variable_scope(scope, default_name: base_name), captured_scope =>
{
// convert variable_scope to VariableScope
_scope = captured_scope;
diff --git a/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs b/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs
index 147dccde..13635860 100644
--- a/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs
+++ b/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs
@@ -41,8 +41,8 @@ namespace Tensorflow.Operations.Initializers
public Tensor Apply(InitializerArgs args)
{
if (args.DType == TF_DataType.DtInvalid)
- args.DType = this.dtype;
- return random_ops.random_normal(args.Shape, mean, stddev, dtype, seed: seed);
+ args.DType = dtype;
+ return random_ops.random_normal(args.Shape, mean, stddev, args.DType, seed: seed);
}
}
}
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/Convolution.cs b/src/TensorFlowNET.Core/Operations/NnOps/Convolution.cs
deleted file mode 100644
index be4aca3c..00000000
--- a/src/TensorFlowNET.Core/Operations/NnOps/Convolution.cs
+++ /dev/null
@@ -1,84 +0,0 @@
-/*****************************************************************************
- Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
-******************************************************************************/
-
-using System.Linq;
-
-namespace Tensorflow.Operations
-{
- public class Convolution
- {
- public TensorShape input_shape;
- public TensorShape filter_shape;
- public string data_format;
- public int[] strides;
- public string name;
- public _WithSpaceToBatch conv_op;
-
- public Convolution(TensorShape input_shape,
- TensorShape filter_shape,
- string padding,
- int[] strides,
- int[] dilation_rate,
- string name = null,
- string data_format = null)
- {
- var num_total_dims = filter_shape.ndim;
- var num_spatial_dims = num_total_dims - 2;
- int input_channels_dim;
- int[] spatial_dims;
- if (string.IsNullOrEmpty(data_format) || !data_format.StartsWith("NC"))
- {
- input_channels_dim = input_shape.dims[num_spatial_dims + 1];
- spatial_dims = Enumerable.Range(1, num_spatial_dims).ToArray();
- }
- else
- {
- input_channels_dim = input_shape.dims[1];
- spatial_dims = Enumerable.Range(2, num_spatial_dims).ToArray();
- }
-
- this.input_shape = input_shape;
- this.filter_shape = filter_shape;
- this.data_format = data_format;
- this.strides = strides;
- this.name = name;
-
- conv_op = new _WithSpaceToBatch(
- input_shape,
- dilation_rate: dilation_rate,
- padding: padding,
- build_op: _build_op,
- filter_shape: filter_shape,
- spatial_dims: spatial_dims,
- data_format: data_format);
- }
-
- public _NonAtrousConvolution _build_op(int _, string padding)
- {
- return new _NonAtrousConvolution(input_shape,
- filter_shape: filter_shape,
- padding: padding,
- data_format: data_format,
- strides: strides,
- name: name);
- }
-
- public Tensor __call__(Tensor inp, IVariableV1 filter)
- {
- return conv_op.__call__(inp, filter);
- }
- }
-}
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/ConvolutionInternal.cs b/src/TensorFlowNET.Core/Operations/NnOps/ConvolutionInternal.cs
new file mode 100644
index 00000000..75b44af3
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/NnOps/ConvolutionInternal.cs
@@ -0,0 +1,100 @@
+/*****************************************************************************
+ Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+******************************************************************************/
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Xml;
+using Tensorflow.Keras.ArgsDefinition;
+using static Tensorflow.Binding;
+
+namespace Tensorflow.Operations
+{
+ internal class ConvolutionInternal
+ {
+ ConvolutionalArgs args;
+
+ string data_format => args.DataFormat;
+ string name;
+ string padding => args.Padding;
+
+ public ConvolutionInternal(ConvolutionalArgs args)
+ {
+ this.args = args;
+ name = args.Name;
+ }
+
+ public Tensor Apply(Tensors input, IVariableV1 filters)
+ {
+ var filters_rank = filters.shape.rank;
+ var inputs_rank = input.shape.rank;
+ var num_spatial_dims = args.NumSpatialDims;
+ if (num_spatial_dims == Unknown)
+ num_spatial_dims = filters_rank - 2;
+
+ // Channel dimension.
+ var num_batch_dims = inputs_rank - num_spatial_dims - 1;
+ if (!new[] { 1, 2, 3 }.Contains(num_spatial_dims))
+ throw new ValueError($"num_spatial_dims (input.shape.ndims - num_batch_dims - 1) must be one " +
+ $"of 1, 2 or 3 but saw {num_spatial_dims}. num_batch_dims: {num_batch_dims}.");
+
+ var channel_index = num_batch_dims + num_spatial_dims;
+ var dilations = _get_sequence(args.DilationRate, num_spatial_dims, channel_index);
+ var strides = _get_sequence(args.Strides, num_spatial_dims, channel_index);
+
+ Tensor result = null;
+ tf_with(ops.name_scope(name, default_name: null, (input, filters)), scope =>
+ {
+ name = scope;
+ if (num_spatial_dims == 2)
+ result = gen_nn_ops.conv2d(new Conv2dParams
+ {
+ Input = input,
+ Filter = filters.AsTensor(),
+ Strides = strides,
+ Padding = padding,
+ DataFormat = data_format,
+ Dilations = dilations,
+ Name = name
+ });
+ else
+ throw new NotImplementedException("");
+ });
+
+ return result;
+ }
+
+ int[] _get_sequence(int[] value, int n, int channel_index)
+ {
+ var seq = new List<int>();
+
+ if (channel_index == 1)
+ {
+ seq.Add(1);
+ seq.Add(1);
+ seq.AddRange(value);
+ }
+ else
+ {
+ seq.Add(1);
+ seq.AddRange(value);
+ seq.Add(1);
+ }
+
+ return seq.ToArray();
+ }
+ }
+}
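
Worked example of `_get_sequence` (a sketch, not part of the patch): for a rank-4 NHWC input, `num_spatial_dims` is 2 and `channel_index` is 3, so per-dimension values get padded with 1s before being handed to the Conv2D op.

    // _get_sequence(new[] { 2, 2 }, n: 2, channel_index: 3) -> { 1, 2, 2, 1 }
    // _get_sequence(new[] { 2, 2 }, n: 2, channel_index: 1) -> { 1, 1, 2, 2 }
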
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/_NonAtrousConvolution.cs b/src/TensorFlowNET.Core/Operations/NnOps/_NonAtrousConvolution.cs
deleted file mode 100644
index f947cdbc..00000000
--- a/src/TensorFlowNET.Core/Operations/NnOps/_NonAtrousConvolution.cs
+++ /dev/null
@@ -1,83 +0,0 @@
-/*****************************************************************************
- Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
-******************************************************************************/
-
-using System;
-using System.Linq;
-
-namespace Tensorflow.Operations
-{
- public class _NonAtrousConvolution
- {
- public string padding;
- public string name;
- public int[] strides;
- public string data_format;
- private Func<Conv2dParams, Tensor> conv_op;
-
- public _NonAtrousConvolution(TensorShape input_shape,
- TensorShape filter_shape,
- string padding,
- string data_format,
- int[] strides,
- string name)
- {
- this.padding = padding;
- this.name = name;
- var conv_dims = input_shape.ndim - 2;
- if (conv_dims == 1)
- {
- throw new NotImplementedException("_NonAtrousConvolution conv_dims 1");
- }
- else if (conv_dims == 2)
- {
- var list = strides.ToList();
-
- if (string.IsNullOrEmpty(data_format) || data_format == "NHWC")
- {
- data_format = "NHWC";
- list.Insert(0, 1);
- list.Add(1);
- }
- else if (data_format == "NCHW")
- list.InsertRange(0, new int[] { 1, 1 });
- else
- throw new ValueError("data_format must be \"NHWC\" or \"NCHW\".");
-
- strides = list.ToArray();
- this.strides = strides;
- this.data_format = data_format;
- conv_op = gen_nn_ops.conv2d;
- }
- else if (conv_dims == 3)
- {
- throw new NotImplementedException("_NonAtrousConvolution conv_dims 3");
- }
- }
-
- public Tensor __call__(Tensor inp, IVariableV1 filter)
- {
- return conv_op(new Conv2dParams
- {
- Input = inp,
- Filter = filter.AsTensor(),
- Strides = strides,
- Padding = padding,
- DataFormat = data_format,
- Name = name
- });
- }
- }
-}
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/_WithSpaceToBatch.cs b/src/TensorFlowNET.Core/Operations/NnOps/_WithSpaceToBatch.cs
deleted file mode 100644
index 8ae4ee36..00000000
--- a/src/TensorFlowNET.Core/Operations/NnOps/_WithSpaceToBatch.cs
+++ /dev/null
@@ -1,76 +0,0 @@
-/*****************************************************************************
- Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
-******************************************************************************/
-
-using System;
-using System.Linq;
-
-namespace Tensorflow.Operations
-{
- public class _WithSpaceToBatch
- {
- private _NonAtrousConvolution call;
-
- public _WithSpaceToBatch(TensorShape input_shape,
- int[] dilation_rate,
- string padding,
- Func build_op,
- TensorShape filter_shape = null,
- int[] spatial_dims = null,
- string data_format = null)
- {
- var dilation_rate_tensor = ops.convert_to_tensor(dilation_rate, TF_DataType.TF_INT32, name: "dilation_rate");
- var rate_shape = dilation_rate_tensor.TensorShape;
- var num_spatial_dims = rate_shape.dims[0];
-#pragma warning disable CS0219 // Variable is assigned but its value is never used
- int starting_spatial_dim = -1;
-#pragma warning restore CS0219 // Variable is assigned but its value is never used
- if (!string.IsNullOrEmpty(data_format) && data_format.StartsWith("NC"))
- starting_spatial_dim = 2;
- else
- starting_spatial_dim = 1;
-
- if (spatial_dims == null)
- throw new NotImplementedException("_WithSpaceToBatch spatial_dims");
-
- var orig_spatial_dims = spatial_dims;
- spatial_dims = spatial_dims.OrderBy(x => x).ToArray();
- if (!Enumerable.SequenceEqual(spatial_dims, orig_spatial_dims) || spatial_dims.Any(x => x < 1))
- throw new ValueError("spatial_dims must be a montonically increasing sequence of positive integers");
-
- int expected_input_rank = -1;
- if (!string.IsNullOrEmpty(data_format) && data_format.StartsWith("NC"))
- expected_input_rank = spatial_dims.Last();
- else
- expected_input_rank = spatial_dims.Last() + 1;
-
- var const_rate = tensor_util.constant_value(dilation_rate_tensor);
- var rate_or_const_rate = dilation_rate;
- if(!(const_rate is null))
- {
- if (const_rate.Data<int>().Count(x => x == 1) == const_rate.size)
- {
- call = build_op(num_spatial_dims, padding);
- return;
- }
- }
- }
-
- public Tensor __call__(Tensor inp, IVariableV1 filter)
- {
- return call.__call__(inp, filter);
- }
- }
-}
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
index b239cfd8..fb19ab4e 100644
--- a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
@@ -78,35 +78,45 @@ namespace Tensorflow.Operations
///
///
///
- public static Tensor conv2d_backprop_filter(Conv2dParams parameters)
+ public static Tensor conv2d_backprop_filter(Tensor input, Tensor filter_sizes, Tensor out_backprop,
+ int[] strides, string padding, bool use_cudnn_on_gpu = true,
+ int[] explicit_paddings = null,
+ string data_format = "NHWC",
+ int[] dilations = null,
+ string name = null)
{
+ if (explicit_paddings == null)
+ explicit_paddings = new int[0];
+ if (dilations == null)
+ dilations = new int[] { 1, 1, 1, 1 };
+
if (tf.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
- "Conv2DBackpropFilter", parameters.Name,
+ "Conv2DBackpropFilter", name,
null,
- parameters.Input, parameters.FilterSizes, parameters.OutBackProp,
- "strides", parameters.Strides,
- "use_cudnn_on_gpu", parameters.UseCudnnOnGpu,
- "padding", parameters.Padding,
- "explicit_paddings", parameters.ExplicitPaddings,
- "data_format", parameters.DataFormat,
- "dilations", parameters.Dilations);
+ input, filter_sizes, out_backprop,
+ "strides", strides,
+ "use_cudnn_on_gpu", use_cudnn_on_gpu,
+ "padding", padding,
+ "explicit_paddings", explicit_paddings,
+ "data_format", data_format,
+ "dilations", dilations);
return results[0];
}
- var _op = tf.OpDefLib._apply_op_helper("Conv2DBackpropFilter", name: parameters.Name, args: new
+ var _op = tf.OpDefLib._apply_op_helper("Conv2DBackpropFilter", name: name, args: new
{
- input = parameters.Input,
- filter_sizes = parameters.FilterSizes,
- out_backprop = parameters.OutBackProp,
- strides = parameters.Strides,
- padding = parameters.Padding,
- use_cudnn_on_gpu = parameters.UseCudnnOnGpu,
- explicit_paddings = parameters.ExplicitPaddings,
- data_format = parameters.DataFormat,
- dilations = parameters.Dilations
+ input,
+ filter_sizes,
+ out_backprop,
+ strides,
+ padding,
+ use_cudnn_on_gpu,
+ explicit_paddings,
+ data_format,
+ dilations
});
return _op.outputs[0];
@@ -117,35 +127,45 @@ namespace Tensorflow.Operations
///
///
///
- public static Tensor conv2d_backprop_input(Conv2dParams parameters)
+ public static Tensor conv2d_backprop_input(Tensor input_sizes, Tensor filter, Tensor out_backprop,
+ int[] strides, string padding, bool use_cudnn_on_gpu = true,
+ int[] explicit_paddings = null,
+ string data_format= "NHWC",
+ int[] dilations = null,
+ string name = null)
{
+ if (explicit_paddings == null)
+ explicit_paddings = new int[0];
+ if (dilations == null)
+ dilations = new int[] { 1, 1, 1, 1 };
+
if (tf.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
- "Conv2DBackpropInput", parameters.Name,
+ "Conv2DBackpropInput", name,
null,
- parameters.InputSizes, parameters.Filter, parameters.OutBackProp,
- "strides", parameters.Strides,
- "use_cudnn_on_gpu", parameters.UseCudnnOnGpu,
- "padding", parameters.Padding,
- "explicit_paddings", parameters.ExplicitPaddings,
- "data_format", parameters.DataFormat,
- "dilations", parameters.Dilations);
+ input_sizes, filter, out_backprop,
+ "strides", strides,
+ "use_cudnn_on_gpu", use_cudnn_on_gpu,
+ "padding", padding,
+ "explicit_paddings", explicit_paddings,
+ "data_format", data_format,
+ "dilations", dilations);
return results[0];
}
- var _op = tf.OpDefLib._apply_op_helper("Conv2DBackpropInput", name: parameters.Name, args: new
+ var _op = tf.OpDefLib._apply_op_helper("Conv2DBackpropInput", name: name, args: new
{
- input_sizes = parameters.InputSizes,
- filter = parameters.Filter,
- out_backprop = parameters.OutBackProp,
- strides = parameters.Strides,
- padding = parameters.Padding,
- use_cudnn_on_gpu = parameters.UseCudnnOnGpu,
- explicit_paddings = parameters.ExplicitPaddings,
- data_format = parameters.DataFormat,
- dilations = parameters.Dilations
+ input_sizes,
+ filter,
+ out_backprop,
+ strides,
+ padding,
+ use_cudnn_on_gpu,
+ explicit_paddings,
+ data_format,
+ dilations
});
return _op.outputs[0];
diff --git a/src/TensorFlowNET.Core/Operations/gen_random_ops.cs b/src/TensorFlowNET.Core/Operations/gen_random_ops.cs
index f3442be8..a56a4cdc 100644
--- a/src/TensorFlowNET.Core/Operations/gen_random_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_random_ops.cs
@@ -33,11 +33,6 @@ namespace Tensorflow
///
public static Tensor random_standard_normal(Tensor shape, TF_DataType dtype = TF_DataType.DtInvalid, int? seed = null, int? seed2 = null, string name = null)
{
- if (!seed.HasValue)
- seed = 0;
- if (!seed2.HasValue)
- seed2 = 0;
-
if (tf.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
@@ -51,6 +46,11 @@ namespace Tensorflow
return results[0];
}
+ if (!seed.HasValue)
+ seed = 0;
+ if (!seed2.HasValue)
+ seed2 = 0;
+
var _op = tf.OpDefLib._apply_op_helper("RandomStandardNormal",
name: name,
args: new { shape, dtype, seed, seed2 });
diff --git a/src/TensorFlowNET.Core/Operations/nn_ops.cs b/src/TensorFlowNET.Core/Operations/nn_ops.cs
index 4c30c34e..8ded44f1 100644
--- a/src/TensorFlowNET.Core/Operations/nn_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/nn_ops.cs
@@ -16,6 +16,7 @@
using System;
using System.Linq;
+using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Operations;
using static Tensorflow.Binding;
@@ -23,19 +24,18 @@ namespace Tensorflow
{
public class nn_ops
{
- public static Convolution Convolution(TensorShape input_shape,
- TensorShape filter_shape,
- string padding,
+ internal static ConvolutionInternal convolution_internal(string padding,
int[] strides,
int[] dilation_rate,
string name = null,
- string data_format = null) => new Convolution(input_shape,
- filter_shape,
- padding,
- strides,
- dilation_rate,
- name: name,
- data_format: data_format);
+ string data_format = null) => new ConvolutionInternal(new ConvolutionalArgs
+ {
+ Padding = padding,
+ Strides = strides,
+ DilationRate = dilation_rate,
+ DataFormat = data_format,
+ Name = name
+ });
///
/// Adds `bias` to `value`.
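Note: a hedged sketch of the new internal factory in use. Field names mirror the ConvolutionalArgs initializer above; because the method is `internal`, the caller sits inside the core assembly, and the input/filter shapes the old Convolution constructor required are presumably supplied later, when the returned ConvolutionInternal is applied.

    // Inside TensorFlowNET.Core:
    var conv = nn_ops.convolution_internal(
        padding: "SAME",
        strides: new[] { 1, 1 },
        dilation_rate: new[] { 1, 1 },
        data_format: "NHWC");
    // The returned ConvolutionInternal is applied to the input and filter
    // tensors afterwards; that Apply signature is not part of this diff.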
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs
index 0a3ea47a..7d4f57d9 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensor.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs
@@ -64,7 +64,7 @@ namespace Tensorflow
/// The string name of this tensor.
/// Tensor.name is meaningless when eager execution is enabled.
///
- public string name => $"{(op == null ? "" : $"{op.name}:{_value_index}")}";
+ public virtual string name => $"{(op == null ? "" : $"{op.name}:{_value_index}")}";
///
/// The index of this tensor in the outputs of its Operation.
diff --git a/src/TensorFlowNET.Core/Tensors/TensorShape.cs b/src/TensorFlowNET.Core/Tensors/TensorShape.cs
index 2f130002..34c26bbb 100644
--- a/src/TensorFlowNET.Core/Tensors/TensorShape.cs
+++ b/src/TensorFlowNET.Core/Tensors/TensorShape.cs
@@ -132,7 +132,7 @@ namespace Tensorflow
}
}
- public int this[int index] => dims[index];
+ public int this[int index] => index < 0 ? dims[ndim + index] : dims[index];
///
/// Returns True iff `self` is fully defined in every dimension.
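Note: the indexer now accepts Python-style negative indices. A minimal sketch, assuming the existing params-int TensorShape constructor:

    var shape = new TensorShape(32, 28, 28, 3);
    var channels = shape[-1];   // 3, i.e. dims[ndim - 1]
    var height = shape[1];      // 28, positive indices behave as before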
diff --git a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs
index 5fe0043e..fca60f88 100644
--- a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs
+++ b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs
@@ -8,7 +8,7 @@ using static Tensorflow.Binding;
namespace Tensorflow
{
- public class BaseResourceVariable : DisposableObject, IVariableV1
+ public class BaseResourceVariable : DisposableObject
{
protected string _name;
public virtual string Name => _handle_name;
@@ -92,7 +92,8 @@ namespace Tensorflow
return assign_op;
}
- public Tensor value() => tf.executing_eagerly() ? _read_variable_op() : GraphElement;
+ public Tensor value()
+ => GraphElement ?? _read_variable_op();
protected Tensor _read_variable_op()
{
@@ -159,7 +160,15 @@ namespace Tensorflow
{
}
- public Tensor AsTensor()
- => tf.executing_eagerly() ? read_value() : GraphElement;
+ public Tensor AsTensor(bool as_ref = true)
+ {
+ if (!as_ref && GraphElement != null)
+ return GraphElement;
+
+ if (as_ref)
+ return tf.executing_eagerly() ? read_value() : GraphElement;
+ else
+ return _read_variable_op();
+ }
}
}
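Note: the as_ref split deserves a call-site illustration. A hedged sketch, assuming the usual tf.Variable factory yields a resource-backed variable exposing the IVariableV1 surface:

    var v = tf.Variable(1.0f);               // assumption: resource-backed variable
    var r1 = v.AsTensor();                   // as_ref = true: eager -> read_value(), graph -> GraphElement
    var r2 = v.AsTensor(as_ref: false);      // prefers GraphElement, else a fresh ReadVariableOp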
diff --git a/src/TensorFlowNET.Core/Variables/IVariableV1.cs b/src/TensorFlowNET.Core/Variables/IVariableV1.cs
index 52549ecc..4367cf09 100644
--- a/src/TensorFlowNET.Core/Variables/IVariableV1.cs
+++ b/src/TensorFlowNET.Core/Variables/IVariableV1.cs
@@ -49,6 +49,6 @@ namespace Tensorflow
public TensorShape shape { get; }
Tensor assign_add(T delta, bool use_locking = false, string name = null, bool read_value = true);
Tensor assign(T value, bool use_locking = false, string name = null, bool read_value = true);
- Tensor AsTensor();
+ Tensor AsTensor(bool as_ref = true);
}
}
diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs
index cf9fe2f1..68df1e66 100644
--- a/src/TensorFlowNET.Core/Variables/RefVariable.cs
+++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs
@@ -152,7 +152,7 @@ namespace Tensorflow
if (trainable && !collections.Contains(tf.GraphKeys.TRAINABLE_VARIABLES))
collections.Add(tf.GraphKeys.TRAINABLE_VARIABLES);
- tf_with(ops.init_scope2(), delegate
+ tf_with(ops.init_scope(), init_scope =>
{
var values = init_from_fn ? new object[0] : new object[] { initial_value };
tf_with(ops.name_scope(name, "Variable", values), scope =>
@@ -222,7 +222,7 @@ namespace Tensorflow
public Tensor value() => _snapshot;
- public Tensor AsTensor() => _snapshot;
+ public Tensor AsTensor(bool as_ref = true) => _snapshot;
public Tensor _as_graph_element() => _variable;
diff --git a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs
index 3655a6db..40fb07bc 100644
--- a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs
+++ b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs
@@ -26,7 +26,7 @@ namespace Tensorflow
///
/// Variable based on resource handles.
///
- public partial class ResourceVariable : BaseResourceVariable
+ public partial class ResourceVariable : BaseResourceVariable, IVariableV1
{
Tensor _cached_value;
public string Device => handle.Device;
@@ -90,7 +90,7 @@ namespace Tensorflow
collections.Add(tf.GraphKeys.TRAINABLE_VARIABLES);
_in_graph_mode = !tf.Context.executing_eagerly();
- tf_with(ops.init_scope2(), delegate
+ tf_with(ops.init_scope(), init_scope =>
{
var values = init_from_fn ? new object[0] : new object[] { initial_value };
tf_with(ops.name_scope(name, "Variable", values), scope =>
diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs
index cf935ab3..fb74cb89 100644
--- a/src/TensorFlowNET.Core/ops.cs
+++ b/src/TensorFlowNET.Core/ops.cs
@@ -239,11 +239,8 @@ namespace Tensorflow
/// A context manager that lifts ops out of control-flow scopes and function-building graphs.
///
///
- public static void init_scope()
+ public static NameScope init_scope()
{
- if (tf.Context.executing_eagerly())
- return;
-
// Retrieve the active name scope: entering an `init_scope` preserves
// the name scope of the current context.
var default_graph = get_default_graph();
@@ -257,25 +254,11 @@ namespace Tensorflow
tf_with(ops.control_dependencies(null), delegate
{
- var outer_graph = get_default_graph();
+ // var outer_graph = get_default_graph();
// outer_device_stack = None
});
- }
-
- public static ITensorFlowObject init_scope2()
- {
- // Retrieve the active name scope: entering an `init_scope` preserves
- // the name scope of the current context.
- var default_graph = get_default_graph();
- var scope = default_graph.get_name_scope();
- if (!String.IsNullOrEmpty(scope) && !scope.EndsWith("/"))
- // Names that end with trailing slashes are treated by `name_scope` as
- // absolute.
- scope += "/";
- // inner_device_stack = default_graph._device_function_stack
- // var outer_context = default_graph.as_default;
- return ops.control_dependencies(null);
+ return ops.name_scope(scope);
}
private static int uid_number = -1;
@@ -460,6 +443,8 @@ namespace Tensorflow
{
case NDArray nd:
return constant_op.constant(nd, dtype: dtype, name: name);
+ case EagerTensor tensor:
+ return tf.executing_eagerly() ? tensor : tensor.AsPlaceholder(name: name);
case Tensor tensor:
return tensor;
case Tensor[] tensors:
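Note: since init_scope now returns a NameScope, call sites enter it through tf_with, which is exactly the pattern adopted in the RefVariable and ResourceVariable hunks above. A minimal sketch:

    tf_with(ops.init_scope(), init_scope =>
    {
        // Ops built here are lifted out of the surrounding control-flow
        // scope / function-building graph while keeping the outer name scope.
        var zero = constant_op.constant(0.0f, name: "lifted_init");
    });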
diff --git a/src/TensorFlowNET.Core/ops.name_scope.cs b/src/TensorFlowNET.Core/ops.name_scope.cs
index 97a24525..1fb2b43c 100644
--- a/src/TensorFlowNET.Core/ops.name_scope.cs
+++ b/src/TensorFlowNET.Core/ops.name_scope.cs
@@ -90,6 +90,7 @@ namespace Tensorflow
return (scope_name, old_name);
}
+ [DebuggerHidden]
public void Dispose()
{
if (tf.Context.executing_eagerly())
diff --git a/test/TensorFlowNET.UnitTest/Keras/ModelSaveTest.cs b/test/TensorFlowNET.UnitTest/Keras/ModelSaveTest.cs
index 050151af..40d0e22b 100644
--- a/test/TensorFlowNET.UnitTest/Keras/ModelSaveTest.cs
+++ b/test/TensorFlowNET.UnitTest/Keras/ModelSaveTest.cs
@@ -28,7 +28,8 @@ namespace TensorFlowNET.UnitTest.Keras
// Create a simple model.
var inputs = keras.Input(shape: 32);
- var outputs = keras.layers.Dense(1).Apply(inputs);
+ var dense_layer = keras.layers.Dense(1);
+ var outputs = dense_layer.Apply(inputs);
var model = keras.Model(inputs, outputs);
model.compile("adam", "mean_squared_error");
return model;
diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/BitwiseApiTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/BitwiseApiTest.cs
index 8c5420da..4ec4eb25 100644
--- a/test/TensorFlowNET.UnitTest/ManagedAPI/BitwiseApiTest.cs
+++ b/test/TensorFlowNET.UnitTest/ManagedAPI/BitwiseApiTest.cs
@@ -8,6 +8,11 @@ namespace TensorFlowNET.UnitTest.ManagedAPI
[TestClass]
public class BitwiseApiTest : TFNetApiTest
{
+ [TestInitialize]
+ public void Init()
+ {
+ tf.enable_eager_execution();
+ }
[TestMethod]
public void BitwiseAnd()