Browse Source

Partial implementation of tf.keras. #355

tags/v0.20
Oceania2018 5 years ago
parent
commit
5f1f59897d
30 changed files with 329 additions and 196 deletions
  1. +1
    -53
      TensorFlow.NET.sln
  2. +15
    -8
      src/TensorFlowNET.Core/APIs/tf.layers.cs
  3. +3
    -0
      src/TensorFlowNET.Core/APIs/tf.math.cs
  4. +1
    -1
      src/TensorFlowNET.Core/Data/DatasetManager.cs
  5. +2
    -2
      src/TensorFlowNET.Core/Data/TensorSliceDataset.cs
  6. +17
    -0
      src/TensorFlowNET.Core/Exceptions/InvalidArgumentError.cs
  7. +1
    -1
      src/TensorFlowNET.Core/Gradients/GradientTape.cs
  8. +56
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/DenseArgs.cs
  9. +51
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs
  10. +10
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/ModelArgs.cs
  11. +14
    -0
      src/TensorFlowNET.Core/Keras/Engine/CallContext.cs
  12. +14
    -0
      src/TensorFlowNET.Core/Keras/Engine/CallContextManager.cs
  13. +15
    -0
      src/TensorFlowNET.Core/Keras/Engine/ILayer.cs
  14. +52
    -18
      src/TensorFlowNET.Core/Keras/Engine/Layer.cs
  15. +8
    -4
      src/TensorFlowNET.Core/Keras/Engine/Model.cs
  16. +0
    -55
      src/TensorFlowNET.Core/Keras/Engine/Network.cs
  17. +4
    -3
      src/TensorFlowNET.Core/Keras/Engine/Sequential.cs
  18. +18
    -1
      src/TensorFlowNET.Core/Keras/KerasApi.cs
  19. +2
    -1
      src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs
  20. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/Conv.cs
  21. +13
    -19
      src/TensorFlowNET.Core/Keras/Layers/Dense.cs
  22. +9
    -2
      src/TensorFlowNET.Core/Keras/Layers/Embedding.cs
  23. +7
    -1
      src/TensorFlowNET.Core/Keras/Layers/InputLayer.cs
  24. +1
    -0
      src/TensorFlowNET.Core/Keras/Layers/Node.cs
  25. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs
  26. +0
    -20
      src/TensorFlowNET.Core/Layers/Dense.cs
  27. +9
    -2
      src/TensorFlowNET.Core/Layers/Layer.cs
  28. +1
    -1
      src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs
  29. +1
    -1
      src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs
  30. +2
    -1
      src/TensorFlowNET.Core/Status/Status.cs

+ 1
- 53
TensorFlow.NET.sln View File

