Browse Source

Keras step_function doesn't return accuracy #630

tags/v0.30
Oceania2018 4 years ago
parent
commit
8141f8c59f
14 changed files with 222 additions and 21 deletions
  1. +6
    -0
      src/TensorFlowNET.Core/Binding.Util.cs
  2. +12
    -2
      src/TensorFlowNET.Core/Keras/Engine/LossesContainer.cs
  3. +57
    -5
      src/TensorFlowNET.Core/Keras/Engine/MetricsContainer.cs
  4. +11
    -2
      src/TensorFlowNET.Core/Keras/Engine/Model.Fit.cs
  5. +31
    -0
      src/TensorFlowNET.Core/Keras/Engine/Model.Metrics.cs
  6. +5
    -4
      src/TensorFlowNET.Core/Keras/Engine/Model.Train.cs
  7. +2
    -0
      src/TensorFlowNET.Core/Keras/KerasApi.cs
  8. +1
    -1
      src/TensorFlowNET.Core/Keras/Metrics/Mean.cs
  9. +27
    -0
      src/TensorFlowNET.Core/Keras/Metrics/MeanMetricWrapper.cs
  10. +14
    -0
      src/TensorFlowNET.Core/Keras/Metrics/Metric.cs
  11. +33
    -0
      src/TensorFlowNET.Core/Keras/Metrics/MetricsApi.cs
  12. +19
    -7
      src/TensorFlowNET.Core/Keras/Metrics/Reduce.cs
  13. +1
    -0
      src/TensorFlowNET.Core/Operations/Losses/Reduction.cs
  14. +3
    -0
      src/TensorFlowNET.Core/Operations/math_ops.cs

+ 6
- 0
src/TensorFlowNET.Core/Binding.Util.cs View File

@@ -49,6 +49,12 @@ namespace Tensorflow
public static void add<T>(this IList<T> list, T element)
=> list.Add(element);

public static void add<T>(this IList<T> list, IEnumerable<T> elements)
{
foreach (var ele in elements)
list.Add(ele);
}

public static void append<T>(this IList<T> list, T element)
=> list.Insert(list.Count, element);



+ 12
- 2
src/TensorFlowNET.Core/Keras/Engine/LossesContainer.cs View File

@@ -11,7 +11,6 @@ namespace Tensorflow.Keras.Engine
Mean _loss_metric;
bool _built;
Tensor[] _per_output_metrics;
List<Tensor> loss_metric_values;

public LossesContainer(ILossFunc losses, string[] output_names = null)
: base(output_names)
@@ -19,7 +18,6 @@ namespace Tensorflow.Keras.Engine
_user_losses = losses;
_losses = losses;
_loss_metric = new Mean(name: "loss");
loss_metric_values = new List<Tensor>();
_built = false;
}

@@ -37,6 +35,7 @@ namespace Tensorflow.Keras.Engine
var batch_dim = array_ops.shape(y_true)[0];

var loss_values = new List<Tensor>();
var loss_metric_values = new List<Tensor>();

