Browse Source

fix: LSTM training does not align with TensorFlow.

tags/v0.110.0-LSTM-Model
Yaohui Liu 2 years ago
parent
commit
a0df8109f8
No known key found for this signature in database GPG Key ID: E86D01E1809BD23E
14 changed files with 68 additions and 37 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Binding.Util.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Eager/EagerRunner.TFE_TapeGradient.cs
  3. +6
    -1
      src/TensorFlowNET.Core/Eager/EagerTensor.ToString.cs
  4. +1
    -1
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMCellArgs.cs
  5. +1
    -1
      src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
  6. +9
    -9
      src/TensorFlowNET.Core/NumPy/NDArrayRender.cs
  7. +22
    -0
      src/TensorFlowNET.Core/Operations/Initializers/NpyLoadInitializer.cs
  8. +1
    -1
      src/TensorFlowNET.Core/Tensorflow.Binding.csproj
  9. +1
    -2
      src/TensorFlowNET.Core/Training/Trackable.cs
  10. +4
    -3
      src/TensorFlowNET.Keras/Layers/LayersApi.cs
  11. +7
    -4
      src/TensorFlowNET.Keras/Layers/Rnn/LSTMCell.cs
  12. +7
    -10
      test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs
  13. +6
    -2
      tools/Tensorflow.CodeGen/FunctionGenerator.cs
  14. +1
    -1
      tools/Tensorflow.CodeGen/Tensorflow.CodeGen.csproj

+ 1
- 1
src/TensorFlowNET.Core/Binding.Util.cs View File

@@ -503,7 +503,7 @@ namespace Tensorflow
case Tensors tensors: case Tensors tensors:
return tensors.dtype; return tensors.dtype;
case IEnumerable<Tensor> tensors: case IEnumerable<Tensor> tensors:
return tensors.First().dtype;
return tensors.Where(x => x is not null).First().dtype;
case RefVariable variable: case RefVariable variable:
return variable.dtype; return variable.dtype;
case ResourceVariable variable: case ResourceVariable variable:


+ 1
- 1
src/TensorFlowNET.Core/Eager/EagerRunner.TFE_TapeGradient.cs View File

@@ -65,7 +65,7 @@ namespace Tensorflow.Eager
{ {
outgrad_vec = output_gradients.ToList(); outgrad_vec = output_gradients.ToList();
} }
var result = tape.ComputeGradient(target_vec, sources_vec, source_tensors_that_are_targets, outgrad_vec, false);
var result = tape.ComputeGradient(target_vec, sources_vec, source_tensors_that_are_targets, outgrad_vec, true);




bool unconnected_gradients_zero = unconnected_gradients == "zero"; bool unconnected_gradients_zero = unconnected_gradients == "zero";


+ 6
- 1
src/TensorFlowNET.Core/Eager/EagerTensor.ToString.cs View File

@@ -10,6 +10,11 @@ namespace Tensorflow.Eager
var str = NDArrayRender.ToString(nd); var str = NDArrayRender.ToString(nd);
return $"tf.Tensor: shape={shape}, dtype={dtype.as_numpy_name()}, numpy={str}"; return $"tf.Tensor: shape={shape}, dtype={dtype.as_numpy_name()}, numpy={str}";
} }
public string ToString(int maxLength)
{
var nd = new NDArray(this);
var str = NDArrayRender.ToString(nd, maxLength);
return $"tf.Tensor: shape={shape}, dtype={dtype.as_numpy_name()}, numpy={str}";
}
} }
} }

+ 1
- 1
src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMCellArgs.cs View File

@@ -29,7 +29,7 @@ namespace Tensorflow.Keras.ArgsDefinition.Rnn
[JsonProperty("unit_forget_bias")] [JsonProperty("unit_forget_bias")]
public bool UnitForgetBias { get; set; } = true; public bool UnitForgetBias { get; set; } = true;
[JsonProperty("implementation")] [JsonProperty("implementation")]
public int Implementation { get; set; } = 1;
public int Implementation { get; set; } = 2;


} }
} }

+ 1
- 1
src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs View File

