Browse Source

fix KerasHistory.

tags/v0.30
Oceania2018 5 years ago
parent
commit
b02f285496
11 changed files with 94 additions and 35 deletions
  1. +13
    -2
      src/TensorFlowNET.Core/Binding.Util.cs
  2. +10
    -0
      src/TensorFlowNET.Core/Keras/BackendImpl.cs
  3. +6
    -2
      src/TensorFlowNET.Core/Keras/Engine/Functional.cs
  4. +0
    -2
      src/TensorFlowNET.Core/Keras/Engine/KerasHistory.cs
  5. +1
    -15
      src/TensorFlowNET.Core/Keras/Engine/Layer.cs
  6. +2
    -0
      src/TensorFlowNET.Core/Keras/Engine/Node.cs
  7. +2
    -2
      src/TensorFlowNET.Core/Keras/Engine/TensorFlowOpLayer.cs
  8. +2
    -0
      src/TensorFlowNET.Core/Keras/KerasApi.cs
  9. +33
    -0
      src/TensorFlowNET.Core/Keras/Optimizers/OptimizerApi.cs
  10. +24
    -11
      src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs
  11. +1
    -1
      src/TensorFlowNET.Core/Tensorflow.Binding.csproj

+ 13
- 2
src/TensorFlowNET.Core/Binding.Util.cs View File

@@ -32,8 +32,19 @@ namespace Tensorflow
{
public static T2 get<T1, T2>(this Dictionary<T1, T2> dict, T1 key)
=> key == null ?
default(T2) :
(dict.ContainsKey(key) ? dict[key] : default(T2));
default :
(dict.ContainsKey(key) ? dict[key] : default);

public static void Update<T>(this IList<T> list, T element)
{
var index = list.IndexOf(element);
if (index < 0)
list.Add(element);
else
{
list[index] = element;
}
}

public static void add<T>(this IList<T> list, T element)
=> list.Add(element);


+ 10
- 0
src/TensorFlowNET.Core/Keras/BackendImpl.cs View File

@@ -155,6 +155,16 @@ namespace Tensorflow.Keras
return array_ops.pad(x, pattern);
}

/// <summary>
/// Method to evaluate a tensor in eager or in a tf.function.
/// </summary>
/// <param name="outputs"></param>
/// <returns></returns>
public Tensor eval_in_eager_or_function(Tensor outputs)
{
throw new NotImplementedException("");
}

public class _DummyEagerGraph
{ }
}


+ 6
- 2
src/TensorFlowNET.Core/Keras/Engine/Functional.cs View File

@@ -1,7 +1,6 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Security.Cryptography.X509Certificates;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Utils;
@@ -69,6 +68,11 @@ namespace Tensorflow.Keras.Engine
_input_coordinates.append(new KerasHistory(layer, node_index, tensor_index, x));
}
}

protected override Tensors call(Tensors inputs, Tensor state = null, bool is_training = false)
{
return base.call(inputs, state, is_training);
}

}
}

+ 0
- 2
src/TensorFlowNET.Core/Keras/Engine/KerasHistory.cs View File

@@ -20,8 +20,6 @@ namespace Tensorflow.Keras.Engine
this.node_index = node_index;
this.tensor_index = tensor_index;
this.tensor = tensor;
Layer.KerasHistories.Add(this);
Console.WriteLine(tensor.name);
}

public void Deconstruct(out Layer layer, out int node_index, out int tensor_index)


+ 1
- 15
src/TensorFlowNET.Core/Keras/Engine/Layer.cs View File

@@ -89,7 +89,6 @@ namespace Tensorflow.Keras.Engine

ThreadLocal<CallContext> callContext;
public CallContext CallContext => callContext.Value;
public static List<KerasHistory> KerasHistories = new List<KerasHistory>();

public Layer(LayerArgs args)
{
@@ -125,29 +124,16 @@ namespace Tensorflow.Keras.Engine
}

public void SetConnectivityMetadata(Tensors inputs, Tensors outputs)
{

}
=> _set_connectivity_metadata_(inputs, outputs);

private Tensors _set_connectivity_metadata_(Tensors inputs, Tensors outputs)
{
/*var returnOutputs = new List<Tensor>();
foreach(var x in outputs)
{
if (inputs.Contains(x))
{

}
returnOutputs.Add(x);
}*/

new Node(this, new NodeArgs
{
InputTensors = inputs,
Outputs = outputs
});

//_add_inbound_node(input_tensors: inputs, output_tensors: outputs);
return outputs;
}



+ 2
- 0
src/TensorFlowNET.Core/Keras/Engine/Node.cs View File

