diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs index 254f4ded..d8732224 100644 --- a/src/TensorFlowNET.Core/Binding.Util.cs +++ b/src/TensorFlowNET.Core/Binding.Util.cs @@ -424,6 +424,9 @@ namespace Tensorflow return true; } + public static bool empty(this Queue queue) + => queue.Count == 0; + public static TValue SetDefault(this Dictionary dic, TKey key, TValue value) { if (dic.ContainsKey(key)) diff --git a/src/TensorFlowNET.Core/Keras/Engine/Layer.Layers.cs b/src/TensorFlowNET.Core/Keras/Engine/Layer.Layers.cs index 1cddc769..cccb605b 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Layer.Layers.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Layer.Layers.cs @@ -11,6 +11,7 @@ namespace Tensorflow.Keras.Engine public partial class Layer { protected List _layers = new List(); + public List Layers => _layers; protected Layer Dense(int units, Activation activation = null, diff --git a/src/TensorFlowNET.Core/Keras/Engine/Layer.cs b/src/TensorFlowNET.Core/Keras/Engine/Layer.cs index f15e2edf..00cd858d 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/Layer.cs @@ -61,7 +61,6 @@ namespace Tensorflow.Keras.Engine protected List trainable_weights; public virtual List trainable_variables => trainable_weights; - protected List non_trainable_weights; public List non_trainable_variables => non_trainable_weights; @@ -83,7 +82,8 @@ namespace Tensorflow.Keras.Engine ThreadLocal callContext; public CallContext CallContext => callContext.Value; public Tensor[] input => inboundNodes[0].input_tensors; - + public Dictionary> NodesByDepth { get; set; } + public TensorShape output_shape => inboundNodes[0].Outputs.shape; public Layer(LayerArgs args) { this.args = args; @@ -224,5 +224,23 @@ namespace Tensorflow.Keras.Engine this.name = base_layer_utils.unique_layer_name(base_name, zero_based: zero_based); } } + + public int count_params() + { + if (Trainable) + return layer_utils.count_params(this, weights); + return 0; + } + + public List weights + { + get + { + var weights = new List(); + weights.AddRange(trainable_weights); + weights.AddRange(non_trainable_weights); + return weights; + } + } } } diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index 7f1f9be0..c35ff349 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -388,19 +388,19 @@ namespace Tensorflow return _op.outputs[0]; } - public static Tensor select(Tensor condition, Tx t, Ty e, string name = null) + public static Tensor select(Tensor condition, Tx x, Ty y, string name = null) { if (tf.Context.executing_eagerly()) { var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "SelectV2", name, + "Select", name, null, - condition, t, e); + condition, x, y); return results[0]; } - var _op = tf.OpDefLib._apply_op_helper("Select", name, new { condition, t, e }); + var _op = tf.OpDefLib._apply_op_helper("Select", name, new { condition, t = x, e = y }); return _op.outputs[0]; } @@ -580,26 +580,33 @@ namespace Tensorflow /// An optional `int`. Defaults to `0`. /// A name for the operation (optional). /// A `Tensor`. Has the same type as `dy`. - public static Tensor strided_slice_grad(Tensor shape, Tensor begin, Tensor end, Tensor strides, Tensor dy, + public static Tensor strided_slice_grad(Tensor shape, Tensor begin, Tensor end, Tensor strides, Tensor dy, int begin_mask = 0, int end_mask = 0, int ellipsis_mask = 0, int new_axis_mask = 0, int shrink_axis_mask = 0, string name = null) - { - var op = tf.OpDefLib._apply_op_helper("StridedSliceGrad", name: name, args: new - { - shape, - begin, - end, - strides, - dy, - begin_mask, - end_mask, - ellipsis_mask, - new_axis_mask, - shrink_axis_mask - }); - - return op.output; - } + => tf.Context.RunInAutoMode(() + => tf.OpDefLib._apply_op_helper("StridedSliceGrad", name, new + { + shape, + begin, + end, + strides, + dy, + begin_mask, + end_mask, + ellipsis_mask, + new_axis_mask, + shrink_axis_mask + }).output, () + => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + "StridedSliceGrad", name, + null, + shape, begin, end, strides, dy, + "begin_mask", begin_mask, + "end_mask", end_mask, + "ellipsis_mask", ellipsis_mask, + "new_axis_mask", new_axis_mask, + "shrink_axis_mask", shrink_axis_mask).FirstOrDefault(), + shape, begin, end, strides, dy); public static Tensor slice(Tensor input, Tb begin, Ts size, string name = null) {