@@ -9,11 +9,7 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Benchmark", "src
EndProject EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.UnitTest", "test\TensorFlowNET.UnitTest\Tensorflow.UnitTest.csproj", "{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}" Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.UnitTest", "test\TensorFlowNET.UnitTest\Tensorflow.UnitTest.csproj", "{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}"
EndProject EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Keras", "src\TensorFlowNET.Keras\Tensorflow.Keras.csproj", "{6268B461-486A-460B-9B3C-86493CBBAAF7}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Keras.UnitTest", "test\Tensorflow.Keras.UnitTest\Tensorflow.Keras.UnitTest.csproj", "{EB92DD90-6346-41FB-B967-2B33A860AD98}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "TensorFlowNET.Console", "src\TensorFlowNET.Console\TensorFlowNET.Console.csproj", "{03F06299-3F4B-4449-A709-3A647657BC0C}"
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Console", "src\TensorFlowNET.Console\TensorFlowNET.Console.csproj", "{03F06299-3F4B-4449-A709-3A647657BC0C}"
EndProject EndProject
Global Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution GlobalSection(SolutionConfigurationPlatforms) = preSolution
@@ -103,54 +99,6 @@ Global
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|x64.Build.0 = Release|x64 {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|x64.Build.0 = Release|x64
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|x86.ActiveCfg = Release|Any CPU {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|x86.ActiveCfg = Release|Any CPU
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|x86.Build.0 = Release|Any CPU {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|x86.Build.0 = Release|Any CPU
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug|Any CPU.Build.0 = Debug|Any CPU
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug|x64.ActiveCfg = Debug|x64
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug|x64.Build.0 = Debug|x64
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug|x86.ActiveCfg = Debug|Any CPU
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug|x86.Build.0 = Debug|Any CPU
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug-Minimal|x64.ActiveCfg = Debug|x64
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug-Minimal|x64.Build.0 = Debug|x64
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug-Minimal|x86.ActiveCfg = Debug|Any CPU
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug-Minimal|x86.Build.0 = Debug|Any CPU
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Publish|Any CPU.ActiveCfg = Release|Any CPU
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Publish|Any CPU.Build.0 = Release|Any CPU
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Publish|x64.ActiveCfg = Release|x64
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Publish|x64.Build.0 = Release|x64
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Publish|x86.ActiveCfg = Release|Any CPU
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Publish|x86.Build.0 = Release|Any CPU
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Release|Any CPU.ActiveCfg = Release|Any CPU
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Release|Any CPU.Build.0 = Release|Any CPU
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Release|x64.ActiveCfg = Release|x64
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Release|x64.Build.0 = Release|x64
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Release|x86.ActiveCfg = Release|Any CPU
{6268B461-486A-460B-9B3C-86493CBBAAF7}.Release|x86.Build.0 = Release|Any CPU
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug|Any CPU.Build.0 = Debug|Any CPU
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug|x64.ActiveCfg = Debug|x64
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug|x64.Build.0 = Debug|x64
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug|x86.ActiveCfg = Debug|Any CPU
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug|x86.Build.0 = Debug|Any CPU
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug-Minimal|x64.ActiveCfg = Debug|x64
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug-Minimal|x64.Build.0 = Debug|x64
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug-Minimal|x86.ActiveCfg = Debug|Any CPU
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug-Minimal|x86.Build.0 = Debug|Any CPU
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Publish|Any CPU.ActiveCfg = Release|Any CPU
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Publish|Any CPU.Build.0 = Release|Any CPU
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Publish|x64.ActiveCfg = Release|x64
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Publish|x64.Build.0 = Release|x64
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Publish|x86.ActiveCfg = Release|Any CPU
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Publish|x86.Build.0 = Release|Any CPU
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|Any CPU.ActiveCfg = Release|Any CPU
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|Any CPU.Build.0 = Release|Any CPU
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|x64.ActiveCfg = Release|x64
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|x64.Build.0 = Release|x64
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|x86.ActiveCfg = Release|Any CPU
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|x86.Build.0 = Release|Any CPU
{03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|Any CPU.Build.0 = Debug|Any CPU {03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|Any CPU.Build.0 = Debug|Any CPU
{03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|x64.ActiveCfg = Debug|Any CPU {03F06299-3F4B-4449-A709-3A647657BC0C}.Debug|x64.ActiveCfg = Debug|Any CPU


+ 15
- 8
src/TensorFlowNET.Core/APIs/tf.layers.cs View File

@@ -14,9 +14,11 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using NumSharp; using NumSharp;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Layers; using Tensorflow.Keras.Layers;
using Tensorflow.Operations.Activation; using Tensorflow.Operations.Activation;
using static Tensorflow.Binding; using static Tensorflow.Binding;
@@ -173,14 +175,19 @@ namespace Tensorflow
if (bias_initializer == null) if (bias_initializer == null)
bias_initializer = tf.zeros_initializer; bias_initializer = tf.zeros_initializer;


var layer = new Dense(units, activation,
use_bias: use_bias,
bias_initializer: bias_initializer,
kernel_initializer: kernel_initializer,
trainable: trainable,
name: name);

return layer.apply(inputs).Item1;
var layer = new Dense(new DenseArgs
{
Units = units,
Activation = activation,
UseBias = use_bias,
BiasInitializer = bias_initializer,
KernelInitializer = kernel_initializer,
Trainable = trainable,
Name = name
});

throw new NotImplementedException("");
//return layer.apply(inputs).Item1;
} }


/// <summary> /// <summary>


+ 3
- 0
src/TensorFlowNET.Core/APIs/tf.math.cs View File

@@ -515,6 +515,9 @@ namespace Tensorflow
public Tensor sum(Tensor input, int axis, bool keep_dims = false, string name = null) public Tensor sum(Tensor input, int axis, bool keep_dims = false, string name = null)
=> gen_math_ops._sum(input, axis, keep_dims: keep_dims, name: name); => gen_math_ops._sum(input, axis, keep_dims: keep_dims, name: name);


public Tensor reduce_mean(Tensor input_tensors, int axis, bool keepdims = false, string name = null)
=> math_ops.reduce_mean(input_tensors, axis: new[] { axis }, keepdims: keepdims, name: name);

public Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null, int? reduction_indices = null) public Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null, int? reduction_indices = null)
=> math_ops.reduce_mean(input_tensor, axis: axis, keepdims: keepdims, name: name, reduction_indices: reduction_indices); => math_ops.reduce_mean(input_tensor, axis: axis, keepdims: keepdims, name: name, reduction_indices: reduction_indices);




+ 1
- 1
src/TensorFlowNET.Core/Data/DatasetManager.cs View File

@@ -7,7 +7,7 @@ namespace Tensorflow
{ {
public class DatasetManager public class DatasetManager
{ {
public IDatasetV2 from_tensor_slices(NDArray features, NDArray labels)
public IDatasetV2 from_tensor_slices(Tensor features, Tensor labels)
=> new TensorSliceDataset(features, labels); => new TensorSliceDataset(features, labels);
} }
} }

+ 2
- 2
src/TensorFlowNET.Core/Data/TensorSliceDataset.cs View File

@@ -11,9 +11,9 @@ namespace Tensorflow
{ {
public class TensorSliceDataset : DatasetSource public class TensorSliceDataset : DatasetSource
{ {
public TensorSliceDataset(NDArray features, NDArray labels)
public TensorSliceDataset(Tensor features, Tensor labels)
{ {
_tensors = new[] { tf.convert_to_tensor(features), tf.convert_to_tensor(labels) };
_tensors = new[] { features, labels };
var batched_spec = _tensors.Select(x => x.ToTensorSpec()).ToArray(); var batched_spec = _tensors.Select(x => x.ToTensorSpec()).ToArray();
structure = batched_spec.Select(x => x._unbatch()).ToArray(); structure = batched_spec.Select(x => x._unbatch()).ToArray();


+ 17
- 0
src/TensorFlowNET.Core/Exceptions/InvalidArgumentError.cs View File

@@ -0,0 +1,17 @@
using System;

namespace Tensorflow
{
public class InvalidArgumentError : TensorflowException
{
public InvalidArgumentError() : base()
{

}

public InvalidArgumentError(string message) : base(message)
{

}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Gradients/GradientTape.cs View File

@@ -119,7 +119,7 @@ namespace Tensorflow.Gradients
return (results[0], results[1]); return (results[0], results[1]);
} }


public Tensor[] gradient(Tensor target, ResourceVariable[] sources)
public Tensor[] gradient(Tensor target, IEnumerable<IVariableV1> sources)
{ {
if (_recording) if (_recording)
{ {


+ 56
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/DenseArgs.cs View File

@@ -0,0 +1,56 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Operations.Activation;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.ArgsDefinition
{
public class DenseArgs : LayerArgs
{
/// <summary>
/// Positive integer, dimensionality of the output space.
/// </summary>
public int Units { get; set; }

/// <summary>
/// Activation function to use.
/// </summary>
public IActivation Activation { get; set; }

/// <summary>
/// Whether the layer uses a bias vector.
/// </summary>
public bool UseBias { get; set; } = true;

/// <summary>
/// Initializer for the `kernel` weights matrix.
/// </summary>
public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer;

/// <summary>
/// Initializer for the bias vector.
/// </summary>
public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer;

/// <summary>
/// Regularizer function applied to the `kernel` weights matrix.
/// </summary>
public IInitializer KernelRegularizer { get; set; }

/// <summary>
/// Regularizer function applied to the bias vector.
/// </summary>
public IInitializer BiasRegularizer { get; set; }

/// <summary>
/// Constraint function applied to the `kernel` weights matrix.
/// </summary>
public Action KernelConstraint { get; set; }

/// <summary>
/// Constraint function applied to the bias vector.
/// </summary>
public Action BiasConstraint { get; set; }
}
}

+ 51
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs View File

@@ -0,0 +1,51 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.ArgsDefinition
{
public class LayerArgs
{
/// <summary>
/// Indicates whether the layer's weights are updated during training
/// and whether the layer's updates are run during training.
/// </summary>
public bool Trainable { get; set; } = true;

public string Name { get; set; }

/// <summary>
/// Only applicable to input layers.
/// </summary>
public TF_DataType DType { get; set; }

/// <summary>
/// Whether the `call` method can be used to build a TF graph without issues.
/// This attribute has no effect if the model is created using the Functional
/// API. Instead, `model.dynamic` is determined based on the internal layers.
/// </summary>
public bool Dynamic { get; set; } = false;

/// <summary>
/// Only applicable to input layers.
/// </summary>
public TensorShape InputShape { get; set; }

/// <summary>
/// Only applicable to input layers.
/// </summary>
public TensorShape BatchInputShape { get; set; }

/// <summary>
/// Initial weight values.
/// </summary>
public float[] Weights { get; set; }

/// <summary>
/// Regularizer function applied to the output of the layer(its "activation").
/// </summary>
public IInitializer ActivityRegularizer { get; set; }

public bool Autocast { get; set; }
}
}

+ 10
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/ModelArgs.cs View File

@@ -0,0 +1,10 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.ArgsDefinition
{
public class ModelArgs : LayerArgs
{
}
}

+ 14
- 0
src/TensorFlowNET.Core/Keras/Engine/CallContext.cs View File

@@ -0,0 +1,14 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Engine
{
public class CallContext
{
public CallContextManager enter()
{
return new CallContextManager();
}
}
}

+ 14
- 0
src/TensorFlowNET.Core/Keras/Engine/CallContextManager.cs View File

@@ -0,0 +1,14 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Engine
{
public class CallContextManager : IDisposable
{
public void Dispose()
{
}
}
}

+ 15
- 0
src/TensorFlowNET.Core/Keras/Engine/ILayer.cs View File

@@ -0,0 +1,15 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Engine
{
/// <summary>
/// A layer is a callable object that takes as input one or more tensors and
/// that outputs one or more tensors.
/// </summary>
public interface ILayer
{
Tensor Apply(Tensor inputs, bool is_training = false);
}
}

src/TensorFlowNET.Core/Keras/Layers/Layer.cs → src/TensorFlowNET.Core/Keras/Engine/Layer.cs View File

@@ -17,12 +17,14 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using Tensorflow.Keras.Engine;
using System.Threading;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Layers;
using Tensorflow.Keras.Utils; using Tensorflow.Keras.Utils;
using Tensorflow.Train; using Tensorflow.Train;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace Tensorflow.Keras.Layers
namespace Tensorflow.Keras.Engine
{ {
/// <summary> /// <summary>
/// Base layer class. /// Base layer class.
@@ -32,8 +34,10 @@ namespace Tensorflow.Keras.Layers
/// ///
/// tensorflow\python\keras\engine\base_layer.py /// tensorflow\python\keras\engine\base_layer.py
/// </summary> /// </summary>
public class Layer : AutoTrackable
public class Layer : AutoTrackable, ILayer
{ {
protected LayerArgs _args;

/// <summary> /// <summary>
/// Indicates whether `build` needs to be called upon layer call, to create /// Indicates whether `build` needs to be called upon layer call, to create
/// the layer's weights. /// the layer's weights.
@@ -52,6 +56,7 @@ namespace Tensorflow.Keras.Layers
protected InputSpec input_spec; protected InputSpec input_spec;
protected bool supports_masking; protected bool supports_masking;
protected List<IVariableV1> _trainable_weights; protected List<IVariableV1> _trainable_weights;
public List<IVariableV1> trainable_variables => _trainable_weights;
protected List<IVariableV1> _non_trainable_weights; protected List<IVariableV1> _non_trainable_weights;
private string _name; private string _name;
public string name => _name; public string name => _name;
@@ -72,13 +77,12 @@ namespace Tensorflow.Keras.Layers
float _initial_weights; float _initial_weights;
#pragma warning restore CS0169 // The field 'Layer._initial_weights' is never used #pragma warning restore CS0169 // The field 'Layer._initial_weights' is never used


public Layer(bool trainable = true,
string name = null,
TF_DataType dtype = TF_DataType.DtInvalid,
int[] input_shape = null)
ThreadLocal<CallContext> _call_context;
public CallContext CallContext => _call_context.Value;
public Layer(LayerArgs args)
{ {
this.trainable = trainable;
this._dtype = dtype;
_args = args;
// A stateful layer is a layer whose updates are run during inference too, // A stateful layer is a layer whose updates are run during inference too,
// for instance stateful RNNs. // for instance stateful RNNs.
stateful = false; stateful = false;
@@ -94,17 +98,47 @@ namespace Tensorflow.Keras.Layers
_updates = new List<Operation>(); _updates = new List<Operation>();


// Manage input shape information if passed. // Manage input shape information if passed.
if(input_shape != null)
_inbound_nodes = new List<Node>();
}

/// <summary>
/// Wraps `call`, applying pre- and post-processing steps.
/// </summary>
/// <param name="input"></param>
/// <param name="is_training"></param>
/// <returns></returns>
public Tensor Apply(Tensor input, bool is_training = false)
{
var input_list = new Tensor[] { input };

if (_call_context == null)
_call_context = new ThreadLocal<CallContext>()
{
Value = new CallContext()
};

using var ctxManager = CallContext.enter();

string name_scope = "";
if (tf.context.executing_eagerly())
{ {
var shapes = new List<int> { -1 };
shapes.AddRange(input_shape);
_batch_input_shape = shapes.ToArray();
name_scope = _name;
}
else
{
throw new NotImplementedException("");
} }


_dtype = dtype;
tf_with(ops.name_scope(name_scope), scope =>
{
if (!built)
_maybe_build(input);


_inbound_nodes = new List<Node>();
call(input, is_training: is_training);
});

throw new NotImplementedException("");
} }


public Tensor[] __call__(Tensor[] inputs, public Tensor[] __call__(Tensor[] inputs,
@@ -147,7 +181,7 @@ namespace Tensorflow.Keras.Layers
_maybe_build(inputs[0]); _maybe_build(inputs[0]);


outputs = call(inputs[0], outputs = call(inputs[0],
training: training,
// training: training,
state: state); state: state);


(input, outputs) = _set_connectivity_metadata_(input, outputs); (input, outputs) = _set_connectivity_metadata_(input, outputs);
@@ -183,7 +217,7 @@ namespace Tensorflow.Keras.Layers
return null; return null;
} }


protected virtual Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null)
protected virtual Tensor[] call(Tensor inputs, bool is_training = false, Tensor state = null)
{ {
throw new NotImplementedException(""); throw new NotImplementedException("");
} }

+ 8
- 4
src/TensorFlowNET.Core/Keras/Engine/Model.cs View File

@@ -1,8 +1,12 @@
using Tensorflow.Keras.Optimizers;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Optimizers;


namespace Tensorflow.Keras.Engine namespace Tensorflow.Keras.Engine
{ {
public class Model : Network
/// <summary>
/// `Model` groups layers into an object with training and inference features.
/// </summary>
public class Model : Layer
{ {
#pragma warning disable CS0169 // The field 'Model._cloning' is never used #pragma warning disable CS0169 // The field 'Model._cloning' is never used
bool _cloning; bool _cloning;
@@ -15,8 +19,8 @@ namespace Tensorflow.Keras.Engine
string loss; string loss;
IOptimizer optimizer; IOptimizer optimizer;


public Model(string name = null)
: base(name: name)
public Model(ModelArgs args)
: base(args)
{ {


} }


+ 0
- 55
src/TensorFlowNET.Core/Keras/Engine/Network.cs View File

@@ -1,55 +0,0 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System.Collections.Generic;
using Tensorflow.Keras.Layers;

namespace Tensorflow.Keras.Engine
{
public class Network : Layer
{
protected bool _is_compiled;
protected bool _expects_training_arg;
protected bool _compute_output_and_mask_jointly;
/// <summary>
/// All layers in order of horizontal graph traversal.
/// Entries are unique. Includes input and output layers.
/// </summary>
protected List<Layer> _layers;

public Network(string name = null)
: base(name: name)
{
_init_subclassed_network(name);
}

protected virtual void _init_subclassed_network(string name = null)
{
_base_init(name: name);
}

protected virtual void _base_init(string name = null)
{
_init_set_name(name);
trainable = true;
_is_compiled = false;
_expects_training_arg = false;
_compute_output_and_mask_jointly = false;
supports_masking = false;
_layers = new List<Layer>();
}
}
}

+ 4
- 3
src/TensorFlowNET.Core/Keras/Engine/Sequential.cs View File

@@ -14,6 +14,7 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Layers; using Tensorflow.Keras.Layers;


namespace Tensorflow.Keras.Engine namespace Tensorflow.Keras.Engine
@@ -28,10 +29,10 @@ namespace Tensorflow.Keras.Engine
#pragma warning restore CS0169 // The field 'Sequential.outputs' is never used #pragma warning restore CS0169 // The field 'Sequential.outputs' is never used


public Sequential(string name = null) public Sequential(string name = null)
: base(name: name)
: base(new ModelArgs { Name = name})
{ {
supports_masking = true; supports_masking = true;
_compute_output_and_mask_jointly = true;
// _compute_output_and_mask_jointly = true;
} }


public void __enter__() public void __enter__()
@@ -47,7 +48,7 @@ namespace Tensorflow.Keras.Engine
{ {
built = false; built = false;
var set_inputs = false; var set_inputs = false;
if(_layers.Count == 0)
//if(_layers.Count == 0)
{ {
if(layer is InputLayer) if(layer is InputLayer)
{ {


+ 18
- 1
src/TensorFlowNET.Core/Keras/KerasApi.cs View File

@@ -1,6 +1,11 @@
using System.Data;
using System;
using System.Data;
using Tensorflow.Keras; using Tensorflow.Keras;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Datasets; using Tensorflow.Keras.Datasets;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Layers;
using Tensorflow.Operations.Activation;


namespace Tensorflow namespace Tensorflow
{ {
@@ -8,5 +13,17 @@ namespace Tensorflow
{ {
public KerasDataset datasets { get; } = new KerasDataset(); public KerasDataset datasets { get; } = new KerasDataset();
public Initializers initializers { get; } = new Initializers(); public Initializers initializers { get; } = new Initializers();
public Layers layers { get; } = new Layers();

public class Layers
{
public ILayer Dense(int units,
IActivation activation = null)
=> new Dense(new DenseArgs
{
Units = units,
Activation = activation
});
}
} }
} }

+ 2
- 1
src/TensorFlowNET.Core/Keras/Layers/BatchNormalization.cs View File

@@ -143,12 +143,13 @@ namespace Tensorflow.Keras.Layers
built = true; built = true;
} }


protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null)
protected override Tensor[] call(Tensor inputs, bool is_training = false, Tensor state = null)
{ {
Tensor outputs = null; Tensor outputs = null;


if (fused) if (fused)
{ {
Tensor training = tf.convert_to_tensor(is_training);
outputs = _fused_batch_norm(inputs, training: training); outputs = _fused_batch_norm(inputs, training: training);
return new[] { outputs, outputs }; return new[] { outputs, outputs };
} }


+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/Conv.cs View File

@@ -108,7 +108,7 @@ namespace Tensorflow.Keras.Layers
built = true; built = true;
} }


protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null)
protected override Tensor[] call(Tensor inputs, bool training = false, Tensor state = null)
{ {
var outputs = _convolution_op.__call__(inputs, kernel); var outputs = _convolution_op.__call__(inputs, kernel);
if (use_bias) if (use_bias)


+ 13
- 19
src/TensorFlowNET.Core/Keras/Layers/Dense.cs View File

@@ -17,35 +17,29 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine; using Tensorflow.Keras.Engine;
using Tensorflow.Operations.Activation; using Tensorflow.Operations.Activation;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace Tensorflow.Keras.Layers namespace Tensorflow.Keras.Layers
{ {
public class Dense : Tensorflow.Layers.Layer
/// <summary>
/// Just your regular densely-connected NN layer.
/// </summary>
public class Dense : Layer
{ {
protected int units; protected int units;
protected IActivation activation; protected IActivation activation;
protected bool use_bias; protected bool use_bias;
protected IInitializer kernel_initializer; protected IInitializer kernel_initializer;
protected IInitializer bias_initializer; protected IInitializer bias_initializer;
protected RefVariable kernel;
protected RefVariable bias;
protected IVariableV1 kernel;
protected IVariableV1 bias;


public Dense(int units,
IActivation activation,
string name = null,
bool use_bias = true,
bool trainable = false,
IInitializer kernel_initializer = null,
IInitializer bias_initializer = null) : base(trainable: trainable, name: name)
public Dense(DenseArgs args) :
base(args)
{ {
this.units = units;
this.activation = activation;
this.use_bias = use_bias;
this.kernel_initializer = kernel_initializer;
this.bias_initializer = bias_initializer;
this.supports_masking = true; this.supports_masking = true;
this.input_spec = new InputSpec(min_ndim: 2); this.input_spec = new InputSpec(min_ndim: 2);
} }
@@ -56,14 +50,14 @@ namespace Tensorflow.Keras.Layers
var axes = new Dictionary<int, int>(); var axes = new Dictionary<int, int>();
axes[-1] = last_dim; axes[-1] = last_dim;
input_spec = new InputSpec(min_ndim: 2, axes: axes); input_spec = new InputSpec(min_ndim: 2, axes: axes);
kernel = (RefVariable)add_weight(
kernel = add_weight(
"kernel", "kernel",
shape: new int[] { last_dim, units }, shape: new int[] { last_dim, units },
initializer: kernel_initializer, initializer: kernel_initializer,
dtype: _dtype, dtype: _dtype,
trainable: true); trainable: true);
if (use_bias) if (use_bias)
bias = (RefVariable)add_weight(
bias = add_weight(
"bias", "bias",
shape: new int[] { units }, shape: new int[] { units },
initializer: bias_initializer, initializer: bias_initializer,
@@ -73,7 +67,7 @@ namespace Tensorflow.Keras.Layers
built = true; built = true;
} }


protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null)
protected override Tensor[] call(Tensor inputs, bool training = false, Tensor state = null)
{ {
Tensor outputs = null; Tensor outputs = null;
var rank = inputs.rank; var rank = inputs.rank;
@@ -83,7 +77,7 @@ namespace Tensorflow.Keras.Layers
} }
else else
{ {
outputs = gen_math_ops.mat_mul(inputs, kernel);
outputs = gen_math_ops.mat_mul(inputs, kernel.Handle);
} }


if (use_bias) if (use_bias)


+ 9
- 2
src/TensorFlowNET.Core/Keras/Layers/Embedding.cs View File

@@ -14,6 +14,8 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace Tensorflow.Keras.Layers namespace Tensorflow.Keras.Layers
@@ -32,7 +34,12 @@ namespace Tensorflow.Keras.Layers
bool mask_zero = false, bool mask_zero = false,
TF_DataType dtype = TF_DataType.TF_FLOAT, TF_DataType dtype = TF_DataType.TF_FLOAT,
int[] input_shape = null, int[] input_shape = null,
int input_length = -1) : base(dtype: dtype, input_shape: input_shape ?? new[] { input_length })
int input_length = -1) :
base(new LayerArgs
{
DType = dtype,
InputShape = input_shape ?? new[] { input_length }
})
{ {
this.input_dim = input_dim; this.input_dim = input_dim;
this.output_dim = output_dim; this.output_dim = output_dim;
@@ -50,7 +57,7 @@ namespace Tensorflow.Keras.Layers
built = true; built = true;
} }


protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null)
protected override Tensor[] call(Tensor inputs, bool is_training = false, Tensor state = null)
{ {
var dtype = inputs.dtype; var dtype = inputs.dtype;
if (dtype != tf.int32 && dtype != tf.int64) if (dtype != tf.int32 && dtype != tf.int64)


+ 7
- 1
src/TensorFlowNET.Core/Keras/Layers/InputLayer.cs View File

@@ -17,6 +17,8 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;


namespace Tensorflow.Keras.Layers namespace Tensorflow.Keras.Layers
{ {
@@ -35,7 +37,11 @@ namespace Tensorflow.Keras.Layers
TF_DataType dtype = TF_DataType.DtInvalid, TF_DataType dtype = TF_DataType.DtInvalid,
string name = null, string name = null,
bool sparse = false, bool sparse = false,
Tensor input_tensor = null) : base(dtype: dtype, name: name)
Tensor input_tensor = null) :
base(new LayerArgs
{
DType = dtype, Name = name
})
{ {
built = true; built = true;
this.sparse = sparse; this.sparse = sparse;


+ 1
- 0
src/TensorFlowNET.Core/Keras/Layers/Node.cs View File

@@ -15,6 +15,7 @@
******************************************************************************/ ******************************************************************************/


using System.Linq; using System.Linq;
using Tensorflow.Keras.Engine;


namespace Tensorflow.Keras.Layers namespace Tensorflow.Keras.Layers
{ {


+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/Pooling2D.cs View File

@@ -45,7 +45,7 @@ namespace Tensorflow.Keras.Layers
this.input_spec = new InputSpec(ndim: 4); this.input_spec = new InputSpec(ndim: 4);
} }


protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null)
protected override Tensor[] call(Tensor inputs, bool is_training = false, Tensor state = null)
{ {
int[] pool_shape; int[] pool_shape;
if (data_format == "channels_last") if (data_format == "channels_last")


+ 0
- 20
src/TensorFlowNET.Core/Layers/Dense.cs View File

@@ -1,20 +0,0 @@
using Tensorflow.Operations.Activation;

namespace Tensorflow.Layers
{
public class Dense : Keras.Layers.Dense
{
public Dense(int units,
IActivation activation,
bool use_bias = true,
bool trainable = false,
IInitializer kernel_initializer = null) : base(units,
activation,
use_bias: use_bias,
trainable: trainable,
kernel_initializer: kernel_initializer)
{

}
}
}

+ 9
- 2
src/TensorFlowNET.Core/Layers/Layer.cs View File

@@ -16,11 +16,12 @@


using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using Tensorflow.Keras.ArgsDefinition;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace Tensorflow.Layers namespace Tensorflow.Layers
{ {
public class Layer : Keras.Layers.Layer
public class Layer : Keras.Engine.Layer
{ {
protected Graph _graph; protected Graph _graph;
@@ -34,7 +35,13 @@ namespace Tensorflow.Layers
public Layer(bool trainable = true, public Layer(bool trainable = true,
string name = null, string name = null,
TF_DataType dtype = TF_DataType.DtInvalid, TF_DataType dtype = TF_DataType.DtInvalid,
bool? _reuse = null) : base(trainable: trainable, name: name, dtype: dtype)
bool? _reuse = null) :
base(new LayerArgs
{
Trainable = trainable,
Name = name,
DType = dtype
})
{ {
// For backwards compatibility, legacy layers do not use `ResourceVariable` // For backwards compatibility, legacy layers do not use `ResourceVariable`
// by default. // by default.


+ 1
- 1
src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs View File

@@ -74,7 +74,7 @@ namespace Tensorflow
/// <param name="training"></param> /// <param name="training"></param>
/// <param name="state"></param> /// <param name="state"></param>
/// <returns></returns> /// <returns></returns>
protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null)
protected override Tensor[] call(Tensor inputs, bool is_training = false, Tensor state = null)
{ {
var one = constant_op.constant(1, dtype: dtypes.int32); var one = constant_op.constant(1, dtype: dtypes.int32);
// Parameters of gates are concatenated into one multiply for efficiency. // Parameters of gates are concatenated into one multiply for efficiency.


+ 1
- 1
src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs View File

@@ -67,7 +67,7 @@ namespace Tensorflow
built = true; built = true;
} }


protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null)
protected override Tensor[] call(Tensor inputs, bool is_training = false, Tensor state = null)
{ {
// Most basic RNN: output = new_state = act(W * input + U * state + B). // Most basic RNN: output = new_state = act(W * input + U * state + B).
var concat = array_ops.concat(new[] { inputs, state }, 1); var concat = array_ops.concat(new[] { inputs, state }, 1);


+ 2
- 1
src/TensorFlowNET.Core/Status/Status.cs View File

@@ -85,8 +85,9 @@ namespace Tensorflow
{ {
case TF_Code.TF_OUT_OF_RANGE: case TF_Code.TF_OUT_OF_RANGE:
throw new OutOfRangeError(message); throw new OutOfRangeError(message);
case TF_Code.TF_INVALID_ARGUMENT:
throw new InvalidArgumentError(message);
default: default:
Console.WriteLine(message);
throw new TensorflowException(message); throw new TensorflowException(message);
} }
} }


Loading…
Cancel
Save