Browse Source

Fix fused_batch_norm_v3 for eager mode.

tags/v0.30
Oceania2018 5 years ago
parent
commit
ce0d722355
21 changed files with 42 additions and 20 deletions
  1. +1
    -0
      src/TensorFlowNET.Core/Gradients/gradient_exclustions.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Keras/Engine/Flatten.cs
  3. +2
    -1
      src/TensorFlowNET.Core/Keras/Engine/Functional.cs
  4. +1
    -1
      src/TensorFlowNET.Core/Keras/Engine/Layer.Apply.cs
  5. +1
    -1
      src/TensorFlowNET.Core/Keras/Engine/Layer.FunctionalConstructionCall.cs
  6. +1
    -1
      src/TensorFlowNET.Core/Keras/Engine/Layer.cs
  7. +2
    -2
      src/TensorFlowNET.Core/Keras/Engine/TensorFlowOpLayer.cs
  8. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs
  9. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/Convolutional.cs
  10. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/Dense.cs
  11. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/Dropout.cs
  12. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/Embedding.cs
  13. +2
    -2
      src/TensorFlowNET.Core/Keras/Layers/LSTM.cs
  14. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs
  15. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/Rescaling.cs
  16. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/ZeroPadding2D.cs
  17. +1
    -1
      src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs
  18. +1
    -1
      src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs
  19. +17
    -0
      src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
  20. +1
    -1
      src/TensorFlowNET.Core/Tensorflow.Binding.csproj
  21. +3
    -0
      src/TensorFlowNET.Core/Tensors/Tensor.Value.cs

+ 1
- 0
src/TensorFlowNET.Core/Gradients/gradient_exclustions.cs View File

@@ -20,6 +20,7 @@ namespace Tensorflow.Gradients
public static int[] OpGradientUnusedOutputIndices(string op_name)
=> op_name switch
{
"FusedBatchNormV3" => new[] { 0, 1, 2 },
"ReadVariableOp" => new int[0],
"SoftmaxCrossEntropyWithLogits" => new[] { 0 },
"TensorArrayConcat" => new[] { 0 },


+ 1
- 1
src/TensorFlowNET.Core/Keras/Engine/Flatten.cs View File

@@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Engine
_channels_first = args.DataFormat == "channels_first";
}

protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false)
protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false)
{
if (_channels_first)
{


+ 2
- 1
src/TensorFlowNET.Core/Keras/Engine/Functional.cs View File

@@ -268,7 +268,7 @@ namespace Tensorflow.Keras.Engine
nodes_in_decreasing_depth.Insert(nodes_in_decreasing_depth.Count, node);
}

protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false)
protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false)
{
return run_internal_graph(inputs, is_training);
}
@@ -305,6 +305,7 @@ namespace Tensorflow.Keras.Engine
tensor_dict[node.FlatInputIds[0]] = new Tensor[0];

var outputs = node.Layer.Apply(layer_inputs, is_training: training);
// Update tensor_dict.
foreach (var (x_id, y) in zip(node.FlatOutputIds, outputs))
tensor_dict[x_id] = Enumerable.Range(0, tensor_usage_count[x_id]).Select(x => y).ToArray();


+ 1
- 1
src/TensorFlowNET.Core/Keras/Engine/Layer.Apply.cs View File

@@ -46,7 +46,7 @@ namespace Tensorflow.Keras.Engine
if (!built)
MaybeBuild(inputs);

outputs = CallFn(inputs, state: state, is_training: is_training);
outputs = Call(inputs, state: state, is_training: is_training);

outputs = _set_connectivity_metadata_(inputs, outputs);
_handle_activity_regularization(inputs, outputs);


+ 1
- 1
src/TensorFlowNET.Core/Keras/Engine/Layer.FunctionalConstructionCall.cs View File

@@ -42,7 +42,7 @@ namespace Tensorflow.Keras.Engine
if (!dynamic)
throw new NotImplementedException("");

outputs = CallFn(inputs);
outputs = Call(inputs);

outputs = _set_connectivity_metadata_(inputs, outputs);
_handle_activity_regularization(inputs, outputs);


+ 1
- 1
src/TensorFlowNET.Core/Keras/Engine/Layer.cs View File

@@ -162,7 +162,7 @@ namespace Tensorflow.Keras.Engine
/// <param name="state"></param>
/// <param name="is_training"></param>
/// <returns></returns>
protected virtual Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false)
protected virtual Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false)
{
throw new NotImplementedException("");
}


+ 2
- 2
src/TensorFlowNET.Core/Keras/Engine/TensorFlowOpLayer.cs View File

@@ -23,9 +23,9 @@ namespace Tensorflow.Keras.Engine
built = true;
}

protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false)
protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false)
{
return base.CallFn(inputs, state, is_training);
return base.Call(inputs, state, is_training);
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs View File

@@ -119,7 +119,7 @@ namespace Tensorflow.Keras.Layers
built = true;
}

protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false)
protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false)
{
Tensor outputs = null;



+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/Convolutional.cs View File

@@ -98,7 +98,7 @@ namespace Tensorflow.Keras.Layers
built = true;
}

protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool training = false)
protected override Tensors Call(Tensors inputs, Tensor state = null, bool training = false)
{
var outputs = _convolution_op.Apply(inputs, kernel);
if (use_bias)


+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/Dense.cs View File

@@ -65,7 +65,7 @@ namespace Tensorflow.Keras.Layers
built = true;
}

protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool training = false)
protected override Tensors Call(Tensors inputs, Tensor state = null, bool training = false)
{
Tensor outputs = null;
var rank = inputs.rank;


+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/Dropout.cs View File

@@ -18,7 +18,7 @@ namespace Tensorflow.Keras.Layers
this.args = args;
}

protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false)
protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false)
{
var output = tf_utils.smart_cond(is_training,
() => tf.nn.dropout(inputs,


+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/Embedding.cs View File

@@ -62,7 +62,7 @@ namespace Tensorflow.Keras.Layers
built = true;
}

protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false)
protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false)
{
var dtype = inputs.dtype;
if (dtype != tf.int32 && dtype != tf.int64)


+ 2
- 2
src/TensorFlowNET.Core/Keras/Layers/LSTM.cs View File

@@ -29,9 +29,9 @@ namespace Tensorflow.Keras.Layers
.ToArray();
}

protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false)
protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false)
{
return base.CallFn(inputs, state: state, is_training: is_training);
return base.Call(inputs, state: state, is_training: is_training);
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs View File

@@ -36,7 +36,7 @@ namespace Tensorflow.Keras.Layers
input_spec = new InputSpec(ndim: 4);
}

protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false)
protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false)
{
int[] pool_shape;
int[] strides;


+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/Rescaling.cs View File

@@ -20,7 +20,7 @@ namespace Tensorflow.Keras.Layers
this.args = args;
}

protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false)
protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false)
{
scale = math_ops.cast(args.Scale, args.DType);
offset = math_ops.cast(args.Offset, args.DType);


+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/ZeroPadding2D.cs View File

@@ -29,7 +29,7 @@ namespace Tensorflow.Keras.Layers
this.input_spec = new InputSpec(ndim: 4);
}

protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false)
protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false)
{
return tf.keras.backend.spatial_2d_padding(inputs,
padding: padding,


+ 1
- 1
src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs View File

@@ -74,7 +74,7 @@ namespace Tensorflow
/// <param name="training"></param>
/// <param name="state"></param>
/// <returns></returns>
protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false)
protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false)
{
var one = constant_op.constant(1, dtype: dtypes.int32);
// Parameters of gates are concatenated into one multiply for efficiency.


+ 1
- 1
src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs View File

@@ -67,7 +67,7 @@ namespace Tensorflow
built = true;
}

protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false)
protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false)
{
// Most basic RNN: output = new_state = act(W * input + U * state + B).
var concat = array_ops.concat(new Tensor[] { inputs, state }, 1);


+ 17
- 0
src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs View File

@@ -321,6 +321,23 @@ namespace Tensorflow.Operations
bool is_training = true,
string name = null)
{
if (tf.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"FusedBatchNormV3", name,
null,
x,
scale,
offset,
mean,
variance,
"epsilon", epsilon,
"data_format", data_format,
"is_training", is_training);

return results;
}

var _op = tf.OpDefLib._apply_op_helper("FusedBatchNormV3", name: name, args: new
{
x,


+ 1
- 1
src/TensorFlowNET.Core/Tensorflow.Binding.csproj View File

@@ -79,7 +79,7 @@ https://tensorflownet.readthedocs.io</Description>

<ItemGroup>
<PackageReference Include="Google.Protobuf" Version="3.11.4" />
<PackageReference Include="NumSharp.Lite" Version="0.1.8" />
<PackageReference Include="NumSharp.Lite" Version="0.1.9" />
<PackageReference Include="Protobuf.Text" Version="0.4.0" />
</ItemGroup>



+ 3
- 0
src/TensorFlowNET.Core/Tensors/Tensor.Value.cs View File

@@ -158,6 +158,9 @@ namespace Tensorflow
UnmanagedStorage storage;
switch (dtype)
{
case TF_DataType.TF_BOOL:
storage = new UnmanagedStorage(NPTypeCode.Boolean);
break;
case TF_DataType.TF_STRING:
return np.array(StringBytes()[0]);
case TF_DataType.TF_INT32:


Loading…
Cancel
Save