@@ -92,7 +92,7 @@ namespace Tensorflow | |||
/// <param name="renorm"></param> | |||
/// <param name="renorm_momentum"></param> | |||
/// <returns></returns> | |||
public Tensor batch_normalization(Tensor inputs, | |||
public Tensors batch_normalization(Tensor inputs, | |||
int axis = -1, | |||
float momentum = 0.99f, | |||
float epsilon = 0.001f, | |||
@@ -108,22 +108,24 @@ namespace Tensorflow | |||
bool renorm = false, | |||
float renorm_momentum = 0.99f) | |||
{ | |||
var layer = new BatchNormalization( | |||
axis: axis, | |||
momentum: momentum, | |||
epsilon: epsilon, | |||
center: center, | |||
scale: scale, | |||
beta_initializer: beta_initializer, | |||
gamma_initializer: gamma_initializer, | |||
moving_mean_initializer: moving_mean_initializer, | |||
moving_variance_initializer: moving_variance_initializer, | |||
renorm: renorm, | |||
renorm_momentum: renorm_momentum, | |||
trainable: trainable, | |||
name: name); | |||
return layer.apply(inputs, training: training).Item1; | |||
var layer = new BatchNormalization(new BatchNormalizationArgs | |||
{ | |||
Axis = axis, | |||
Momentum = momentum, | |||
Epsilon = epsilon, | |||
Center = center, | |||
Scale = scale, | |||
BetaInitializer = beta_initializer, | |||
GammaInitializer = gamma_initializer, | |||
MovingMeanInitializer = moving_mean_initializer, | |||
MovingVarianceInitializer = moving_variance_initializer, | |||
Renorm = renorm, | |||
RenormMomentum = renorm_momentum, | |||
Trainable = trainable, | |||
Name = name | |||
}); | |||
return layer.Apply(inputs); | |||
} | |||
/// <summary> | |||
@@ -41,6 +41,8 @@ namespace Tensorflow | |||
/// <summary>
/// Lifts ops out of control-flow scopes and function-building graphs by
/// forwarding to <c>ops.init_scope()</c>.
/// When eager execution is enabled, code inside an init_scope block runs with
/// eager execution enabled even when tracing a `tf.function`.
/// </summary>
public void init_scope()
{
    ops.init_scope();
}
@@ -227,6 +227,11 @@ namespace Tensorflow.Eager | |||
input_handle = input.EagerTensorHandle; | |||
flattened_inputs.Add(input); | |||
break; | |||
case ResourceVariable variable: | |||
var var_tensor = variable.AsTensor(); | |||
input_handle = var_tensor.EagerTensorHandle; | |||
flattened_inputs.Add(var_tensor); | |||
break; | |||
default: | |||
var tensor = tf.convert_to_tensor(inputs); | |||
input_handle = tensor.EagerTensorHandle; | |||
@@ -57,6 +57,26 @@ namespace Tensorflow.Eager | |||
return this; | |||
} | |||
/// <summary>
/// Creates a placeholder with this tensor's dtype and shape
/// (named after <c>_create_substitute_placeholder</c>).
/// </summary>
/// <param name="name">Placeholder name; this tensor's own name is used when null.</param>
/// <returns>The tensor produced by <c>tf.placeholder</c>.</returns>
public Tensor AsPlaceholder(string name = null)
{
    Tensor result = null;

    // The placeholder is created under ops.control_dependencies(null).
    tf_with(ops.control_dependencies(null), _ =>
    {
        result = tf.placeholder(dtype, shape: shape, name: name ?? this.name);
    });

    // Not ported yet: custom_gradient.copy_handle_data(value, placeholder)
    return result;
}
// Currently a no-op stub: the body is empty.
void copy_handle_data() { }
/// <summary>
/// Returns the raw pointer held by <c>EagerTensorHandle</c>,
/// or <c>IntPtr.Zero</c> when no handle is set.
/// </summary>
public override IntPtr ToPointer()
{
    var handle = EagerTensorHandle;
    if (handle is null)
        return IntPtr.Zero;

    return handle.DangerousGetHandle();
}
@@ -138,30 +138,16 @@ namespace Tensorflow.Gradients | |||
return new Tensor[] | |||
{ | |||
gen_nn_ops.conv2d_backprop_input(new Conv2dParams | |||
{ | |||
InputSizes = shape[0], | |||
Filter = op.inputs[1], | |||
OutBackProp = grads[0], | |||
Dilations = dilations, | |||
Strides = strides, | |||
Padding = padding.ToString(), | |||
ExplicitPaddings = explicit_paddings, | |||
UseCudnnOnGpu = (bool)use_cudnn_on_gpu, | |||
DataFormat = data_format.ToString(), | |||
}), | |||
gen_nn_ops.conv2d_backprop_filter(new Conv2dParams | |||
{ | |||
Input = op.inputs[0], | |||
FilterSizes = shape[1], | |||
OutBackProp = grads[0], | |||
Dilations = dilations, | |||
Strides = strides, | |||
Padding = padding.ToString(), | |||
ExplicitPaddings = explicit_paddings, | |||
UseCudnnOnGpu = (bool)use_cudnn_on_gpu, | |||
DataFormat = data_format.ToString() | |||
}) | |||
gen_nn_ops.conv2d_backprop_input(shape[0], op.inputs[1], grads[0], | |||
strides, padding, use_cudnn_on_gpu, explicit_paddings, | |||
dilations: dilations, | |||
data_format: data_format), | |||
gen_nn_ops.conv2d_backprop_filter(op.inputs[0], shape[1], grads[0], | |||
strides, padding, | |||
dilations: dilations, | |||
explicit_paddings: explicit_paddings, | |||
use_cudnn_on_gpu: use_cudnn_on_gpu, | |||
data_format: data_format) | |||
}; | |||
} | |||
@@ -0,0 +1,24 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
/// <summary>
/// Argument bag for the BatchNormalization layer; extends <c>LayerArgs</c>
/// with the batch-normalization specific settings below.
/// </summary>
public class BatchNormalizationArgs : LayerArgs
{
    /// <summary>Axis (or axes) setting; defaults to -1.</summary>
    public TensorShape Axis { get; set; } = -1;

    /// <summary>Momentum setting; defaults to 0.99.</summary>
    public float Momentum { get; set; } = 0.99f;

    /// <summary>Epsilon setting; defaults to 0.001.</summary>
    public float Epsilon { get; set; } = 0.001f;

    /// <summary>Center flag; defaults to true.</summary>
    public bool Center { get; set; } = true;

    /// <summary>Scale flag; defaults to true.</summary>
    public bool Scale { get; set; } = true;

    /// <summary>Initializer for beta; defaults to <c>tf.zeros_initializer</c>.</summary>
    public IInitializer BetaInitializer { get; set; } = tf.zeros_initializer;

    /// <summary>Initializer for gamma; defaults to <c>tf.ones_initializer</c>.</summary>
    public IInitializer GammaInitializer { get; set; } = tf.ones_initializer;

    /// <summary>Initializer for the moving mean; defaults to <c>tf.zeros_initializer</c>.</summary>
    public IInitializer MovingMeanInitializer { get; set; } = tf.zeros_initializer;

    /// <summary>Initializer for the moving variance; defaults to <c>tf.ones_initializer</c>.</summary>
    public IInitializer MovingVarianceInitializer { get; set; } = tf.ones_initializer;

    /// <summary>Optional regularizer for beta; null unless set.</summary>
    public IRegularizer BetaRegularizer { get; set; }

    /// <summary>Optional regularizer for gamma; null unless set.</summary>
    public IRegularizer GammaRegularizer { get; set; }

    /// <summary>Renorm flag; defaults to false.</summary>
    public bool Renorm { get; set; }

