@@ -92,7 +92,7 @@ namespace Tensorflow | |||||
/// <param name="renorm"></param> | /// <param name="renorm"></param> | ||||
/// <param name="renorm_momentum"></param> | /// <param name="renorm_momentum"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
public Tensor batch_normalization(Tensor inputs, | |||||
public Tensors batch_normalization(Tensor inputs, | |||||
int axis = -1, | int axis = -1, | ||||
float momentum = 0.99f, | float momentum = 0.99f, | ||||
float epsilon = 0.001f, | float epsilon = 0.001f, | ||||
@@ -108,22 +108,24 @@ namespace Tensorflow | |||||
bool renorm = false, | bool renorm = false, | ||||
float renorm_momentum = 0.99f) | float renorm_momentum = 0.99f) | ||||
{ | { | ||||
var layer = new BatchNormalization( | |||||
axis: axis, | |||||
momentum: momentum, | |||||
epsilon: epsilon, | |||||
center: center, | |||||
scale: scale, | |||||
beta_initializer: beta_initializer, | |||||
gamma_initializer: gamma_initializer, | |||||
moving_mean_initializer: moving_mean_initializer, | |||||
moving_variance_initializer: moving_variance_initializer, | |||||
renorm: renorm, | |||||
renorm_momentum: renorm_momentum, | |||||
trainable: trainable, | |||||
name: name); | |||||
return layer.apply(inputs, training: training).Item1; | |||||
var layer = new BatchNormalization(new BatchNormalizationArgs | |||||
{ | |||||
Axis = axis, | |||||
Momentum = momentum, | |||||
Epsilon = epsilon, | |||||
Center = center, | |||||
Scale = scale, | |||||
BetaInitializer = beta_initializer, | |||||
GammaInitializer = gamma_initializer, | |||||
MovingMeanInitializer = moving_mean_initializer, | |||||
MovingVarianceInitializer = moving_variance_initializer, | |||||
Renorm = renorm, | |||||
RenormMomentum = renorm_momentum, | |||||
Trainable = trainable, | |||||
Name = name | |||||
}); | |||||
return layer.Apply(inputs); | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
@@ -41,6 +41,8 @@ namespace Tensorflow | |||||
/// <summary> | /// <summary> | ||||
/// A context manager that lifts ops out of control-flow scopes and function-building graphs. | /// A context manager that lifts ops out of control-flow scopes and function-building graphs. | ||||
/// When eager execution is enabled, code inside an init_scope block runs with | |||||
/// eager execution enabled even when tracing a `tf.function`. | |||||
/// </summary> | /// </summary> | ||||
public void init_scope() | public void init_scope() | ||||
=> ops.init_scope(); | => ops.init_scope(); | ||||
@@ -227,6 +227,11 @@ namespace Tensorflow.Eager | |||||
input_handle = input.EagerTensorHandle; | input_handle = input.EagerTensorHandle; | ||||
flattened_inputs.Add(input); | flattened_inputs.Add(input); | ||||
break; | break; | ||||
case ResourceVariable variable: | |||||
var var_tensor = variable.AsTensor(); | |||||
input_handle = var_tensor.EagerTensorHandle; | |||||
flattened_inputs.Add(var_tensor); | |||||
break; | |||||
default: | default: | ||||
var tensor = tf.convert_to_tensor(inputs); | var tensor = tf.convert_to_tensor(inputs); | ||||
input_handle = tensor.EagerTensorHandle; | input_handle = tensor.EagerTensorHandle; | ||||
@@ -57,6 +57,26 @@ namespace Tensorflow.Eager | |||||
return this; | return this; | ||||
} | } | ||||
/// <summary> | |||||
/// _create_substitute_placeholder | |||||
/// </summary> | |||||
/// <returns></returns> | |||||
public Tensor AsPlaceholder(string name = null) | |||||
{ | |||||
Tensor placeholder = null; | |||||
tf_with(ops.control_dependencies(null), delegate | |||||
{ | |||||
placeholder = tf.placeholder(dtype, shape: shape, name: name ?? this.name); | |||||
}); | |||||
// custom_gradient.copy_handle_data(value, placeholder) | |||||
return placeholder; | |||||
} | |||||
void copy_handle_data() | |||||
{ | |||||
} | |||||
public override IntPtr ToPointer() | public override IntPtr ToPointer() | ||||
=> EagerTensorHandle?.DangerousGetHandle() ?? IntPtr.Zero; | => EagerTensorHandle?.DangerousGetHandle() ?? IntPtr.Zero; | ||||
@@ -138,30 +138,16 @@ namespace Tensorflow.Gradients | |||||
return new Tensor[] | return new Tensor[] | ||||
{ | { | ||||
gen_nn_ops.conv2d_backprop_input(new Conv2dParams | |||||
{ | |||||
InputSizes = shape[0], | |||||
Filter = op.inputs[1], | |||||
OutBackProp = grads[0], | |||||
Dilations = dilations, | |||||
Strides = strides, | |||||
Padding = padding.ToString(), | |||||
ExplicitPaddings = explicit_paddings, | |||||
UseCudnnOnGpu = (bool)use_cudnn_on_gpu, | |||||
DataFormat = data_format.ToString(), | |||||
}), | |||||
gen_nn_ops.conv2d_backprop_filter(new Conv2dParams | |||||
{ | |||||
Input = op.inputs[0], | |||||
FilterSizes = shape[1], | |||||
OutBackProp = grads[0], | |||||
Dilations = dilations, | |||||
Strides = strides, | |||||
Padding = padding.ToString(), | |||||
ExplicitPaddings = explicit_paddings, | |||||
UseCudnnOnGpu = (bool)use_cudnn_on_gpu, | |||||
DataFormat = data_format.ToString() | |||||
}) | |||||
gen_nn_ops.conv2d_backprop_input(shape[0], op.inputs[1], grads[0], | |||||
strides, padding, use_cudnn_on_gpu, explicit_paddings, | |||||
dilations: dilations, | |||||
data_format: data_format), | |||||
gen_nn_ops.conv2d_backprop_filter(op.inputs[0], shape[1], grads[0], | |||||
strides, padding, | |||||
dilations: dilations, | |||||
explicit_paddings: explicit_paddings, | |||||
use_cudnn_on_gpu: use_cudnn_on_gpu, | |||||
data_format: data_format) | |||||
}; | }; | ||||
} | } | ||||
@@ -0,0 +1,24 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow.Keras.ArgsDefinition | |||||
{ | |||||
public class BatchNormalizationArgs : LayerArgs | |||||
{ | |||||
public TensorShape Axis { get; set; } = -1; | |||||
public float Momentum { get; set; } = 0.99f; | |||||
public float Epsilon { get; set; } = 1e-3f; | |||||
public bool Center { get; set; } = true; | |||||
public bool Scale { get; set; } = true; | |||||
public IInitializer BetaInitializer { get; set; } = tf.zeros_initializer; | |||||
public IInitializer GammaInitializer { get; set; } = tf.ones_initializer; | |||||
public IInitializer MovingMeanInitializer { get; set; } = tf.zeros_initializer; | |||||
public IInitializer MovingVarianceInitializer { get; set; } = tf.ones_initializer; | |||||
public IRegularizer BetaRegularizer { get; set; } | |||||
public IRegularizer GammaRegularizer { get; set; } | |||||
public bool Renorm { get; set; } | |||||
public float RenormMomentum { get; set; } = 0.99f; | |||||
} | |||||
} |
@@ -4,7 +4,7 @@ using System.Text; | |||||
namespace Tensorflow.Keras.ArgsDefinition | namespace Tensorflow.Keras.ArgsDefinition | ||||
{ | { | ||||
public class Conv2DArgs : ConvArgs | |||||
public class Conv2DArgs : ConvolutionalArgs | |||||
{ | { | ||||
} | } | ||||
@@ -5,10 +5,11 @@ using static Tensorflow.Binding; | |||||
namespace Tensorflow.Keras.ArgsDefinition | namespace Tensorflow.Keras.ArgsDefinition | ||||
{ | { | ||||
public class ConvArgs : LayerArgs | |||||
public class ConvolutionalArgs : LayerArgs | |||||
{ | { | ||||
public int Rank { get; set; } = 2; | public int Rank { get; set; } = 2; | ||||
public int Filters { get; set; } | public int Filters { get; set; } | ||||
public int NumSpatialDims { get; set; } = Unknown; | |||||
public TensorShape KernelSize { get; set; } = 5; | public TensorShape KernelSize { get; set; } = 5; | ||||
/// <summary> | /// <summary> | ||||
@@ -24,8 +25,8 @@ namespace Tensorflow.Keras.ArgsDefinition | |||||
public bool UseBias { get; set; } | public bool UseBias { get; set; } | ||||
public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer; | public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer; | ||||
public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer; | public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer; | ||||
public IInitializer KernelRegularizer { get; set; } | |||||
public IInitializer BiasRegularizer { get; set; } | |||||
public IRegularizer KernelRegularizer { get; set; } | |||||
public IRegularizer BiasRegularizer { get; set; } | |||||
public Action KernelConstraint { get; set; } | public Action KernelConstraint { get; set; } | ||||
public Action BiasConstraint { get; set; } | public Action BiasConstraint { get; set; } | ||||
} | } |
@@ -46,7 +46,7 @@ namespace Tensorflow.Keras.ArgsDefinition | |||||
/// <summary> | /// <summary> | ||||
/// Regularizer function applied to the output of the layer(its "activation"). | /// Regularizer function applied to the output of the layer(its "activation"). | ||||
/// </summary> | /// </summary> | ||||
public IInitializer ActivityRegularizer { get; set; } | |||||
public IRegularizer ActivityRegularizer { get; set; } | |||||
public bool Autocast { get; set; } | public bool Autocast { get; set; } | ||||
} | } | ||||
@@ -6,7 +6,7 @@ namespace Tensorflow.Keras.ArgsDefinition | |||||
{ | { | ||||
public class ModelArgs : LayerArgs | public class ModelArgs : LayerArgs | ||||
{ | { | ||||
public Tensor[] Inputs { get; set; } | |||||
public Tensor[] Outputs { get; set; } | |||||
public Tensors Inputs { get; set; } | |||||
public Tensors Outputs { get; set; } | |||||
} | } | ||||
} | } |
@@ -42,7 +42,7 @@ namespace Tensorflow.Keras | |||||
/// for various layer names in each graph. | /// for various layer names in each graph. | ||||
/// Allows to give unique autogenerated names to layers, in a graph-specific way. | /// Allows to give unique autogenerated names to layers, in a graph-specific way. | ||||
/// </summary> | /// </summary> | ||||
public Dictionary<Graph, Dictionary<(string, string), int>> PER_GRAPH_LAYER_NAME_UIDS = new Dictionary<Graph, Dictionary<(string, string), int>>(); | |||||
public Dictionary<Graph, Dictionary<string, int>> PER_GRAPH_LAYER_NAME_UIDS = new Dictionary<Graph, Dictionary<string, int>>(); | |||||
public Dictionary<string, IVariableV1> _GRAPH_VARIABLES = new Dictionary<string, IVariableV1>(); | public Dictionary<string, IVariableV1> _GRAPH_VARIABLES = new Dictionary<string, IVariableV1>(); | ||||
public Dictionary<string, Optimizer> _GRAPH_TF_OPTIMIZERS = new Dictionary<string, Optimizer>(); | public Dictionary<string, Optimizer> _GRAPH_TF_OPTIMIZERS = new Dictionary<string, Optimizer>(); | ||||
@@ -80,25 +80,19 @@ namespace Tensorflow.Keras | |||||
return ops.get_default_graph(); | return ops.get_default_graph(); | ||||
} | } | ||||
public int get_uid(string prefix, string @namespace = "") | |||||
public int get_uid(string prefix) | |||||
{ | { | ||||
var graph = tf.get_default_graph(); | var graph = tf.get_default_graph(); | ||||
if (!PER_GRAPH_LAYER_NAME_UIDS.ContainsKey(graph)) | if (!PER_GRAPH_LAYER_NAME_UIDS.ContainsKey(graph)) | ||||
PER_GRAPH_LAYER_NAME_UIDS.Add(graph, new defaultdict<(string, string), int>()); | |||||
PER_GRAPH_LAYER_NAME_UIDS[graph][(@namespace, prefix)] += 1; | |||||
PER_GRAPH_LAYER_NAME_UIDS.Add(graph, new defaultdict<string, int>()); | |||||
if (!PER_GRAPH_LAYER_NAME_UIDS[graph].ContainsKey(prefix)) | |||||
PER_GRAPH_LAYER_NAME_UIDS[graph][prefix] = 0; | |||||
PER_GRAPH_LAYER_NAME_UIDS[graph][prefix] += 1; | |||||
return PER_GRAPH_LAYER_NAME_UIDS[graph][(@namespace, prefix)]; | |||||
return PER_GRAPH_LAYER_NAME_UIDS[graph][prefix]; | |||||
} | } | ||||
public int get_uid((string, string) name) | |||||
{ | |||||
var graph = tf.get_default_graph(); | |||||
if (!PER_GRAPH_LAYER_NAME_UIDS.ContainsKey(graph)) | |||||
PER_GRAPH_LAYER_NAME_UIDS.Add(graph, new defaultdict<(string, string), int>()); | |||||
PER_GRAPH_LAYER_NAME_UIDS[graph][(name)] += 1; | |||||
return PER_GRAPH_LAYER_NAME_UIDS[graph][name]; | |||||
} | |||||
public void reset_uids() => PER_GRAPH_LAYER_NAME_UIDS = new Dictionary<Graph, Dictionary<(string, string), int>>(); | |||||
public void reset_uids() => PER_GRAPH_LAYER_NAME_UIDS = new Dictionary<Graph, Dictionary<string, int>>(); | |||||
public void clear_session() | public void clear_session() | ||||
{ | { | ||||
ops.reset_default_graph(); | ops.reset_default_graph(); | ||||
@@ -0,0 +1,67 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using Tensorflow.Keras.ArgsDefinition; | |||||
namespace Tensorflow.Keras.Engine | |||||
{ | |||||
/// <summary> | |||||
/// A `Functional` model is a `Model` defined as a directed graph of layers. | |||||
/// </summary> | |||||
public class Functional : Model | |||||
{ | |||||
TensorShape _build_input_shape; | |||||
bool _compute_output_and_mask_jointly; | |||||
bool _expects_training_arg; | |||||
bool _expects_mask_arg; | |||||
bool _autocast; | |||||
List<Layer> _output_layers; | |||||
List<Layer> _input_layers; | |||||
List<KerasHistory> _input_coordinates; | |||||
List<KerasHistory> _output_coordinates; | |||||
public Functional(Tensors inputs, Tensors outputs) | |||||
: base(new ModelArgs | |||||
{ | |||||
Inputs = inputs, | |||||
Outputs = outputs | |||||
}) | |||||
{ | |||||
_input_layers = new List<Layer>(); | |||||
_output_layers = new List<Layer>(); | |||||
_input_coordinates = new List<KerasHistory>(); | |||||
_output_coordinates = new List<KerasHistory>(); | |||||
_init_graph_network(inputs, outputs); | |||||
} | |||||
void _init_graph_network(Tensors inputs, Tensors outputs) | |||||
{ | |||||
_is_graph_network = true; | |||||
this.inputs = inputs; | |||||
this.outputs = outputs; | |||||
built = true; | |||||
_build_input_shape = inputs.shape; | |||||
_compute_output_and_mask_jointly = true; | |||||
_expects_training_arg = true; | |||||
_expects_mask_arg = true; | |||||
// A graph network does not autocast inputs, as its layers will cast them instead. | |||||
_autocast = false; | |||||
// Build self._output_layers: | |||||
foreach(var x in outputs) | |||||
{ | |||||
var (layer, node_index, tensor_index) = x.KerasHistory; | |||||
_output_layers.append(layer); | |||||
_output_coordinates.append(new KerasHistory(layer, node_index, tensor_index)); | |||||
} | |||||
// Build self._input_layers: | |||||
foreach(var x in inputs) | |||||
{ | |||||
var (layer, node_index, tensor_index) = x.KerasHistory; | |||||
_input_layers.append(layer); | |||||
_input_coordinates.append(new KerasHistory(layer, node_index, tensor_index)); | |||||
} | |||||
} | |||||
} | |||||
} |
@@ -15,6 +15,7 @@ | |||||
******************************************************************************/ | ******************************************************************************/ | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | |||||
namespace Tensorflow.Keras.Engine | namespace Tensorflow.Keras.Engine | ||||
{ | { | ||||
@@ -27,6 +28,7 @@ namespace Tensorflow.Keras.Engine | |||||
public int? min_ndim; | public int? min_ndim; | ||||
Dictionary<int, int> axes; | Dictionary<int, int> axes; | ||||
TensorShape shape; | TensorShape shape; | ||||
public int[] AllAxisDim; | |||||
public InputSpec(TF_DataType dtype = TF_DataType.DtInvalid, | public InputSpec(TF_DataType dtype = TF_DataType.DtInvalid, | ||||
int? ndim = null, | int? ndim = null, | ||||
@@ -42,6 +44,12 @@ namespace Tensorflow.Keras.Engine | |||||
this.shape = shape; | this.shape = shape; | ||||
if (ndim == null && shape != null) | if (ndim == null && shape != null) | ||||
this.ndim = shape.ndim; | this.ndim = shape.ndim; | ||||
if(axes != null) | |||||
AllAxisDim = axes.Select(x => x.Value).ToArray(); | |||||
} | } | ||||
public override string ToString() | |||||
=> $"min_ndim={min_ndim}, , axes={axes.Count}"; | |||||
} | } | ||||
} | } |
@@ -20,10 +20,14 @@ namespace Tensorflow.Keras.Engine | |||||
this.tensor_index = tensor_index; | this.tensor_index = tensor_index; | ||||
} | } | ||||
public void Deconstruct(out Layer layer, out int node_index, out int tensor_index) | |||||
{ | |||||
layer = this.layer; | |||||
node_index = this.node_index; | |||||
tensor_index = this.tensor_index; | |||||
} | |||||
public static implicit operator Layer(KerasHistory history) | public static implicit operator Layer(KerasHistory history) | ||||
=> history.layer; | => history.layer; | ||||
public static implicit operator (Layer, int, int)(KerasHistory history) | |||||
=> (history.layer, history.node_index, history.tensor_index); | |||||
} | } | ||||
} | } |
@@ -72,10 +72,10 @@ namespace Tensorflow.Keras.Engine | |||||
protected List<IVariableV1> nonTrainableWeights; | protected List<IVariableV1> nonTrainableWeights; | ||||
public List<IVariableV1> non_trainable_variables => nonTrainableWeights; | public List<IVariableV1> non_trainable_variables => nonTrainableWeights; | ||||
string name; | |||||
protected string name; | |||||
protected string base_name; | |||||
public string Name => name; | public string Name => name; | ||||
protected string baseName; | |||||
protected bool computePreviousMask; | protected bool computePreviousMask; | ||||
protected List<Operation> updates; | protected List<Operation> updates; | ||||
public TensorShape BatchInputShape => args.BatchInputShape; | public TensorShape BatchInputShape => args.BatchInputShape; | ||||
@@ -98,9 +98,9 @@ namespace Tensorflow.Keras.Engine | |||||
// Indicates whether `build` needs to be called upon layer call, to create | // Indicates whether `build` needs to be called upon layer call, to create | ||||
// the layer's weights. | // the layer's weights. | ||||
built = false; | built = false; | ||||
this.SupportsMasking = false; | |||||
SupportsMasking = false; | |||||
_init_set_name(name); | |||||
_init_set_name(args.Name); | |||||
trainableWeights = new List<IVariableV1>(); | trainableWeights = new List<IVariableV1>(); | ||||
nonTrainableWeights = new List<IVariableV1>(); | nonTrainableWeights = new List<IVariableV1>(); | ||||
computePreviousMask = false; | computePreviousMask = false; | ||||
@@ -124,23 +124,25 @@ namespace Tensorflow.Keras.Engine | |||||
/// <returns></returns> | /// <returns></returns> | ||||
public Tensors Apply(Tensors inputs, Tensor state = null, bool is_training = false) | public Tensors Apply(Tensors inputs, Tensor state = null, bool is_training = false) | ||||
{ | { | ||||
Tensors outputs = null; | |||||
callContext = callContext ?? new ThreadLocal<CallContext>() | callContext = callContext ?? new ThreadLocal<CallContext>() | ||||
{ | { | ||||
Value = new CallContext() | Value = new CallContext() | ||||
}; | }; | ||||
if (_in_functional_construction_mode(inputs)) | |||||
return _functional_construction_call(inputs); | |||||
Tensors outputs = null; | |||||
var eager = tf.executing_eagerly(); | var eager = tf.executing_eagerly(); | ||||
using var ctxManager = CallContext.enter(); | using var ctxManager = CallContext.enter(); | ||||
string nameScope = ""; | string nameScope = ""; | ||||
if (eager) | if (eager) | ||||
nameScope = name; | |||||
nameScope = Name; | |||||
else | else | ||||
nameScope = _name_scope(); | nameScope = _name_scope(); | ||||
// using var graph = tf.keras.backend.get_graph().as_default(); | |||||
if (!inputs.IsEagerTensor) | if (!inputs.IsEagerTensor) | ||||
tf.Context.graph_mode(); | tf.Context.graph_mode(); | ||||
@@ -162,6 +164,46 @@ namespace Tensorflow.Keras.Engine | |||||
return outputs; | return outputs; | ||||
} | } | ||||
bool _in_functional_construction_mode(Tensors inputs) | |||||
{ | |||||
return inputs.Count(x => !x.IsEagerTensor) == inputs.Count(); | |||||
} | |||||
Tensors _functional_construction_call(Tensors inputs) | |||||
{ | |||||
bool mask_arg_passed_by_framework = false; | |||||
bool training_arg_passed_by_framework = false; | |||||
Tensor training_value = null; | |||||
if(training_value == null) | |||||
{ | |||||
training_arg_passed_by_framework = true; | |||||
} | |||||
Tensors outputs = null; | |||||
using var ctxManager = CallContext.enter(); | |||||
// using var graph = tf.keras.backend.get_graph().as_default(); | |||||
if (!inputs.IsEagerTensor) | |||||
tf.Context.graph_mode(); | |||||
tf_with(ops.name_scope(_name_scope()), scope => | |||||
{ | |||||
MaybeBuild(inputs); | |||||
outputs = call(inputs); | |||||
outputs = _set_connectivity_metadata_(inputs, outputs); | |||||
_handle_activity_regularization(inputs, outputs); | |||||
_set_mask_metadata(inputs, outputs, null); | |||||
}); | |||||
if (!inputs.IsEagerTensor) | |||||
tf.Context.restore_mode(); | |||||
return outputs; | |||||
} | |||||
private Tensors _set_connectivity_metadata_(Tensors inputs, Tensors outputs) | private Tensors _set_connectivity_metadata_(Tensors inputs, Tensors outputs) | ||||
{ | { | ||||
/*var returnOutputs = new List<Tensor>(); | /*var returnOutputs = new List<Tensor>(); | ||||
@@ -219,8 +261,12 @@ namespace Tensorflow.Keras.Engine | |||||
if (DType == TF_DataType.DtInvalid) | if (DType == TF_DataType.DtInvalid) | ||||
args.DType = inputs.dtype; | args.DType = inputs.dtype; | ||||
var input_shapes = inputs.shape; | |||||
build(input_shapes); | |||||
tf.init_scope(); | |||||
//tf.Context.eager_mode(); | |||||
build(inputs.shape); | |||||
//tf.Context.restore_mode(); | |||||
built = true; | built = true; | ||||
} | } | ||||
@@ -229,10 +275,16 @@ namespace Tensorflow.Keras.Engine | |||||
built = true; | built = true; | ||||
} | } | ||||
protected virtual void add_loss(Func<Tensor> losses) | |||||
{ | |||||
} | |||||
protected virtual IVariableV1 add_weight(string name, | protected virtual IVariableV1 add_weight(string name, | ||||
TensorShape shape, | TensorShape shape, | ||||
TF_DataType dtype = TF_DataType.DtInvalid, | TF_DataType dtype = TF_DataType.DtInvalid, | ||||
IInitializer initializer = null, | IInitializer initializer = null, | ||||
IRegularizer regularizer = null, | |||||
bool? trainable = null, | bool? trainable = null, | ||||
Func<VariableArgs, IVariableV1> getter = null) | Func<VariableArgs, IVariableV1> getter = null) | ||||
{ | { | ||||
@@ -251,7 +303,7 @@ namespace Tensorflow.Keras.Engine | |||||
else if (dtype.is_integer()) | else if (dtype.is_integer()) | ||||
initializer = tf.zeros_initializer; | initializer = tf.zeros_initializer; | ||||
else | else | ||||
throw new ValueError($"An initializer for variable {name} of type {dtype.as_base_dtype()} is required for layer {this.Name}"); | |||||
throw new ValueError($"An initializer for variable {name} of type {dtype.as_base_dtype()} is required for layer {name}"); | |||||
} | } | ||||
var args = new VariableArgs | var args = new VariableArgs | ||||
@@ -266,6 +318,12 @@ namespace Tensorflow.Keras.Engine | |||||
}; | }; | ||||
var variable = _add_variable_with_custom_getter(args); | var variable = _add_variable_with_custom_getter(args); | ||||
if(regularizer != null) | |||||
{ | |||||
var name_in_scope = variable.Name.Split(':')[0]; | |||||
_handle_weight_regularization(name_in_scope, variable, regularizer); | |||||
} | |||||
//backend.track_variable(variable); | //backend.track_variable(variable); | ||||
if (trainable == true) | if (trainable == true) | ||||
trainableWeights.Add(variable); | trainableWeights.Add(variable); | ||||
@@ -275,6 +333,20 @@ namespace Tensorflow.Keras.Engine | |||||
return variable; | return variable; | ||||
} | } | ||||
/// <summary> | |||||
/// Create lambdas which compute regularization losses. | |||||
/// </summary> | |||||
/// <param name="name"></param> | |||||
/// <param name="variable"></param> | |||||
/// <param name="regularizer"></param> | |||||
void _handle_weight_regularization(string name, IVariableV1 variable, IRegularizer regularizer) | |||||
{ | |||||
add_loss(() => regularizer.Apply(new RegularizerArgs | |||||
{ | |||||
})); | |||||
} | |||||
protected virtual void add_update(Tensor[] updates, bool inputs = false) | protected virtual void add_update(Tensor[] updates, bool inputs = false) | ||||
{ | { | ||||
var updates_op = updates.Select(x => x.op).ToArray(); | var updates_op = updates.Select(x => x.op).ToArray(); | ||||
@@ -284,17 +356,13 @@ namespace Tensorflow.Keras.Engine | |||||
// Determine layer name (non-unique). | // Determine layer name (non-unique). | ||||
protected virtual void _init_set_name(string name, bool zero_based = true) | protected virtual void _init_set_name(string name, bool zero_based = true) | ||||
{ | { | ||||
var base_name = name; | |||||
base_name = name; | |||||
this.name = name; | this.name = name; | ||||
if (name == null) | if (name == null) | ||||
(this.name, baseName) = _make_unique_name(); | |||||
} | |||||
protected virtual (string, string) _make_unique_name() | |||||
{ | |||||
string base_name = generic_utils.to_snake_case(this.GetType().Name); | |||||
string name = base_layer_utils.unique_layer_name(base_name); | |||||
return (name, base_name); | |||||
{ | |||||
base_name = generic_utils.to_snake_case(this.GetType().Name); | |||||
this.name = base_layer_utils.unique_layer_name(base_name, zero_based: zero_based); | |||||
} | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -23,15 +23,14 @@ namespace Tensorflow.Keras.Engine | |||||
string loss; | string loss; | ||||
IOptimizer optimizer; | IOptimizer optimizer; | ||||
IVariableV1 _steps_per_execution; | IVariableV1 _steps_per_execution; | ||||
protected bool _is_graph_network; | |||||
protected Tensors inputs; | |||||
protected Tensors outputs; | |||||
public Model(ModelArgs args) | public Model(ModelArgs args) | ||||
: base(args) | : base(args) | ||||
{ | { | ||||
// Build _output_layers | |||||
/*foreach(var x in args.Outputs) | |||||
{ | |||||
var layer = x.KerasHistory; | |||||
}*/ | |||||
} | } | ||||
public void compile(string optimizerName, string lossName) | public void compile(string optimizerName, string lossName) | ||||
@@ -16,6 +16,7 @@ namespace Tensorflow | |||||
{ | { | ||||
public KerasDataset datasets { get; } = new KerasDataset(); | public KerasDataset datasets { get; } = new KerasDataset(); | ||||
public Initializers initializers { get; } = new Initializers(); | public Initializers initializers { get; } = new Initializers(); | ||||
public Regularizers regularizers { get; } = new Regularizers(); | |||||
public LayersApi layers { get; } = new LayersApi(); | public LayersApi layers { get; } = new LayersApi(); | ||||
public LossesApi losses { get; } = new LossesApi(); | public LossesApi losses { get; } = new LossesApi(); | ||||
public Activations activations { get; } = new Activations(); | public Activations activations { get; } = new Activations(); | ||||
@@ -36,12 +37,8 @@ namespace Tensorflow | |||||
/// <param name="input"></param> | /// <param name="input"></param> | ||||
/// <param name="output"></param> | /// <param name="output"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
public Model Model(Tensor input, Tensor output) | |||||
=> new Model(new ModelArgs | |||||
{ | |||||
Inputs = new[] { input }, | |||||
Outputs = new[] { output } | |||||
}); | |||||
public Functional Model(Tensors inputs, Tensors outputs) | |||||
=> new Functional(inputs, outputs); | |||||
/// <summary> | /// <summary> | ||||
/// Instantiate a Keras tensor. | /// Instantiate a Keras tensor. | ||||
@@ -15,73 +15,41 @@ | |||||
******************************************************************************/ | ******************************************************************************/ | ||||
using System; | using System; | ||||
using System.Collections.Generic; | |||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Keras.ArgsDefinition; | |||||
using Tensorflow.Keras.Engine; | |||||
using Tensorflow.Keras.Utils; | using Tensorflow.Keras.Utils; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace Tensorflow.Keras.Layers | namespace Tensorflow.Keras.Layers | ||||
{ | { | ||||
public class BatchNormalization : Tensorflow.Layers.Layer | |||||
public class BatchNormalization : Layer | |||||
{ | { | ||||
#pragma warning disable CS0414 // The field 'BatchNormalization._USE_V2_BEHAVIOR' is assigned but its value is never used | |||||
private bool _USE_V2_BEHAVIOR = true; | |||||
#pragma warning restore CS0414 // The field 'BatchNormalization._USE_V2_BEHAVIOR' is assigned but its value is never used | |||||
private float momentum; | |||||
private float epsilon; | |||||
private bool center; | |||||
private bool scale; | |||||
private bool renorm; | |||||
private bool fused; | |||||
#pragma warning disable CS0414 // The field 'BatchNormalization._bessels_correction_test_only' is assigned but its value is never used | |||||
private bool _bessels_correction_test_only; | |||||
#pragma warning restore CS0414 // The field 'BatchNormalization._bessels_correction_test_only' is assigned but its value is never used | |||||
private int[] axis; | |||||
private string _data_format; | |||||
private IInitializer beta_initializer; | |||||
private IInitializer gamma_initializer; | |||||
private IInitializer moving_mean_initializer; | |||||
private IInitializer moving_variance_initializer; | |||||
private IVariableV1 gamma; | |||||
private IVariableV1 beta; | |||||
private RefVariable moving_mean; | |||||
private RefVariable moving_variance; | |||||
public BatchNormalization(int axis = -1, | |||||
float momentum = 0.99f, | |||||
float epsilon = 0.001f, | |||||
bool center = true, | |||||
bool scale = true, | |||||
IInitializer beta_initializer = null, | |||||
IInitializer gamma_initializer = null, | |||||
IInitializer moving_mean_initializer = null, | |||||
IInitializer moving_variance_initializer = null, | |||||
bool renorm = false, | |||||
float renorm_momentum = 0.99f, | |||||
bool trainable = true, | |||||
string name = null) : base(trainable: trainable, | |||||
name: name) | |||||
BatchNormalizationArgs args; | |||||
float momentum => args.Momentum; | |||||
float epsilon => args.Epsilon; | |||||
bool center => args.Center; | |||||
bool scale => args.Scale; | |||||
bool renorm => args.Renorm; | |||||
bool fused; | |||||
int[] axis; | |||||
string _data_format; | |||||
IInitializer beta_initializer => args.BetaInitializer; | |||||
IInitializer gamma_initializer => args.GammaInitializer; | |||||
IInitializer moving_mean_initializer; | |||||
IInitializer moving_variance_initializer; | |||||
IRegularizer gamma_regularizer => args.GammaRegularizer; | |||||
IVariableV1 gamma; | |||||
IVariableV1 beta; | |||||
IVariableV1 moving_mean; | |||||
IVariableV1 moving_variance; | |||||
public BatchNormalization(BatchNormalizationArgs args) : base(args) | |||||
{ | { | ||||
this.axis = new int[] { axis }; | |||||
this.momentum = momentum; | |||||
this.epsilon = epsilon; | |||||
this.center = center; | |||||
this.scale = scale; | |||||
if (beta_initializer == null) | |||||
beta_initializer = tf.zeros_initializer; | |||||
if (gamma_initializer == null) | |||||
gamma_initializer = tf.ones_initializer; | |||||
if (moving_mean_initializer == null) | |||||
moving_mean_initializer = tf.zeros_initializer; | |||||
if (moving_variance_initializer == null) | |||||
moving_variance_initializer = tf.ones_initializer; | |||||
this.beta_initializer = beta_initializer; | |||||
this.gamma_initializer = gamma_initializer; | |||||
this.moving_mean_initializer = moving_mean_initializer; | |||||
this.moving_variance_initializer = moving_variance_initializer; | |||||
this.renorm = renorm; | |||||
this.fused = true; | |||||
this.SupportsMasking = true; | |||||
this._bessels_correction_test_only = true; | |||||
this.args = args; | |||||
axis = args.Axis.dims; | |||||
} | } | ||||
protected override void build(TensorShape input_shape) | protected override void build(TensorShape input_shape) | ||||
@@ -91,12 +59,25 @@ namespace Tensorflow.Keras.Layers | |||||
if (x < 0) | if (x < 0) | ||||
axis[idx] = ndims + x; | axis[idx] = ndims + x; | ||||
fused = ndims == 4; | |||||
if (fused) | if (fused) | ||||
if (Enumerable.SequenceEqual(axis, new int[] { 3 })) | |||||
{ | |||||
if (Enumerable.SequenceEqual(axis, new int[] { 1 })) | |||||
_data_format = "NCHW"; | |||||
else if (Enumerable.SequenceEqual(axis, new int[] { 3 })) | |||||
_data_format = "NHWC"; | _data_format = "NHWC"; | ||||
else | |||||
throw new ValueError($"Unsupported axis, fused batch norm only supports axis == [1] or axis == [3]"); | |||||
} | |||||
var axis_to_dim = new Dictionary<int, int>(); | |||||
foreach(var x in axis) | |||||
axis_to_dim[x] = input_shape[x]; | |||||
inputSpec = new InputSpec(ndim: ndims, axes: axis_to_dim); | |||||
var param_dtype = DType == TF_DataType.DtInvalid ? TF_DataType.TF_FLOAT : DType; | var param_dtype = DType == TF_DataType.DtInvalid ? TF_DataType.TF_FLOAT : DType; | ||||
var param_shape = new int[] { input_shape.dims[axis[0]] }; | |||||
var param_shape = inputSpec.AllAxisDim; | |||||
if (scale) | if (scale) | ||||
gamma = add_weight("gamma", | gamma = add_weight("gamma", | ||||
@@ -116,26 +97,17 @@ namespace Tensorflow.Keras.Layers | |||||
else | else | ||||
throw new NotImplementedException("add_weight beta"); | throw new NotImplementedException("add_weight beta"); | ||||
if(_scope != null) | |||||
{ | |||||
} | |||||
moving_mean = (RefVariable)add_weight("moving_mean", | |||||
moving_mean = add_weight("moving_mean", | |||||
param_shape, | param_shape, | ||||
dtype: param_dtype, | dtype: param_dtype, | ||||
initializer: moving_mean_initializer, | initializer: moving_mean_initializer, | ||||
synchronization: VariableSynchronization.OnRead, | |||||
trainable: false, | |||||
aggregation: VariableAggregation.Mean); | |||||
trainable: false); | |||||
moving_variance = (RefVariable)add_weight("moving_variance", | |||||
moving_variance = add_weight("moving_variance", | |||||
shape: param_shape, | shape: param_shape, | ||||
dtype: param_dtype, | dtype: param_dtype, | ||||
initializer: moving_variance_initializer, | initializer: moving_variance_initializer, | ||||
synchronization: VariableSynchronization.OnRead, | |||||
trainable: false, | |||||
aggregation: VariableAggregation.Mean); | |||||
trainable: false); | |||||
if (renorm) | if (renorm) | ||||
throw new NotImplementedException("build when renorm is true"); | throw new NotImplementedException("build when renorm is true"); | ||||
@@ -178,8 +150,8 @@ namespace Tensorflow.Keras.Layers | |||||
inputs, | inputs, | ||||
gamma, | gamma, | ||||
beta, | beta, | ||||
mean: moving_mean, | |||||
variance: moving_variance, | |||||
mean: moving_mean.AsTensor(), | |||||
variance: moving_variance.AsTensor(), | |||||
epsilon: epsilon, | epsilon: epsilon, | ||||
is_training: false, | is_training: false, | ||||
data_format: _data_format); | data_format: _data_format); | ||||
@@ -202,8 +174,8 @@ namespace Tensorflow.Keras.Layers | |||||
if(training_value == null) | if(training_value == null) | ||||
{ | { | ||||
var mean_update = _assign_moving_average(moving_mean, mean, momentum_tensor); | |||||
var variance_update = _assign_moving_average(moving_variance, variance, momentum_tensor); | |||||
var mean_update = _assign_moving_average(moving_mean.AsTensor(), mean, momentum_tensor); | |||||
var variance_update = _assign_moving_average(moving_variance.AsTensor(), variance, momentum_tensor); | |||||
add_update(new Tensor[] { mean_update }, inputs: true); | add_update(new Tensor[] { mean_update }, inputs: true); | ||||
add_update(new Tensor[] { variance_update }, inputs: true); | add_update(new Tensor[] { variance_update }, inputs: true); | ||||
} | } | ||||
@@ -19,12 +19,11 @@ using Tensorflow.Operations.Activation; | |||||
namespace Tensorflow.Keras.Layers | namespace Tensorflow.Keras.Layers | ||||
{ | { | ||||
public class Conv2D : Conv | |||||
public class Conv2D : Convolutional | |||||
{ | { | ||||
public Conv2D(Conv2DArgs args) | |||||
: base(args) | |||||
public Conv2D(Conv2DArgs args) : base(args) | |||||
{ | { | ||||
} | } | ||||
} | } | ||||
} | } |
@@ -20,13 +20,13 @@ using Tensorflow.Keras.ArgsDefinition; | |||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using Tensorflow.Keras.Utils; | using Tensorflow.Keras.Utils; | ||||
using Tensorflow.Operations; | using Tensorflow.Operations; | ||||
using Tensorflow.Operations.Activation; | |||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow.Keras.Layers | namespace Tensorflow.Keras.Layers | ||||
{ | { | ||||
public class Conv : Layer | |||||
public class Convolutional : Layer | |||||
{ | { | ||||
ConvArgs args; | |||||
ConvolutionalArgs args; | |||||
protected int rank => args.Rank; | protected int rank => args.Rank; | ||||
protected int filters => args.Filters; | protected int filters => args.Filters; | ||||
protected TensorShape kernel_size => args.KernelSize; | protected TensorShape kernel_size => args.KernelSize; | ||||
@@ -37,13 +37,14 @@ namespace Tensorflow.Keras.Layers | |||||
protected Activation activation => args.Activation; | protected Activation activation => args.Activation; | ||||
protected bool use_bias => args.UseBias; | protected bool use_bias => args.UseBias; | ||||
protected IInitializer kernel_initializer => args.KernelInitializer; | protected IInitializer kernel_initializer => args.KernelInitializer; | ||||
protected IRegularizer kernel_regularizer => args.KernelRegularizer; | |||||
protected IInitializer bias_initializer => args.BiasInitializer; | protected IInitializer bias_initializer => args.BiasInitializer; | ||||
protected IVariableV1 kernel; | protected IVariableV1 kernel; | ||||
protected IVariableV1 bias; | protected IVariableV1 bias; | ||||
protected Convolution _convolution_op; | |||||
string _tf_data_format; | |||||
ConvolutionInternal _convolution_op; | |||||
protected string _tf_data_format; | |||||
public Conv(ConvArgs args) : base(args) | |||||
public Convolutional(ConvolutionalArgs args) : base(args) | |||||
{ | { | ||||
this.args = args; | this.args = args; | ||||
args.KernelSize = conv_utils.normalize_tuple(args.KernelSize.dims, args.Rank, "kernel_size"); | args.KernelSize = conv_utils.normalize_tuple(args.KernelSize.dims, args.Rank, "kernel_size"); | ||||
@@ -65,6 +66,7 @@ namespace Tensorflow.Keras.Layers | |||||
kernel = add_weight(name: "kernel", | kernel = add_weight(name: "kernel", | ||||
shape: kernel_shape, | shape: kernel_shape, | ||||
initializer: kernel_initializer, | initializer: kernel_initializer, | ||||
regularizer: kernel_regularizer, | |||||
trainable: true, | trainable: true, | ||||
dtype: DType); | dtype: DType); | ||||
if (use_bias) | if (use_bias) | ||||
@@ -76,7 +78,7 @@ namespace Tensorflow.Keras.Layers | |||||
var axes = new Dictionary<int, int>(); | var axes = new Dictionary<int, int>(); | ||||
axes.Add(-1, input_channel); | axes.Add(-1, input_channel); | ||||
inputSpec = new InputSpec(ndim: rank + 2, axes: axes); | |||||
inputSpec = new InputSpec(min_ndim: rank + 2, axes: axes); | |||||
string tf_padding; | string tf_padding; | ||||
if (padding == "causal") | if (padding == "causal") | ||||
@@ -84,20 +86,21 @@ namespace Tensorflow.Keras.Layers | |||||
else | else | ||||
tf_padding = padding.ToUpper(); | tf_padding = padding.ToUpper(); | ||||
_convolution_op = nn_ops.Convolution(input_shape, | |||||
kernel.shape, | |||||
tf_padding, | |||||
string tf_op_name = GetType().Name; | |||||
_convolution_op = nn_ops.convolution_internal(tf_padding, | |||||
strides, | strides, | ||||
dilation_rate, | dilation_rate, | ||||
data_format: _tf_data_format); | |||||
data_format: _tf_data_format, | |||||
name: tf_op_name); | |||||
built = true; | built = true; | ||||
} | } | ||||
protected override Tensors call(Tensors inputs, Tensor state = null, bool training = false) | protected override Tensors call(Tensors inputs, Tensor state = null, bool training = false) | ||||
{ | { | ||||
var outputs = _convolution_op.__call__(inputs, kernel); | |||||
var outputs = _convolution_op.Apply(inputs, kernel); | |||||
if (use_bias) | if (use_bias) | ||||
{ | { | ||||
if (data_format == "channels_first") | if (data_format == "channels_first") |
@@ -47,10 +47,10 @@ namespace Tensorflow.Keras.Layers | |||||
} | } | ||||
// moved to base class | // moved to base class | ||||
if (string.IsNullOrEmpty(Name)) | |||||
if (string.IsNullOrEmpty(args.Name)) | |||||
{ | { | ||||
var prefix = "input"; | var prefix = "input"; | ||||
args.Name = prefix + '_' + tf.keras.backend.get_uid(prefix); | |||||
name = prefix + '_' + tf.keras.backend.get_uid(prefix); | |||||
} | } | ||||
if(args.DType == TF_DataType.DtInvalid) | if(args.DType == TF_DataType.DtInvalid) | ||||
@@ -91,7 +91,6 @@ namespace Tensorflow.Keras.Layers | |||||
// input_tensor._keras_mask = None | // input_tensor._keras_mask = None | ||||
new Node(this, new NodeArgs | new Node(this, new NodeArgs | ||||
{ | { | ||||
InputTensors = args.InputTensor, | |||||
Outputs = args.InputTensor | Outputs = args.InputTensor | ||||
}); | }); | ||||
@@ -11,15 +11,35 @@ namespace Tensorflow.Keras.Layers | |||||
{ | { | ||||
public Conv2D Conv2D(int filters, | public Conv2D Conv2D(int filters, | ||||
TensorShape kernel_size = null, | TensorShape kernel_size = null, | ||||
TensorShape strides = null, | |||||
string padding = "valid", | string padding = "valid", | ||||
string activation = "relu") | |||||
=> new Conv2D(new Conv2DArgs | |||||
{ | |||||
Filters = filters, | |||||
KernelSize = kernel_size, | |||||
Padding = padding, | |||||
Activation = GetActivationByName(activation) | |||||
}); | |||||
string data_format = null, | |||||
TensorShape dilation_rate = null, | |||||
int groups = 1, | |||||
string activation = null, | |||||
bool use_bias = true, | |||||
IInitializer kernel_initializer = null, | |||||
IInitializer bias_initializer = null, | |||||
IRegularizer kernel_regularizer = null, | |||||
IRegularizer bias_regularizer = null, | |||||
IRegularizer activity_regularizer = null) | |||||
=> new Conv2D(new Conv2DArgs | |||||
{ | |||||
Rank = 2, | |||||
Filters = filters, | |||||
KernelSize = kernel_size, | |||||
Strides = strides == null ? (1, 1) : strides, | |||||
Padding = padding, | |||||
DataFormat = data_format, | |||||
DilationRate = dilation_rate == null ? (1, 1) : dilation_rate, | |||||
Groups = groups, | |||||
KernelRegularizer = kernel_regularizer, | |||||
KernelInitializer = kernel_initializer == null ? tf.glorot_uniform_initializer : kernel_initializer, | |||||
BiasInitializer = bias_initializer == null ? tf.zeros_initializer : bias_initializer, | |||||
BiasRegularizer = bias_regularizer, | |||||
ActivityRegularizer = activity_regularizer, | |||||
Activation = GetActivationByName(activation) | |||||
}); | |||||
public Dense Dense(int units, | public Dense Dense(int units, | ||||
@@ -65,6 +85,30 @@ namespace Tensorflow.Keras.Layers | |||||
DataFormat = data_format | DataFormat = data_format | ||||
}); | }); | ||||
/// <summary> | |||||
/// `Input()` is used to instantiate a Keras tensor. | |||||
/// </summary> | |||||
/// <param name="shape">A shape tuple not including the batch size.</param> | |||||
/// <param name="name"></param> | |||||
/// <param name="sparse"></param> | |||||
/// <param name="ragged"></param> | |||||
/// <returns></returns> | |||||
public Tensors Input(TensorShape shape, | |||||
string name = null, | |||||
bool sparse = false, | |||||
bool ragged = false) | |||||
{ | |||||
var input_layer = new InputLayer(new InputLayerArgs | |||||
{ | |||||
InputShape = shape, | |||||
Name = name, | |||||
Sparse = sparse, | |||||
Ragged = ragged | |||||
}); | |||||
return input_layer.InboundNodes[0].Outputs; | |||||
} | |||||
public MaxPooling2D MaxPooling2D(TensorShape pool_size = null, | public MaxPooling2D MaxPooling2D(TensorShape pool_size = null, | ||||
TensorShape strides = null, | TensorShape strides = null, | ||||
string padding = "valid") | string padding = "valid") | ||||
@@ -0,0 +1,12 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Keras | |||||
{ | |||||
public class Regularizers | |||||
{ | |||||
public IRegularizer l2(float l2 = 0.01f) | |||||
=> new L2(l2); | |||||
} | |||||
} |
@@ -0,0 +1,11 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Keras | |||||
{ | |||||
public interface IRegularizer | |||||
{ | |||||
Tensor Apply(RegularizerArgs args); | |||||
} | |||||
} |
@@ -0,0 +1,21 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Keras | |||||
{ | |||||
public class L2 : IRegularizer | |||||
{ | |||||
float l2; | |||||
public L2(float l2 = 0.01f) | |||||
{ | |||||
this.l2 = l2; | |||||
} | |||||
public Tensor Apply(RegularizerArgs args) | |||||
{ | |||||
throw new NotImplementedException(); | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,10 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Keras | |||||
{ | |||||
public class RegularizerArgs | |||||
{ | |||||
} | |||||
} |
@@ -55,8 +55,8 @@ namespace Tensorflow.Keras.Utils | |||||
/// </summary> | /// </summary> | ||||
/// <param name="name"></param> | /// <param name="name"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
public static string unique_layer_name(string name, Dictionary<(string, string), int> name_uid_map = null, | |||||
string[] avoid_names = null, string @namespace = "", bool zero_based = false) | |||||
public static string unique_layer_name(string name, Dictionary<string, int> name_uid_map = null, | |||||
string[] avoid_names = null, bool zero_based = false) | |||||
{ | { | ||||
if (name_uid_map == null) | if (name_uid_map == null) | ||||
name_uid_map = get_default_graph_uid_map(); | name_uid_map = get_default_graph_uid_map(); | ||||
@@ -66,41 +66,40 @@ namespace Tensorflow.Keras.Utils | |||||
string proposed_name = null; | string proposed_name = null; | ||||
while (proposed_name == null || avoid_names.Contains(proposed_name)) | while (proposed_name == null || avoid_names.Contains(proposed_name)) | ||||
{ | { | ||||
var name_key = (@namespace, name); | |||||
if (!name_uid_map.ContainsKey(name_key)) | |||||
name_uid_map[name_key] = 0; | |||||
if (!name_uid_map.ContainsKey(name)) | |||||
name_uid_map[name] = 0; | |||||
if (zero_based) | if (zero_based) | ||||
{ | { | ||||
int number = name_uid_map[name_key]; | |||||
int number = name_uid_map[name]; | |||||
if (number > 0) | if (number > 0) | ||||
proposed_name = $"{name}_{number}"; | proposed_name = $"{name}_{number}"; | ||||
else | else | ||||
proposed_name = name; | proposed_name = name; | ||||
name_uid_map[name_key] += 1; | |||||
name_uid_map[name] += 1; | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
name_uid_map[name_key] += 1; | |||||
proposed_name = $"{name}_{name_uid_map[name_key]}"; | |||||
name_uid_map[name] += 1; | |||||
proposed_name = $"{name}_{name_uid_map[name]}"; | |||||
} | } | ||||
} | } | ||||
return proposed_name; | return proposed_name; | ||||
} | } | ||||
public static Dictionary<(string, string), int> get_default_graph_uid_map() | |||||
public static Dictionary<string, int> get_default_graph_uid_map() | |||||
{ | { | ||||
var graph = ops.get_default_graph(); | var graph = ops.get_default_graph(); | ||||
Dictionary<(string, string), int> name_uid_map = null; | |||||
Dictionary<string, int> name_uid_map = null; | |||||
if (tf.keras.backend.PER_GRAPH_LAYER_NAME_UIDS.ContainsKey(graph)) | if (tf.keras.backend.PER_GRAPH_LAYER_NAME_UIDS.ContainsKey(graph)) | ||||
{ | { | ||||
name_uid_map = tf.keras.backend.PER_GRAPH_LAYER_NAME_UIDS[graph]; | name_uid_map = tf.keras.backend.PER_GRAPH_LAYER_NAME_UIDS[graph]; | ||||
} | } | ||||
else | else | ||||
{ | { | ||||
name_uid_map = new Dictionary<(string, string), int>(); | |||||
name_uid_map = new Dictionary<string, int>(); | |||||
tf.keras.backend.PER_GRAPH_LAYER_NAME_UIDS[graph] = name_uid_map; | tf.keras.backend.PER_GRAPH_LAYER_NAME_UIDS[graph] = name_uid_map; | ||||
} | } | ||||
@@ -183,8 +183,6 @@ namespace Tensorflow.Layers | |||||
}); | }); | ||||
} | } | ||||
protected override string _name_scope() | protected override string _name_scope() | ||||
{ | { | ||||
return _current_scope.original_name_scope; | return _current_scope.original_name_scope; | ||||
@@ -202,7 +200,7 @@ namespace Tensorflow.Layers | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
tf_with(tf.variable_scope(scope, default_name: baseName), captured_scope => | |||||
tf_with(tf.variable_scope(scope, default_name: base_name), captured_scope => | |||||
{ | { | ||||
// convert variable_scope to VariableScope | // convert variable_scope to VariableScope | ||||
_scope = captured_scope; | _scope = captured_scope; | ||||
@@ -41,8 +41,8 @@ namespace Tensorflow.Operations.Initializers | |||||
public Tensor Apply(InitializerArgs args) | public Tensor Apply(InitializerArgs args) | ||||
{ | { | ||||
if (args.DType == TF_DataType.DtInvalid) | if (args.DType == TF_DataType.DtInvalid) | ||||
args.DType = this.dtype; | |||||
return random_ops.random_normal(args.Shape, mean, stddev, dtype, seed: seed); | |||||
args.DType = dtype; | |||||
return random_ops.random_normal(args.Shape, mean, stddev, args.DType, seed: seed); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -1,84 +0,0 @@ | |||||
/***************************************************************************** | |||||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
******************************************************************************/ | |||||
using System.Linq; | |||||
namespace Tensorflow.Operations | |||||
{ | |||||
public class Convolution | |||||
{ | |||||
public TensorShape input_shape; | |||||
public TensorShape filter_shape; | |||||
public string data_format; | |||||
public int[] strides; | |||||
public string name; | |||||
public _WithSpaceToBatch conv_op; | |||||
public Convolution(TensorShape input_shape, | |||||
TensorShape filter_shape, | |||||
string padding, | |||||
int[] strides, | |||||
int[] dilation_rate, | |||||
string name = null, | |||||
string data_format = null) | |||||
{ | |||||
var num_total_dims = filter_shape.ndim; | |||||
var num_spatial_dims = num_total_dims - 2; | |||||
int input_channels_dim; | |||||
int[] spatial_dims; | |||||
if (string.IsNullOrEmpty(data_format) || !data_format.StartsWith("NC")) | |||||
{ | |||||
input_channels_dim = input_shape.dims[num_spatial_dims + 1]; | |||||
spatial_dims = Enumerable.Range(1, num_spatial_dims).ToArray(); | |||||
} | |||||
else | |||||
{ | |||||
input_channels_dim = input_shape.dims[1]; | |||||
spatial_dims = Enumerable.Range(2, num_spatial_dims).ToArray(); | |||||
} | |||||
this.input_shape = input_shape; | |||||
this.filter_shape = filter_shape; | |||||
this.data_format = data_format; | |||||
this.strides = strides; | |||||
this.name = name; | |||||
conv_op = new _WithSpaceToBatch( | |||||
input_shape, | |||||
dilation_rate: dilation_rate, | |||||
padding: padding, | |||||
build_op: _build_op, | |||||
filter_shape: filter_shape, | |||||
spatial_dims: spatial_dims, | |||||
data_format: data_format); | |||||
} | |||||
public _NonAtrousConvolution _build_op(int _, string padding) | |||||
{ | |||||
return new _NonAtrousConvolution(input_shape, | |||||
filter_shape: filter_shape, | |||||
padding: padding, | |||||
data_format: data_format, | |||||
strides: strides, | |||||
name: name); | |||||
} | |||||
public Tensor __call__(Tensor inp, IVariableV1 filter) | |||||
{ | |||||
return conv_op.__call__(inp, filter); | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,100 @@ | |||||
/***************************************************************************** | |||||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
******************************************************************************/ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using System.Xml; | |||||
using Tensorflow.Keras.ArgsDefinition; | |||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow.Operations | |||||
{ | |||||
internal class ConvolutionInternal | |||||
{ | |||||
ConvolutionalArgs args; | |||||
string data_format => args.DataFormat; | |||||
string name; | |||||
string padding => args.Padding; | |||||
public ConvolutionInternal(ConvolutionalArgs args) | |||||
{ | |||||
this.args = args; | |||||
name = args.Name; | |||||
} | |||||
public Tensor Apply(Tensors input, IVariableV1 filters) | |||||
{ | |||||
var filters_rank = filters.shape.rank; | |||||
var inputs_rank = input.shape.rank; | |||||
var num_spatial_dims = args.NumSpatialDims; | |||||
if (num_spatial_dims == Unknown) | |||||
num_spatial_dims = filters_rank - 2; | |||||
// Channel dimension. | |||||
var num_batch_dims = inputs_rank - num_spatial_dims - 1; | |||||
if (!new[] { 1, 2, 3 }.Contains(num_spatial_dims)) | |||||
throw new ValueError($"num_spatial_dims (input.shape.ndims - num_batch_dims - 1) must be one " + | |||||
$"of 1, 2 or 3 but saw {num_spatial_dims}. num_batch_dims: {num_batch_dims}."); | |||||
var channel_index = num_batch_dims + num_spatial_dims; | |||||
var dilations = _get_sequence(args.DilationRate, num_spatial_dims, channel_index); | |||||
var strides = _get_sequence(args.Strides, num_spatial_dims, channel_index); | |||||
Tensor result = null; | |||||
tf_with(ops.name_scope(name, default_name: null, (input, filters)), scope => | |||||
{ | |||||
name = scope; | |||||
if (num_spatial_dims == 2) | |||||
result = gen_nn_ops.conv2d(new Conv2dParams | |||||
{ | |||||
Input = input, | |||||
Filter = filters.AsTensor(), | |||||
Strides = strides, | |||||
Padding = padding, | |||||
DataFormat = data_format, | |||||
Dilations = dilations, | |||||
Name = name | |||||
}); | |||||
else | |||||
throw new NotImplementedException(""); | |||||
}); | |||||
return result; | |||||
} | |||||
int[] _get_sequence(int[] value, int n, int channel_index) | |||||
{ | |||||
var seq = new List<int>(); | |||||
if (channel_index == 1) | |||||
{ | |||||
seq.Add(1); | |||||
seq.Add(1); | |||||
seq.AddRange(value); | |||||
} | |||||
else | |||||
{ | |||||
seq.Add(1); | |||||
seq.AddRange(value); | |||||
seq.Add(1); | |||||
} | |||||
return seq.ToArray(); | |||||
} | |||||
} | |||||
} |
@@ -1,83 +0,0 @@ | |||||
/***************************************************************************** | |||||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
******************************************************************************/ | |||||
using System; | |||||
using System.Linq; | |||||
namespace Tensorflow.Operations | |||||
{ | |||||
public class _NonAtrousConvolution | |||||
{ | |||||
public string padding; | |||||
public string name; | |||||
public int[] strides; | |||||
public string data_format; | |||||
private Func<Conv2dParams, Tensor> conv_op; | |||||
public _NonAtrousConvolution(TensorShape input_shape, | |||||
TensorShape filter_shape, | |||||
string padding, | |||||
string data_format, | |||||
int[] strides, | |||||
string name) | |||||
{ | |||||
this.padding = padding; | |||||
this.name = name; | |||||
var conv_dims = input_shape.ndim - 2; | |||||
if (conv_dims == 1) | |||||
{ | |||||
throw new NotImplementedException("_NonAtrousConvolution conv_dims 1"); | |||||
} | |||||
else if (conv_dims == 2) | |||||
{ | |||||
var list = strides.ToList(); | |||||
if (string.IsNullOrEmpty(data_format) || data_format == "NHWC") | |||||
{ | |||||
data_format = "NHWC"; | |||||
list.Insert(0, 1); | |||||
list.Add(1); | |||||
} | |||||
else if (data_format == "NCHW") | |||||
list.InsertRange(0, new int[] { 1, 1 }); | |||||
else | |||||
throw new ValueError("data_format must be \"NHWC\" or \"NCHW\"."); | |||||
strides = list.ToArray(); | |||||
this.strides = strides; | |||||
this.data_format = data_format; | |||||
conv_op = gen_nn_ops.conv2d; | |||||
} | |||||
else if (conv_dims == 3) | |||||
{ | |||||
throw new NotImplementedException("_NonAtrousConvolution conv_dims 3"); | |||||
} | |||||
} | |||||
public Tensor __call__(Tensor inp, IVariableV1 filter) | |||||
{ | |||||
return conv_op(new Conv2dParams | |||||
{ | |||||
Input = inp, | |||||
Filter = filter.AsTensor(), | |||||
Strides = strides, | |||||
Padding = padding, | |||||
DataFormat = data_format, | |||||
Name = name | |||||
}); | |||||
} | |||||
} | |||||
} |
@@ -1,76 +0,0 @@ | |||||
/***************************************************************************** | |||||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
******************************************************************************/ | |||||
using System; | |||||
using System.Linq; | |||||
namespace Tensorflow.Operations | |||||
{ | |||||
public class _WithSpaceToBatch | |||||
{ | |||||
private _NonAtrousConvolution call; | |||||
public _WithSpaceToBatch(TensorShape input_shape, | |||||
int[] dilation_rate, | |||||
string padding, | |||||
Func<int, string, _NonAtrousConvolution> build_op, | |||||
TensorShape filter_shape = null, | |||||
int[] spatial_dims = null, | |||||
string data_format = null) | |||||
{ | |||||
var dilation_rate_tensor = ops.convert_to_tensor(dilation_rate, TF_DataType.TF_INT32, name: "dilation_rate"); | |||||
var rate_shape = dilation_rate_tensor.TensorShape; | |||||
var num_spatial_dims = rate_shape.dims[0]; | |||||
#pragma warning disable CS0219 // Variable is assigned but its value is never used | |||||
int starting_spatial_dim = -1; | |||||
#pragma warning restore CS0219 // Variable is assigned but its value is never used | |||||
if (!string.IsNullOrEmpty(data_format) && data_format.StartsWith("NC")) | |||||
starting_spatial_dim = 2; | |||||
else | |||||
starting_spatial_dim = 1; | |||||
if (spatial_dims == null) | |||||
throw new NotImplementedException("_WithSpaceToBatch spatial_dims"); | |||||
var orig_spatial_dims = spatial_dims; | |||||
spatial_dims = spatial_dims.OrderBy(x => x).ToArray(); | |||||
if (!Enumerable.SequenceEqual(spatial_dims, orig_spatial_dims) || spatial_dims.Any(x => x < 1)) | |||||
throw new ValueError("spatial_dims must be a montonically increasing sequence of positive integers"); | |||||
int expected_input_rank = -1; | |||||
if (!string.IsNullOrEmpty(data_format) && data_format.StartsWith("NC")) | |||||
expected_input_rank = spatial_dims.Last(); | |||||
else | |||||
expected_input_rank = spatial_dims.Last() + 1; | |||||
var const_rate = tensor_util.constant_value(dilation_rate_tensor); | |||||
var rate_or_const_rate = dilation_rate; | |||||
if(!(const_rate is null)) | |||||
{ | |||||
if (const_rate.Data<int>().Count(x => x == 1) == const_rate.size) | |||||
{ | |||||
call = build_op(num_spatial_dims, padding); | |||||
return; | |||||
} | |||||
} | |||||
} | |||||
public Tensor __call__(Tensor inp, IVariableV1 filter) | |||||
{ | |||||
return call.__call__(inp, filter); | |||||
} | |||||
} | |||||
} |
@@ -78,35 +78,45 @@ namespace Tensorflow.Operations | |||||
/// </summary> | /// </summary> | ||||
/// <param name="parameters"></param> | /// <param name="parameters"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
public static Tensor conv2d_backprop_filter(Conv2dParams parameters) | |||||
public static Tensor conv2d_backprop_filter(Tensor input, Tensor filter_sizes, Tensor out_backprop, | |||||
int[] strides, string padding, bool use_cudnn_on_gpu = true, | |||||
int[] explicit_paddings = null, | |||||
string data_format = "NHWC", | |||||
int[] dilations = null, | |||||
string name = null) | |||||
{ | { | ||||
if (explicit_paddings == null) | |||||
explicit_paddings = new int[0]; | |||||
if (dilations == null) | |||||
dilations = new int[] { 1, 1, 1, 1 }; | |||||
if (tf.executing_eagerly()) | if (tf.executing_eagerly()) | ||||
{ | { | ||||
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | ||||
"Conv2DBackpropFilter", parameters.Name, | |||||
"Conv2DBackpropFilter", name, | |||||
null, | null, | ||||
parameters.Input, parameters.FilterSizes, parameters.OutBackProp, | |||||
"strides", parameters.Strides, | |||||
"use_cudnn_on_gpu", parameters.UseCudnnOnGpu, | |||||
"padding", parameters.Padding, | |||||
"explicit_paddings", parameters.ExplicitPaddings, | |||||
"data_format", parameters.DataFormat, | |||||
"dilations", parameters.Dilations); | |||||
input, filter_sizes, out_backprop, | |||||
"strides", strides, | |||||
"use_cudnn_on_gpu", use_cudnn_on_gpu, | |||||
"padding", padding, | |||||
"explicit_paddings", explicit_paddings, | |||||
"data_format", data_format, | |||||
"dilations", dilations); | |||||
return results[0]; | return results[0]; | ||||
} | } | ||||
var _op = tf.OpDefLib._apply_op_helper("Conv2DBackpropFilter", name: parameters.Name, args: new | |||||
var _op = tf.OpDefLib._apply_op_helper("Conv2DBackpropFilter", name: name, args: new | |||||
{ | { | ||||
input = parameters.Input, | |||||
filter_sizes = parameters.FilterSizes, | |||||
out_backprop = parameters.OutBackProp, | |||||
strides = parameters.Strides, | |||||
padding = parameters.Padding, | |||||
use_cudnn_on_gpu = parameters.UseCudnnOnGpu, | |||||
explicit_paddings = parameters.ExplicitPaddings, | |||||
data_format = parameters.DataFormat, | |||||
dilations = parameters.Dilations | |||||
input, | |||||
filter_sizes, | |||||
out_backprop, | |||||
strides, | |||||
padding, | |||||
use_cudnn_on_gpu, | |||||
explicit_paddings, | |||||
data_format, | |||||
dilations | |||||
}); | }); | ||||
return _op.outputs[0]; | return _op.outputs[0]; | ||||
@@ -117,35 +127,45 @@ namespace Tensorflow.Operations | |||||
/// </summary> | /// </summary> | ||||
/// <param name="parameters"></param> | /// <param name="parameters"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
public static Tensor conv2d_backprop_input(Conv2dParams parameters) | |||||
public static Tensor conv2d_backprop_input(Tensor input_sizes, Tensor filter, Tensor out_backprop, | |||||
int[] strides, string padding, bool use_cudnn_on_gpu = true, | |||||
int[] explicit_paddings = null, | |||||
string data_format= "NHWC", | |||||
int[] dilations = null, | |||||
string name = null) | |||||
{ | { | ||||
if (explicit_paddings == null) | |||||
explicit_paddings = new int[0]; | |||||
if (dilations == null) | |||||
dilations = new int[] { 1, 1, 1, 1 }; | |||||
if (tf.executing_eagerly()) | if (tf.executing_eagerly()) | ||||
{ | { | ||||
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | ||||
"Conv2DBackpropInput", parameters.Name, | |||||
"Conv2DBackpropInput", name, | |||||
null, | null, | ||||
parameters.InputSizes, parameters.Filter, parameters.OutBackProp, | |||||
"strides", parameters.Strides, | |||||
"use_cudnn_on_gpu", parameters.UseCudnnOnGpu, | |||||
"padding", parameters.Padding, | |||||
"explicit_paddings", parameters.ExplicitPaddings, | |||||
"data_format", parameters.DataFormat, | |||||
"dilations", parameters.Dilations); | |||||
input_sizes, filter, out_backprop, | |||||
"strides", strides, | |||||
"use_cudnn_on_gpu", use_cudnn_on_gpu, | |||||
"padding", padding, | |||||
"explicit_paddings", explicit_paddings, | |||||
"data_format", data_format, | |||||
"dilations", dilations); | |||||
return results[0]; | return results[0]; | ||||
} | } | ||||
var _op = tf.OpDefLib._apply_op_helper("Conv2DBackpropInput", name: parameters.Name, args: new | |||||
var _op = tf.OpDefLib._apply_op_helper("Conv2DBackpropInput", name: name, args: new | |||||
{ | { | ||||
input_sizes = parameters.InputSizes, | |||||
filter = parameters.Filter, | |||||
out_backprop = parameters.OutBackProp, | |||||
strides = parameters.Strides, | |||||
padding = parameters.Padding, | |||||
use_cudnn_on_gpu = parameters.UseCudnnOnGpu, | |||||
explicit_paddings = parameters.ExplicitPaddings, | |||||
data_format = parameters.DataFormat, | |||||
dilations = parameters.Dilations | |||||
input_sizes, | |||||
filter, | |||||
out_backprop, | |||||
strides, | |||||
padding, | |||||
use_cudnn_on_gpu, | |||||
explicit_paddings, | |||||
data_format, | |||||
dilations | |||||
}); | }); | ||||
return _op.outputs[0]; | return _op.outputs[0]; | ||||
@@ -33,11 +33,6 @@ namespace Tensorflow | |||||
/// <returns></returns> | /// <returns></returns> | ||||
public static Tensor random_standard_normal(Tensor shape, TF_DataType dtype = TF_DataType.DtInvalid, int? seed = null, int? seed2 = null, string name = null) | public static Tensor random_standard_normal(Tensor shape, TF_DataType dtype = TF_DataType.DtInvalid, int? seed = null, int? seed2 = null, string name = null) | ||||
{ | { | ||||
if (!seed.HasValue) | |||||
seed = 0; | |||||
if (!seed2.HasValue) | |||||
seed2 = 0; | |||||
if (tf.executing_eagerly()) | if (tf.executing_eagerly()) | ||||
{ | { | ||||
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | ||||
@@ -51,6 +46,11 @@ namespace Tensorflow | |||||
return results[0]; | return results[0]; | ||||
} | } | ||||
if (!seed.HasValue) | |||||
seed = 0; | |||||
if (!seed2.HasValue) | |||||
seed2 = 0; | |||||
var _op = tf.OpDefLib._apply_op_helper("RandomStandardNormal", | var _op = tf.OpDefLib._apply_op_helper("RandomStandardNormal", | ||||
name: name, | name: name, | ||||
args: new { shape, dtype, seed, seed2 }); | args: new { shape, dtype, seed, seed2 }); | ||||
@@ -16,6 +16,7 @@ | |||||
using System; | using System; | ||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Keras.ArgsDefinition; | |||||
using Tensorflow.Operations; | using Tensorflow.Operations; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
@@ -23,19 +24,18 @@ namespace Tensorflow | |||||
{ | { | ||||
public class nn_ops | public class nn_ops | ||||
{ | { | ||||
public static Convolution Convolution(TensorShape input_shape, | |||||
TensorShape filter_shape, | |||||
string padding, | |||||
internal static ConvolutionInternal convolution_internal(string padding, | |||||
int[] strides, | int[] strides, | ||||
int[] dilation_rate, | int[] dilation_rate, | ||||
string name = null, | string name = null, | ||||
string data_format = null) => new Convolution(input_shape, | |||||
filter_shape, | |||||
padding, | |||||
strides, | |||||
dilation_rate, | |||||
name: name, | |||||
data_format: data_format); | |||||
string data_format = null) => new ConvolutionInternal(new ConvolutionalArgs | |||||
{ | |||||
Padding = padding, | |||||
Strides = strides, | |||||
DilationRate = dilation_rate, | |||||
DataFormat = data_format, | |||||
Name = name | |||||
}); | |||||
/// <summary> | /// <summary> | ||||
/// Adds `bias` to `value`. | /// Adds `bias` to `value`. | ||||
@@ -64,7 +64,7 @@ namespace Tensorflow | |||||
/// The string name of this tensor.<br/> | /// The string name of this tensor.<br/> | ||||
/// Tensor.name is meaningless when eager execution is enabled. | /// Tensor.name is meaningless when eager execution is enabled. | ||||
/// </summary> | /// </summary> | ||||
public string name => $"{(op == null ? "<unnamed>" : $"{op.name}:{_value_index}")}"; | |||||
public virtual string name => $"{(op == null ? "<unnamed>" : $"{op.name}:{_value_index}")}"; | |||||
/// <summary> | /// <summary> | ||||
/// The index of this tensor in the outputs of its Operation. | /// The index of this tensor in the outputs of its Operation. | ||||
@@ -132,7 +132,7 @@ namespace Tensorflow | |||||
} | } | ||||
} | } | ||||
public int this[int index] => dims[index]; | |||||
public int this[int index] => index < 0 ? dims[ndim + index] : dims[index]; | |||||
/// <summary> | /// <summary> | ||||
/// Returns True iff `self` is fully defined in every dimension. | /// Returns True iff `self` is fully defined in every dimension. | ||||
@@ -8,7 +8,7 @@ using static Tensorflow.Binding; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
public class BaseResourceVariable : DisposableObject, IVariableV1 | |||||
public class BaseResourceVariable : DisposableObject | |||||
{ | { | ||||
protected string _name; | protected string _name; | ||||
public virtual string Name => _handle_name; | public virtual string Name => _handle_name; | ||||
@@ -92,7 +92,8 @@ namespace Tensorflow | |||||
return assign_op; | return assign_op; | ||||
} | } | ||||
public Tensor value() => tf.executing_eagerly() ? _read_variable_op() : GraphElement; | |||||
public Tensor value() | |||||
=> GraphElement ?? _read_variable_op(); | |||||
protected Tensor _read_variable_op() | protected Tensor _read_variable_op() | ||||
{ | { | ||||
@@ -159,7 +160,15 @@ namespace Tensorflow | |||||
{ | { | ||||
} | } | ||||
public Tensor AsTensor() | |||||
=> tf.executing_eagerly() ? read_value() : GraphElement; | |||||
public Tensor AsTensor(bool as_ref = true) | |||||
{ | |||||
if (!as_ref && GraphElement != null) | |||||
return GraphElement; | |||||
if (as_ref) | |||||
return tf.executing_eagerly() ? read_value() : GraphElement; | |||||
else | |||||
return _read_variable_op(); | |||||
} | |||||
} | } | ||||
} | } |
@@ -49,6 +49,6 @@ namespace Tensorflow | |||||
public TensorShape shape { get; } | public TensorShape shape { get; } | ||||
Tensor assign_add<T>(T delta, bool use_locking = false, string name = null, bool read_value = true); | Tensor assign_add<T>(T delta, bool use_locking = false, string name = null, bool read_value = true); | ||||
Tensor assign<T>(T value, bool use_locking = false, string name = null, bool read_value = true); | Tensor assign<T>(T value, bool use_locking = false, string name = null, bool read_value = true); | ||||
Tensor AsTensor(); | |||||
Tensor AsTensor(bool as_ref = true); | |||||
} | } | ||||
} | } |
@@ -152,7 +152,7 @@ namespace Tensorflow | |||||
if (trainable && !collections.Contains(tf.GraphKeys.TRAINABLE_VARIABLES)) | if (trainable && !collections.Contains(tf.GraphKeys.TRAINABLE_VARIABLES)) | ||||
collections.Add(tf.GraphKeys.TRAINABLE_VARIABLES); | collections.Add(tf.GraphKeys.TRAINABLE_VARIABLES); | ||||
tf_with(ops.init_scope2(), delegate | |||||
tf_with(ops.init_scope(), init_scope => | |||||
{ | { | ||||
var values = init_from_fn ? new object[0] : new object[] { initial_value }; | var values = init_from_fn ? new object[0] : new object[] { initial_value }; | ||||
tf_with(ops.name_scope(name, "Variable", values), scope => | tf_with(ops.name_scope(name, "Variable", values), scope => | ||||
@@ -222,7 +222,7 @@ namespace Tensorflow | |||||
public Tensor value() => _snapshot; | public Tensor value() => _snapshot; | ||||
public Tensor AsTensor() => _snapshot; | |||||
public Tensor AsTensor(bool as_ref = true) => _snapshot; | |||||
public Tensor _as_graph_element() => _variable; | public Tensor _as_graph_element() => _variable; | ||||
@@ -26,7 +26,7 @@ namespace Tensorflow | |||||
/// <summary> | /// <summary> | ||||
/// Variable based on resource handles. | /// Variable based on resource handles. | ||||
/// </summary> | /// </summary> | ||||
public partial class ResourceVariable : BaseResourceVariable | |||||
public partial class ResourceVariable : BaseResourceVariable, IVariableV1 | |||||
{ | { | ||||
Tensor _cached_value; | Tensor _cached_value; | ||||
public string Device => handle.Device; | public string Device => handle.Device; | ||||
@@ -90,7 +90,7 @@ namespace Tensorflow | |||||
collections.Add(tf.GraphKeys.TRAINABLE_VARIABLES); | collections.Add(tf.GraphKeys.TRAINABLE_VARIABLES); | ||||
_in_graph_mode = !tf.Context.executing_eagerly(); | _in_graph_mode = !tf.Context.executing_eagerly(); | ||||
tf_with(ops.init_scope2(), delegate | |||||
tf_with(ops.init_scope(), init_scope => | |||||
{ | { | ||||
var values = init_from_fn ? new object[0] : new object[] { initial_value }; | var values = init_from_fn ? new object[0] : new object[] { initial_value }; | ||||
tf_with(ops.name_scope(name, "Variable", values), scope => | tf_with(ops.name_scope(name, "Variable", values), scope => | ||||
@@ -239,11 +239,8 @@ namespace Tensorflow | |||||
/// A context manager that lifts ops out of control-flow scopes and function-building graphs. | /// A context manager that lifts ops out of control-flow scopes and function-building graphs. | ||||
/// </summary> | /// </summary> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
public static void init_scope() | |||||
public static NameScope init_scope() | |||||
{ | { | ||||
if (tf.Context.executing_eagerly()) | |||||
return; | |||||
// Retrieve the active name scope: entering an `init_scope` preserves | // Retrieve the active name scope: entering an `init_scope` preserves | ||||
// the name scope of the current context. | // the name scope of the current context. | ||||
var default_graph = get_default_graph(); | var default_graph = get_default_graph(); | ||||
@@ -257,25 +254,11 @@ namespace Tensorflow | |||||
tf_with(ops.control_dependencies(null), delegate | tf_with(ops.control_dependencies(null), delegate | ||||
{ | { | ||||
var outer_graph = get_default_graph(); | |||||
// var outer_graph = get_default_graph(); | |||||
// outer_device_stack = None | // outer_device_stack = None | ||||
}); | }); | ||||
} | |||||
public static ITensorFlowObject init_scope2() | |||||
{ | |||||
// Retrieve the active name scope: entering an `init_scope` preserves | |||||
// the name scope of the current context. | |||||
var default_graph = get_default_graph(); | |||||
var scope = default_graph.get_name_scope(); | |||||
if (!String.IsNullOrEmpty(scope) && !scope.EndsWith("/")) | |||||
// Names that end with trailing slashes are treated by `name_scope` as | |||||
// absolute. | |||||
scope += "/"; | |||||
// inner_device_stack = default_graph._device_function_stack | |||||
// var outer_context = default_graph.as_default; | |||||
return ops.control_dependencies(null); | |||||
return ops.name_scope(scope); | |||||
} | } | ||||
private static int uid_number = -1; | private static int uid_number = -1; | ||||
@@ -460,6 +443,8 @@ namespace Tensorflow | |||||
{ | { | ||||
case NDArray nd: | case NDArray nd: | ||||
return constant_op.constant(nd, dtype: dtype, name: name); | return constant_op.constant(nd, dtype: dtype, name: name); | ||||
case EagerTensor tensor: | |||||
return tf.executing_eagerly() ? tensor : tensor.AsPlaceholder(name: name); | |||||
case Tensor tensor: | case Tensor tensor: | ||||
return tensor; | return tensor; | ||||
case Tensor[] tensors: | case Tensor[] tensors: | ||||
@@ -90,6 +90,7 @@ namespace Tensorflow | |||||
return (scope_name, old_name); | return (scope_name, old_name); | ||||
} | } | ||||
[DebuggerHidden] | |||||
public void Dispose() | public void Dispose() | ||||
{ | { | ||||
if (tf.Context.executing_eagerly()) | if (tf.Context.executing_eagerly()) | ||||
@@ -28,7 +28,8 @@ namespace TensorFlowNET.UnitTest.Keras | |||||
// Create a simple model. | // Create a simple model. | ||||
var inputs = keras.Input(shape: 32); | var inputs = keras.Input(shape: 32); | ||||
var outputs = keras.layers.Dense(1).Apply(inputs); | |||||
var dense_layer = keras.layers.Dense(1); | |||||
var outputs = dense_layer.Apply(inputs); | |||||
var model = keras.Model(inputs, outputs); | var model = keras.Model(inputs, outputs); | ||||
model.compile("adam", "mean_squared_error"); | model.compile("adam", "mean_squared_error"); | ||||
return model; | return model; | ||||
@@ -8,6 +8,11 @@ namespace TensorFlowNET.UnitTest.ManagedAPI | |||||
[TestClass] | [TestClass] | ||||
public class BitwiseApiTest : TFNetApiTest | public class BitwiseApiTest : TFNetApiTest | ||||
{ | { | ||||
[TestInitialize] | |||||
public void Init() | |||||
{ | |||||
tf.enable_eager_execution(); | |||||
} | |||||
[TestMethod] | [TestMethod] | ||||
public void BitwiseAnd() | public void BitwiseAnd() | ||||