Browse Source

Fix NeuralNetXorKeras accuracy. #952

tags/v0.100.4-load-saved-model
Haiping Chen 2 years ago
parent
commit
9b11d45906
7 changed files with 92 additions and 21 deletions
  1. +1
    -1
      src/TensorFlowNET.Keras/Engine/Functional.cs
  2. +2
    -2
      src/TensorFlowNET.Keras/Engine/Layer.AddWeights.cs
  3. +10
    -10
      src/TensorFlowNET.Keras/Engine/Layer.cs
  4. +40
    -7
      src/TensorFlowNET.Keras/Engine/Model.Fit.cs
  5. +15
    -0
      src/TensorFlowNET.Keras/Engine/Model.cs
  6. +1
    -1
      src/python/.vscode/launch.json
  7. +23
    -0
      src/python/xor_keras.py

+ 1
- 1
src/TensorFlowNET.Keras/Engine/Functional.cs View File

@@ -71,7 +71,7 @@ namespace Tensorflow.Keras.Engine
NodesByDepth = nodes_by_depth;
if (_layers.Count == 0)
_layers = layers;
_self_tracked_trackables = layers;
// Build self.input_names and self.output_names.
_set_output_names();



+ 2
- 2
src/TensorFlowNET.Keras/Engine/Layer.AddWeights.cs View File

@@ -53,9 +53,9 @@ namespace Tensorflow.Keras.Engine

//backend.track_variable(variable);
if (trainable == true)
trainable_weights.Add(variable);
_trainable_weights.Add(variable);
else
non_trainable_weights.Add(variable);
_non_trainable_weights.Add(variable);

return variable;
}


+ 10
- 10
src/TensorFlowNET.Keras/Engine/Layer.cs View File

@@ -61,12 +61,12 @@ namespace Tensorflow.Keras.Engine
protected InputSpec inputSpec;
bool dynamic = true;
public bool SupportsMasking { get; set; }
protected List<IVariableV1> trainable_weights;
protected List<IVariableV1> _trainable_weights;

public virtual List<IVariableV1> trainable_variables => trainable_weights;
public virtual List<IVariableV1> trainable_variables => _trainable_weights;

protected List<IVariableV1> non_trainable_weights;
public List<IVariableV1> non_trainable_variables => non_trainable_weights;
protected List<IVariableV1> _non_trainable_weights;
public List<IVariableV1> non_trainable_variables => _non_trainable_weights;

protected int id;
public int Id => id;
@@ -104,8 +104,8 @@ namespace Tensorflow.Keras.Engine

id = ops.uid_layer();
_init_set_name(args.Name);
trainable_weights = new List<IVariableV1>();
non_trainable_weights = new List<IVariableV1>();
_trainable_weights = new List<IVariableV1>();
_non_trainable_weights = new List<IVariableV1>();
computePreviousMask = false;
updates = new List<Operation>();
_self_tracked_trackables = new List<ILayer>();
@@ -254,7 +254,7 @@ namespace Tensorflow.Keras.Engine
{
get
{
return trainable_weights;
return _trainable_weights;
}
}

@@ -262,7 +262,7 @@ namespace Tensorflow.Keras.Engine
{
get
{
return non_trainable_weights;
return _non_trainable_weights;
}
}

@@ -271,8 +271,8 @@ namespace Tensorflow.Keras.Engine
get
{
var weights = new List<IVariableV1>();
weights.AddRange(trainable_weights);
weights.AddRange(non_trainable_weights);
weights.AddRange(_trainable_weights);
weights.AddRange(_non_trainable_weights);
return weights;
}
set


+ 40
- 7
src/TensorFlowNET.Keras/Engine/Model.Fit.cs View File

@@ -4,6 +4,7 @@ using System.Collections.Generic;
using System.Linq;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine.DataAdapters;
using System.Diagnostics;