    /// <summary>Renorm momentum setting; defaults to 0.99.</summary>
    public float RenormMomentum { get; set; } = 0.99f;
}
} |
@@ -4,7 +4,7 @@ using System.Text; | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
public class Conv2DArgs : ConvArgs | |||
public class Conv2DArgs : ConvolutionalArgs | |||
{ | |||
} | |||
@@ -5,10 +5,11 @@ using static Tensorflow.Binding; | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
public class ConvArgs : LayerArgs | |||
public class ConvolutionalArgs : LayerArgs | |||
{ | |||
public int Rank { get; set; } = 2; | |||
public int Filters { get; set; } | |||
public int NumSpatialDims { get; set; } = Unknown; | |||
public TensorShape KernelSize { get; set; } = 5; | |||
/// <summary> | |||
@@ -24,8 +25,8 @@ namespace Tensorflow.Keras.ArgsDefinition | |||
public bool UseBias { get; set; } | |||
public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer; | |||
public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer; | |||
public IInitializer KernelRegularizer { get; set; } | |||
public IInitializer BiasRegularizer { get; set; } | |||
public IRegularizer KernelRegularizer { get; set; } | |||
public IRegularizer BiasRegularizer { get; set; } | |||
public Action KernelConstraint { get; set; } | |||
public Action BiasConstraint { get; set; } | |||
} |
@@ -46,7 +46,7 @@ namespace Tensorflow.Keras.ArgsDefinition | |||
/// <summary> | |||
/// Regularizer function applied to the output of the layer (its "activation"). | |||
/// </summary> | |||
public IInitializer ActivityRegularizer { get; set; } | |||
public IRegularizer ActivityRegularizer { get; set; } | |||
public bool Autocast { get; set; } | |||
} | |||
@@ -6,7 +6,7 @@ namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
public class ModelArgs : LayerArgs | |||
{ | |||
public Tensor[] Inputs { get; set; } | |||
public Tensor[] Outputs { get; set; } | |||
public Tensors Inputs { get; set; } | |||
public Tensors Outputs { get; set; } | |||
} | |||
} |
@@ -42,7 +42,7 @@ namespace Tensorflow.Keras | |||
/// for various layer names in each graph. | |||
/// Allows to give unique autogenerated names to layers, in a graph-specific way. | |||
/// </summary> | |||
public Dictionary<Graph, Dictionary<(string, string), int>> PER_GRAPH_LAYER_NAME_UIDS = new Dictionary<Graph, Dictionary<(string, string), int>>(); | |||
public Dictionary<Graph, Dictionary<string, int>> PER_GRAPH_LAYER_NAME_UIDS = new Dictionary<Graph, Dictionary<string, int>>(); | |||
public Dictionary<string, IVariableV1> _GRAPH_VARIABLES = new Dictionary<string, IVariableV1>(); | |||
public Dictionary<string, Optimizer> _GRAPH_TF_OPTIMIZERS = new Dictionary<string, Optimizer>(); | |||
@@ -80,25 +80,19 @@ namespace Tensorflow.Keras | |||
return ops.get_default_graph(); | |||
} | |||
public int get_uid(string prefix, string @namespace = "") | |||
public int get_uid(string prefix) | |||
{ | |||
var graph = tf.get_default_graph(); | |||
if (!PER_GRAPH_LAYER_NAME_UIDS.ContainsKey(graph)) | |||
PER_GRAPH_LAYER_NAME_UIDS.Add(graph, new defaultdict<(string, string), int>()); | |||
PER_GRAPH_LAYER_NAME_UIDS[graph][(@namespace, prefix)] += 1; | |||
PER_GRAPH_LAYER_NAME_UIDS.Add(graph, new defaultdict<string, int>()); | |||
if (!PER_GRAPH_LAYER_NAME_UIDS[graph].ContainsKey(prefix)) | |||
PER_GRAPH_LAYER_NAME_UIDS[graph][prefix] = 0; | |||
PER_GRAPH_LAYER_NAME_UIDS[graph][prefix] += 1; | |||
return PER_GRAPH_LAYER_NAME_UIDS[graph][(@namespace, prefix)]; | |||
return PER_GRAPH_LAYER_NAME_UIDS[graph][prefix]; | |||
} | |||
public int get_uid((string, string) name) | |||
{ | |||
var graph = tf.get_default_graph(); | |||
if (!PER_GRAPH_LAYER_NAME_UIDS.ContainsKey(graph)) | |||
PER_GRAPH_LAYER_NAME_UIDS.Add(graph, new defaultdict<(string, string), int>()); | |||
PER_GRAPH_LAYER_NAME_UIDS[graph][(name)] += 1; | |||
return PER_GRAPH_LAYER_NAME_UIDS[graph][name]; | |||
} | |||
public void reset_uids() => PER_GRAPH_LAYER_NAME_UIDS = new Dictionary<Graph, Dictionary<(string, string), int>>(); | |||
public void reset_uids() => PER_GRAPH_LAYER_NAME_UIDS = new Dictionary<Graph, Dictionary<string, int>>(); | |||
public void clear_session() | |||
{ | |||
ops.reset_default_graph(); | |||
@@ -0,0 +1,67 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
namespace Tensorflow.Keras.Engine | |||
{ | |||
/// <summary>
/// A `Functional` model is a `Model` defined as a directed graph of layers.
/// </summary>
public class Functional : Model
{
    TensorShape _build_input_shape;
    bool _compute_output_and_mask_jointly;
    bool _expects_training_arg;
    bool _expects_mask_arg;
    bool _autocast;

    // Layers that produce the model outputs / receive the model inputs, plus the
    // KerasHistory coordinates recorded for each of those tensors.
    List<Layer> _output_layers = new List<Layer>();
    List<Layer> _input_layers = new List<Layer>();
    List<KerasHistory> _input_coordinates = new List<KerasHistory>();
    List<KerasHistory> _output_coordinates = new List<KerasHistory>();

    /// <summary>
    /// Creates the model from its input and output tensors and records
    /// which layers those tensors come from.
    /// </summary>
    /// <param name="inputs">Input tensors of the model.</param>
    /// <param name="outputs">Output tensors of the model.</param>
    public Functional(Tensors inputs, Tensors outputs)
        : base(new ModelArgs { Inputs = inputs, Outputs = outputs })
    {
        _init_graph_network(inputs, outputs);
    }

    /// <summary>
    /// Marks the model as a built graph network and fills the
    /// output/input layer and coordinate lists.
    /// </summary>
    void _init_graph_network(Tensors inputs, Tensors outputs)
    {
        _is_graph_network = true;
        this.inputs = inputs;
        this.outputs = outputs;
        built = true;
        _build_input_shape = inputs.shape;
        _compute_output_and_mask_jointly = true;
        _expects_training_arg = true;
        _expects_mask_arg = true;
        // A graph network does not autocast inputs, as its layers will cast them instead.
        _autocast = false;

        // Outputs are recorded first, then inputs.
        _record_history(outputs, _output_layers, _output_coordinates);
        _record_history(inputs, _input_layers, _input_coordinates);
    }

    /// <summary>
    /// Appends, for every tensor, its originating layer to <paramref name="layers"/> and a
    /// new KerasHistory built from the tensor's history to <paramref name="coordinates"/>.
    /// </summary>
    static void _record_history(Tensors tensors, List<Layer> layers, List<KerasHistory> coordinates)
    {
        foreach (var tensor in tensors)
        {
            var (sourceLayer, nodeIndex, tensorIndex) = tensor.KerasHistory;
            layers.append(sourceLayer);
            coordinates.append(new KerasHistory(sourceLayer, nodeIndex, tensorIndex));
        }
    }
}
} |
@@ -15,6 +15,7 @@ | |||
******************************************************************************/ | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
namespace Tensorflow.Keras.Engine | |||
{ | |||
@@ -27,6 +28,7 @@ namespace Tensorflow.Keras.Engine | |||
public int? min_ndim; | |||
Dictionary<int, int> axes; | |||
TensorShape shape; | |||
public int[] AllAxisDim; | |||
public InputSpec(TF_DataType dtype = TF_DataType.DtInvalid, | |||
int? ndim = null, | |||
@@ -42,6 +44,12 @@ namespace Tensorflow.Keras.Engine | |||
this.shape = shape; | |||
if (ndim == null && shape != null) | |||
this.ndim = shape.ndim; | |||
if(axes != null) | |||
AllAxisDim = axes.Select(x => x.Value).ToArray(); | |||
} | |||
/// <summary>
/// Short description of this spec for debugging: the minimum rank and
/// the number of axis constraints.
/// </summary>
/// <returns>
/// A string of the form <c>min_ndim=…, axes=…</c>; the axis count is reported as 0
/// when no axes dictionary is set.
/// </returns>
public override string ToString()
    // `axes` may be null (the constructor guards for that case), so use ?. to avoid a
    // NullReferenceException; the stray ", ," in the original format is also removed.
    => $"min_ndim={min_ndim}, axes={axes?.Count ?? 0}";
} | |||
} |
@@ -20,10 +20,14 @@ namespace Tensorflow.Keras.Engine | |||
this.tensor_index = tensor_index; | |||
} | |||
/// <summary>
/// Splits this history into its layer, node index and tensor index so it can be
/// used with deconstruction syntax.
/// </summary>
public void Deconstruct(out Layer layer, out int node_index, out int tensor_index)
    => (layer, node_index, tensor_index) = (this.layer, this.node_index, this.tensor_index);
