Browse Source

Keras step_function doesn't return accuracy #630

tags/v0.30
Oceania2018 4 years ago
parent
commit
8141f8c59f
14 changed files with 222 additions and 21 deletions
  1. +6
    -0
      src/TensorFlowNET.Core/Binding.Util.cs
  2. +12
    -2
      src/TensorFlowNET.Core/Keras/Engine/LossesContainer.cs
  3. +57
    -5
      src/TensorFlowNET.Core/Keras/Engine/MetricsContainer.cs
  4. +11
    -2
      src/TensorFlowNET.Core/Keras/Engine/Model.Fit.cs
  5. +31
    -0
      src/TensorFlowNET.Core/Keras/Engine/Model.Metrics.cs
  6. +5
    -4
      src/TensorFlowNET.Core/Keras/Engine/Model.Train.cs
  7. +2
    -0
      src/TensorFlowNET.Core/Keras/KerasApi.cs
  8. +1
    -1
      src/TensorFlowNET.Core/Keras/Metrics/Mean.cs
  9. +27
    -0
      src/TensorFlowNET.Core/Keras/Metrics/MeanMetricWrapper.cs
  10. +14
    -0
      src/TensorFlowNET.Core/Keras/Metrics/Metric.cs
  11. +33
    -0
      src/TensorFlowNET.Core/Keras/Metrics/MetricsApi.cs
  12. +19
    -7
      src/TensorFlowNET.Core/Keras/Metrics/Reduce.cs
  13. +1
    -0
      src/TensorFlowNET.Core/Operations/Losses/Reduction.cs
  14. +3
    -0
      src/TensorFlowNET.Core/Operations/math_ops.cs

+ 6
- 0
src/TensorFlowNET.Core/Binding.Util.cs View File

@@ -49,6 +49,12 @@ namespace Tensorflow
public static void add<T>(this IList<T> list, T element) public static void add<T>(this IList<T> list, T element)
=> list.Add(element); => list.Add(element);


public static void add<T>(this IList<T> list, IEnumerable<T> elements)
{
foreach (var ele in elements)
list.Add(ele);
}

public static void append<T>(this IList<T> list, T element) public static void append<T>(this IList<T> list, T element)
=> list.Insert(list.Count, element); => list.Insert(list.Count, element);




+ 12
- 2
src/TensorFlowNET.Core/Keras/Engine/LossesContainer.cs View File

@@ -11,7 +11,6 @@ namespace Tensorflow.Keras.Engine
Mean _loss_metric; Mean _loss_metric;
bool _built; bool _built;
Tensor[] _per_output_metrics; Tensor[] _per_output_metrics;
List<Tensor> loss_metric_values;


public LossesContainer(ILossFunc losses, string[] output_names = null) public LossesContainer(ILossFunc losses, string[] output_names = null)
: base(output_names) : base(output_names)
@@ -19,7 +18,6 @@ namespace Tensorflow.Keras.Engine
_user_losses = losses; _user_losses = losses;
_losses = losses; _losses = losses;
_loss_metric = new Mean(name: "loss"); _loss_metric = new Mean(name: "loss");
loss_metric_values = new List<Tensor>();
_built = false; _built = false;
} }


@@ -37,6 +35,7 @@ namespace Tensorflow.Keras.Engine
var batch_dim = array_ops.shape(y_true)[0]; var batch_dim = array_ops.shape(y_true)[0];


var loss_values = new List<Tensor>(); var loss_values = new List<Tensor>();
var loss_metric_values = new List<Tensor>();