/*if (_losses.Reduction == ReductionV2.SUM_OVER_BATCH_SIZE
|| _losses.Reduction == ReductionV2.AUTO)
@@ -69,5 +68,16 @@ namespace Tensorflow.Keras.Engine
{
// _per_output_metrics = _output_names.Select(x => null);
}

public IEnumerable<Metric> metrics
{
get
{
if (!_built)
return new List<Metric>();

return new[] { _loss_metric };
}
}
}
}

+ 57
- 5
src/TensorFlowNET.Core/Keras/Engine/MetricsContainer.cs View File

@@ -1,29 +1,81 @@
namespace Tensorflow.Keras.Engine
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Keras.Metrics;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Engine
{
public class MetricsContainer : Container
{
string[] _user_metrics;
string[] _metrics;
string[] _metric_names;
Metric[] _metrics;
List<Metric> _metrics_in_order;

public MetricsContainer(string[] metrics, string[] output_names = null)
: base(output_names)
{
_user_metrics = metrics;
_metrics = metrics;
_metric_names = metrics;
_built = false;
}

public void update_state(Tensor y_true, Tensor y_pred, Tensor sample_weight = null)
{
if (!_built)
Build();
Build(y_true, y_pred);

foreach (var metric_obj in _metrics_in_order)
metric_obj.update_state(y_true, y_pred);
}

void Build(Tensor y_true, Tensor y_pred)
{
_metrics = _get_metric_objects(_metric_names, y_true, y_pred);
_set_metric_names();
_create_ordered_metrics();
_built = true;
}

void Build()
void _set_metric_names()
{

}

void _create_ordered_metrics()
{
_metrics_in_order = new List<Metric>();
foreach (var m in _metrics)
_metrics_in_order.append(m);
}

Metric[] _get_metric_objects(string[] metrics, Tensor y_t, Tensor y_p)
{
return metrics.Select(x => _get_metric_object(x, y_t, y_p)).ToArray();
}

Metric _get_metric_object(string metric, Tensor y_t, Tensor y_p)
{
Func<Tensor, Tensor, Tensor> metric_obj = null;
if (metric == "accuracy" || metric == "acc")
{
metric_obj = tf.keras.metrics.sparse_categorical_accuracy;
return new MeanMetricWrapper(metric_obj, metric);
}

throw new NotImplementedException("");
}

public IEnumerable<Metric> metrics
{
get
{
if (!_built)
return new List<Metric>();

return _metrics_in_order;
}
}
}
}

+ 11
- 2
src/TensorFlowNET.Core/Keras/Engine/Model.Fit.cs View File

@@ -1,5 +1,7 @@
using NumSharp;
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine.DataAdapters;

@@ -51,17 +53,24 @@ namespace Tensorflow.Keras.Engine

stop_training = false;
_train_counter.assign(0);
bool first_step = true;
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
{
// reset_metrics();
// callbacks.on_epoch_begin(epoch)
// data_handler.catch_stop_iteration();
IEnumerable<(string, Tensor)> results = null;
foreach (var step in data_handler.steps())
{
// callbacks.on_train_batch_begin(step)
step_function(iterator);
results = step_function(iterator);
if (first_step)
{
Console.WriteLine($"epoch: {epoch}, " + string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2}")));
first_step = false;
}
}
Console.WriteLine($"epoch: {epoch + 1}, " + string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2}")));
}
}
}


+ 31
- 0
src/TensorFlowNET.Core/Keras/Engine/Model.Metrics.cs View File

@@ -0,0 +1,31 @@
using System.Collections.Generic;
using Tensorflow.Keras.Metrics;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Engine
{
public partial class Model
{
public IEnumerable<Metric> metrics
{
get
{
var _metrics = new List<Metric>();
if (_is_compiled)
{
if (compiled_loss != null)
_metrics.add(compiled_loss.metrics);
if (compiled_metrics != null)
_metrics.add(compiled_metrics.metrics);
}

foreach(var layer in _flatten_layers())
{
// _metrics.extend(layer.metrics);
}

return _metrics;
}
}
}
}

+ 5
- 4
src/TensorFlowNET.Core/Keras/Engine/Model.Train.cs View File

@@ -8,12 +8,12 @@ namespace Tensorflow.Keras.Engine
{
public partial class Model
{
Tensor step_function(OwnedIterator iterator)
IEnumerable<(string, Tensor)> step_function(OwnedIterator iterator)
{
var data = iterator.next();
var outputs = train_step(data[0], data[1]);
tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1));
return null;
return outputs;
}

/// <summary>
@@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Engine
/// </summary>
/// <param name="data"></param>
/// <returns></returns>
IEnumerable<(string, Tensor)> train_step(Tensor x, Tensor y)
List<(string, Tensor)> train_step(Tensor x, Tensor y)
{
(x, y) = data_handler.DataAdapter.Expand1d(x, y);
using var tape = tf.GradientTape();
@@ -36,7 +36,8 @@ namespace Tensorflow.Keras.Engine
// such as loss scaling and gradient clipping.
_minimize(tape, optimizer, loss, trainable_variables);
compiled_metrics.update_state(y, y_pred);
return new[] { ("loss", loss) };
return metrics.Select(x => (x.Name, x.result())).ToList();
}

void _minimize(GradientTape tape, OptimizerV2 optimizer, Tensor loss, List<IVariableV1> trainable_variables)


+ 2
- 0
src/TensorFlowNET.Core/Keras/KerasApi.cs View File

@@ -5,6 +5,7 @@ using Tensorflow.Keras.Datasets;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Layers;
using Tensorflow.Keras.Losses;
using Tensorflow.Keras.Metrics;
using Tensorflow.Keras.Optimizers;

namespace Tensorflow
@@ -20,6 +21,7 @@ namespace Tensorflow
public Preprocessing preprocessing { get; } = new Preprocessing();
public BackendImpl backend { get; } = new BackendImpl();
public OptimizerApi optimizers { get; } = new OptimizerApi();
public MetricsApi metrics { get; } = new MetricsApi();

public Sequential Sequential(List<Layer> layers = null,
string name = null)


+ 1
- 1
src/TensorFlowNET.Core/Keras/Metrics/Mean.cs View File

@@ -5,7 +5,7 @@
/// </summary>
public class Mean : Reduce
{
public Mean(string name = "mean", TF_DataType dtype = TF_DataType.DtInvalid)
public Mean(string name = "mean", TF_DataType dtype = TF_DataType.TF_FLOAT)
: base(Reduction.WEIGHTED_MEAN, name, dtype: dtype)
{



+ 27
- 0
src/TensorFlowNET.Core/Keras/Metrics/MeanMetricWrapper.cs View File

@@ -0,0 +1,27 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Metrics
{
public class MeanMetricWrapper : Mean
{
string name;
Func<Tensor, Tensor, Tensor> _fn = null;
public MeanMetricWrapper(Func<Tensor, Tensor, Tensor> fn, string name, TF_DataType dtype = TF_DataType.TF_FLOAT)
: base(name: name, dtype: dtype)
{
_fn = fn;
}

public override Tensor update_state(Tensor y_true, Tensor y_pred, Tensor sample_weight = null)
{
y_true = math_ops.cast(y_true, _dtype);
y_pred = math_ops.cast(y_pred, _dtype);

var matches = _fn(y_true, y_pred);
return update_state(matches, sample_weight: sample_weight);
}
}
}

+ 14
- 0
src/TensorFlowNET.Core/Keras/Metrics/Metric.cs View File

@@ -10,6 +10,11 @@ namespace Tensorflow.Keras.Metrics
/// </summary>
public class Metric : Layer
{
protected IVariableV1 total;
protected IVariableV1 count;
protected string _reduction;
protected TF_DataType _dtype;

public Metric(string name = null, TF_DataType dtype = TF_DataType.DtInvalid)
: base(new LayerArgs
{
@@ -44,5 +49,14 @@ namespace Tensorflow.Keras.Metrics
aggregation: aggregation);
});
}

public virtual Tensor update_state(Tensor y_true, Tensor y_pred, Tensor sample_weight = null)
=> throw new NotImplementedException("");

public virtual Tensor result()
=> throw new NotImplementedException("");

public override string ToString()
=> $"{name} {(float)total.numpy()}/{(float)count.numpy()}";
}
}

+ 33
- 0
src/TensorFlowNET.Core/Keras/Metrics/MetricsApi.cs View File

@@ -0,0 +1,33 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Metrics
{
public class MetricsApi
{
/// <summary>
/// Calculates how often predictions matches integer labels.
/// </summary>
/// <param name="y_true">Integer ground truth values.</param>
/// <param name="y_pred">The prediction values.</param>
/// <returns>Sparse categorical accuracy values.</returns>
public Tensor sparse_categorical_accuracy(Tensor y_true, Tensor y_pred)
{
var y_pred_rank = y_pred.TensorShape.ndim;
var y_true_rank = y_true.TensorShape.ndim;
// If the shape of y_true is (num_samples, 1), squeeze to (num_samples,)
if (y_true_rank != -1 && y_pred_rank != -1
&& y_true.shape.Length == y_pred.shape.Length)
y_true = array_ops.squeeze(y_true, axis: new[] { -1 });
y_pred = math_ops.argmax(y_pred, -1);

// If the predicted output and actual output types don't match, force cast them
// to match.
if (y_pred.dtype != y_true.dtype)
y_pred = math_ops.cast(y_pred, y_true.dtype);

return math_ops.cast(math_ops.equal(y_true, y_pred), TF_DataType.TF_FLOAT);
}
}
}

+ 19
- 7
src/TensorFlowNET.Core/Keras/Metrics/Reduce.cs View File

@@ -9,11 +9,6 @@ namespace Tensorflow.Keras.Metrics
/// </summary>
public class Reduce : Metric
{
IVariableV1 total;
IVariableV1 count;
string _reduction;
TF_DataType _dtype;

public Reduce(string reduction, string name, TF_DataType dtype = TF_DataType.DtInvalid)
: base(name: name, dtype: dtype)
{
@@ -43,11 +38,18 @@ namespace Tensorflow.Keras.Metrics
var value_sum = math_ops.reduce_sum(values);
tf_with(ops.control_dependencies(new[] { value_sum }), ctl =>
{
var update_total_op = total.assign_add(value_sum);
update_total_op = total.assign_add(value_sum);
});

// Exit early if the reduction doesn't have a denominator.
if (_reduction == Reduction.SUM)
return update_total_op;

// Update `count` for reductions that require a denominator.
Tensor num_values = null;
if (_reduction == ReductionV2.WEIGHTED_MEAN)
if (_reduction == Reduction.SUM_OVER_BATCH_SIZE)
num_values = math_ops.cast(array_ops.size(values), _dtype);
else if (_reduction == ReductionV2.WEIGHTED_MEAN)
{
if (sample_weight == null)
num_values = math_ops.cast(array_ops.size(values), _dtype);
@@ -58,5 +60,15 @@ namespace Tensorflow.Keras.Metrics
return tf_with(ops.control_dependencies(new[] { update_total_op }), ctl
=> count.assign_add(num_values));
}

public override Tensor result()
{
if (_reduction == Reduction.SUM)
return array_ops.identity(total.AsTensor());
else if (_reduction == Reduction.WEIGHTED_MEAN || _reduction == Reduction.SUM_OVER_BATCH_SIZE)
return math_ops.div_no_nan(total.AsTensor(), count.AsTensor());

return base.result();
}
}
}

+ 1
- 0
src/TensorFlowNET.Core/Operations/Losses/Reduction.cs View File

@@ -3,6 +3,7 @@
public class Reduction
{
public const string NONE = "none";
public const string SUM = "sum";
public const string WEIGHTED_SUM = "weighted_sum";
public const string SUM_OVER_BATCH_SIZE = "weighted_sum_over_batch_size";
public const string WEIGHTED_MEAN = "weighted_mean";


+ 3
- 0
src/TensorFlowNET.Core/Operations/math_ops.cs View File

@@ -68,6 +68,9 @@ namespace Tensorflow
return gen_math_ops.add_n(inputs, name: name);
}

public static Tensor argmax(Tensor input, int dimension, TF_DataType output_type = TF_DataType.TF_INT64, string name = null)
=> gen_math_ops.arg_max(input, dimension, output_type: output_type, name: name);

public static Tensor round(Tensor x, string name = null)
{
x = ops.convert_to_tensor(x, name: "x");


Loading…
Cancel
Save