
fix: LSTM training does not align with tensorflow.

tags/v0.110.0-LSTM-Model
Yaohui Liu 2 years ago
commit a0df8109f8
14 changed files with 68 additions and 37 deletions
  1. +1 -1   src/TensorFlowNET.Core/Binding.Util.cs
  2. +1 -1   src/TensorFlowNET.Core/Eager/EagerRunner.TFE_TapeGradient.cs
  3. +6 -1   src/TensorFlowNET.Core/Eager/EagerTensor.ToString.cs
  4. +1 -1   src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMCellArgs.cs
  5. +1 -1   src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
  6. +9 -9   src/TensorFlowNET.Core/NumPy/NDArrayRender.cs
  7. +22 -0  src/TensorFlowNET.Core/Operations/Initializers/NpyLoadInitializer.cs
  8. +1 -1   src/TensorFlowNET.Core/Tensorflow.Binding.csproj
  9. +1 -2   src/TensorFlowNET.Core/Training/Trackable.cs
 10. +4 -3   src/TensorFlowNET.Keras/Layers/LayersApi.cs
 11. +7 -4   src/TensorFlowNET.Keras/Layers/Rnn/LSTMCell.cs
 12. +7 -10  test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs
 13. +6 -2   tools/Tensorflow.CodeGen/FunctionGenerator.cs
 14. +1 -1   tools/Tensorflow.CodeGen/Tensorflow.CodeGen.csproj

src/TensorFlowNET.Core/Binding.Util.cs (+1 -1)

@@ -503,7 +503,7 @@ namespace Tensorflow
case Tensors tensors:
return tensors.dtype;
case IEnumerable<Tensor> tensors:
- return tensors.First().dtype;
+ return tensors.Where(x => x is not null).First().dtype;
case RefVariable variable:
return variable.dtype;
case ResourceVariable variable:
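
Note: a minimal sketch of the case this guards against, with hypothetical data. An IEnumerable<Tensor> such as an RNN state list can contain null placeholders; First() would return the null entry and the following .dtype access would throw a NullReferenceException, whereas the added Where filter skips it (requires System.Linq):

    IEnumerable<Tensor> states = new Tensor[] { null, tf.constant(1.0f) };   // hypothetical state list
    var dtype = states.Where(x => x is not null).First().dtype;              // TF_FLOAT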


src/TensorFlowNET.Core/Eager/EagerRunner.TFE_TapeGradient.cs (+1 -1)

@@ -65,7 +65,7 @@ namespace Tensorflow.Eager
{
outgrad_vec = output_gradients.ToList();
}
- var result = tape.ComputeGradient(target_vec, sources_vec, source_tensors_that_are_targets, outgrad_vec, false);
+ var result = tape.ComputeGradient(target_vec, sources_vec, source_tensors_that_are_targets, outgrad_vec, true);


bool unconnected_gradients_zero = unconnected_gradients == "zero";
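
Note: the flipped boolean is the last argument of tape.ComputeGradient; judging from the upstream TensorFlow tape API it presumably corresponds to the build_default_zeros_grads flag, i.e. how gradients are materialized for targets that received no incoming gradient. A minimal sketch of the public path that reaches this code (standard TF.NET gradient-tape usage, values illustrative):

    var x = tf.Variable(3.0f);
    var tape = tf.GradientTape();
    var y = tf.square(x);
    // tape.gradient ends up in TFE_TapeGradient, which now calls ComputeGradient(..., true)
    var dy_dx = tape.gradient(y, x);   // expected: 6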


src/TensorFlowNET.Core/Eager/EagerTensor.ToString.cs (+6 -1)

@@ -10,6 +10,11 @@ namespace Tensorflow.Eager
var str = NDArrayRender.ToString(nd);
return $"tf.Tensor: shape={shape}, dtype={dtype.as_numpy_name()}, numpy={str}";
}
+ public string ToString(int maxLength)
+ {
+ var nd = new NDArray(this);
+ var str = NDArrayRender.ToString(nd, maxLength);
+ return $"tf.Tensor: shape={shape}, dtype={dtype.as_numpy_name()}, numpy={str}";
+ }
}
}
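
Note: the new overload just threads a caller-chosen maxLength through to NDArrayRender. Hypothetical usage (values and the cast are illustrative; eager tensors are EagerTensor at runtime):

    var t = (EagerTensor)tf.ones((100, 100));
    // the default ToString() truncates each axis at 10 elements; the overload makes the cut-off explicit
    Console.WriteLine(t.ToString(maxLength: 6));   // prints 3 leading and 3 trailing rows per axis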

src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMCellArgs.cs (+1 -1)

@@ -29,7 +29,7 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn
[JsonProperty("unit_forget_bias")]
public bool UnitForgetBias { get; set; } = true;
[JsonProperty("implementation")]
- public int Implementation { get; set; } = 1;
+ public int Implementation { get; set; } = 2;

}
}
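
Note: this matches Keras, where implementation 1 builds the step out of many small dot products while implementation 2 batches them into fewer, larger matmuls; the two modes should differ only in performance, not in results. Illustrative calls through the public API (argument values made up):

    var cell = keras.layers.LSTMCell(4);                                          // implementation: 2 is now the default
    var layer = keras.layers.LSTM(4, implementation: 2, return_sequences: true);  // explicit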

src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs (+1 -1)

@@ -182,7 +182,7 @@ namespace Tensorflow.Keras.Layers
bool unit_forget_bias = true,
float dropout = 0f,
float recurrent_dropout = 0f,
- int implementation = 1,
+ int implementation = 2,
bool return_sequences = false,
bool return_state = false,
bool go_backwards = false,


src/TensorFlowNET.Core/NumPy/NDArrayRender.cs (+9 -9)

@@ -7,7 +7,7 @@ namespace Tensorflow.NumPy
{
public class NDArrayRender
{
- public static string ToString(NDArray array)
+ public static string ToString(NDArray array, int maxLength = 10)
{
Shape shape = array.shape;
if (shape.IsScalar)
@@ -15,12 +15,12 @@ namespace Tensorflow.NumPy

var s = new StringBuilder();
s.Append("array(");
- Build(s, array);
+ Build(s, array, maxLength);
s.Append(")");
return s.ToString();
}

- static void Build(StringBuilder s, NDArray array)
+ static void Build(StringBuilder s, NDArray array, int maxLength)
{
var shape = array.shape;

@@ -35,11 +35,11 @@ namespace Tensorflow.NumPy
var len = shape[0];
s.Append("[");

- if (len <= 10)
+ if (len <= maxLength)
{
for (int i = 0; i < len; i++)
{
- Build(s, array[i]);
+ Build(s, array[i], maxLength);
if (i < len - 1)
{
s.Append(", ");
@@ -49,9 +49,9 @@ namespace Tensorflow.NumPy
}
else
{
- for (int i = 0; i < 5; i++)
+ for (int i = 0; i < maxLength / 2; i++)
{
- Build(s, array[i]);
+ Build(s, array[i], maxLength);
if (i < len - 1)
{
s.Append(", ");
@@ -62,9 +62,9 @@ namespace Tensorflow.NumPy
s.Append(" ... ");
s.AppendLine();

- for (int i = (int)len - 5; i < len; i++)
+ for (int i = (int)len - maxLength / 2; i < len; i++)
{
- Build(s, array[i]);
+ Build(s, array[i], maxLength);
if (i < len - 1)
{
s.Append(", ");
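
Note: with the hard-coded 10 and 5 replaced by maxLength and maxLength / 2, the truncation point is now configurable per call. Quick illustrative use (array contents hypothetical):

    var nd = np.arange(30).reshape((30, 1));
    // the default prints rows 0-4, " ... ", rows 25-29; with maxLength: 6 it prints rows 0-2, " ... ", rows 27-29
    Console.WriteLine(NDArrayRender.ToString(nd, maxLength: 6));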


src/TensorFlowNET.Core/Operations/Initializers/NpyLoadInitializer.cs (+22 -0)

@@ -0,0 +1,22 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.NumPy;

namespace Tensorflow.Operations.Initializers
{
/// <summary>
/// An initializer specially used for debugging (to load weights from disk).
/// </summary>
class NpyLoadInitializer : IInitializer
{
string _path;
public NpyLoadInitializer(string path) { _path = path; }
public string ClassName => "";
public IDictionary<string, object> Config => new Dictionary<string, object>();
public Tensor Apply(InitializerArgs args)
{
return np.load(_path);
}
}
}
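
Note: the class is internal and purely a debugging aid; Apply ignores the requested shape and dtype and returns whatever np.load reads from disk, so the .npy file must already match the weight being initialized. Hypothetical round-trip (file name and shape are made up, and the usual InitializerArgs(shape) constructor is assumed):

    // in Python: numpy.save("lstm_kernel.npy", reference_kernel)
    IInitializer init = new NpyLoadInitializer("lstm_kernel.npy");
    Tensor kernel = init.Apply(new InitializerArgs(new Shape(28, 400)));   // the shape argument is ignored by this initializer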

src/TensorFlowNET.Core/Tensorflow.Binding.csproj (+1 -1)

@@ -111,7 +111,7 @@ https://tensorflownet.readthedocs.io</Description>
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" />
<PackageReference Include="Newtonsoft.Json" Version="13.0.3" />
<PackageReference Include="OneOf" Version="3.0.223" />
- <PackageReference Include="Protobuf.Text" Version="0.7.0" />
+ <PackageReference Include="Protobuf.Text" Version="0.7.1" />
<PackageReference Include="Serilog.Sinks.Console" Version="4.1.0" />
</ItemGroup>



src/TensorFlowNET.Core/Training/Trackable.cs (+1 -2)

@@ -179,8 +179,7 @@ namespace Tensorflow.Train
// handles slot variables.
if (!args.Overwrite || new_variable is RefVariable || new_variable is Trackable)
{
- var temp = new_variable as Trackable;
- var res = _track_trackable(temp, args.Name, args.Overwrite);
+ var res = _track_trackable(new_variable as Trackable, args.Name, args.Overwrite);
Debug.Assert(res is IVariableV1);
return res as IVariableV1;
}


src/TensorFlowNET.Keras/Layers/LayersApi.cs (+4 -3)

@@ -793,7 +793,7 @@ namespace Tensorflow.Keras.Layers
bool unit_forget_bias = true,
float dropout = 0f,
float recurrent_dropout = 0f,
- int implementation = 1)
+ int implementation = 2)
=> new LSTMCell(new LSTMCellArgs
{
Units = uints,
@@ -846,7 +846,7 @@ namespace Tensorflow.Keras.Layers
bool unit_forget_bias = true,
float dropout = 0f,
float recurrent_dropout = 0f,
- int implementation = 1,
+ int implementation = 2,
bool return_sequences = false,
bool return_state = false,
bool go_backwards = false,
@@ -869,7 +869,8 @@ namespace Tensorflow.Keras.Layers
GoBackwards = go_backwards,
Stateful = stateful,
TimeMajor = time_major,
- Unroll = unroll
+ Unroll = unroll,
+ UnitForgetBias = unit_forget_bias
});

/// <summary>
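
Note: besides the new implementation default, the LSTM factory previously did not forward unit_forget_bias into the args object at all, so the setting never reached the cell; it is now copied as UnitForgetBias. Illustrative call (values hypothetical):

    // before this commit the argument was silently dropped; now it is honored
    var lstm = keras.layers.LSTM(64, unit_forget_bias: false, implementation: 2);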


src/TensorFlowNET.Keras/Layers/Rnn/LSTMCell.cs (+7 -4)

@@ -1,4 +1,5 @@
- using Serilog.Core;
+ using Newtonsoft.Json;
+ using Serilog.Core;
using System.Diagnostics;
using Tensorflow.Common.Extensions;
using Tensorflow.Common.Types;
@@ -54,6 +55,7 @@ namespace Tensorflow.Keras.Layers.Rnn

public override void build(KerasShapesWrapper input_shape)
{
+ base.build(input_shape);
var single_shape = input_shape.ToSingleShape();
var input_dim = single_shape[-1];
_kernel = add_weight("kernel", (input_dim, _args.Units * 4),
@@ -82,7 +84,8 @@ namespace Tensorflow.Keras.Layers.Rnn
_bias_initializer = _args.BiasInitializer;
}
_bias = add_weight("bias", (_args.Units * 4),
- initializer: _bias_initializer);
+ initializer: _bias_initializer
+ );
}
built = true;
}
@@ -203,7 +206,7 @@ namespace Tensorflow.Keras.Layers.Rnn
x_c + math_ops.matmul(h_tm1_c, _recurrent_kernel_slice));
_recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor,
new[] { 0, _args.Units * 3 }, new[] { startIndex, _args.Units });
- var o = _args.RecurrentActivation.Apply(
+ var o = _args.Activation.Apply(
x_o + math_ops.matmul(h_tm1_o, _recurrent_kernel_slice));

return new Tensors(c, o);
@@ -220,7 +223,7 @@ namespace Tensorflow.Keras.Layers.Rnn
Tensor z0 = z[0], z1 = z[1], z2 = z[2], z3 = z[3];
var i = _args.RecurrentActivation.Apply(z0);
var f = _args.RecurrentActivation.Apply(z1);
- var c = f * c_tm1 + i * _args.RecurrentActivation.Apply(z2);
+ var c = f * c_tm1 + i * _args.Activation.Apply(z2);
var o = _args.RecurrentActivation.Apply(z3);
return new Tensors(c, o);
}
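
Note: for reference, the standard Keras LSTM step this is being aligned with (implementation-2 path; recurrent_activation is the sigmoid-like gate function, activation is tanh by default):

    // z = matmul(inputs, kernel) + matmul(h_tm1, recurrent_kernel) + bias, split into z0..z3
    // i = recurrent_activation(z0)          // input gate
    // f = recurrent_activation(z1)          // forget gate
    // c = f * c_tm1 + i * activation(z2)    // cell state; this is the line corrected in the hunk directly above
    // o = recurrent_activation(z3)          // output gate
    // h = o * activation(c)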


test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs (+7 -10)

@@ -60,26 +60,23 @@ namespace Tensorflow.Keras.UnitTest.Layers
{
var input = keras.Input((784));
var x = keras.layers.Reshape((28, 28)).Apply(input);
- //x = keras.layers.LSTM(50, return_sequences: true).Apply(x);
- //x = keras.layers.LSTM(100, return_sequences: true).Apply(x);
- //x = keras.layers.LSTM(150, return_sequences: true).Apply(x);
- x = keras.layers.LSTM(4, implementation: 2).Apply(x);
- //x = keras.layers.Dense(100).Apply(x);
+ x = keras.layers.LSTM(50, return_sequences: true).Apply(x);
+ x = keras.layers.LSTM(100).Apply(x);
var output = keras.layers.Dense(10, activation: "softmax").Apply(x);

var model = keras.Model(input, output);
model.summary();
- model.compile(keras.optimizers.Adam(), keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" });
+ model.compile(keras.optimizers.Adam(), keras.losses.CategoricalCrossentropy(), new string[] { "accuracy" });

var data_loader = new MnistModelLoader();
var dataset = data_loader.LoadAsync(new ModelLoadSetting
{
TrainDir = "mnist",
- OneHot = false,
- ValidationSize = 58000,
+ OneHot = true,
+ ValidationSize = 55000,
}).Result;

- model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size: 16, epochs: 30);
+ model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size: 16, epochs: 1);
}

[TestMethod]
@@ -102,7 +99,7 @@ namespace Tensorflow.Keras.UnitTest.Layers
ValidationSize = 58000,
}).Result;

- model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size: 16, epochs: 10);
+ model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size: 16, epochs: 2);
}

[TestMethod]
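
Note: the loss change follows from the label encoding: with OneHot = true the loader yields (batch, 10) one-hot labels, which pair with CategoricalCrossentropy, while SparseCategoricalCrossentropy expects integer class indices. The larger ValidationSize and the drop to 1 epoch mainly keep the unit test fast. Rule of thumb (general Keras behaviour, not specific to this repo):

    // OneHot = true  -> labels shaped (batch, 10) -> keras.losses.CategoricalCrossentropy()
    // OneHot = false -> integer labels (batch,)   -> keras.losses.SparseCategoricalCrossentropy()
    model.compile(keras.optimizers.Adam(), keras.losses.CategoricalCrossentropy(), new string[] { "accuracy" });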


tools/Tensorflow.CodeGen/FunctionGenerator.cs (+6 -2)

@@ -83,8 +83,12 @@ namespace Tensorflow.CodeGen

sb.AppendLine("}"); // try

- sb.Append("catch(NotOkStatusException ex)\n{\n");
- sb.AppendLine("throw ex;");
+ sb.Append("catch(NotOkStatusException ex1)\n{\n");
+ sb.AppendLine("throw ex1;");
sb.AppendLine("}"); // catch

+ sb.Append("catch(InvalidArgumentError ex2)\n{\n");
+ sb.AppendLine("throw ex2;");
+ sb.AppendLine("}"); // catch

sb.Append("catch(Exception)\n{\n");
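
Note: the generator emits C# source as strings, so the change is easiest to read as the code it now produces, roughly (bodies elided):

    try
    {
        // ... fast-path op invocation emitted by the generator ...
    }
    catch (NotOkStatusException ex1)
    {
        throw ex1;
    }
    catch (InvalidArgumentError ex2)
    {
        throw ex2;
    }
    catch (Exception)
    {
        // ... fallback emitted by the generator ...
    }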


tools/Tensorflow.CodeGen/Tensorflow.CodeGen.csproj (+1 -1)

@@ -9,7 +9,7 @@

<ItemGroup>
<PackageReference Include="Microsoft.CodeAnalysis.CSharp.Scripting" Version="4.6.0-1.final" />
- <PackageReference Include="Protobuf.Text" Version="0.7.0" />
+ <PackageReference Include="Protobuf.Text" Version="0.7.1" />
</ItemGroup>

<ItemGroup>

