Browse Source

rename call to call_fn.

tags/v0.30
Oceania2018 5 years ago
parent
commit
e92aa44c1d
23 changed files with 38 additions and 34 deletions
  1. +3
    -1
      src/TensorFlowNET.Core/Keras/ArgsDefinition/TensorFlowOpLayerArgs.cs
  2. +2
    -2
      src/TensorFlowNET.Core/Keras/BackendImpl.cs
  3. +1
    -1
      src/TensorFlowNET.Core/Keras/Engine/Flatten.cs
  4. +6
    -2
      src/TensorFlowNET.Core/Keras/Engine/Functional.cs
  5. +1
    -1
      src/TensorFlowNET.Core/Keras/Engine/Layer.Apply.cs
  6. +1
    -1
      src/TensorFlowNET.Core/Keras/Engine/Layer.FunctionalConstructionCall.cs
  7. +1
    -1
      src/TensorFlowNET.Core/Keras/Engine/Layer.cs
  8. +0
    -2
      src/TensorFlowNET.Core/Keras/Engine/Node.cs
  9. +2
    -2
      src/TensorFlowNET.Core/Keras/Engine/TensorFlowOpLayer.cs
  10. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs
  11. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/Convolutional.cs
  12. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/Dense.cs
  13. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/Dropout.cs
  14. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/Embedding.cs
  15. +2
    -2
      src/TensorFlowNET.Core/Keras/Layers/LSTM.cs
  16. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs
  17. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/Rescaling.cs
  18. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/ZeroPadding2D.cs
  19. +4
    -3
      src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs
  20. +1
    -1
      src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs
  21. +1
    -1
      src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs
  22. +3
    -0
      src/TensorFlowNET.Core/Sessions/_FetchHandler.cs
  23. +2
    -6
      src/TensorFlowNET.Core/Tensorflow.Binding.csproj

+ 3
- 1
src/TensorFlowNET.Core/Keras/ArgsDefinition/TensorFlowOpLayerArgs.cs View File

@@ -1,4 +1,5 @@
using System;
using NumSharp;
using System;
using System.Collections.Generic;
using System.Text;

@@ -7,5 +8,6 @@ namespace Tensorflow.Keras.ArgsDefinition
public class TensorFlowOpLayerArgs : LayerArgs
{
public NodeDef NodeDef { get; set; }
public Dictionary<int, NDArray> Constants { get; set; }
}
}

+ 2
- 2
src/TensorFlowNET.Core/Keras/BackendImpl.cs View File

@@ -160,9 +160,9 @@ namespace Tensorflow.Keras
/// </summary>
/// <param name="outputs"></param>
/// <returns></returns>
public Tensor eval_in_eager_or_function(Tensor outputs)
public NDArray eval_in_eager_or_function(Tensor outputs)
{
throw new NotImplementedException("");
return outputs.eval();
}

public class _DummyEagerGraph


+ 1
- 1
src/TensorFlowNET.Core/Keras/Engine/Flatten.cs View File

@@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Engine
_channels_first = args.DataFormat == "channels_first";
}

protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false)
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false)
{
if (_channels_first)
{


+ 6
- 2
src/TensorFlowNET.Core/Keras/Engine/Functional.cs View File

@@ -69,10 +69,14 @@ namespace Tensorflow.Keras.Engine
}
}

protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false)
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false)
{
return base.call(inputs, state, is_training);
return run_internal_graph(inputs, state, is_training);
}

Tensors run_internal_graph(Tensors inputs, Tensor state = null, bool is_training = false)
{
throw new NotImplementedException("");
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Keras/Engine/Layer.Apply.cs View File

@@ -46,7 +46,7 @@ namespace Tensorflow.Keras.Engine
if (!built)
MaybeBuild(inputs);

outputs = call(inputs, state: state, is_training: is_training);
outputs = call_fn(inputs, state: state, is_training: is_training);

outputs = _set_connectivity_metadata_(inputs, outputs);
_handle_activity_regularization(inputs, outputs);


+ 1
- 1
src/TensorFlowNET.Core/Keras/Engine/Layer.FunctionalConstructionCall.cs View File

@@ -42,7 +42,7 @@ namespace Tensorflow.Keras.Engine
if (!dynamic)
throw new NotImplementedException("");

outputs = call(inputs);
outputs = call_fn(inputs);

outputs = _set_connectivity_metadata_(inputs, outputs);
_handle_activity_regularization(inputs, outputs);


+ 1
- 1
src/TensorFlowNET.Core/Keras/Engine/Layer.cs View File

@@ -162,7 +162,7 @@ namespace Tensorflow.Keras.Engine
/// <param name="state"></param>
/// <param name="is_training"></param>
/// <returns></returns>
protected virtual Tensors call(Tensors inputs, Tensor state = null, bool is_training = false)
protected virtual Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false)
{
throw new NotImplementedException("");
}


+ 0
- 2
src/TensorFlowNET.Core/Keras/Engine/Node.cs View File

@@ -52,8 +52,6 @@ namespace Tensorflow.Keras.Engine
layer.InboundNodes.Add(this);
foreach (var kt in kerasInputs)
{
if (kt.KerasHistory == null)
continue;
var inbound_layer = kt.KerasHistory.layer;
if (inbound_layer != null)
inbound_layer.OutboundNodes.Add(this);


+ 2
- 2
src/TensorFlowNET.Core/Keras/Engine/TensorFlowOpLayer.cs View File

@@ -23,9 +23,9 @@ namespace Tensorflow.Keras.Engine
built = true;
}

protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false)
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false)
{
return base.call(inputs, state, is_training);
return base.call_fn(inputs, state, is_training);
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs View File

@@ -119,7 +119,7 @@ namespace Tensorflow.Keras.Layers
built = true;
}

protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false)
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false)
{
Tensor outputs = null;



+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/Convolutional.cs View File

@@ -98,7 +98,7 @@ namespace Tensorflow.Keras.Layers
built = true;
}

protected override Tensors call(Tensors inputs, Tensor state = null, bool training = false)
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool training = false)
{
var outputs = _convolution_op.Apply(inputs, kernel);
if (use_bias)


+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/Dense.cs View File

@@ -65,7 +65,7 @@ namespace Tensorflow.Keras.Layers
built = true;
}

protected override Tensors call(Tensors inputs, Tensor state = null, bool training = false)
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool training = false)
{
Tensor outputs = null;
var rank = inputs.rank;


+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/Dropout.cs View File

@@ -18,7 +18,7 @@ namespace Tensorflow.Keras.Layers
this.args = args;
}

protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false)
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false)
{
var output = tf_utils.smart_cond(is_training,
() => tf.nn.dropout(inputs,


+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/Embedding.cs View File

@@ -62,7 +62,7 @@ namespace Tensorflow.Keras.Layers
built = true;
}

protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false)
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false)
{
var dtype = inputs.dtype;
if (dtype != tf.int32 && dtype != tf.int64)


+ 2
- 2
src/TensorFlowNET.Core/Keras/Layers/LSTM.cs View File

@@ -29,9 +29,9 @@ namespace Tensorflow.Keras.Layers
.ToArray();
}

protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false)
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false)
{
return base.call(inputs, state: state, is_training: is_training);
return base.call_fn(inputs, state: state, is_training: is_training);
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs View File

@@ -36,7 +36,7 @@ namespace Tensorflow.Keras.Layers
input_spec = new InputSpec(ndim: 4);
}

protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false)
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false)
{
int[] pool_shape;
int[] strides;


+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/Rescaling.cs View File

@@ -20,7 +20,7 @@ namespace Tensorflow.Keras.Layers
this.args = args;
}

protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false)
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false)
{
scale = math_ops.cast(args.Scale, args.DType);
offset = math_ops.cast(args.Offset, args.DType);


+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/ZeroPadding2D.cs View File

@@ -29,7 +29,7 @@ namespace Tensorflow.Keras.Layers
this.input_spec = new InputSpec(ndim: 4);
}

protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false)
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false)
{
return tf.keras.backend.spatial_2d_padding(inputs,
padding: padding,


+ 4
- 3
src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs View File

@@ -14,6 +14,7 @@
limitations under the License.
******************************************************************************/

using NumSharp;
using System;
using System.Collections.Generic;
using System.Linq;
@@ -135,7 +136,7 @@ namespace Tensorflow.Keras.Utils
if (!processed_ops.Contains(op))
{
var layer_inputs = new List<Tensor>();
var constants = new Dictionary<int, NDArray>();
foreach (var (i, op_input) in enumerate(op.inputs._inputs))
{
if (uses_keras_history(op_input))
@@ -144,8 +145,7 @@ namespace Tensorflow.Keras.Utils
{
tf_with(ops.init_scope(), delegate
{


constants[i] = tf.keras.backend.eval_in_eager_or_function(op_input);
});
}
}
@@ -155,6 +155,7 @@ namespace Tensorflow.Keras.Utils
var op_layer = new TensorFlowOpLayer(new TensorFlowOpLayerArgs
{
NodeDef = op.node_def,
Constants = constants,
Name = op.name
});
created_layers.Add(op_layer);


+ 1
- 1
src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs View File

@@ -74,7 +74,7 @@ namespace Tensorflow
/// <param name="training"></param>
/// <param name="state"></param>
/// <returns></returns>
protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false)
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false)
{
var one = constant_op.constant(1, dtype: dtypes.int32);
// Parameters of gates are concatenated into one multiply for efficiency.


+ 1
- 1
src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs View File

@@ -67,7 +67,7 @@ namespace Tensorflow
built = true;
}

protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false)
protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false)
{
// Most basic RNN: output = new_state = act(W * input + U * state + B).
var concat = array_ops.concat(new Tensor[] { inputs, state }, 1);


+ 3
- 0
src/TensorFlowNET.Core/Sessions/_FetchHandler.cs View File

@@ -124,6 +124,9 @@ namespace Tensorflow
case NPTypeCode.Double:
full_values.Add(value.GetValue<double>(0));
break;
case NPTypeCode.Boolean:
full_values.Add(value.GetValue<bool>(0));
break;
/*case "String":
full_values.Add(value.Data<byte>()[0]);
break;*/


+ 2
- 6
src/TensorFlowNET.Core/Tensorflow.Binding.csproj View File

@@ -72,6 +72,8 @@ https://tensorflownet.readthedocs.io</Description>
</ItemGroup>

<ItemGroup>
<None Remove="FodyWeavers.xml" />
<None Remove="FodyWeavers.xsd" />
<None Remove="Protobuf\README.md" />
</ItemGroup>

@@ -84,10 +86,4 @@ https://tensorflownet.readthedocs.io</Description>
<ItemGroup>
<Folder Include="Keras\Initializers\" />
</ItemGroup>

<ItemGroup>
<None Update="FodyWeavers.xml">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
</ItemGroup>
</Project>

Loading…
Cancel
Save