namespace Tensorflow.Keras.Engine
{
@@ -87,25 +88,57 @@ namespace Tensorflow.Keras.Engine
{
stop_training = false;
_train_counter.assign(0);
Stopwatch sw = new Stopwatch();
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
{
reset_metrics();
// callbacks.on_epoch_begin(epoch)
on_epoch_begin(epoch, epochs);
// data_handler.catch_stop_iteration();
foreach (var step in data_handler.steps())
{
// callbacks.on_train_batch_begin(step)
sw.Start();
var results = train_step_function(iterator);
if (verbose == 1)
sw.Stop();
on_train_batch_begin(verbose, step, sw.ElapsedMilliseconds, results);

// recycle memory more frequency
if (sw.ElapsedMilliseconds > 100)
{
var result_pairs = string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2:F6}"));
Binding.tf_output_redirect.WriteLine($"Epoch: {epoch + 1:D3}/{epochs:D3}, Step: {step + 1:D4}/{data_handler.Inferredsteps:D4}, {result_pairs}");
GC.Collect();
}

GC.Collect();
sw.Reset();
}
Console.WriteLine();

GC.Collect();
GC.WaitForPendingFinalizers();
}
}

void on_epoch_begin(int epoch, int epochs)
{
Binding.tf_output_redirect.WriteLine($"Epoch: {epoch + 1:D3}/{epochs:D3}");
}

void on_train_batch_begin(int verbose, long step, long elapse, IEnumerable<(string, Tensor)> results)
{
if (verbose == 1)
{
var result_pairs = string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2:F6}"));

var progress = "";
for (int i = 0; i < step + 1; i++)
for (int j = 0; j < 30 / data_handler.Inferredsteps; j++)
progress += "=";
progress += ">";

var remaining = "";
for (int i = 1; i < 30 - progress.Length; i++)
remaining += ".";

Binding.tf_output_redirect.Write($"{step + 1:D4}/{data_handler.Inferredsteps:D4} [{progress}{remaining}] - {elapse}ms/step {result_pairs}");
Console.CursorLeft = 0;
}
}
}
}

+ 15
- 0
src/TensorFlowNET.Keras/Engine/Model.cs View File

@@ -75,11 +75,26 @@ namespace Tensorflow.Keras.Engine
get
{
var variables = new List<IVariableV1>();

if (!Trainable)
{
return variables;
}

foreach (var trackable_obj in _self_tracked_trackables)
{
if (trackable_obj.Trainable)
variables.AddRange(trackable_obj.trainable_variables);
}

foreach (var layer in _layers)
{
if (layer.Trainable)
variables.AddRange(layer.trainable_variables);
}

// variables.AddRange(_trainable_weights);

return variables;
}
}


+ 1
- 1
src/python/.vscode/launch.json View File

@@ -8,7 +8,7 @@
"name": "Python: Current File",
"type": "python",
"request": "launch",
"program": "${file}",
"program": "${workspaceFolder}/xor_keras.py",
"console": "integratedTerminal",
"justMyCode": false
}


+ 23
- 0
src/python/xor_keras.py View File

@@ -0,0 +1,23 @@
import os
import numpy as np
import tensorflow as tf

os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
print(tf.__version__)
# tf.compat.v1.enable_eager_execution()
# tf.debugging.set_log_device_placement(True);
tf.config.run_functions_eagerly(True)

x = np.array([[ 0, 0 ], [ 0, 1 ], [ 1, 0 ], [ 1, 1 ]])
y = np.array([[ 0 ], [ 1 ], [ 1 ], [ 0 ] ])

model = tf.keras.Sequential()
model.add(tf.keras.Input(2))
model.add(tf.keras.layers.Dense(32, "relu"))
model.add(tf.keras.layers.Dense(1, "sigmoid"))
model.compile(optimizer = tf.keras.optimizers.Adam(),
loss = tf.keras.losses.MeanSquaredError(),
metrics = ["accuracy"])
model.fit(x, y, 1, 100)
result = model.evaluate(x, y)
print(model.predict(x, 4))

Loading…
Cancel
Save