@@ -52,6 +52,8 @@ namespace Tensorflow.Keras.Engine
layer.InboundNodes.Add(this);
foreach (var kt in kerasInputs)
{
if (kt.KerasHistory == null)
continue;
var inbound_layer = kt.KerasHistory.layer;
if (inbound_layer != null)
inbound_layer.OutboundNodes.Add(this);


+ 2
- 2
src/TensorFlowNET.Core/Keras/Engine/TensorFlowOpLayer.cs View File

@@ -8,12 +8,12 @@ namespace Tensorflow.Keras.Engine
public class TensorFlowOpLayer : Layer
{
TensorFlowOpLayerArgs args;
string _TF_OP_LAYER_NAME_PREFIX = "";
static string TF_OP_LAYER_NAME_PREFIX = "tf_op_layer_";

public TensorFlowOpLayer(TensorFlowOpLayerArgs args)
: base(new LayerArgs
{
Name = "tf_op_layer_" + args.Name,
Name = TF_OP_LAYER_NAME_PREFIX + args.Name,
Trainable = args.Trainable,
DType = args.DType,
Autocast = false


+ 2
- 0
src/TensorFlowNET.Core/Keras/KerasApi.cs View File

@@ -8,6 +8,7 @@ using Tensorflow.Keras.Datasets;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Layers;
using Tensorflow.Keras.Losses;
using Tensorflow.Keras.Optimizers;
using static Tensorflow.Binding;

namespace Tensorflow
@@ -22,6 +23,7 @@ namespace Tensorflow
public Activations activations { get; } = new Activations();
public Preprocessing preprocessing { get; } = new Preprocessing();
public BackendImpl backend { get; } = new BackendImpl();
public OptimizerApi optimizers { get; } = new OptimizerApi();

public Sequential Sequential(List<Layer> layers = null,
string name = null)


+ 33
- 0
src/TensorFlowNET.Core/Keras/Optimizers/OptimizerApi.cs View File

@@ -0,0 +1,33 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Optimizers
{
public class OptimizerApi
{
/// <summary>
/// Adam optimization is a stochastic gradient descent method that is based on
/// adaptive estimation of first-order and second-order moments.
/// </summary>
/// <param name="learning_rate"></param>
/// <param name="beta_1"></param>
/// <param name="beta_2"></param>
/// <param name="epsilon"></param>
/// <param name="amsgrad"></param>
/// <param name="name"></param>
/// <returns></returns>
public OptimizerV2 Adam(float learning_rate = 0.001f,
float beta_1 = 0.9f,
float beta_2 = 0.999f,
float epsilon = 1e-7f,
bool amsgrad = false,
string name = "Adam")
=> new Adam(learning_rate: learning_rate,
beta_1: beta_1,
beta_2: beta_2,
epsilon: epsilon,
amsgrad: amsgrad,
name: name);
}
}

+ 24
- 11
src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs View File

@@ -142,26 +142,39 @@ namespace Tensorflow.Keras.Utils
layer_inputs.Add(op_input);
else
{
tf_with(ops.init_scope(), delegate
{

}

// recursively
CreateKerasHistoryHelper(layer_inputs, processed_ops, created_layers);
var op_layer = new TensorFlowOpLayer(new TensorFlowOpLayerArgs
{
NodeDef = op.node_def,
Name = op.name
});
created_layers.Add(op_layer);
op_layer.SetConnectivityMetadata(layer_inputs, op.outputs);
});
}
}

// recursively
CreateKerasHistoryHelper(layer_inputs, processed_ops, created_layers);
var op_layer = new TensorFlowOpLayer(new TensorFlowOpLayerArgs
{
NodeDef = op.node_def,
Name = op.name
});
created_layers.Add(op_layer);
op_layer.SetConnectivityMetadata(layer_inputs, op.outputs);
processed_ops.Add(op);
}
}
}

// recusive
static bool uses_keras_history(Tensor op_input)
{
return Layer.KerasHistories.Any(x => x.tensor.name == op_input.name);
if (op_input.KerasHistory != null)
return true;

foreach (var input in op_input.op.inputs._inputs)
if (uses_keras_history(input))
return true;

return false;
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Tensorflow.Binding.csproj View File

@@ -29,7 +29,7 @@ https://tensorflownet.readthedocs.io</Description>
<FileVersion>0.21.0.0</FileVersion>
<PackageLicenseFile>LICENSE</PackageLicenseFile>
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance>
<SignAssembly>false</SignAssembly>
<SignAssembly>true</SignAssembly>
<AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile>
<Platforms>AnyCPU;x64</Platforms>
</PropertyGroup>


Loading…
Cancel
Save