Browse Source

Fix strided_slice_grad for Eager mode.

tags/v0.30
Oceania2018 5 years ago
parent
commit
1deaa75dfe
4 changed files with 53 additions and 24 deletions
  1. +3
    -0
      src/TensorFlowNET.Core/Binding.Util.cs
  2. +1
    -0
      src/TensorFlowNET.Core/Keras/Engine/Layer.Layers.cs
  3. +20
    -2
      src/TensorFlowNET.Core/Keras/Engine/Layer.cs
  4. +29
    -22
      src/TensorFlowNET.Core/Operations/gen_array_ops.cs

+ 3
- 0
src/TensorFlowNET.Core/Binding.Util.cs View File

@@ -424,6 +424,9 @@ namespace Tensorflow
return true;
}

public static bool empty<T>(this Queue<T> queue)
=> queue.Count == 0;

public static TValue SetDefault<TKey, TValue>(this Dictionary<TKey, TValue> dic, TKey key, TValue value)
{
if (dic.ContainsKey(key))


+ 1
- 0
src/TensorFlowNET.Core/Keras/Engine/Layer.Layers.cs View File

@@ -11,6 +11,7 @@ namespace Tensorflow.Keras.Engine
public partial class Layer
{
protected List<Layer> _layers = new List<Layer>();
public List<Layer> Layers => _layers;

protected Layer Dense(int units,
Activation activation = null,


+ 20
- 2
src/TensorFlowNET.Core/Keras/Engine/Layer.cs View File

@@ -61,7 +61,6 @@ namespace Tensorflow.Keras.Engine
protected List<IVariableV1> trainable_weights;

public virtual List<IVariableV1> trainable_variables => trainable_weights;

protected List<IVariableV1> non_trainable_weights;
public List<IVariableV1> non_trainable_variables => non_trainable_weights;
@@ -83,7 +82,8 @@ namespace Tensorflow.Keras.Engine
ThreadLocal<CallContext> callContext;
public CallContext CallContext => callContext.Value;
public Tensor[] input => inboundNodes[0].input_tensors;

public Dictionary<int, List<Node>> NodesByDepth { get; set; }
public TensorShape output_shape => inboundNodes[0].Outputs.shape;
public Layer(LayerArgs args)
{
this.args = args;
@@ -224,5 +224,23 @@ namespace Tensorflow.Keras.Engine
this.name = base_layer_utils.unique_layer_name(base_name, zero_based: zero_based);
}
}

public int count_params()
{
if (Trainable)
return layer_utils.count_params(this, weights);
return 0;
}

public List<IVariableV1> weights
{
get
{
var weights = new List<IVariableV1>();
weights.AddRange(trainable_weights);
weights.AddRange(non_trainable_weights);
return weights;
}
}
}
}

+ 29
- 22
src/TensorFlowNET.Core/Operations/gen_array_ops.cs View File

@@ -388,19 +388,19 @@ namespace Tensorflow
return _op.outputs[0];
}

public static Tensor select<Tx, Ty>(Tensor condition, Tx t, Ty e, string name = null)
public static Tensor select<Tx, Ty>(Tensor condition, Tx x, Ty y, string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"SelectV2", name,
"Select", name,
null,
condition, t, e);
condition, x, y);

return results[0];
}

var _op = tf.OpDefLib._apply_op_helper("Select", name, new { condition, t, e });
var _op = tf.OpDefLib._apply_op_helper("Select", name, new { condition, t = x, e = y });
return _op.outputs[0];
}

@@ -580,26 +580,33 @@ namespace Tensorflow
/// <param name="shrink_axis_mask">An optional `int`. Defaults to `0`.</param>
/// <param name="name">A name for the operation (optional).</param>
/// <returns>A `Tensor`. Has the same type as `dy`.</returns>
public static Tensor strided_slice_grad(Tensor shape, Tensor begin, Tensor end, Tensor strides, Tensor dy,
public static Tensor strided_slice_grad(Tensor shape, Tensor begin, Tensor end, Tensor strides, Tensor dy,
int begin_mask = 0, int end_mask = 0, int ellipsis_mask = 0, int new_axis_mask = 0,
int shrink_axis_mask = 0, string name = null)
{
var op = tf.OpDefLib._apply_op_helper("StridedSliceGrad", name: name, args: new
{
shape,
begin,
end,
strides,
dy,
begin_mask,
end_mask,
ellipsis_mask,
new_axis_mask,
shrink_axis_mask
});

return op.output;
}
=> tf.Context.RunInAutoMode(()
=> tf.OpDefLib._apply_op_helper("StridedSliceGrad", name, new
{
shape,
begin,
end,
strides,
dy,
begin_mask,
end_mask,
ellipsis_mask,
new_axis_mask,
shrink_axis_mask
}).output, ()
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"StridedSliceGrad", name,
null,
shape, begin, end, strides, dy,
"begin_mask", begin_mask,
"end_mask", end_mask,
"ellipsis_mask", ellipsis_mask,
"new_axis_mask", new_axis_mask,
"shrink_axis_mask", shrink_axis_mask).FirstOrDefault(),
shape, begin, end, strides, dy);

public static Tensor slice<Tb, Ts>(Tensor input, Tb begin, Ts size, string name = null)
{


Loading…
Cancel
Save