/*if (_losses.Reduction == ReductionV2.SUM_OVER_BATCH_SIZE /*if (_losses.Reduction == ReductionV2.SUM_OVER_BATCH_SIZE
|| _losses.Reduction == ReductionV2.AUTO) || _losses.Reduction == ReductionV2.AUTO)
@@ -69,5 +68,16 @@ namespace Tensorflow.Keras.Engine
{ {
// _per_output_metrics = _output_names.Select(x => null); // _per_output_metrics = _output_names.Select(x => null);
} }

public IEnumerable<Metric> metrics
{
get
{
if (!_built)
return new List<Metric>();

return new[] { _loss_metric };
}
}
} }
} }

+ 57
- 5
src/TensorFlowNET.Core/Keras/Engine/MetricsContainer.cs View File

@@ -1,29 +1,81 @@
namespace Tensorflow.Keras.Engine
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Keras.Metrics;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Engine
{ {
public class MetricsContainer : Container public class MetricsContainer : Container
{ {
string[] _user_metrics; string[] _user_metrics;
string[] _metrics;
string[] _metric_names;
Metric[] _metrics;
List<Metric> _metrics_in_order;


public MetricsContainer(string[] metrics, string[] output_names = null) public MetricsContainer(string[] metrics, string[] output_names = null)
: base(output_names) : base(output_names)
{ {
_user_metrics = metrics; _user_metrics = metrics;
_metrics = metrics;
_metric_names = metrics;
_built = false; _built = false;
} }


public void update_state(Tensor y_true, Tensor y_pred, Tensor sample_weight = null) public void update_state(Tensor y_true, Tensor y_pred, Tensor sample_weight = null)
{ {
if (!_built) if (!_built)
Build();
Build(y_true, y_pred);

foreach (var metric_obj in _metrics_in_order)
metric_obj.update_state(y_true, y_pred);
}


void Build(Tensor y_true, Tensor y_pred)
{
_metrics = _get_metric_objects(_metric_names, y_true, y_pred);
_set_metric_names();
_create_ordered_metrics();
_built = true; _built = true;
} }


void Build()
void _set_metric_names()
{

}

void _create_ordered_metrics()
{
_metrics_in_order = new List<Metric>();
foreach (var m in _metrics)
_metrics_in_order.append(m);
}

Metric[] _get_metric_objects(string[] metrics, Tensor y_t, Tensor y_p)
{
return metrics.Select(x => _get_metric_object(x, y_t, y_p)).ToArray();
}

Metric _get_metric_object(string metric, Tensor y_t, Tensor y_p)
{
Func<Tensor, Tensor, Tensor> metric_obj = null;
if (metric == "accuracy" || metric == "acc")
{
metric_obj = tf.keras.metrics.sparse_categorical_accuracy;
return new MeanMetricWrapper(metric_obj, metric);
}

throw new NotImplementedException("");
}

public IEnumerable<Metric> metrics
{ {
get
{
if (!_built)
return new List<Metric>();


return _metrics_in_order;
}
} }
} }
} }

+ 11
- 2
src/TensorFlowNET.Core/Keras/Engine/Model.Fit.cs View File

@@ -1,5 +1,7 @@
using NumSharp; using NumSharp;
using System; using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine.DataAdapters; using Tensorflow.Keras.Engine.DataAdapters;


@@ -51,17 +53,24 @@ namespace Tensorflow.Keras.Engine


stop_training = false; stop_training = false;
_train_counter.assign(0); _train_counter.assign(0);
bool first_step = true;
foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
{ {
// reset_metrics(); // reset_metrics();
// callbacks.on_epoch_begin(epoch) // callbacks.on_epoch_begin(epoch)
// data_handler.catch_stop_iteration(); // data_handler.catch_stop_iteration();
IEnumerable<(string, Tensor)> results = null;
foreach (var step in data_handler.steps()) foreach (var step in data_handler.steps())
{ {
// callbacks.on_train_batch_begin(step) // callbacks.on_train_batch_begin(step)
step_function(iterator);
results = step_function(iterator);
if (first_step)
{
Console.WriteLine($"epoch: {epoch}, " + string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2}")));
first_step = false;
}
} }
Console.WriteLine($"epoch: {epoch + 1}, " + string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2}")));
} }
} }
} }


+ 31
- 0
src/TensorFlowNET.Core/Keras/Engine/Model.Metrics.cs View File

@@ -0,0 +1,31 @@
using System.Collections.Generic;
using Tensorflow.Keras.Metrics;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Engine
{
public partial class Model
{
public IEnumerable<Metric> metrics
{
get
{
var _metrics = new List<Metric>();
if (_is_compiled)
{
if (compiled_loss != null)
_metrics.add(compiled_loss.metrics);
if (compiled_metrics != null)
_metrics.add(compiled_metrics.metrics);
}

foreach(var layer in _flatten_layers())
{
// _metrics.extend(layer.metrics);
}

return _metrics;
}
}
}
}

+ 5
- 4
src/TensorFlowNET.Core/Keras/Engine/Model.Train.cs View File

@@ -8,12 +8,12 @@ namespace Tensorflow.Keras.Engine
{ {
public partial class Model public partial class Model
{ {
Tensor step_function(OwnedIterator iterator)
IEnumerable<(string, Tensor)> step_function(OwnedIterator iterator)
{ {
var data = iterator.next(); var data = iterator.next();
var outputs = train_step(data[0], data[1]); var outputs = train_step(data[0], data[1]);
tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1)); tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1));
return null;
return outputs;
} }


/// <summary> /// <summary>
@@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Engine
/// </summary> /// </summary>
/// <param name="data"></param> /// <param name="data"></param>
/// <returns></returns> /// <returns></returns>
IEnumerable<(string, Tensor)> train_step(Tensor x, Tensor y)
List<(string, Tensor)> train_step(Tensor x, Tensor y)
{ {
(x, y) = data_handler.DataAdapter.Expand1d(x, y); (x, y) = data_handler.DataAdapter.Expand1d(x, y);
using var tape = tf.GradientTape(); using var tape = tf.GradientTape();
@@ -36,7 +36,8 @@ namespace Tensorflow.Keras.Engine
// such as loss scaling and gradient clipping. // such as loss scaling and gradient clipping.
_minimize(tape, optimizer, loss, trainable_variables); _minimize(tape, optimizer, loss, trainable_variables);
compiled_metrics.update_state(y, y_pred); compiled_metrics.update_state(y, y_pred);
return new[] { ("loss", loss) };
return metrics.Select(x => (x.Name, x.result())).ToList();
} }


void _minimize(GradientTape tape, OptimizerV2 optimizer, Tensor loss, List<IVariableV1> trainable_variables) void _minimize(GradientTape tape, OptimizerV2 optimizer, Tensor loss, List<IVariableV1> trainable_variables)


+ 2
- 0
src/TensorFlowNET.Core/Keras/KerasApi.cs View File

@@ -5,6 +5,7 @@ using Tensorflow.Keras.Datasets;
using Tensorflow.Keras.Engine; using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Layers; using Tensorflow.Keras.Layers;
using Tensorflow.Keras.Losses; using Tensorflow.Keras.Losses;
using Tensorflow.Keras.Metrics;
using Tensorflow.Keras.Optimizers; using Tensorflow.Keras.Optimizers;


namespace Tensorflow namespace Tensorflow
@@ -20,6 +21,7 @@ namespace Tensorflow
public Preprocessing preprocessing { get; } = new Preprocessing(); public Preprocessing preprocessing { get; } = new Preprocessing();
public BackendImpl backend { get; } = new BackendImpl(); public BackendImpl backend { get; } = new BackendImpl();
public OptimizerApi optimizers { get; } = new OptimizerApi(); public OptimizerApi optimizers { get; } = new OptimizerApi();
public MetricsApi metrics { get; } = new MetricsApi();


public Sequential Sequential(List<Layer> layers = null, public Sequential Sequential(List<Layer> layers = null,
string name = null) string name = null)


+ 1
- 1
src/TensorFlowNET.Core/Keras/Metrics/Mean.cs View File

@@ -5,7 +5,7 @@
/// </summary> /// </summary>
public class Mean : Reduce public class Mean : Reduce
{ {
public Mean(string name = "mean", TF_DataType dtype = TF_DataType.DtInvalid)
public Mean(string name = "mean", TF_DataType dtype = TF_DataType.TF_FLOAT)
: base(Reduction.WEIGHTED_MEAN, name, dtype: dtype) : base(Reduction.WEIGHTED_MEAN, name, dtype: dtype)
{ {




+ 27
- 0
src/TensorFlowNET.Core/Keras/Metrics/MeanMetricWrapper.cs View File

@@ -0,0 +1,27 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Metrics
{
public class MeanMetricWrapper : Mean
{
string name;
Func<Tensor, Tensor, Tensor> _fn = null;
public MeanMetricWrapper(Func<Tensor, Tensor, Tensor> fn, string name, TF_DataType dtype = TF_DataType.TF_FLOAT)
: base(name: name, dtype: dtype)
{
_fn = fn;
}

public override Tensor update_state(Tensor y_true, Tensor y_pred, Tensor sample_weight = null)
{
y_true = math_ops.cast(y_true, _dtype);
y_pred = math_ops.cast(y_pred, _dtype);

var matches = _fn(y_true, y_pred);
return update_state(matches, sample_weight: sample_weight);
}
}
}

+ 14
- 0
src/TensorFlowNET.Core/Keras/Metrics/Metric.cs View File

@@ -10,6 +10,11 @@ namespace Tensorflow.Keras.Metrics
/// </summary> /// </summary>
public class Metric : Layer public class Metric : Layer
{ {
protected IVariableV1 total;
protected IVariableV1 count;
protected string _reduction;
protected TF_DataType _dtype;

public Metric(string name = null, TF_DataType dtype = TF_DataType.DtInvalid) public Metric(string name = null, TF_DataType dtype = TF_DataType.DtInvalid)
: base(new LayerArgs : base(new LayerArgs
{ {
@@ -44,5 +49,14 @@ namespace Tensorflow.Keras.Metrics
aggregation: aggregation); aggregation: aggregation);
}); });
} }

public virtual Tensor update_state(Tensor y_true, Tensor y_pred, Tensor sample_weight = null)
=> throw new NotImplementedException("");

public virtual Tensor result()
=> throw new NotImplementedException("");

public override string ToString()
=> $"{name} {(float)total.numpy()}/{(float)count.numpy()}";
} }
} }

+ 33
- 0
src/TensorFlowNET.Core/Keras/Metrics/MetricsApi.cs View File

@@ -0,0 +1,33 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Metrics
{
public class MetricsApi
{
/// <summary>
/// Calculates how often predictions matches integer labels.
/// </summary>
/// <param name="y_true">Integer ground truth values.</param>
/// <param name="y_pred">The prediction values.</param>
/// <returns>Sparse categorical accuracy values.</returns>
public Tensor sparse_categorical_accuracy(Tensor y_true, Tensor y_pred)
{
var y_pred_rank = y_pred.TensorShape.ndim;
var y_true_rank = y_true.TensorShape.ndim;
// If the shape of y_true is (num_samples, 1), squeeze to (num_samples,)
if (y_true_rank != -1 && y_pred_rank != -1
&& y_true.shape.Length == y_pred.shape.Length)
y_true = array_ops.squeeze(y_true, axis: new[] { -1 });
y_pred = math_ops.argmax(y_pred, -1);

// If the predicted output and actual output types don't match, force cast them
// to match.
if (y_pred.dtype != y_true.dtype)
y_pred = math_ops.cast(y_pred, y_true.dtype);

return math_ops.cast(math_ops.equal(y_true, y_pred), TF_DataType.TF_FLOAT);
}
}
}

+ 19
- 7
src/TensorFlowNET.Core/Keras/Metrics/Reduce.cs View File

@@ -9,11 +9,6 @@ namespace Tensorflow.Keras.Metrics
/// </summary> /// </summary>
public class Reduce : Metric public class Reduce : Metric
{ {
IVariableV1 total;
IVariableV1 count;
string _reduction;
TF_DataType _dtype;

public Reduce(string reduction, string name, TF_DataType dtype = TF_DataType.DtInvalid) public Reduce(string reduction, string name, TF_DataType dtype = TF_DataType.DtInvalid)
: base(name: name, dtype: dtype) : base(name: name, dtype: dtype)
{ {
@@ -43,11 +38,18 @@ namespace Tensorflow.Keras.Metrics
var value_sum = math_ops.reduce_sum(values); var value_sum = math_ops.reduce_sum(values);
tf_with(ops.control_dependencies(new[] { value_sum }), ctl => tf_with(ops.control_dependencies(new[] { value_sum }), ctl =>
{ {
var update_total_op = total.assign_add(value_sum);
update_total_op = total.assign_add(value_sum);
}); });


// Exit early if the reduction doesn't have a denominator.
if (_reduction == Reduction.SUM)
return update_total_op;

// Update `count` for reductions that require a denominator.
Tensor num_values = null; Tensor num_values = null;
if (_reduction == ReductionV2.WEIGHTED_MEAN)
if (_reduction == Reduction.SUM_OVER_BATCH_SIZE)
num_values = math_ops.cast(array_ops.size(values), _dtype);
else if (_reduction == ReductionV2.WEIGHTED_MEAN)
{ {
if (sample_weight == null) if (sample_weight == null)
num_values = math_ops.cast(array_ops.size(values), _dtype); num_values = math_ops.cast(array_ops.size(values), _dtype);
@@ -58,5 +60,15 @@ namespace Tensorflow.Keras.Metrics
return tf_with(ops.control_dependencies(new[] { update_total_op }), ctl return tf_with(ops.control_dependencies(new[] { update_total_op }), ctl
=> count.assign_add(num_values)); => count.assign_add(num_values));
} }

public override Tensor result()
{
if (_reduction == Reduction.SUM)
return array_ops.identity(total.AsTensor());
else if (_reduction == Reduction.WEIGHTED_MEAN || _reduction == Reduction.SUM_OVER_BATCH_SIZE)
return math_ops.div_no_nan(total.AsTensor(), count.AsTensor());

return base.result();
}
} }
} }

+ 1
- 0
src/TensorFlowNET.Core/Operations/Losses/Reduction.cs View File

@@ -3,6 +3,7 @@
public class Reduction public class Reduction
{ {
public const string NONE = "none"; public const string NONE = "none";
public const string SUM = "sum";
public const string WEIGHTED_SUM = "weighted_sum"; public const string WEIGHTED_SUM = "weighted_sum";
public const string SUM_OVER_BATCH_SIZE = "weighted_sum_over_batch_size"; public const string SUM_OVER_BATCH_SIZE = "weighted_sum_over_batch_size";
public const string WEIGHTED_MEAN = "weighted_mean"; public const string WEIGHTED_MEAN = "weighted_mean";


+ 3
- 0
src/TensorFlowNET.Core/Operations/math_ops.cs View File

@@ -68,6 +68,9 @@ namespace Tensorflow
return gen_math_ops.add_n(inputs, name: name); return gen_math_ops.add_n(inputs, name: name);
} }


public static Tensor argmax(Tensor input, int dimension, TF_DataType output_type = TF_DataType.TF_INT64, string name = null)
=> gen_math_ops.arg_max(input, dimension, output_type: output_type, name: name);

public static Tensor round(Tensor x, string name = null) public static Tensor round(Tensor x, string name = null)
{ {
x = ops.convert_to_tensor(x, name: "x"); x = ops.convert_to_tensor(x, name: "x");


Loading…
Cancel
Save