
Merge pull request #1187 from Wanglongzhi2001/master

feat: add the implementation of sample_weight in model.fit
tags/v0.150.0-BERT-Model
Haiping (via GitHub) committed 2 years ago
commit 0ee9d424e5
13 changed files with 250 additions and 100 deletions
1. +3 -0  src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs
2. +3 -0  src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs
3. +9 -2  src/TensorFlowNET.Core/Keras/Engine/IModel.cs
4. +66 -0  src/TensorFlowNET.Core/Util/Data.cs
5. +59 -0  src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs
6. +3 -0  src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs
7. +2 -0  src/TensorFlowNET.Keras/Engine/DataAdapters/IDataAdapter.cs
8. +5 -2  src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs
9. +2 -2  src/TensorFlowNET.Keras/Engine/LossesContainer.cs
10. +14 -5  src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs
11. +44 -85  src/TensorFlowNET.Keras/Engine/Model.Fit.cs
12. +38 -2  src/TensorFlowNET.Keras/Engine/Model.Train.cs
13. +2 -2  test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs
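
In short, callers can now weight individual samples during training and evaluation. A minimal usage sketch, assuming a compiled model and training arrays (model, x_train, y_train are hypothetical; the parameter names follow the new IModel signature in the diff below):

// Build a per-sample weight vector, one float per training example.
// np.ones defaults to double; fit() casts the weights to float internally (see Model.Fit.cs).
var sample_weight = np.ones((int)x_train.shape[0]);
model.fit(x_train, y_train,
    batch_size: 16,
    epochs: 1,
    sample_weight: sample_weight);   // new parameter added by this PR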

+3 -0  src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs

@@ -1,5 +1,6 @@
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.NumPy;

namespace Tensorflow.Keras.ArgsDefinition
{
@@ -16,5 +17,7 @@ namespace Tensorflow.Keras.ArgsDefinition
public int Worker { get; set; }
public bool UseMultiprocessing { get; set; }
public IModel Model { get; set; }
public Dictionary<int, float> ClassWeight = null;
public NDArray SampleWeight = null;
}
}

+3 -0  src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs

@@ -1,5 +1,6 @@
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.NumPy;

namespace Tensorflow.Keras.ArgsDefinition
{
@@ -18,5 +19,7 @@ namespace Tensorflow.Keras.ArgsDefinition
public bool UseMultiprocessing { get; set; } = false;
public IModel Model { get; set; }
public IVariableV1 StepsPerExecution { get; set; }
public Dictionary<int, float> ClassWeight = null;
public NDArray SampleWeight = null;
}
}

+9 -2  src/TensorFlowNET.Core/Keras/Engine/IModel.cs

@@ -3,6 +3,7 @@ using Tensorflow.Keras.Losses;
using Tensorflow.Keras.Metrics;
using Tensorflow.Keras.Saving;
using Tensorflow.NumPy;
using Tensorflow.Util;

namespace Tensorflow.Keras.Engine;

@@ -22,8 +23,10 @@ public interface IModel : ILayer
int verbose = 1,
List<ICallback> callbacks = null,
float validation_split = 0f,
(NDArray val_x, NDArray val_y)? validation_data = null,
ValidationDataPack validation_data = null,
bool shuffle = true,
Dictionary<int, float> class_weight = null,
NDArray sample_weight = null,
int initial_epoch = 0,
int max_queue_size = 10,
int workers = 1,
@@ -35,8 +38,10 @@ public interface IModel : ILayer
int verbose = 1,
List<ICallback> callbacks = null,
float validation_split = 0f,
(IEnumerable<NDArray> val_x, NDArray val_y)? validation_data = null,
ValidationDataPack validation_data = null,
bool shuffle = true,
Dictionary<int, float> class_weight = null,
NDArray sample_weight = null,
int initial_epoch = 0,
int max_queue_size = 10,
int workers = 1,
@@ -63,6 +68,8 @@ public interface IModel : ILayer
Dictionary<string, float> evaluate(NDArray x, NDArray y,
int batch_size = -1,
int verbose = 1,
NDArray sample_weight = null,
int steps = -1,
int max_queue_size = 10,
int workers = 1,


+66 -0  src/TensorFlowNET.Core/Util/Data.cs

@@ -0,0 +1,66 @@
using Tensorflow.NumPy;

namespace Tensorflow.Util
{
/// <summary>
/// ValidationDataPack is used to pass validation data to the fit method.
/// It can receive data as a tuple `(x_val, y_val)` or `(x_val, y_val, sample_weight_val)` of NumPy arrays.
/// </summary>
public class ValidationDataPack
{
public NDArray val_x;
public NDArray val_y;
public NDArray val_sample_weight = null;

public ValidationDataPack((NDArray, NDArray) validation_data)
{
this.val_x = validation_data.Item1;
this.val_y = validation_data.Item2;
}

public ValidationDataPack((NDArray, NDArray, NDArray) validation_data)
{
this.val_x = validation_data.Item1;
this.val_y = validation_data.Item2;
this.val_sample_weight = validation_data.Item3;
}

public ValidationDataPack((IEnumerable<NDArray>, NDArray) validation_data)
{
this.val_x = validation_data.Item1.ToArray()[0];
this.val_y = validation_data.Item2;
}

public ValidationDataPack((IEnumerable<NDArray>, NDArray, NDArray) validation_data)
{
this.val_x = validation_data.Item1.ToArray()[0];
this.val_y = validation_data.Item2;
this.val_sample_weight = validation_data.Item3;
}

public static implicit operator ValidationDataPack((NDArray, NDArray) validation_data)
=> new ValidationDataPack(validation_data);

public static implicit operator ValidationDataPack((NDArray, NDArray, NDArray) validation_data)
=> new ValidationDataPack(validation_data);

public static implicit operator ValidationDataPack((IEnumerable<NDArray>, NDArray) validation_data)
=> new ValidationDataPack(validation_data);

public static implicit operator ValidationDataPack((IEnumerable<NDArray>, NDArray, NDArray) validation_data)
=> new ValidationDataPack(validation_data);

public void Deconstruct(out NDArray val_x, out NDArray val_y)
{
val_x = this.val_x;
val_y = this.val_y;
}

public void Deconstruct(out NDArray val_x, out NDArray val_y, out NDArray val_sample_weight)
{
val_x = this.val_x;
val_y = this.val_y;
val_sample_weight = this.val_sample_weight;
}
}
}
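
A short usage sketch (our illustration; the variables are hypothetical): the implicit operators let callers keep passing plain tuples, and the Deconstruct overloads let the fit/evaluate loops unpack either form:

// Given NDArrays x_val, y_val and w_val (hypothetical):
ValidationDataPack packed = (x_val, y_val, w_val);   // implicit 3-tuple conversion
ValidationDataPack plain = (x_val, y_val);           // implicit 2-tuple conversion

var (vx, vy, vw) = packed;   // 3-way Deconstruct; vw holds val_sample_weight
// For `plain`, the same Deconstruct would return vw == null.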

+59 -0  src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs

@@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Util;

namespace Tensorflow.Keras.Engine.DataAdapters
{
@@ -34,9 +35,67 @@ namespace Tensorflow.Keras.Engine.DataAdapters
return (x, y);
}

public virtual (Tensors, Tensors, Tensors) Expand1d(Tensors x, Tensors y, Tensors sample_weight)
{
for (int i = 0; i < x.Length; i++)
{
if (x[i].shape.ndim == 1)
x[i] = array_ops.expand_dims(x[i], axis: -1);
}
for (int i = 0; i < y.Length; i++)
{
if (y[i].shape.ndim == 1)
y[i] = array_ops.expand_dims(y[i], axis: -1);
}
for (int i = 0; i < sample_weight.Length; i++)
{
if (sample_weight[i].shape.ndim == 1)
sample_weight[i] = array_ops.expand_dims(sample_weight[i], axis: -1);
}
return (x, y, sample_weight);
}

public virtual bool ShouldRecreateIterator()
{
return true;
}

public static ((NDArray, NDArray, NDArray), ValidationDataPack) train_validation_split((NDArray, NDArray, NDArray) x_y_sample_weight, float validation_split)
{
var x = x_y_sample_weight.Item1;
var y = x_y_sample_weight.Item2;
var sample_weight = x_y_sample_weight.Item3;
int train_count = Convert.ToInt32(x.dims[0] * (1 - validation_split));
var train_x = x[new Slice(0, train_count)];
var train_y = y[new Slice(0, train_count)];
ValidationDataPack validation_data;
if (sample_weight != null)
{
validation_data = (x[new Slice(train_count)], y[new Slice(train_count)], sample_weight[new Slice(train_count)]);
sample_weight = sample_weight[new Slice(0, train_count)];
}
else
{
validation_data = (x[new Slice(train_count)], y[new Slice(train_count)]);
}

return ((train_x, train_y, sample_weight), validation_data);
}

public static ((IEnumerable<NDArray>, NDArray, NDArray), ValidationDataPack) train_validation_split((IEnumerable<NDArray>, NDArray, NDArray) x_y_sample_weight, float validation_split)
{
var x = x_y_sample_weight.Item1;
var y = x_y_sample_weight.Item2;
var sample_weight = x_y_sample_weight.Item3;
int train_count = Convert.ToInt32(y.dims[0] * (1 - validation_split));
var train_x = x.Select(x => x[new Slice(0, train_count)] as NDArray);
var train_y = y[new Slice(0, train_count)];
var val_x = x.Select(x => x[new Slice(train_count)] as NDArray);
var val_y = y[new Slice(train_count)];
// Guard against a null sample_weight, mirroring the NDArray overload above.
NDArray tmp_sample_weight = sample_weight;
sample_weight = sample_weight?[new Slice(0, train_count)];
ValidationDataPack validation_data = tmp_sample_weight != null
? new ValidationDataPack((val_x, val_y, tmp_sample_weight[new Slice(train_count)]))
: new ValidationDataPack((val_x, val_y));
return ((train_x, train_y, sample_weight), validation_data);
}
}
}
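
This is how the reworked fit() consumes the helper when validation_split is set (a sketch mirroring the new Model.Fit.cs code further down):

// Carves the trailing split off as validation data; sample_weight may be null,
// in which case the returned ValidationDataPack carries no weights.
((x, y, sample_weight), validation_data) =
    DataAdapter.train_validation_split((x, y, sample_weight), validation_split);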

+3 -0  src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs

@@ -2,6 +2,7 @@
using System.Collections.Generic;
using Tensorflow.Keras.ArgsDefinition;
using static Tensorflow.Binding;
using Tensorflow.Keras.Utils;

namespace Tensorflow.Keras.Engine.DataAdapters
{
@@ -28,6 +29,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters
public DataHandler(DataHandlerArgs args)
{
this.args = args;
if (args.StepsPerExecution == null)
{
_steps_per_execution = tf.Variable(1L);
@@ -48,6 +50,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters
BatchSize = args.BatchSize,
Steps = args.StepsPerEpoch,
Epochs = args.Epochs - args.InitialEpoch,
SampleWeight = args.SampleWeight,
Shuffle = args.Shuffle,
MaxQueueSize = args.MaxQueueSize,
Worker = args.Workers,


+2 -0  src/TensorFlowNET.Keras/Engine/DataAdapters/IDataAdapter.cs

@@ -17,6 +17,8 @@
IDatasetV2 GetDataset();
int GetSize();
(Tensors, Tensors) Expand1d(Tensors x, Tensors y);
(Tensors, Tensors, Tensors) Expand1d(Tensors x, Tensors y, Tensors sample_weight);

bool ShouldRecreateIterator();
}
}

+5 -2  src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs

@@ -20,7 +20,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters
public TensorLikeDataAdapter(DataAdapterArgs args)
{
this.args = args;
_process_tensorlike();
Tensor sample_weight_tensor = args.SampleWeight != null ? _process_tensorlike(args.SampleWeight) : null;
num_samples = (int)args.X.shape[0];
var batch_size = args.BatchSize == -1 ? 32 : args.BatchSize;
_batch_size = batch_size;
@@ -37,6 +37,8 @@ namespace Tensorflow.Keras.Engine.DataAdapters
inputs.AddRange(args.X);
if (args.Y != null)
inputs.AddRange(args.Y);
if (sample_weight_tensor != null)
inputs.Add(sample_weight_tensor);
dataset = slice_inputs(indices_dataset, inputs);
dataset.FirstInputTensorCount = args.X.Length;
}
@@ -94,8 +96,9 @@ namespace Tensorflow.Keras.Engine.DataAdapters

public override bool ShouldRecreateIterator() => false;

void _process_tensorlike()
Tensor _process_tensorlike(NDArray sample_weights)
{
return tf.convert_to_tensor(sample_weights);
}
}
}
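
Worth spelling out (our reading of the change, not stated in the PR): the converted weight tensor is appended after the X and Y tensors, so each dataset batch is laid out as [x... | y... | sample_weight], and the step functions detect the weights from the batch length:

// Single-input case, as checked in Model.Train.cs / Model.Evaluate.cs below:
var data = iterator.next();        // data[0] = x, data[1] = y, data[2] = weights (if any)
var weighted = data.Length == 3;   // true only when sample_weight was supplied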

+2 -2  src/TensorFlowNET.Keras/Engine/LossesContainer.cs

@@ -26,11 +26,11 @@ namespace Tensorflow.Keras.Engine
/// </summary>
/// <param name="y_true"></param>
/// <param name="y_pred"></param>
public Tensor Call(Tensor y_true, Tensor y_pred)
public Tensor Call(Tensor y_true, Tensor y_pred, Tensor sample_weight = null)
{
if (!_built)
Build(y_pred);
var loss_value = _losses.Call(y_true, y_pred);
var loss_value = _losses.Call(y_true, y_pred, sample_weight:sample_weight);
var loss_metric_value = loss_value;
var batch_dim = array_ops.shape(y_true)[0];
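
For reference (our summary, assuming the usual Keras-style mean reduction): passing sample_weight scales each per-sample loss before the batch reduction,

L = \frac{1}{N} \sum_{i=1}^{N} w_i \, \ell(y_i, \hat{y}_i)

so with sample_weight == null the call reduces to the unweighted mean, as before.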



+14 -5  src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs

@@ -30,6 +30,7 @@ namespace Tensorflow.Keras.Engine
public Dictionary<string, float> evaluate(NDArray x, NDArray y,
int batch_size = -1,
int verbose = 1,
NDArray sample_weight = null,
int steps = -1,
int max_queue_size = 10,
int workers = 1,
@@ -51,6 +52,7 @@ namespace Tensorflow.Keras.Engine
StepsPerEpoch = steps,
InitialEpoch = 0,
Epochs = 1,
SampleWeight = sample_weight,
MaxQueueSize = max_queue_size,
Workers = workers,
UseMultiprocessing = use_multiprocessing,
@@ -140,7 +142,8 @@ namespace Tensorflow.Keras.Engine
Dictionary<string, float> test_function(DataHandler data_handler, OwnedIterator iterator)
{
var data = iterator.next();
var outputs = test_step(data_handler, data[0], data[1]);
var outputs = data.Length == 2 ? test_step(data_handler, data[0], data[1]) :
test_step(data_handler, data[0], data[1], data[2]);
tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1));
return outputs;
}
@@ -149,17 +152,23 @@ namespace Tensorflow.Keras.Engine
{
var data = iterator.next();
var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount;
var outputs = test_step(data_handler, data.Take(x_size).ToArray(), data.Skip(x_size).ToArray());
var outputs = data.Length == 2 ?
test_step(data_handler, new Tensors(data.Take(x_size).ToArray()), new Tensors(data.Skip(x_size).ToArray())) :
test_step(
data_handler,
new Tensors(data.Take(x_size).ToArray()),
new Tensors(data.Skip(x_size).Take(x_size).ToArray()),
new Tensors(data.Skip(2 * x_size).ToArray()));
tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1));
return outputs;
}


Dictionary<string, float> test_step(DataHandler data_handler, Tensors x, Tensors y)
Dictionary<string, float> test_step(DataHandler data_handler, Tensors x, Tensors y, Tensors sample_weight = null)
{
(x, y) = data_handler.DataAdapter.Expand1d(x, y);
(x, y, sample_weight) = data_handler.DataAdapter.Expand1d(x, y, sample_weight);
var y_pred = Apply(x, training: false);
var loss = compiled_loss.Call(y, y_pred);
var loss = compiled_loss.Call(y, y_pred, sample_weight:sample_weight);
compiled_metrics.update_state(y, y_pred);
return metrics.Select(x => (x.Name, x.result())).ToDictionary(x => x.Item1, x => (float)x.Item2);
}
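
The extended evaluate() can then be called like this (hedged sketch; x_test, y_test, test_weights are hypothetical):

var results = model.evaluate(x_test, y_test,
    batch_size: 32,
    sample_weight: test_weights);   // weights the per-sample test losses
// results maps metric names to values, e.g. results["loss"].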


+44 -85  src/TensorFlowNET.Keras/Engine/Model.Fit.cs

@@ -6,10 +6,12 @@ using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine.DataAdapters;
using System.Diagnostics;
using Tensorflow.Keras.Callbacks;
using System.Data;
using Tensorflow.Util;

namespace Tensorflow.Keras.Engine
{


public partial class Model
{
/// <summary>
@@ -19,19 +21,29 @@ namespace Tensorflow.Keras.Engine
/// <param name="y"></param>
/// <param name="batch_size"></param>
/// <param name="epochs"></param>
/// <param name="callbacks"></param>
/// <param name="verbose"></param>
/// <param name="callbacks"></param>
/// <param name="validation_split"></param>
/// <param name="validation_data"></param>
/// <param name="shuffle"></param>
/// <param name="class_weight"></param>
/// <param name="sample_weight"></param>
/// <param name="initial_epoch"></param>
/// <param name="max_queue_size"></param>
/// <param name="workers"></param>
/// <param name="use_multiprocessing"></param>
/// <returns></returns>
/// <exception cref="InvalidArgumentError"></exception>
public ICallback fit(NDArray x, NDArray y,
int batch_size = -1,
int epochs = 1,
int verbose = 1,
List<ICallback> callbacks = null,
float validation_split = 0f,
(NDArray val_x, NDArray val_y)? validation_data = null,
ValidationDataPack validation_data = null,
bool shuffle = true,
Dictionary<int, float> class_weight = null,
NDArray sample_weight = null,
int initial_epoch = 0,
int max_queue_size = 10,
int workers = 1,
@@ -43,21 +55,25 @@ namespace Tensorflow.Keras.Engine
$"The array x and y should have same value at dim 0, but got {x.dims[0]} and {y.dims[0]}");
}

var train_x = x;
var train_y = y;
// The default dtype of an NDArray is double, so cast sample_weight to float so it can be multiplied with the loss, whose dtype is float.
sample_weight = sample_weight?.astype(TF_DataType.TF_FLOAT);

if (validation_split != 0f && validation_data == null)
{
int train_count = Convert.ToInt32(x.dims[0] * (1 - validation_split));
train_x = x[new Slice(0, train_count)];
train_y = y[new Slice(0, train_count)];
validation_data = (val_x: x[new Slice(train_count)], val_y: y[new Slice(train_count)]);
((x, y, sample_weight), validation_data) = DataAdapter.train_validation_split((x, y, sample_weight), validation_split);
}

// TODO(Wanglongzhi2001)
if (class_weight != null)
{
throw new NotImplementedException("class_weight is not implemented");
}

var data_handler = new DataHandler(new DataHandlerArgs
{
X = train_x,
Y = train_y,
X = x,
Y = y,
SampleWeight = sample_weight,
BatchSize = batch_size,
InitialEpoch = initial_epoch,
Epochs = epochs,
@@ -73,14 +89,17 @@ namespace Tensorflow.Keras.Engine
train_step_func: train_step_function);
}


public ICallback fit(IEnumerable<NDArray> x, NDArray y,
int batch_size = -1,
int epochs = 1,
int verbose = 1,
List<ICallback> callbacks = null,
float validation_split = 0f,
(IEnumerable<NDArray> val_x, NDArray val_y)? validation_data = null,
ValidationDataPack validation_data = null,
bool shuffle = true,
Dictionary<int, float> class_weight = null,
NDArray sample_weight = null,
int initial_epoch = 0,
int max_queue_size = 10,
int workers = 1,
@@ -95,27 +114,23 @@ namespace Tensorflow.Keras.Engine
}
}

var train_x = x;
var train_y = y;
sample_weight = sample_weight?.astype(TF_DataType.TF_FLOAT);
if (validation_split != 0f && validation_data == null)
{
int train_count = Convert.ToInt32(y.dims[0] * (1 - validation_split));
train_x = x.Select(x => x[new Slice(0, train_count)] as NDArray);
train_y = y[new Slice(0, train_count)];
var val_x = x.Select(x => x[new Slice(train_count)] as NDArray);
var val_y = y[new Slice(train_count)];
validation_data = (val_x, val_y);
((x, y, sample_weight), validation_data) = DataAdapter.train_validation_split((x, y, sample_weight), validation_split);
}


var data_handler = new DataHandler(new DataHandlerArgs
{
X = new Tensors(train_x.ToArray()),
Y = train_y,
X = new Tensors(x.ToArray()),
Y = y,
BatchSize = batch_size,
InitialEpoch = initial_epoch,
Epochs = epochs,
Shuffle = shuffle,
SampleWeight = sample_weight,
MaxQueueSize = max_queue_size,
Workers = workers,
UseMultiprocessing = use_multiprocessing,
@@ -142,8 +157,10 @@ namespace Tensorflow.Keras.Engine
int verbose = 1,
List<ICallback> callbacks = null,
IDatasetV2 validation_data = null,
int validation_step = 10, // how often (in epochs) validation runs
int validation_step = 10,
bool shuffle = true,
Dictionary<int, float> class_weight = null,
NDArray sample_weight = null,
int initial_epoch = 0,
int max_queue_size = 10,
int workers = 1,
@@ -210,7 +227,7 @@ namespace Tensorflow.Keras.Engine
{
if (validation_step > 0 && epoch == 0 || epoch % validation_step != 0)
continue;
var val_logs = evaluate(validation_data);
foreach(var log in val_logs)
{
@@ -233,7 +250,7 @@ namespace Tensorflow.Keras.Engine
return callbacks.History;
}

History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICallback> callbackList, (NDArray, NDArray)? validation_data,
History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICallback> callbackList, ValidationDataPack validation_data,
Func<DataHandler, OwnedIterator, Dictionary<string, float>> train_step_func)
{
stop_training = false;
@@ -274,7 +291,8 @@ namespace Tensorflow.Keras.Engine
{
// Because evaluate calls call_test_batch_end, this interferes with our output on the screen
// so we need to pass a is_val parameter to stop on_test_batch_end
var val_logs = evaluate(validation_data.Value.Item1, validation_data.Value.Item2, is_val:true);
var (val_x, val_y, val_sample_weight) = validation_data;
var val_logs = evaluate(val_x, val_y, sample_weight:val_sample_weight, is_val:true);
foreach (var log in val_logs)
{
logs["val_" + log.Key] = log.Value;
@@ -296,64 +314,5 @@ namespace Tensorflow.Keras.Engine
return callbacks.History;
}

History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICallback> callbackList, (IEnumerable<Tensor>, NDArray)? validation_data,
Func<DataHandler, OwnedIterator, Dictionary<string, float>> train_step_func)
{
stop_training = false;
_train_counter.assign(0);
var callbacks = new CallbackList(new CallbackParams
{
Model = this,
Verbose = verbose,
Epochs = epochs,
Steps = data_handler.Inferredsteps
});

if (callbackList != null)
{
foreach (var callback in callbackList)
callbacks.callbacks.add(callback);
}

callbacks.on_train_begin();

foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
{
reset_metrics();
callbacks.on_epoch_begin(epoch);
// data_handler.catch_stop_iteration();
var logs = new Dictionary<string, float>();
long End_step = 0;
foreach (var step in data_handler.steps())
{
callbacks.on_train_batch_begin(step);
logs = train_step_func(data_handler, iterator);
var end_step = step + data_handler.StepIncrement;
End_step = end_step;
callbacks.on_train_batch_end(end_step, logs);
}

if (validation_data != null)
{
var val_logs = evaluate(validation_data.Value.Item1, validation_data.Value.Item2);
foreach (var log in val_logs)
{
logs["val_" + log.Key] = log.Value;
callbacks.on_train_batch_end(End_step, logs);
}
}

callbacks.on_epoch_end(epoch, logs);

GC.Collect();
GC.WaitForPendingFinalizers();
if (stop_training)
{
break;
}
}

return callbacks.History;
}
}
}

+38 -2  src/TensorFlowNET.Keras/Engine/Model.Train.cs

@@ -12,7 +12,9 @@ namespace Tensorflow.Keras.Engine
Dictionary<string, float> train_step_function(DataHandler data_handler, OwnedIterator iterator)
{
var data = iterator.next();
var outputs = train_step(data_handler, data[0], data[1]);
// detect whether the batch carries sample_weight
var outputs = data.Length == 2 ? train_step(data_handler, data[0], data[1]) :
train_step(data_handler, data[0], data[1], data[2]);
tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1));
return outputs;
}
@@ -21,7 +23,13 @@ namespace Tensorflow.Keras.Engine
{
var data = iterator.next();
var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount;
var outputs = train_step(data_handler, new Tensors(data.Take(x_size).ToArray()), new Tensors(data.Skip(x_size).ToArray()));
var outputs = data.Length == 2 ?
train_step(data_handler, new Tensors(data.Take(x_size).ToArray()), new Tensors(data.Skip(x_size).ToArray())) :
train_step(
data_handler,
new Tensors(data.Take(x_size).ToArray()),
new Tensors(data.Skip(x_size).Take(x_size).ToArray()),
new Tensors(data.Skip(2 * x_size).ToArray()));
tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1));
return outputs;
}
@@ -61,6 +69,34 @@ namespace Tensorflow.Keras.Engine
});
return dict;
}
Dictionary<string, float> train_step(DataHandler data_handler, Tensors x, Tensors y, Tensors sample_weight = null)
{
(x, y, sample_weight) = data_handler.DataAdapter.Expand1d(x, y, sample_weight);
using var tape = tf.GradientTape();
var y_pred = Apply(x, training: true);
var loss = compiled_loss.Call(y, y_pred, sample_weight:sample_weight);

// For custom training steps, users can just write:
// trainable_variables = self.trainable_variables
// gradients = tape.gradient(loss, trainable_variables)
// self.optimizer.apply_gradients(zip(gradients, trainable_variables))
// The _minimize call does a few extra steps unnecessary in most cases,
// such as loss scaling and gradient clipping.
_minimize(tape, optimizer, loss, TrainableVariables);
compiled_metrics.update_state(y, y_pred);

var dict = new Dictionary<string, float>();
metrics.ToList().ForEach(x =>
{
var r = x.result();
if (r.ndim > 0)
{
r = tf.reduce_mean(r);
}
dict[x.Name] = (float)r;
});
return dict;
}

void _minimize(GradientTape tape, IOptimizer optimizer, Tensor loss, List<IVariableV1> trainable_variables)
{


+2 -2  test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs

@@ -74,8 +74,8 @@ namespace Tensorflow.Keras.UnitTest.Layers
OneHot = true,
ValidationSize = 55000,
}).Result;
model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size: 16, epochs: 1);
var sample_weight = np.ones(((int)dataset.Train.Data.shape[0]));
model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size: 16, epochs: 1, sample_weight:sample_weight);
}

[TestMethod]