@@ -182,7 +182,7 @@ namespace Tensorflow.Keras.Layers
bool unit_forget_bias = true, bool unit_forget_bias = true,
float dropout = 0f, float dropout = 0f,
float recurrent_dropout = 0f, float recurrent_dropout = 0f,
int implementation = 1,
int implementation = 2,
bool return_sequences = false, bool return_sequences = false,
bool return_state = false, bool return_state = false,
bool go_backwards = false, bool go_backwards = false,


+ 9
- 9
src/TensorFlowNET.Core/NumPy/NDArrayRender.cs View File

@@ -7,7 +7,7 @@ namespace Tensorflow.NumPy
{ {
public class NDArrayRender public class NDArrayRender
{ {
public static string ToString(NDArray array)
public static string ToString(NDArray array, int maxLength = 10)
{ {
Shape shape = array.shape; Shape shape = array.shape;
if (shape.IsScalar) if (shape.IsScalar)
@@ -15,12 +15,12 @@ namespace Tensorflow.NumPy


var s = new StringBuilder(); var s = new StringBuilder();
s.Append("array("); s.Append("array(");
Build(s, array);
Build(s, array, maxLength);
s.Append(")"); s.Append(")");
return s.ToString(); return s.ToString();
} }


static void Build(StringBuilder s, NDArray array)
static void Build(StringBuilder s, NDArray array, int maxLength)
{ {
var shape = array.shape; var shape = array.shape;


@@ -35,11 +35,11 @@ namespace Tensorflow.NumPy
var len = shape[0]; var len = shape[0];
s.Append("["); s.Append("[");


if (len <= 10)
if (len <= maxLength)
{ {
for (int i = 0; i < len; i++) for (int i = 0; i < len; i++)
{ {
Build(s, array[i]);
Build(s, array[i], maxLength);
if (i < len - 1) if (i < len - 1)
{ {
s.Append(", "); s.Append(", ");
@@ -49,9 +49,9 @@ namespace Tensorflow.NumPy
} }
else else
{ {
for (int i = 0; i < 5; i++)
for (int i = 0; i < maxLength / 2; i++)
{ {
Build(s, array[i]);
Build(s, array[i], maxLength);
if (i < len - 1) if (i < len - 1)
{ {
s.Append(", "); s.Append(", ");
@@ -62,9 +62,9 @@ namespace Tensorflow.NumPy
s.Append(" ... "); s.Append(" ... ");
s.AppendLine(); s.AppendLine();


for (int i = (int)len - 5; i < len; i++)
for (int i = (int)len - maxLength / 2; i < len; i++)
{ {
Build(s, array[i]);
Build(s, array[i], maxLength);
if (i < len - 1) if (i < len - 1)
{ {
s.Append(", "); s.Append(", ");


+ 22
- 0
src/TensorFlowNET.Core/Operations/Initializers/NpyLoadInitializer.cs View File

@@ -0,0 +1,22 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.NumPy;

namespace Tensorflow.Operations.Initializers
{
/// <summary>
/// An initializer specially used for debugging (to load weights from disk).
/// </summary>
class NpyLoadInitializer : IInitializer
{
string _path;
public NpyLoadInitializer(string path) { _path = path; }
public string ClassName => "";
public IDictionary<string, object> Config => new Dictionary<string, object>();
public Tensor Apply(InitializerArgs args)
{
return np.load(_path);
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Tensorflow.Binding.csproj View File

@@ -111,7 +111,7 @@ https://tensorflownet.readthedocs.io</Description>
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" /> <PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" />
<PackageReference Include="Newtonsoft.Json" Version="13.0.3" /> <PackageReference Include="Newtonsoft.Json" Version="13.0.3" />
<PackageReference Include="OneOf" Version="3.0.223" /> <PackageReference Include="OneOf" Version="3.0.223" />
<PackageReference Include="Protobuf.Text" Version="0.7.0" />
<PackageReference Include="Protobuf.Text" Version="0.7.1" />
<PackageReference Include="Serilog.Sinks.Console" Version="4.1.0" /> <PackageReference Include="Serilog.Sinks.Console" Version="4.1.0" />
</ItemGroup> </ItemGroup>




+ 1
- 2
src/TensorFlowNET.Core/Training/Trackable.cs View File

@@ -179,8 +179,7 @@ namespace Tensorflow.Train
// handles slot variables. // handles slot variables.
if (!args.Overwrite || new_variable is RefVariable || new_variable is Trackable) if (!args.Overwrite || new_variable is RefVariable || new_variable is Trackable)
{ {
var temp = new_variable as Trackable;
var res = _track_trackable(temp, args.Name, args.Overwrite);
var res = _track_trackable(new_variable as Trackable, args.Name, args.Overwrite);
Debug.Assert(res is IVariableV1); Debug.Assert(res is IVariableV1);
return res as IVariableV1; return res as IVariableV1;
} }


+ 4
- 3
src/TensorFlowNET.Keras/Layers/LayersApi.cs View File

@@ -793,7 +793,7 @@ namespace Tensorflow.Keras.Layers
bool unit_forget_bias = true, bool unit_forget_bias = true,
float dropout = 0f, float dropout = 0f,
float recurrent_dropout = 0f, float recurrent_dropout = 0f,
int implementation = 1)
int implementation = 2)
=> new LSTMCell(new LSTMCellArgs => new LSTMCell(new LSTMCellArgs
{ {
Units = uints, Units = uints,
@@ -846,7 +846,7 @@ namespace Tensorflow.Keras.Layers
bool unit_forget_bias = true, bool unit_forget_bias = true,
float dropout = 0f, float dropout = 0f,
float recurrent_dropout = 0f, float recurrent_dropout = 0f,
int implementation = 1,
int implementation = 2,
bool return_sequences = false, bool return_sequences = false,
bool return_state = false, bool return_state = false,
bool go_backwards = false, bool go_backwards = false,
@@ -869,7 +869,8 @@ namespace Tensorflow.Keras.Layers
GoBackwards = go_backwards, GoBackwards = go_backwards,
Stateful = stateful, Stateful = stateful,
TimeMajor = time_major, TimeMajor = time_major,
Unroll = unroll
Unroll = unroll,
UnitForgetBias = unit_forget_bias
}); });


/// <summary> /// <summary>


+ 7
- 4
src/TensorFlowNET.Keras/Layers/Rnn/LSTMCell.cs View File

@@ -1,4 +1,5 @@
using Serilog.Core;
using Newtonsoft.Json;
using Serilog.Core;
using System.Diagnostics; using System.Diagnostics;
using Tensorflow.Common.Extensions; using Tensorflow.Common.Extensions;
using Tensorflow.Common.Types; using Tensorflow.Common.Types;
@@ -54,6 +55,7 @@ namespace Tensorflow.Keras.Layers.Rnn


public override void build(KerasShapesWrapper input_shape) public override void build(KerasShapesWrapper input_shape)
{ {
base.build(input_shape);
var single_shape = input_shape.ToSingleShape(); var single_shape = input_shape.ToSingleShape();
var input_dim = single_shape[-1]; var input_dim = single_shape[-1];
_kernel = add_weight("kernel", (input_dim, _args.Units * 4), _kernel = add_weight("kernel", (input_dim, _args.Units * 4),
@@ -82,7 +84,8 @@ namespace Tensorflow.Keras.Layers.Rnn
_bias_initializer = _args.BiasInitializer; _bias_initializer = _args.BiasInitializer;
} }
_bias = add_weight("bias", (_args.Units * 4), _bias = add_weight("bias", (_args.Units * 4),
initializer: _bias_initializer);
initializer: _bias_initializer
);
} }
built = true; built = true;
} }
@@ -203,7 +206,7 @@ namespace Tensorflow.Keras.Layers.Rnn
x_c + math_ops.matmul(h_tm1_c, _recurrent_kernel_slice)); x_c + math_ops.matmul(h_tm1_c, _recurrent_kernel_slice));
_recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor, _recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor,
new[] { 0, _args.Units * 3 }, new[] { startIndex, _args.Units }); new[] { 0, _args.Units * 3 }, new[] { startIndex, _args.Units });
var o = _args.RecurrentActivation.Apply(
var o = _args.Activation.Apply(
x_o + math_ops.matmul(h_tm1_o, _recurrent_kernel_slice)); x_o + math_ops.matmul(h_tm1_o, _recurrent_kernel_slice));


return new Tensors(c, o); return new Tensors(c, o);
@@ -220,7 +223,7 @@ namespace Tensorflow.Keras.Layers.Rnn
Tensor z0 = z[0], z1 = z[1], z2 = z[2], z3 = z[3]; Tensor z0 = z[0], z1 = z[1], z2 = z[2], z3 = z[3];
var i = _args.RecurrentActivation.Apply(z0); var i = _args.RecurrentActivation.Apply(z0);
var f = _args.RecurrentActivation.Apply(z1); var f = _args.RecurrentActivation.Apply(z1);
var c = f * c_tm1 + i * _args.RecurrentActivation.Apply(z2);
var c = f * c_tm1 + i * _args.Activation.Apply(z2);
var o = _args.RecurrentActivation.Apply(z3); var o = _args.RecurrentActivation.Apply(z3);
return new Tensors(c, o); return new Tensors(c, o);
} }


+ 7
- 10
test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs View File

@@ -60,26 +60,23 @@ namespace Tensorflow.Keras.UnitTest.Layers
{ {
var input = keras.Input((784)); var input = keras.Input((784));
var x = keras.layers.Reshape((28, 28)).Apply(input); var x = keras.layers.Reshape((28, 28)).Apply(input);
//x = keras.layers.LSTM(50, return_sequences: true).Apply(x);
//x = keras.layers.LSTM(100, return_sequences: true).Apply(x);
//x = keras.layers.LSTM(150, return_sequences: true).Apply(x);
x = keras.layers.LSTM(4, implementation: 2).Apply(x);
//x = keras.layers.Dense(100).Apply(x);
x = keras.layers.LSTM(50, return_sequences: true).Apply(x);
x = keras.layers.LSTM(100).Apply(x);
var output = keras.layers.Dense(10, activation: "softmax").Apply(x); var output = keras.layers.Dense(10, activation: "softmax").Apply(x);


var model = keras.Model(input, output); var model = keras.Model(input, output);
model.summary(); model.summary();
model.compile(keras.optimizers.Adam(), keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" });
model.compile(keras.optimizers.Adam(), keras.losses.CategoricalCrossentropy(), new string[] { "accuracy" });


var data_loader = new MnistModelLoader(); var data_loader = new MnistModelLoader();
var dataset = data_loader.LoadAsync(new ModelLoadSetting var dataset = data_loader.LoadAsync(new ModelLoadSetting
{ {
TrainDir = "mnist", TrainDir = "mnist",
OneHot = false,
ValidationSize = 58000,
OneHot = true,
ValidationSize = 55000,
}).Result; }).Result;


model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size: 16, epochs: 30);
model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size: 16, epochs: 1);
} }


[TestMethod] [TestMethod]
@@ -102,7 +99,7 @@ namespace Tensorflow.Keras.UnitTest.Layers
ValidationSize = 58000, ValidationSize = 58000,
}).Result; }).Result;


model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size: 16, epochs: 10);
model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size: 16, epochs: 2);
} }


[TestMethod] [TestMethod]


+ 6
- 2
tools/Tensorflow.CodeGen/FunctionGenerator.cs View File

@@ -83,8 +83,12 @@ namespace Tensorflow.CodeGen


sb.AppendLine("}"); // try sb.AppendLine("}"); // try


sb.Append("catch(NotOkStatusException ex)\n{\n");
sb.AppendLine("throw ex;");
sb.Append("catch(NotOkStatusException ex1)\n{\n");
sb.AppendLine("throw ex1;");
sb.AppendLine("}"); // catch

sb.Append("catch(InvalidArgumentError ex2)\n{\n");
sb.AppendLine("throw ex2;");
sb.AppendLine("}"); // catch sb.AppendLine("}"); // catch


sb.Append("catch(Exception)\n{\n"); sb.Append("catch(Exception)\n{\n");


+ 1
- 1
tools/Tensorflow.CodeGen/Tensorflow.CodeGen.csproj View File

@@ -9,7 +9,7 @@


<ItemGroup> <ItemGroup>
<PackageReference Include="Microsoft.CodeAnalysis.CSharp.Scripting" Version="4.6.0-1.final" /> <PackageReference Include="Microsoft.CodeAnalysis.CSharp.Scripting" Version="4.6.0-1.final" />
<PackageReference Include="Protobuf.Text" Version="0.7.0" />
<PackageReference Include="Protobuf.Text" Version="0.7.1" />
</ItemGroup> </ItemGroup>


<ItemGroup> <ItemGroup>


Loading…
Cancel
Save