/// <summary>Converts a history to the layer it refers to.</summary>
public static implicit operator Layer(KerasHistory history)
{
    return history.layer;
}
/// <summary>Converts a history to a (layer, node index, tensor index) tuple.</summary>
public static implicit operator (Layer, int, int)(KerasHistory history)
{
    var layer = history.layer;
    var nodeIndex = history.node_index;
    var tensorIndex = history.tensor_index;
    return (layer, nodeIndex, tensorIndex);
}
} | |||
} |
@@ -72,10 +72,10 @@ namespace Tensorflow.Keras.Engine | |||
protected List<IVariableV1> nonTrainableWeights; | |||
public List<IVariableV1> non_trainable_variables => nonTrainableWeights; | |||
string name; | |||
protected string name; | |||
protected string base_name; | |||
public string Name => name; | |||
protected string baseName; | |||
protected bool computePreviousMask; | |||
protected List<Operation> updates; | |||
public TensorShape BatchInputShape => args.BatchInputShape; | |||
@@ -98,9 +98,9 @@ namespace Tensorflow.Keras.Engine | |||
// Indicates whether `build` needs to be called upon layer call, to create | |||
// the layer's weights. | |||
built = false; | |||
this.SupportsMasking = false; | |||
SupportsMasking = false; | |||
_init_set_name(name); | |||
_init_set_name(args.Name); | |||
trainableWeights = new List<IVariableV1>(); | |||
nonTrainableWeights = new List<IVariableV1>(); | |||
computePreviousMask = false; | |||
@@ -124,23 +124,25 @@ namespace Tensorflow.Keras.Engine | |||
/// <returns></returns> | |||
public Tensors Apply(Tensors inputs, Tensor state = null, bool is_training = false) | |||
{ | |||
Tensors outputs = null; | |||
callContext = callContext ?? new ThreadLocal<CallContext>() | |||
{ | |||
Value = new CallContext() | |||
}; | |||
if (_in_functional_construction_mode(inputs)) | |||
return _functional_construction_call(inputs); | |||
Tensors outputs = null; | |||
var eager = tf.executing_eagerly(); | |||
using var ctxManager = CallContext.enter(); | |||
string nameScope = ""; | |||
if (eager) | |||
nameScope = name; | |||
nameScope = Name; | |||
else | |||
nameScope = _name_scope(); | |||
// using var graph = tf.keras.backend.get_graph().as_default(); | |||
if (!inputs.IsEagerTensor) | |||
tf.Context.graph_mode(); | |||
@@ -162,6 +164,46 @@ namespace Tensorflow.Keras.Engine | |||
return outputs; | |||
} | |||
/// <summary>
/// Tells whether every input is a non-eager tensor, which is the condition
/// <c>Apply</c> uses to route the call through <c>_functional_construction_call</c>.
/// </summary>
/// <param name="inputs">Tensors passed to the layer.</param>
/// <returns>True when no input is an eager tensor (also true for an empty collection).</returns>
bool _in_functional_construction_mode(Tensors inputs)
{
    // All(...) enumerates the inputs once and stops at the first eager tensor,
    // instead of running two full Count(...) passes and comparing the totals.
    return inputs.All(x => !x.IsEagerTensor);
}
/// <summary>
/// Calls the layer on its inputs inside the layer's name scope: builds the layer if
/// needed, invokes <c>call</c>, then records connectivity, activity-regularization
/// and mask metadata on the result.
/// </summary>
/// <param name="inputs">Tensors passed to the layer.</param>
/// <returns>The outputs after connectivity metadata has been attached.</returns>
Tensors _functional_construction_call(Tensors inputs)
{
    // NOTE(review): detection of framework-supplied `mask`/`training` arguments is not
    // ported; the previous placeholders for it were unused locals and have been dropped.
    Tensors outputs = null;

    using var ctxManager = CallContext.enter();

    // using var graph = tf.keras.backend.get_graph().as_default();

    // Evaluate once so the enter/restore pair below always agrees.
    bool enteredGraphMode = !inputs.IsEagerTensor;
    if (enteredGraphMode)
        tf.Context.graph_mode();

    try
    {
        tf_with(ops.name_scope(_name_scope()), scope =>
        {
            MaybeBuild(inputs);

            outputs = call(inputs);

            outputs = _set_connectivity_metadata_(inputs, outputs);
            _handle_activity_regularization(inputs, outputs);
            _set_mask_metadata(inputs, outputs, null);
        });
    }
    finally
    {
        // Restore the previous context mode even when build/call throws, so a failed
        // layer call cannot leave the context stuck in graph mode.
        if (enteredGraphMode)
            tf.Context.restore_mode();
    }

    return outputs;
}
private Tensors _set_connectivity_metadata_(Tensors inputs, Tensors outputs) | |||
{ | |||
/*var returnOutputs = new List<Tensor>(); | |||
@@ -219,8 +261,12 @@ namespace Tensorflow.Keras.Engine | |||
if (DType == TF_DataType.DtInvalid) | |||
args.DType = inputs.dtype; | |||
var input_shapes = inputs.shape; | |||
build(input_shapes); | |||
tf.init_scope(); | |||
//tf.Context.eager_mode(); | |||
build(inputs.shape); | |||
//tf.Context.restore_mode(); | |||
built = true; | |||
} | |||
@@ -229,10 +275,16 @@ namespace Tensorflow.Keras.Engine | |||
built = true; | |||
} | |||
/// <summary>
/// Hook for registering a loss callback. This base implementation has an empty body;
/// subclasses may override it.
/// </summary>
/// <param name="losses">Callback producing the loss tensor; not used here.</param>
protected virtual void add_loss(Func<Tensor> losses) { }
protected virtual IVariableV1 add_weight(string name, | |||
TensorShape shape, | |||
TF_DataType dtype = TF_DataType.DtInvalid, | |||
IInitializer initializer = null, | |||
IRegularizer regularizer = null, | |||
bool? trainable = null, | |||
Func<VariableArgs, IVariableV1> getter = null) | |||
{ | |||
@@ -251,7 +303,7 @@ namespace Tensorflow.Keras.Engine | |||
else if (dtype.is_integer()) | |||
initializer = tf.zeros_initializer; | |||
else | |||
throw new ValueError($"An initializer for variable {name} of type {dtype.as_base_dtype()} is required for layer {this.Name}"); | |||
throw new ValueError($"An initializer for variable {name} of type {dtype.as_base_dtype()} is required for layer {name}"); | |||
} | |||
var args = new VariableArgs | |||
@@ -266,6 +318,12 @@ namespace Tensorflow.Keras.Engine | |||
}; | |||
var variable = _add_variable_with_custom_getter(args); | |||
if(regularizer != null) | |||
{ | |||
var name_in_scope = variable.Name.Split(':')[0]; | |||
_handle_weight_regularization(name_in_scope, variable, regularizer); | |||
} | |||
//backend.track_variable(variable); | |||
if (trainable == true) | |||
trainableWeights.Add(variable); | |||
@@ -275,6 +333,20 @@ namespace Tensorflow.Keras.Engine | |||
return variable; | |||
} | |||
/// <summary>
/// Registers, through <c>add_loss</c>, a lambda that computes a regularization loss
/// by applying the regularizer to a default-constructed <c>RegularizerArgs</c>.
/// </summary>
/// <param name="name">Variable name; not read by the current body.</param>
/// <param name="variable">The weight being regularized; not read by the current body.</param>
/// <param name="regularizer">Regularizer whose <c>Apply</c> is invoked by the lambda.</param>
void _handle_weight_regularization(string name, IVariableV1 variable, IRegularizer regularizer)
{
    // NOTE(review): neither `name` nor `variable` is forwarded to the regularizer yet —
    // confirm whether RegularizerArgs is meant to carry the variable.
    Func<Tensor> regularizationLoss = () => regularizer.Apply(new RegularizerArgs());
    add_loss(regularizationLoss);
}
protected virtual void add_update(Tensor[] updates, bool inputs = false) | |||
{ | |||
var updates_op = updates.Select(x => x.op).ToArray(); | |||
@@ -284,17 +356,13 @@ namespace Tensorflow.Keras.Engine | |||
// Determine layer name (non-unique). | |||
protected virtual void _init_set_name(string name, bool zero_based = true) | |||
{ | |||
var base_name = name; | |||
base_name = name; | |||
this.name = name; | |||
if (name == null) | |||
(this.name, baseName) = _make_unique_name(); | |||
} | |||
protected virtual (string, string) _make_unique_name() | |||
{ | |||
string base_name = generic_utils.to_snake_case(this.GetType().Name); | |||
string name = base_layer_utils.unique_layer_name(base_name); | |||
return (name, base_name); | |||
{ | |||
base_name = generic_utils.to_snake_case(this.GetType().Name); | |||
this.name = base_layer_utils.unique_layer_name(base_name, zero_based: zero_based); | |||
} | |||
} | |||
} | |||
} |
@@ -23,15 +23,14 @@ namespace Tensorflow.Keras.Engine | |||
string loss; | |||
IOptimizer optimizer; | |||
IVariableV1 _steps_per_execution; | |||
protected bool _is_graph_network; | |||
protected Tensors inputs; | |||
protected Tensors outputs; | |||
public Model(ModelArgs args) | |||
: base(args) | |||
{ | |||
// Build _output_layers | |||
/*foreach(var x in args.Outputs) | |||
{ | |||
var layer = x.KerasHistory; | |||
}*/ | |||
} | |||
public void compile(string optimizerName, string lossName) | |||
@@ -16,6 +16,7 @@ namespace Tensorflow | |||
{ | |||
public KerasDataset datasets { get; } = new KerasDataset(); | |||
public Initializers initializers { get; } = new Initializers(); | |||
public Regularizers regularizers { get; } = new Regularizers(); | |||
public LayersApi layers { get; } = new LayersApi(); | |||
public LossesApi losses { get; } = new LossesApi(); | |||
public Activations activations { get; } = new Activations(); | |||
@@ -36,12 +37,8 @@ namespace Tensorflow | |||
/// <param name="input"></param> | |||
/// <param name="output"></param> | |||
/// <returns></returns> | |||
public Model Model(Tensor input, Tensor output) | |||
=> new Model(new ModelArgs | |||
{ | |||
Inputs = new[] { input }, | |||
Outputs = new[] { output } | |||
}); | |||
public Functional Model(Tensors inputs, Tensors outputs) | |||
=> new Functional(inputs, outputs); | |||
/// <summary> | |||
/// Instantiate a Keras tensor. | |||
@@ -15,73 +15,41 @@ | |||
******************************************************************************/ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Keras.Utils; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Keras.Layers | |||
{ | |||
public class BatchNormalization : Tensorflow.Layers.Layer | |||
public class BatchNormalization : Layer | |||
{ | |||
#pragma warning disable CS0414 // The field 'BatchNormalization._USE_V2_BEHAVIOR' is assigned but its value is never used | |||
private bool _USE_V2_BEHAVIOR = true; | |||
#pragma warning restore CS0414 // The field 'BatchNormalization._USE_V2_BEHAVIOR' is assigned but its value is never used | |||
private float momentum; | |||
private float epsilon; | |||
private bool center; | |||
private bool scale; | |||
private bool renorm; | |||
private bool fused; | |||
#pragma warning disable CS0414 // The field 'BatchNormalization._bessels_correction_test_only' is assigned but its value is never used | |||
private bool _bessels_correction_test_only; | |||
#pragma warning restore CS0414 // The field 'BatchNormalization._bessels_correction_test_only' is assigned but its value is never used | |||
private int[] axis; | |||
private string _data_format; | |||
private IInitializer beta_initializer; | |||
private IInitializer gamma_initializer; | |||
private IInitializer moving_mean_initializer; | |||
private IInitializer moving_variance_initializer; | |||
private IVariableV1 gamma; | |||
private IVariableV1 beta; | |||
private RefVariable moving_mean; | |||
private RefVariable moving_variance; | |||
public BatchNormalization(int axis = -1, | |||
float momentum = 0.99f, | |||
float epsilon = 0.001f, | |||
bool center = true, | |||
bool scale = true, | |||
IInitializer beta_initializer = null, | |||
IInitializer gamma_initializer = null, | |||
IInitializer moving_mean_initializer = null, | |||
IInitializer moving_variance_initializer = null, | |||
bool renorm = false, | |||
float renorm_momentum = 0.99f, | |||
bool trainable = true, | |||
string name = null) : base(trainable: trainable, | |||
name: name) | |||
BatchNormalizationArgs args; | |||
float momentum => args.Momentum; | |||
float epsilon => args.Epsilon; | |||
bool center => args.Center; | |||
bool scale => args.Scale; | |||
bool renorm => args.Renorm; | |||
bool fused; | |||
int[] axis; | |||
string _data_format; | |||
IInitializer beta_initializer => args.BetaInitializer; | |||
IInitializer gamma_initializer => args.GammaInitializer; | |||
IInitializer moving_mean_initializer; | |||
IInitializer moving_variance_initializer; | |||
IRegularizer gamma_regularizer => args.GammaRegularizer; | |||
IVariableV1 gamma; | |||
IVariableV1 beta; | |||
IVariableV1 moving_mean; | |||
IVariableV1 moving_variance; | |||
public BatchNormalization(BatchNormalizationArgs args) : base(args) | |||
{ | |||
this.axis = new int[] { axis }; | |||
this.momentum = momentum; | |||
this.epsilon = epsilon; | |||
this.center = center; | |||
this.scale = scale; | |||
if (beta_initializer == null) | |||
beta_initializer = tf.zeros_initializer; | |||
if (gamma_initializer == null) | |||
gamma_initializer = tf.ones_initializer; | |||
if (moving_mean_initializer == null) | |||
moving_mean_initializer = tf.zeros_initializer; | |||
if (moving_variance_initializer == null) | |||
moving_variance_initializer = tf.ones_initializer; | |||
this.beta_initializer = beta_initializer; | |||
this.gamma_initializer = gamma_initializer; | |||
this.moving_mean_initializer = moving_mean_initializer; | |||
this.moving_variance_initializer = moving_variance_initializer; | |||
this.renorm = renorm; | |||
this.fused = true; | |||
this.SupportsMasking = true; | |||
this._bessels_correction_test_only = true; | |||
this.args = args; | |||
axis = args.Axis.dims; | |||
} | |||
protected override void build(TensorShape input_shape) | |||
@@ -91,12 +59,25 @@ namespace Tensorflow.Keras.Layers | |||
if (x < 0) | |||
axis[idx] = ndims + x; | |||
fused = ndims == 4; | |||
if (fused) | |||
if (Enumerable.SequenceEqual(axis, new int[] { 3 })) | |||
{ | |||
if (Enumerable.SequenceEqual(axis, new int[] { 1 })) | |||
_data_format = "NCHW"; | |||
else if (Enumerable.SequenceEqual(axis, new int[] { 3 })) | |||
_data_format = "NHWC"; | |||
else | |||
throw new ValueError($"Unsupported axis, fused batch norm only supports axis == [1] or axis == [3]"); | |||
} | |||
var axis_to_dim = new Dictionary<int, int>(); | |||
foreach(var x in axis) | |||
axis_to_dim[x] = input_shape[x]; | |||
inputSpec = new InputSpec(ndim: ndims, axes: axis_to_dim); | |||
var param_dtype = DType == TF_DataType.DtInvalid ? TF_DataType.TF_FLOAT : DType; | |||
var param_shape = new int[] { input_shape.dims[axis[0]] }; | |||
var param_shape = inputSpec.AllAxisDim; | |||
if (scale) | |||
gamma = add_weight("gamma", | |||
@@ -116,26 +97,17 @@ namespace Tensorflow.Keras.Layers | |||
else | |||
throw new NotImplementedException("add_weight beta"); | |||
if(_scope != null) | |||
{ | |||
} | |||
moving_mean = (RefVariable)add_weight("moving_mean", | |||
moving_mean = add_weight("moving_mean", | |||
param_shape, | |||
dtype: param_dtype, | |||
initializer: moving_mean_initializer, | |||
synchronization: VariableSynchronization.OnRead, | |||
trainable: false, | |||
aggregation: VariableAggregation.Mean); | |||
trainable: false); | |||
moving_variance = (RefVariable)add_weight("moving_variance", | |||
moving_variance = add_weight("moving_variance", | |||
shape: param_shape, | |||
dtype: param_dtype, | |||
initializer: moving_variance_initializer, | |||
synchronization: VariableSynchronization.OnRead, | |||
trainable: false, | |||
aggregation: VariableAggregation.Mean); | |||
trainable: false); | |||
if (renorm) | |||
throw new NotImplementedException("build when renorm is true"); | |||
@@ -178,8 +150,8 @@ namespace Tensorflow.Keras.Layers | |||
inputs, | |||
gamma, | |||
beta, | |||
mean: moving_mean, | |||
variance: moving_variance, | |||
mean: moving_mean.AsTensor(), | |||
variance: moving_variance.AsTensor(), | |||
epsilon: epsilon, | |||
is_training: false, | |||
data_format: _data_format); | |||
@@ -202,8 +174,8 @@ namespace Tensorflow.Keras.Layers | |||
if(training_value == null) | |||
{ | |||
var mean_update = _assign_moving_average(moving_mean, mean, momentum_tensor); | |||
var variance_update = _assign_moving_average(moving_variance, variance, momentum_tensor); | |||
var mean_update = _assign_moving_average(moving_mean.AsTensor(), mean, momentum_tensor); | |||
var variance_update = _assign_moving_average(moving_variance.AsTensor(), variance, momentum_tensor); | |||
add_update(new Tensor[] { mean_update }, inputs: true); | |||
add_update(new Tensor[] { variance_update }, inputs: true); | |||
} | |||
@@ -19,12 +19,11 @@ using Tensorflow.Operations.Activation; | |||
namespace Tensorflow.Keras.Layers | |||
{ | |||
public class Conv2D : Conv | |||
public class Conv2D : Convolutional | |||
{ | |||
public Conv2D(Conv2DArgs args) | |||
: base(args) | |||
public Conv2D(Conv2DArgs args) : base(args) | |||
{ | |||
} | |||
} | |||
} |
@@ -20,13 +20,13 @@ using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Keras.Utils; | |||
using Tensorflow.Operations; | |||
using Tensorflow.Operations.Activation; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Keras.Layers | |||
{ | |||
public class Conv : Layer | |||
public class Convolutional : Layer | |||
{ | |||
ConvArgs args; | |||
ConvolutionalArgs args; | |||
protected int rank => args.Rank; | |||
protected int filters => args.Filters; | |||
protected TensorShape kernel_size => args.KernelSize; | |||
@@ -37,13 +37,14 @@ namespace Tensorflow.Keras.Layers | |||
protected Activation activation => args.Activation; | |||
protected bool use_bias => args.UseBias; | |||
protected IInitializer kernel_initializer => args.KernelInitializer; | |||
protected IRegularizer kernel_regularizer => args.KernelRegularizer; | |||
protected IInitializer bias_initializer => args.BiasInitializer; | |||
protected IVariableV1 kernel; | |||
protected IVariableV1 bias; | |||
protected Convolution _convolution_op; | |||
string _tf_data_format; | |||
ConvolutionInternal _convolution_op; | |||
protected string _tf_data_format; | |||
public Conv(ConvArgs args) : base(args) | |||
public Convolutional(ConvolutionalArgs args) : base(args) | |||
{ | |||
this.args = args; | |||
args.KernelSize = conv_utils.normalize_tuple(args.KernelSize.dims, args.Rank, "kernel_size"); | |||
@@ -65,6 +66,7 @@ namespace Tensorflow.Keras.Layers | |||
kernel = add_weight(name: "kernel", | |||
shape: kernel_shape, | |||
initializer: kernel_initializer, | |||
regularizer: kernel_regularizer, | |||
trainable: true, | |||
dtype: DType); | |||
if (use_bias) | |||
@@ -76,7 +78,7 @@ namespace Tensorflow.Keras.Layers | |||
var axes = new Dictionary<int, int>(); | |||
axes.Add(-1, input_channel); | |||
inputSpec = new InputSpec(ndim: rank + 2, axes: axes); | |||
inputSpec = new InputSpec(min_ndim: rank + 2, axes: axes); | |||
string tf_padding; | |||
if (padding == "causal") | |||
@@ -84,20 +86,21 @@ namespace Tensorflow.Keras.Layers | |||
else | |||
tf_padding = padding.ToUpper(); | |||
_convolution_op = nn_ops.Convolution(input_shape, | |||
kernel.shape, | |||
tf_padding, | |||
string tf_op_name = GetType().Name; | |||
_convolution_op = nn_ops.convolution_internal(tf_padding, | |||
strides, | |||
dilation_rate, | |||
data_format: _tf_data_format); | |||
data_format: _tf_data_format, | |||
name: tf_op_name); | |||
built = true; | |||
} | |||
protected override Tensors call(Tensors inputs, Tensor state = null, bool training = false) | |||
{ | |||
var outputs = _convolution_op.__call__(inputs, kernel); | |||
var outputs = _convolution_op.Apply(inputs, kernel); | |||
if (use_bias) | |||
{ | |||
if (data_format == "channels_first") |
@@ -47,10 +47,10 @@ namespace Tensorflow.Keras.Layers | |||
} | |||
// moved to base class | |||
if (string.IsNullOrEmpty(Name)) | |||
if (string.IsNullOrEmpty(args.Name)) | |||
{ | |||
var prefix = "input"; | |||
args.Name = prefix + '_' + tf.keras.backend.get_uid(prefix); | |||
name = prefix + '_' + tf.keras.backend.get_uid(prefix); | |||
} | |||
if(args.DType == TF_DataType.DtInvalid) | |||
@@ -91,7 +91,6 @@ namespace Tensorflow.Keras.Layers | |||
// input_tensor._keras_mask = None | |||
new Node(this, new NodeArgs | |||
{ | |||
InputTensors = args.InputTensor, | |||
Outputs = args.InputTensor | |||
}); | |||
@@ -11,15 +11,35 @@ namespace Tensorflow.Keras.Layers | |||
{ | |||
/// <summary>
/// Creates a 2D convolution layer from the given settings.
/// </summary>
/// <param name="filters">Forwarded as <c>Filters</c>.</param>
/// <param name="kernel_size">Forwarded unchanged as <c>KernelSize</c> (a null value is passed through).</param>
/// <param name="strides">Forwarded as <c>Strides</c>; <c>(1, 1)</c> is used when null.</param>
/// <param name="padding">Forwarded as <c>Padding</c>; defaults to <c>"valid"</c>.</param>
/// <param name="data_format">Forwarded unchanged as <c>DataFormat</c>.</param>
/// <param name="dilation_rate">Forwarded as <c>DilationRate</c>; <c>(1, 1)</c> is used when null.</param>
/// <param name="groups">Forwarded as <c>Groups</c>.</param>
/// <param name="activation">Activation name, resolved through <c>GetActivationByName</c>.</param>
/// <param name="use_bias">Forwarded as <c>UseBias</c>.</param>
/// <param name="kernel_initializer">Forwarded as <c>KernelInitializer</c>; <c>tf.glorot_uniform_initializer</c> is used when null.</param>
/// <param name="bias_initializer">Forwarded as <c>BiasInitializer</c>; <c>tf.zeros_initializer</c> is used when null.</param>
/// <param name="kernel_regularizer">Forwarded unchanged as <c>KernelRegularizer</c>.</param>
/// <param name="bias_regularizer">Forwarded unchanged as <c>BiasRegularizer</c>.</param>
/// <param name="activity_regularizer">Forwarded unchanged as <c>ActivityRegularizer</c>.</param>
/// <returns>A new <c>Conv2D</c> layer built from a <c>Conv2DArgs</c> holding the values above.</returns>
public Conv2D Conv2D(int filters,
    TensorShape kernel_size = null,
    TensorShape strides = null,
    string padding = "valid",
    string data_format = null,
    TensorShape dilation_rate = null,
    int groups = 1,
    string activation = null,
    bool use_bias = true,
    IInitializer kernel_initializer = null,
    IInitializer bias_initializer = null,
    IRegularizer kernel_regularizer = null,
    IRegularizer bias_regularizer = null,
    IRegularizer activity_regularizer = null)
    => new Conv2D(new Conv2DArgs
    {
        Rank = 2,
        Filters = filters,
        KernelSize = kernel_size,
        Strides = strides == null ? (1, 1) : strides,
        Padding = padding,
        DataFormat = data_format,
        DilationRate = dilation_rate == null ? (1, 1) : dilation_rate,
        Groups = groups,
        // Previously the use_bias argument was accepted but never handed to the layer.
        // NOTE(review): assumes Conv2DArgs exposes a UseBias property - confirm.
        UseBias = use_bias,
        KernelRegularizer = kernel_regularizer,
        KernelInitializer = kernel_initializer == null ? tf.glorot_uniform_initializer : kernel_initializer,
        BiasInitializer = bias_initializer == null ? tf.zeros_initializer : bias_initializer,
        BiasRegularizer = bias_regularizer,
        ActivityRegularizer = activity_regularizer,
        Activation = GetActivationByName(activation)
    });
public Dense Dense(int units, | |||
@@ -65,6 +85,30 @@ namespace Tensorflow.Keras.Layers | |||
DataFormat = data_format | |||
}); | |||
/// <summary>
/// `Input()` is used to instantiate a Keras tensor. It builds an <c>InputLayer</c>
/// from the given settings and hands back the outputs of that layer's first inbound node.
/// </summary>
/// <param name="shape">A shape tuple not including the batch size.</param>
/// <param name="name">Forwarded to the input layer arguments as <c>Name</c>.</param>
/// <param name="sparse">Forwarded to the input layer arguments as <c>Sparse</c>.</param>
/// <param name="ragged">Forwarded to the input layer arguments as <c>Ragged</c>.</param>
/// <returns>The <c>Outputs</c> of the new layer's first inbound node.</returns>
public Tensors Input(TensorShape shape,
    string name = null,
    bool sparse = false,
    bool ragged = false)
{
    var layer_args = new InputLayerArgs
    {
        InputShape = shape,
        Name = name,
        Sparse = sparse,
        Ragged = ragged
    };

    // The layer is only needed for the node it registers on construction.
    return new InputLayer(layer_args).InboundNodes[0].Outputs;
}
public MaxPooling2D MaxPooling2D(TensorShape pool_size = null, | |||
TensorShape strides = null, | |||
string padding = "valid") | |||
@@ -0,0 +1,12 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Keras | |||
{ | |||
/// <summary>
/// Factory for regularizer instances.
/// </summary>
public class Regularizers
{
    /// <summary>
    /// Creates an <see cref="L2"/> regularizer.
    /// </summary>
    /// <param name="l2">Value handed to the <see cref="L2"/> constructor; defaults to 0.01.</param>
    /// <returns>The new regularizer, typed as <see cref="IRegularizer"/>.</returns>
    public IRegularizer l2(float l2 = 0.01f)
    {
        return new L2(l2);
    }
}
} |
@@ -0,0 +1,11 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Keras | |||
{ | |||
/// <summary>
/// Contract implemented by regularizers: a single operation that turns
/// a <see cref="RegularizerArgs"/> into a <see cref="Tensor"/>.
/// </summary>
public interface IRegularizer
{
    /// <summary>
    /// Applies the regularizer.
    /// </summary>
    /// <param name="args">Arguments for this application.</param>
    /// <returns>The resulting tensor.</returns>
    Tensor Apply(RegularizerArgs args);
}
} |
@@ -0,0 +1,21 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Keras | |||
{ | |||
/// <summary>
/// L2 regularizer. It currently only stores its coefficient;
/// <see cref="Apply"/> is not implemented yet.
/// </summary>
public class L2 : IRegularizer
{
    // Coefficient captured by the constructor; nothing in this class reads it yet.
    float l2;

    /// <summary>
    /// Stores the given coefficient.
    /// </summary>
    /// <param name="l2">Coefficient to keep; defaults to 0.01.</param>
    public L2(float l2 = 0.01f) => this.l2 = l2;

    /// <summary>
    /// Not implemented.
    /// </summary>
    /// <param name="args">Unused.</param>
    /// <returns>Never returns.</returns>
    /// <exception cref="NotImplementedException">Always thrown.</exception>
    public Tensor Apply(RegularizerArgs args)
        => throw new NotImplementedException();
}
} |
@@ -0,0 +1,10 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Keras | |||
{ | |||
/// <summary>
/// Argument object accepted by <c>IRegularizer.Apply</c>. It declares no members yet.
/// </summary>
public class RegularizerArgs { }
} |
@@ -55,8 +55,8 @@ namespace Tensorflow.Keras.Utils | |||
/// </summary> | |||
/// <param name="name"></param> | |||
/// <returns></returns> | |||
public static string unique_layer_name(string name, Dictionary<(string, string), int> name_uid_map = null, | |||
string[] avoid_names = null, string @namespace = "", bool zero_based = false) | |||
public static string unique_layer_name(string name, Dictionary<string, int> name_uid_map = null, | |||
string[] avoid_names = null, bool zero_based = false) | |||
{ | |||
if (name_uid_map == null) | |||
name_uid_map = get_default_graph_uid_map(); | |||
@@ -66,41 +66,40 @@ namespace Tensorflow.Keras.Utils | |||
string proposed_name = null; | |||
while (proposed_name == null || avoid_names.Contains(proposed_name)) | |||
{ | |||
var name_key = (@namespace, name); | |||
if (!name_uid_map.ContainsKey(name_key)) | |||
name_uid_map[name_key] = 0; | |||
if (!name_uid_map.ContainsKey(name)) | |||
name_uid_map[name] = 0; | |||
if (zero_based) | |||
{ | |||
int number = name_uid_map[name_key]; | |||
int number = name_uid_map[name]; | |||
if (number > 0) | |||
proposed_name = $"{name}_{number}"; | |||
else | |||
proposed_name = name; | |||
name_uid_map[name_key] += 1; | |||
name_uid_map[name] += 1; | |||
} | |||
else | |||
{ | |||
name_uid_map[name_key] += 1; | |||
proposed_name = $"{name}_{name_uid_map[name_key]}"; | |||
name_uid_map[name] += 1; | |||
proposed_name = $"{name}_{name_uid_map[name]}"; | |||
} | |||
} | |||
return proposed_name; | |||
} | |||
public static Dictionary<(string, string), int> get_default_graph_uid_map() | |||
public static Dictionary<string, int> get_default_graph_uid_map() | |||
{ | |||
var graph = ops.get_default_graph(); | |||
Dictionary<(string, string), int> name_uid_map = null; | |||
Dictionary<string, int> name_uid_map = null; | |||
if (tf.keras.backend.PER_GRAPH_LAYER_NAME_UIDS.ContainsKey(graph)) | |||
{ | |||
name_uid_map = tf.keras.backend.PER_GRAPH_LAYER_NAME_UIDS[graph]; | |||
} | |||
else | |||
{ | |||
name_uid_map = new Dictionary<(string, string), int>(); | |||
name_uid_map = new Dictionary<string, int>(); | |||
tf.keras.backend.PER_GRAPH_LAYER_NAME_UIDS[graph] = name_uid_map; | |||
} | |||
@@ -183,8 +183,6 @@ namespace Tensorflow.Layers | |||
}); | |||
} | |||
protected override string _name_scope() | |||
{ | |||
return _current_scope.original_name_scope; | |||
@@ -202,7 +200,7 @@ namespace Tensorflow.Layers | |||
} | |||
else | |||
{ | |||
tf_with(tf.variable_scope(scope, default_name: baseName), captured_scope => | |||
tf_with(tf.variable_scope(scope, default_name: base_name), captured_scope => | |||
{ | |||
// convert variable_scope to VariableScope | |||
_scope = captured_scope; | |||
@@ -41,8 +41,8 @@ namespace Tensorflow.Operations.Initializers | |||
public Tensor Apply(InitializerArgs args) | |||
{ | |||
if (args.DType == TF_DataType.DtInvalid) | |||
args.DType = this.dtype; | |||
return random_ops.random_normal(args.Shape, mean, stddev, dtype, seed: seed); | |||
args.DType = dtype; | |||
return random_ops.random_normal(args.Shape, mean, stddev, args.DType, seed: seed); | |||
} | |||
} | |||
} |
@@ -1,84 +0,0 @@ | |||
/***************************************************************************** | |||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
******************************************************************************/ | |||
using System.Linq; | |||
namespace Tensorflow.Operations | |||
{ | |||
/// <summary>
/// Holds the configuration of a convolution and delegates the actual call
/// to a <see cref="_WithSpaceToBatch"/> instance built in the constructor.
/// </summary>
public class Convolution
{
    public TensorShape input_shape;
    public TensorShape filter_shape;
    public string data_format;
    public int[] strides;
    public string name;
    public _WithSpaceToBatch conv_op;

    /// <summary>
    /// Stores the arguments and builds the wrapped <see cref="_WithSpaceToBatch"/> op.
    /// </summary>
    /// <param name="input_shape">Shape of the input; stored and passed to the wrapped op.</param>
    /// <param name="filter_shape">Shape of the filter; its <c>ndim - 2</c> is used as the number of spatial dimensions.</param>
    /// <param name="padding">Passed to the wrapped op.</param>
    /// <param name="strides">Stored; used later by <see cref="_build_op"/>.</param>
    /// <param name="dilation_rate">Passed to the wrapped op.</param>
    /// <param name="name">Stored; used later by <see cref="_build_op"/>.</param>
    /// <param name="data_format">A value starting with "NC" selects spatial dimensions starting at 2, anything else (including null/empty) at 1.</param>
    public Convolution(TensorShape input_shape,
        TensorShape filter_shape,
        string padding,
        int[] strides,
        int[] dilation_rate,
        string name = null,
        string data_format = null)
    {
        var num_spatial_dims = filter_shape.ndim - 2;
        bool channels_first = !string.IsNullOrEmpty(data_format) && data_format.StartsWith("NC");

        // The channel size is read from the input shape but not used afterwards.
        int input_channels_dim = input_shape.dims[channels_first ? 1 : num_spatial_dims + 1];
        int[] spatial_dims = Enumerable.Range(channels_first ? 2 : 1, num_spatial_dims).ToArray();

        // Fields must be populated before the wrapped op is created,
        // because _build_op (handed over below) reads them.
        this.input_shape = input_shape;
        this.filter_shape = filter_shape;
        this.data_format = data_format;
        this.strides = strides;
        this.name = name;

        conv_op = new _WithSpaceToBatch(
            input_shape,
            dilation_rate: dilation_rate,
            padding: padding,
            build_op: _build_op,
            filter_shape: filter_shape,
            spatial_dims: spatial_dims,
            data_format: data_format);
    }

    /// <summary>
    /// Builds the non-atrous convolution from the stored configuration.
    /// </summary>
    /// <param name="_">Ignored.</param>
    /// <param name="padding">Padding passed to the new op.</param>
    /// <returns>A new <see cref="_NonAtrousConvolution"/>.</returns>
    public _NonAtrousConvolution _build_op(int _, string padding)
        => new _NonAtrousConvolution(input_shape,
            filter_shape: filter_shape,
            padding: padding,
            data_format: data_format,
            strides: strides,
            name: name);

    /// <summary>
    /// Runs the wrapped op on the given input and filter.
    /// </summary>
    public Tensor __call__(Tensor inp, IVariableV1 filter)
        => conv_op.__call__(inp, filter);
}
} |
@@ -0,0 +1,100 @@ | |||
/***************************************************************************** | |||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
******************************************************************************/ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Xml; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Operations | |||
{ | |||
/// <summary>
/// Convolution op configured by a <see cref="ConvolutionalArgs"/>.
/// Only the case with two spatial dimensions is implemented.
/// </summary>
internal class ConvolutionInternal
{
    ConvolutionalArgs args;

    string data_format => args.DataFormat;
    string name;
    string padding => args.Padding;

    /// <summary>
    /// Keeps the arguments and the configured op name.
    /// </summary>
    /// <param name="args">Convolution settings read on every <see cref="Apply"/> call.</param>
    public ConvolutionInternal(ConvolutionalArgs args)
    {
        this.args = args;
        name = args.Name;
    }

    /// <summary>
    /// Applies the convolution to <paramref name="input"/> using <paramref name="filters"/>.
    /// </summary>
    /// <param name="input">Input tensor(s); the rank is read from <c>input.shape</c>.</param>
    /// <param name="filters">Filter variable; converted with <c>AsTensor()</c>.</param>
    /// <returns>The result of <c>gen_nn_ops.conv2d</c>.</returns>
    /// <exception cref="ValueError">The number of spatial dimensions is not 1, 2 or 3, or strides/dilations have an unsupported length.</exception>
    /// <exception cref="NotImplementedException">The number of spatial dimensions is 1 or 3.</exception>
    public Tensor Apply(Tensors input, IVariableV1 filters)
    {
        var filters_rank = filters.shape.rank;
        var inputs_rank = input.shape.rank;

        var num_spatial_dims = args.NumSpatialDims;
        if (num_spatial_dims == Unknown)
            num_spatial_dims = filters_rank - 2;

        var num_batch_dims = inputs_rank - num_spatial_dims - 1;
        if (!new[] { 1, 2, 3 }.Contains(num_spatial_dims))
            throw new ValueError($"num_spatial_dims (input.shape.ndims - num_batch_dims - 1) must be one " +
                $"of 1, 2 or 3 but saw {num_spatial_dims}. num_batch_dims: {num_batch_dims}.");

        // Channel dimension. It follows the batch dimensions for "NC*" formats and
        // comes last otherwise; without this distinction strides and dilations
        // were always laid out for a channels-last input.
        var channels_first = !string.IsNullOrEmpty(data_format) && data_format.StartsWith("NC");
        var channel_index = channels_first
            ? num_batch_dims
            : num_batch_dims + num_spatial_dims;

        var dilations = _get_sequence(args.DilationRate, num_spatial_dims, channel_index);
        var strides = _get_sequence(args.Strides, num_spatial_dims, channel_index);

        Tensor result = null;
        tf_with(ops.name_scope(name, default_name: null, (input, filters)), scope =>
        {
            // Keep the resolved scope local: writing it back into the field made
            // every later call open its scope from the previous call's result.
            string scope_name = scope;
            if (num_spatial_dims == 2)
                result = gen_nn_ops.conv2d(new Conv2dParams
                {
                    Input = input,
                    Filter = filters.AsTensor(),
                    Strides = strides,
                    Padding = padding,
                    DataFormat = data_format,
                    Dilations = dilations,
                    Name = scope_name
                });
            else
                throw new NotImplementedException(
                    $"Convolution with {num_spatial_dims} spatial dimensions is not implemented.");
        });

        return result;
    }

    /// <summary>
    /// Expands a per-spatial-dimension value into a full sequence of length <c>n + 2</c>
    /// by inserting 1 for the batch and channel positions.
    /// </summary>
    /// <param name="value">Null (all ones), one element (repeated <paramref name="n"/> times),
    /// <paramref name="n"/> elements, or <c>n + 2</c> elements (returned as is).</param>
    /// <param name="n">Number of spatial dimensions.</param>
    /// <param name="channel_index">Index of the channel dimension; 1 places both ones in front.</param>
    /// <returns>The expanded sequence.</returns>
    /// <exception cref="ValueError"><paramref name="value"/> has any other length.</exception>
    int[] _get_sequence(int[] value, int n, int channel_index)
    {
        if (value == null)
            return Enumerable.Repeat(1, n + 2).ToArray();

        if (value.Length == n + 2)
            return value;

        if (value.Length == 1)
            value = Enumerable.Repeat(value[0], n).ToArray();
        else if (value.Length != n)
            throw new ValueError($"Expected a sequence of length 1, {n} or {n + 2} but got {value.Length}.");

        var seq = new List<int>(n + 2);
        if (channel_index == 1)
        {
            seq.Add(1);
            seq.Add(1);
            seq.AddRange(value);
        }
        else
        {
            seq.Add(1);
            seq.AddRange(value);
            seq.Add(1);
        }

        return seq.ToArray();
    }
}
} |
@@ -1,83 +0,0 @@ | |||
/***************************************************************************** | |||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
******************************************************************************/ | |||
using System; | |||
using System.Linq; | |||
namespace Tensorflow.Operations | |||
{ | |||
/// <summary>
/// Plain (non-dilated) convolution. Only inputs with two convolution dimensions
/// are supported; they are dispatched to <c>gen_nn_ops.conv2d</c>.
/// </summary>
public class _NonAtrousConvolution
{
    public string padding;
    public string name;
    public int[] strides;
    public string data_format;
    private Func<Conv2dParams, Tensor> conv_op;

    /// <summary>
    /// Stores the configuration and selects the op by <c>input_shape.ndim - 2</c>.
    /// </summary>
    /// <param name="input_shape">Input shape; only its <c>ndim</c> is read.</param>
    /// <param name="filter_shape">Not read by this constructor.</param>
    /// <param name="padding">Stored and passed to the op on call.</param>
    /// <param name="data_format">"NHWC" (also chosen for null/empty) or "NCHW".</param>
    /// <param name="strides">Spatial strides; expanded with ones according to <paramref name="data_format"/>.</param>
    /// <param name="name">Stored and passed to the op on call.</param>
    /// <exception cref="NotImplementedException">One or three convolution dimensions.</exception>
    /// <exception cref="ValueError">Unsupported <paramref name="data_format"/>.</exception>
    public _NonAtrousConvolution(TensorShape input_shape,
        TensorShape filter_shape,
        string padding,
        string data_format,
        int[] strides,
        string name)
    {
        this.padding = padding;
        this.name = name;

        switch (input_shape.ndim - 2)
        {
            case 1:
                throw new NotImplementedException("_NonAtrousConvolution conv_dims 1");

            case 2:
            {
                var full_strides = strides.ToList();
                if (string.IsNullOrEmpty(data_format) || data_format == "NHWC")
                {
                    // Channels last: one in front for batch, one at the end for channels.
                    data_format = "NHWC";
                    full_strides.Insert(0, 1);
                    full_strides.Add(1);
                }
                else if (data_format == "NCHW")
                {
                    // Channels first: both ones go in front.
                    full_strides.InsertRange(0, new int[] { 1, 1 });
                }
                else
                {
                    throw new ValueError("data_format must be \"NHWC\" or \"NCHW\".");
                }

                this.strides = full_strides.ToArray();
                this.data_format = data_format;
                conv_op = gen_nn_ops.conv2d;
                break;
            }

            case 3:
                throw new NotImplementedException("_NonAtrousConvolution conv_dims 3");
        }
    }

    /// <summary>
    /// Invokes the selected op with the stored configuration.
    /// </summary>
    /// <param name="inp">Input tensor.</param>
    /// <param name="filter">Filter variable; converted with <c>AsTensor()</c>.</param>
    /// <returns>The op's result.</returns>
    public Tensor __call__(Tensor inp, IVariableV1 filter)
    {
        var parameters = new Conv2dParams
        {
            Input = inp,
            Filter = filter.AsTensor(),
            Strides = strides,
            Padding = padding,
            DataFormat = data_format,
            Name = name
        };
        return conv_op(parameters);
    }
}
} |
@@ -1,76 +0,0 @@ | |||
/***************************************************************************** | |||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
******************************************************************************/ | |||
using System; | |||
using System.Linq; | |||
namespace Tensorflow.Operations | |||
{ | |||
/// <summary>
/// Wraps a convolution built through a callback. The wrapped op is only created
/// when every entry of the dilation rate is known to be 1.
/// </summary>
public class _WithSpaceToBatch
{
    private _NonAtrousConvolution call;

    /// <summary>
    /// Validates <paramref name="spatial_dims"/> and, when the dilation rate is constant
    /// and all ones, builds the wrapped op with <paramref name="build_op"/>.
    /// </summary>
    /// <param name="input_shape">Not read by this constructor.</param>
    /// <param name="dilation_rate">Converted to an int32 tensor named "dilation_rate".</param>
    /// <param name="padding">Passed to <paramref name="build_op"/>.</param>
    /// <param name="build_op">Callback receiving the number of spatial dimensions and the padding.</param>
    /// <param name="filter_shape">Not read by this constructor.</param>
    /// <param name="spatial_dims">Must be non-null, sorted ascending and contain only values of at least 1.</param>
    /// <param name="data_format">Only used to tell "NC*" formats from the rest.</param>
    /// <exception cref="NotImplementedException"><paramref name="spatial_dims"/> is null.</exception>
    /// <exception cref="ValueError"><paramref name="spatial_dims"/> is unsorted or has a value below 1.</exception>
    public _WithSpaceToBatch(TensorShape input_shape,
        int[] dilation_rate,
        string padding,
        Func<int, string, _NonAtrousConvolution> build_op,
        TensorShape filter_shape = null,
        int[] spatial_dims = null,
        string data_format = null)
    {
        var rate_tensor = ops.convert_to_tensor(dilation_rate, TF_DataType.TF_INT32, name: "dilation_rate");
        var num_spatial_dims = rate_tensor.TensorShape.dims[0];

        if (spatial_dims == null)
            throw new NotImplementedException("_WithSpaceToBatch spatial_dims");

        var sorted_dims = spatial_dims.OrderBy(x => x).ToArray();
        bool already_sorted = Enumerable.SequenceEqual(sorted_dims, spatial_dims);
        if (!already_sorted || sorted_dims.Any(x => x < 1))
            throw new ValueError("spatial_dims must be a montonically increasing sequence of positive integers");

        // Computed for parity with the reference implementation; not used further here.
        bool channels_first = !string.IsNullOrEmpty(data_format) && data_format.StartsWith("NC");
        int last_spatial_dim = sorted_dims.Last();
        int expected_input_rank = channels_first ? last_spatial_dim : last_spatial_dim + 1;

        var const_rate = tensor_util.constant_value(rate_tensor);
        if (const_rate is null)
            return;

        // Build the plain convolution only when every dilation entry equals 1.
        if (const_rate.Data<int>().Count(x => x == 1) != const_rate.size)
            return;

        call = build_op(num_spatial_dims, padding);
    }

    /// <summary>
    /// Runs the wrapped op on the given input and filter.
    /// </summary>
    public Tensor __call__(Tensor inp, IVariableV1 filter)
        => call.__call__(inp, filter);
}
} |
@@ -78,35 +78,45 @@ namespace Tensorflow.Operations | |||
/// </summary> | |||
/// <param name="parameters"></param> | |||
/// <returns></returns> | |||
public static Tensor conv2d_backprop_filter(Conv2dParams parameters) | |||
public static Tensor conv2d_backprop_filter(Tensor input, Tensor filter_sizes, Tensor out_backprop, | |||
int[] strides, string padding, bool use_cudnn_on_gpu = true, | |||
int[] explicit_paddings = null, | |||
string data_format = "NHWC", | |||
int[] dilations = null, | |||
string name = null) | |||
{ | |||
if (explicit_paddings == null) | |||
explicit_paddings = new int[0]; | |||
if (dilations == null) | |||
dilations = new int[] { 1, 1, 1, 1 }; | |||
if (tf.executing_eagerly()) | |||
{ | |||
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"Conv2DBackpropFilter", parameters.Name, | |||
"Conv2DBackpropFilter", name, | |||
null, | |||
parameters.Input, parameters.FilterSizes, parameters.OutBackProp, | |||
"strides", parameters.Strides, | |||
"use_cudnn_on_gpu", parameters.UseCudnnOnGpu, | |||
"padding", parameters.Padding, | |||
"explicit_paddings", parameters.ExplicitPaddings, | |||
"data_format", parameters.DataFormat, | |||
"dilations", parameters.Dilations); | |||
input, filter_sizes, out_backprop, | |||
"strides", strides, | |||
"use_cudnn_on_gpu", use_cudnn_on_gpu, | |||
"padding", padding, | |||
"explicit_paddings", explicit_paddings, | |||
"data_format", data_format, | |||
"dilations", dilations); | |||
return results[0]; | |||
} | |||
var _op = tf.OpDefLib._apply_op_helper("Conv2DBackpropFilter", name: parameters.Name, args: new | |||
var _op = tf.OpDefLib._apply_op_helper("Conv2DBackpropFilter", name: name, args: new | |||
{ | |||
input = parameters.Input, | |||
filter_sizes = parameters.FilterSizes, | |||
out_backprop = parameters.OutBackProp, | |||
strides = parameters.Strides, | |||
padding = parameters.Padding, | |||
use_cudnn_on_gpu = parameters.UseCudnnOnGpu, | |||
explicit_paddings = parameters.ExplicitPaddings, | |||
data_format = parameters.DataFormat, | |||
dilations = parameters.Dilations | |||
input, | |||
filter_sizes, | |||
out_backprop, | |||
strides, | |||
padding, | |||
use_cudnn_on_gpu, | |||
explicit_paddings, | |||
data_format, | |||
dilations | |||
}); | |||
return _op.outputs[0]; | |||
@@ -117,35 +127,45 @@ namespace Tensorflow.Operations | |||
/// </summary> | |||
/// <param name="parameters"></param> | |||
/// <returns></returns> | |||
public static Tensor conv2d_backprop_input(Conv2dParams parameters) | |||
public static Tensor conv2d_backprop_input(Tensor input_sizes, Tensor filter, Tensor out_backprop, | |||
int[] strides, string padding, bool use_cudnn_on_gpu = true, | |||
int[] explicit_paddings = null, | |||
string data_format= "NHWC", | |||
int[] dilations = null, | |||
string name = null) | |||
{ | |||
if (explicit_paddings == null) | |||
explicit_paddings = new int[0]; | |||
if (dilations == null) | |||
dilations = new int[] { 1, 1, 1, 1 }; | |||
if (tf.executing_eagerly()) | |||
{ | |||
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"Conv2DBackpropInput", parameters.Name, | |||
"Conv2DBackpropInput", name, | |||
null, | |||
parameters.InputSizes, parameters.Filter, parameters.OutBackProp, | |||
"strides", parameters.Strides, | |||
"use_cudnn_on_gpu", parameters.UseCudnnOnGpu, | |||
"padding", parameters.Padding, | |||
"explicit_paddings", parameters.ExplicitPaddings, | |||
"data_format", parameters.DataFormat, | |||
"dilations", parameters.Dilations); | |||
input_sizes, filter, out_backprop, | |||
"strides", strides, | |||
"use_cudnn_on_gpu", use_cudnn_on_gpu, | |||
"padding", padding, | |||
"explicit_paddings", explicit_paddings, | |||
"data_format", data_format, | |||
"dilations", dilations); | |||
return results[0]; | |||
} | |||
var _op = tf.OpDefLib._apply_op_helper("Conv2DBackpropInput", name: parameters.Name, args: new | |||
var _op = tf.OpDefLib._apply_op_helper("Conv2DBackpropInput", name: name, args: new | |||
{ | |||
input_sizes = parameters.InputSizes, | |||
filter = parameters.Filter, | |||
out_backprop = parameters.OutBackProp, | |||
strides = parameters.Strides, | |||
padding = parameters.Padding, | |||
use_cudnn_on_gpu = parameters.UseCudnnOnGpu, | |||
explicit_paddings = parameters.ExplicitPaddings, | |||
data_format = parameters.DataFormat, | |||
dilations = parameters.Dilations | |||
input_sizes, | |||
filter, | |||
out_backprop, | |||
strides, | |||
padding, | |||
use_cudnn_on_gpu, | |||
explicit_paddings, | |||
data_format, | |||
dilations | |||
}); | |||
return _op.outputs[0]; | |||
@@ -33,11 +33,6 @@ namespace Tensorflow | |||
/// <returns></returns> | |||
public static Tensor random_standard_normal(Tensor shape, TF_DataType dtype = TF_DataType.DtInvalid, int? seed = null, int? seed2 = null, string name = null) | |||
{ | |||
if (!seed.HasValue) | |||
seed = 0; | |||
if (!seed2.HasValue) | |||
seed2 = 0; | |||
if (tf.executing_eagerly()) | |||
{ | |||
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
@@ -51,6 +46,11 @@ namespace Tensorflow | |||
return results[0]; | |||
} | |||
if (!seed.HasValue) | |||
seed = 0; | |||
if (!seed2.HasValue) | |||
seed2 = 0; | |||
var _op = tf.OpDefLib._apply_op_helper("RandomStandardNormal", | |||
name: name, | |||
args: new { shape, dtype, seed, seed2 }); | |||
@@ -16,6 +16,7 @@ | |||
using System; | |||
using System.Linq; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Operations; | |||
using static Tensorflow.Binding; | |||
@@ -23,19 +24,18 @@ namespace Tensorflow | |||
{ | |||
public class nn_ops | |||
{ | |||
public static Convolution Convolution(TensorShape input_shape, | |||
TensorShape filter_shape, | |||
string padding, | |||
internal static ConvolutionInternal convolution_internal(string padding, | |||
int[] strides, | |||
int[] dilation_rate, | |||
string name = null, | |||
string data_format = null) => new Convolution(input_shape, | |||
filter_shape, | |||
padding, | |||
strides, | |||
dilation_rate, | |||
name: name, | |||
data_format: data_format); | |||
string data_format = null) => new ConvolutionInternal(new ConvolutionalArgs | |||
{ | |||
Padding = padding, | |||
Strides = strides, | |||
DilationRate = dilation_rate, | |||
DataFormat = data_format, | |||
Name = name | |||
}); | |||
/// <summary> | |||
/// Adds `bias` to `value`. | |||
@@ -64,7 +64,7 @@ namespace Tensorflow | |||
/// The string name of this tensor.<br/> | |||
/// Tensor.name is meaningless when eager execution is enabled. | |||
/// </summary> | |||
public string name => $"{(op == null ? "<unnamed>" : $"{op.name}:{_value_index}")}"; | |||
public virtual string name => $"{(op == null ? "<unnamed>" : $"{op.name}:{_value_index}")}"; | |||
/// <summary> | |||
/// The index of this tensor in the outputs of its Operation. | |||
@@ -132,7 +132,7 @@ namespace Tensorflow | |||
} | |||
} | |||
public int this[int index] => dims[index]; | |||
public int this[int index] => index < 0 ? dims[ndim + index] : dims[index]; | |||
/// <summary> | |||
/// Returns True iff `self` is fully defined in every dimension. | |||
@@ -8,7 +8,7 @@ using static Tensorflow.Binding; | |||
namespace Tensorflow | |||
{ | |||
public class BaseResourceVariable : DisposableObject, IVariableV1 | |||
public class BaseResourceVariable : DisposableObject | |||
{ | |||
protected string _name; | |||
public virtual string Name => _handle_name; | |||
@@ -92,7 +92,8 @@ namespace Tensorflow | |||
return assign_op; | |||
} | |||
public Tensor value() => tf.executing_eagerly() ? _read_variable_op() : GraphElement; | |||
public Tensor value() | |||
=> GraphElement ?? _read_variable_op(); | |||
protected Tensor _read_variable_op() | |||
{ | |||
@@ -159,7 +160,15 @@ namespace Tensorflow | |||
{ | |||
} | |||
public Tensor AsTensor() | |||
=> tf.executing_eagerly() ? read_value() : GraphElement; | |||
public Tensor AsTensor(bool as_ref = true) | |||
{ | |||
if (!as_ref && GraphElement != null) | |||
return GraphElement; | |||
if (as_ref) | |||
return tf.executing_eagerly() ? read_value() : GraphElement; | |||
else | |||
return _read_variable_op(); | |||
} | |||
} | |||
} |
@@ -49,6 +49,6 @@ namespace Tensorflow | |||
public TensorShape shape { get; } | |||
Tensor assign_add<T>(T delta, bool use_locking = false, string name = null, bool read_value = true); | |||
Tensor assign<T>(T value, bool use_locking = false, string name = null, bool read_value = true); | |||
Tensor AsTensor(); | |||
Tensor AsTensor(bool as_ref = true); | |||
} | |||
} |
@@ -152,7 +152,7 @@ namespace Tensorflow | |||
if (trainable && !collections.Contains(tf.GraphKeys.TRAINABLE_VARIABLES)) | |||
collections.Add(tf.GraphKeys.TRAINABLE_VARIABLES); | |||
tf_with(ops.init_scope2(), delegate | |||
tf_with(ops.init_scope(), init_scope => | |||
{ | |||
var values = init_from_fn ? new object[0] : new object[] { initial_value }; | |||
tf_with(ops.name_scope(name, "Variable", values), scope => | |||
@@ -222,7 +222,7 @@ namespace Tensorflow | |||
public Tensor value() => _snapshot; | |||
public Tensor AsTensor() => _snapshot; | |||
public Tensor AsTensor(bool as_ref = true) => _snapshot; | |||
public Tensor _as_graph_element() => _variable; | |||
@@ -26,7 +26,7 @@ namespace Tensorflow | |||
/// <summary> | |||
/// Variable based on resource handles. | |||
/// </summary> | |||
public partial class ResourceVariable : BaseResourceVariable | |||
public partial class ResourceVariable : BaseResourceVariable, IVariableV1 | |||
{ | |||
Tensor _cached_value; | |||
public string Device => handle.Device; | |||
@@ -90,7 +90,7 @@ namespace Tensorflow | |||
collections.Add(tf.GraphKeys.TRAINABLE_VARIABLES); | |||
_in_graph_mode = !tf.Context.executing_eagerly(); | |||
tf_with(ops.init_scope2(), delegate | |||
tf_with(ops.init_scope(), init_scope => | |||
{ | |||
var values = init_from_fn ? new object[0] : new object[] { initial_value }; | |||
tf_with(ops.name_scope(name, "Variable", values), scope => | |||
@@ -239,11 +239,8 @@ namespace Tensorflow | |||
/// A context manager that lifts ops out of control-flow scopes and function-building graphs. | |||
/// </summary> | |||
/// <returns></returns> | |||
public static void init_scope() | |||
public static NameScope init_scope() | |||
{ | |||
if (tf.Context.executing_eagerly()) | |||
return; | |||
// Retrieve the active name scope: entering an `init_scope` preserves | |||
// the name scope of the current context. | |||
var default_graph = get_default_graph(); | |||
@@ -257,25 +254,11 @@ namespace Tensorflow | |||
tf_with(ops.control_dependencies(null), delegate | |||
{ | |||
var outer_graph = get_default_graph(); | |||
// var outer_graph = get_default_graph(); | |||
// outer_device_stack = None | |||
}); | |||
} | |||
public static ITensorFlowObject init_scope2() | |||
{ | |||
// Retrieve the active name scope: entering an `init_scope` preserves | |||
// the name scope of the current context. | |||
var default_graph = get_default_graph(); | |||
var scope = default_graph.get_name_scope(); | |||
if (!String.IsNullOrEmpty(scope) && !scope.EndsWith("/")) | |||
// Names that end with trailing slashes are treated by `name_scope` as | |||
// absolute. | |||
scope += "/"; | |||
// inner_device_stack = default_graph._device_function_stack | |||
// var outer_context = default_graph.as_default; | |||
return ops.control_dependencies(null); | |||
return ops.name_scope(scope); | |||
} | |||
private static int uid_number = -1; | |||
@@ -460,6 +443,8 @@ namespace Tensorflow | |||
{ | |||
case NDArray nd: | |||
return constant_op.constant(nd, dtype: dtype, name: name); | |||
case EagerTensor tensor: | |||
return tf.executing_eagerly() ? tensor : tensor.AsPlaceholder(name: name); | |||
case Tensor tensor: | |||
return tensor; | |||
case Tensor[] tensors: | |||
@@ -90,6 +90,7 @@ namespace Tensorflow | |||
return (scope_name, old_name); | |||
} | |||
[DebuggerHidden] | |||
public void Dispose() | |||
{ | |||
if (tf.Context.executing_eagerly()) | |||
@@ -28,7 +28,8 @@ namespace TensorFlowNET.UnitTest.Keras | |||
// Create a simple model. | |||
var inputs = keras.Input(shape: 32); | |||
var outputs = keras.layers.Dense(1).Apply(inputs); | |||
var dense_layer = keras.layers.Dense(1); | |||
var outputs = dense_layer.Apply(inputs); | |||
var model = keras.Model(inputs, outputs); | |||
model.compile("adam", "mean_squared_error"); | |||
return model; | |||
@@ -8,6 +8,11 @@ namespace TensorFlowNET.UnitTest.ManagedAPI | |||
[TestClass] | |||
public class BitwiseApiTest : TFNetApiTest | |||
{ | |||
[TestInitialize] | |||
public void Init() | |||
{ | |||
tf.enable_eager_execution(); | |||
} | |||
[TestMethod] | |||
public void BitwiseAnd() | |||