feat: add the implementation of sample_weight in model.fit
tags/v0.150.0-BERT-Model
@@ -1,5 +1,6 @@
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.NumPy;
namespace Tensorflow.Keras.ArgsDefinition
{
@@ -16,5 +17,7 @@ namespace Tensorflow.Keras.ArgsDefinition
public int Worker { get; set; }
public bool UseMultiprocessing { get; set; }
public IModel Model { get; set; }
public Dictionary<int, float> ClassWeight = null;
public NDArray SampleWeight = null;
}
}
@@ -1,5 +1,6 @@
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.NumPy;
namespace Tensorflow.Keras.ArgsDefinition
{
@@ -18,5 +19,7 @@ namespace Tensorflow.Keras.ArgsDefinition
public bool UseMultiprocessing { get; set; } = false;
public IModel Model { get; set; }
public IVariableV1 StepsPerExecution { get; set; }
public Dictionary<int, float> ClassWeight = null;
public NDArray SampleWeight = null;
}
}
@@ -3,6 +3,7 @@ using Tensorflow.Keras.Losses;
using Tensorflow.Keras.Metrics;
using Tensorflow.Keras.Saving;
using Tensorflow.NumPy;
using Tensorflow.Util;
namespace Tensorflow.Keras.Engine;
@@ -22,8 +23,10 @@ public interface IModel : ILayer
int verbose = 1,
List<ICallback> callbacks = null,
float validation_split = 0f,
(NDArray val_x, NDArray val_y)? validation_data = null,
ValidationDataPack validation_data = null,
bool shuffle = true,
Dictionary<int, float> class_weight = null,
NDArray sample_weight = null,
int initial_epoch = 0,
int max_queue_size = 10,
int workers = 1,
@@ -35,8 +38,10 @@ public interface IModel : ILayer
int verbose = 1,
List<ICallback> callbacks = null,
float validation_split = 0f,
(IEnumerable<NDArray> val_x, NDArray val_y)? validation_data = null,
ValidationDataPack validation_data = null,
bool shuffle = true,
Dictionary<int, float> class_weight = null,
NDArray sample_weight = null,
int initial_epoch = 0,
int max_queue_size = 10,
int workers = 1,
@@ -63,6 +68,8 @@ public interface IModel : ILayer
Dictionary<string, float> evaluate(NDArray x, NDArray y,
int batch_size = -1,
int verbose = 1,
NDArray sample_weight = null,
int steps = -1,
int max_queue_size = 10,
int workers = 1,
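For context at the call site, here is a minimal sketch of the new surface. The model construction (`keras.Input`, `keras.layers.Dense`, `keras.Model`, `keras.optimizers.SGD`, `keras.losses.MeanSquaredError`) is the usual TF.NET Keras API, not part of this diff; shapes and values are made up:

```csharp
using static Tensorflow.KerasApi;
using Tensorflow.NumPy;

// Hypothetical single-input regression model.
var inputs = keras.Input(shape: 4);
var outputs = keras.layers.Dense(1).Apply(inputs);
var model = keras.Model(inputs, outputs);
model.compile(optimizer: keras.optimizers.SGD(0.01f),
              loss: keras.losses.MeanSquaredError(),
              metrics: new string[0]);

var x = np.ones(new Shape(8, 4), np.float32);
var y = np.zeros(new Shape(8, 1), np.float32);
var w = np.ones(8);  // one weight per sample; fit casts it to float internally

// sample_weight scales each sample's contribution to the loss, and
// validation_data now also accepts a (val_x, val_y, val_w) triple.
model.fit(x, y, batch_size: 4, epochs: 1, sample_weight: w,
          validation_data: (x, y, w));
model.evaluate(x, y, sample_weight: w);
```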
@@ -0,0 +1,66 @@
using Tensorflow.NumPy;
namespace Tensorflow.Util
{
/// <summary>
/// ValidationDataPack is used to pass validation data to the fit method.
/// It can receive a tuple `(x_val, y_val)` or `(x_val, y_val, sample_weight_val)` of NumPy arrays.
/// </summary>
public class ValidationDataPack
{
public NDArray val_x;
public NDArray val_y;
public NDArray val_sample_weight = null;
public ValidationDataPack((NDArray, NDArray) validation_data)
{
this.val_x = validation_data.Item1;
this.val_y = validation_data.Item2;
}
public ValidationDataPack((NDArray, NDArray, NDArray) validation_data)
{
this.val_x = validation_data.Item1;
this.val_y = validation_data.Item2;
this.val_sample_weight = validation_data.Item3;
}
public ValidationDataPack((IEnumerable<NDArray>, NDArray) validation_data)
{
// note: only the first array of a multi-input x is kept
this.val_x = validation_data.Item1.ToArray()[0];
this.val_y = validation_data.Item2;
}
public ValidationDataPack((IEnumerable<NDArray>, NDArray, NDArray) validation_data)
{
this.val_x = validation_data.Item1.ToArray()[0];
this.val_y = validation_data.Item2;
this.val_sample_weight = validation_data.Item3;
}
public static implicit operator ValidationDataPack((NDArray, NDArray) validation_data)
=> new ValidationDataPack(validation_data);
public static implicit operator ValidationDataPack((NDArray, NDArray, NDArray) validation_data)
=> new ValidationDataPack(validation_data);
public static implicit operator ValidationDataPack((IEnumerable<NDArray>, NDArray) validation_data)
=> new ValidationDataPack(validation_data);
public static implicit operator ValidationDataPack((IEnumerable<NDArray>, NDArray, NDArray) validation_data)
=> new ValidationDataPack(validation_data);
public void Deconstruct(out NDArray val_x, out NDArray val_y)
{
val_x = this.val_x;
val_y = this.val_y;
}
public void Deconstruct(out NDArray val_x, out NDArray val_y, out NDArray val_sample_weight)
{
val_x = this.val_x;
val_y = this.val_y;
val_sample_weight = this.val_sample_weight;
}
}
}
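Because the conversions are implicit, tuple-based call sites keep compiling unchanged. A small usage sketch (values hypothetical):

```csharp
using Tensorflow.NumPy;
using Tensorflow.Util;

var val_x = np.ones(new Shape(4, 2), np.float32);
var val_y = np.zeros(new Shape(4, 1), np.float32);
var val_w = np.ones(4);

// Two- and three-element tuples convert implicitly:
ValidationDataPack with_weights = (val_x, val_y, val_w);
ValidationDataPack without_weights = (val_x, val_y);

// Deconstruct with or without the optional weight;
// val_sample_weight stays null when no weight was packed.
var (vx, vy, vw) = with_weights;
var (x2, y2) = without_weights;
```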
@@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Util;
namespace Tensorflow.Keras.Engine.DataAdapters
{
@@ -34,9 +35,67 @@ namespace Tensorflow.Keras.Engine.DataAdapters
return (x, y);
}
public virtual (Tensors, Tensors, Tensors) Expand1d(Tensors x, Tensors y, Tensors sample_weight)
{
for (int i = 0; i < x.Length; i++)
{
if (x[i].shape.ndim == 1)
x[i] = array_ops.expand_dims(x[i], axis: -1);
}
for (int i = 0; i < y.Length; i++)
{
if (y[i].shape.ndim == 1)
y[i] = array_ops.expand_dims(y[i], axis: -1);
}
// guard the weight pass: the unweighted train/test steps call this with a null sample_weight
if (sample_weight != null)
{
for (int i = 0; i < sample_weight.Length; i++)
{
if (sample_weight[i].shape.ndim == 1)
sample_weight[i] = array_ops.expand_dims(sample_weight[i], axis: -1);
}
}
return (x, y, sample_weight);
}
public virtual bool ShouldRecreateIterator()
{
return true;
}
public static ((NDArray, NDArray, NDArray), ValidationDataPack) train_validation_split((NDArray, NDArray, NDArray) x_y_sample_weight, float validation_split)
{
var x = x_y_sample_weight.Item1;
var y = x_y_sample_weight.Item2;
var sample_weight = x_y_sample_weight.Item3;
int train_count = Convert.ToInt32(x.dims[0] * (1 - validation_split));
var train_x = x[new Slice(0, train_count)];
var train_y = y[new Slice(0, train_count)];
ValidationDataPack validation_data;
if (sample_weight != null)
{
validation_data = (x[new Slice(train_count)], y[new Slice(train_count)], sample_weight[new Slice(train_count)]);
sample_weight = sample_weight[new Slice(0, train_count)];
}
else
{
validation_data = (x[new Slice(train_count)], y[new Slice(train_count)]);
}
return ((train_x, train_y, sample_weight), validation_data);
}
public static ((IEnumerable<NDArray>, NDArray, NDArray), ValidationDataPack) train_validation_split((IEnumerable<NDArray>, NDArray, NDArray) x_y_sample_weight, float validation_split)
{
var x = x_y_sample_weight.Item1;
var y = x_y_sample_weight.Item2;
var sample_weight = x_y_sample_weight.Item3;
int train_count = Convert.ToInt32(y.dims[0] * (1 - validation_split));
var train_x = x.Select(x => x[new Slice(0, train_count)] as NDArray);
var train_y = y[new Slice(0, train_count)];
var val_x = x.Select(x => x[new Slice(train_count)] as NDArray);
var val_y = y[new Slice(train_count)];
// guard against a null weight, mirroring the single-input overload
ValidationDataPack validation_data;
if (sample_weight != null)
{
NDArray tmp_sample_weight = sample_weight;
sample_weight = sample_weight[new Slice(0, train_count)];
validation_data = (val_x, val_y, tmp_sample_weight[new Slice(train_count)]);
}
else
{
validation_data = (val_x, val_y);
}
return ((train_x, train_y, sample_weight), validation_data);
}
}
}
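To make the split arithmetic concrete, a sketch with made-up sizes: ten samples and `validation_split: 0.2f` give `train_count = 8`, so samples 0–7 train and 8–9 validate, and the same slices are applied to the weights when present:

```csharp
using Tensorflow.Keras.Engine.DataAdapters;
using Tensorflow.NumPy;

var x = np.arange(10).astype(np.float32);
var y = np.arange(10).astype(np.float32);
var w = np.ones(10);

// train_count = Convert.ToInt32(10 * (1 - 0.2f)) = 8
var ((train_x, train_y, train_w), val_pack) =
    DataAdapter.train_validation_split((x, y, w), validation_split: 0.2f);
// train_* hold samples 0..7; val_pack packs samples 8..9 and their weights.
```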
@@ -2,6 +2,7 @@
using System.Collections.Generic;
using Tensorflow.Keras.ArgsDefinition;
using static Tensorflow.Binding;
using Tensorflow.Keras.Utils;
namespace Tensorflow.Keras.Engine.DataAdapters
{
@@ -28,6 +29,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters
public DataHandler(DataHandlerArgs args)
{
this.args = args;
if (args.StepsPerExecution == null)
{
_steps_per_execution = tf.Variable(1L);
@@ -48,6 +50,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters
BatchSize = args.BatchSize,
Steps = args.StepsPerEpoch,
Epochs = args.Epochs - args.InitialEpoch,
SampleWeight = args.SampleWeight,
Shuffle = args.Shuffle,
MaxQueueSize = args.MaxQueueSize,
Worker = args.Workers,
@@ -17,6 +17,8 @@
IDatasetV2 GetDataset();
int GetSize();
(Tensors, Tensors) Expand1d(Tensors x, Tensors y);
(Tensors, Tensors, Tensors) Expand1d(Tensors x, Tensors y, Tensors sample_weight);
bool ShouldRecreateIterator();
}
}
@@ -20,7 +20,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters
public TensorLikeDataAdapter(DataAdapterArgs args)
{
this.args = args;
_process_tensorlike();
Tensor sample_weight_tensor = args.SampleWeight != null ? _process_tensorlike(args.SampleWeight) : null;
num_samples = (int)args.X.shape[0];
var batch_size = args.BatchSize == -1 ? 32 : args.BatchSize;
_batch_size = batch_size;
@@ -37,6 +37,8 @@ namespace Tensorflow.Keras.Engine.DataAdapters
inputs.AddRange(args.X);
if (args.Y != null)
inputs.AddRange(args.Y);
if (sample_weight_tensor != null)
inputs.Add(sample_weight_tensor);
dataset = slice_inputs(indices_dataset, inputs);
dataset.FirstInputTensorCount = args.X.Length;
}
@@ -94,8 +96,9 @@ namespace Tensorflow.Keras.Engine.DataAdapters
public override bool ShouldRecreateIterator() => false;
void _process_tensorlike()
Tensor _process_tensorlike(NDArray sample_weights)
{
return tf.convert_to_tensor(sample_weights);
}
}
}
@@ -26,11 +26,11 @@ namespace Tensorflow.Keras.Engine
/// </summary>
/// <param name="y_true"></param>
/// <param name="y_pred"></param>
public Tensor Call(Tensor y_true, Tensor y_pred)
public Tensor Call(Tensor y_true, Tensor y_pred, Tensor sample_weight = null)
{
if (!_built)
Build(y_pred);
var loss_value = _losses.Call(y_true, y_pred);
var loss_value = _losses.Call(y_true, y_pred, sample_weight: sample_weight);
var loss_metric_value = loss_value;
var batch_dim = array_ops.shape(y_true)[0];
@@ -30,6 +30,7 @@ namespace Tensorflow.Keras.Engine
public Dictionary<string, float> evaluate(NDArray x, NDArray y,
int batch_size = -1,
int verbose = 1,
NDArray sample_weight = null,
int steps = -1,
int max_queue_size = 10,
int workers = 1,
@@ -51,6 +52,7 @@ namespace Tensorflow.Keras.Engine
StepsPerEpoch = steps,
InitialEpoch = 0,
Epochs = 1,
SampleWeight = sample_weight,
MaxQueueSize = max_queue_size,
Workers = workers,
UseMultiprocessing = use_multiprocessing,
@@ -140,7 +142,8 @@ namespace Tensorflow.Keras.Engine
Dictionary<string, float> test_function(DataHandler data_handler, OwnedIterator iterator)
{
var data = iterator.next();
var outputs = test_step(data_handler, data[0], data[1]);
var outputs = data.Length == 2 ? test_step(data_handler, data[0], data[1]) :
test_step(data_handler, data[0], data[1], data[2]);
tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1));
return outputs;
}
@@ -149,17 +152,23 @@ namespace Tensorflow.Keras.Engine
{
var data = iterator.next();
var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount;
var outputs = test_step(data_handler, data.Take(x_size).ToArray(), data.Skip(x_size).ToArray());
var outputs = data.Length == 2 ?
test_step(data_handler, new Tensors(data.Take(x_size).ToArray()), new Tensors(data.Skip(x_size).ToArray())) :
test_step(
data_handler,
new Tensors(data.Take(x_size).ToArray()),
new Tensors(data.Skip(x_size).Take(x_size).ToArray()),
new Tensors(data.Skip(2 * x_size).ToArray()));
tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1));
return outputs;
}
Dictionary<string, float> test_step(DataHandler data_handler, Tensors x, Tensors y)
Dictionary<string, float> test_step(DataHandler data_handler, Tensors x, Tensors y, Tensors sample_weight = null)
{
(x, y) = data_handler.DataAdapter.Expand1d(x, y);
(x, y, sample_weight) = data_handler.DataAdapter.Expand1d(x, y, sample_weight);
var y_pred = Apply(x, training: false);
var loss = compiled_loss.Call(y, y_pred);
var loss = compiled_loss.Call(y, y_pred, sample_weight: sample_weight);
compiled_metrics.update_state(y, y_pred);
return metrics.Select(x => (x.Name, x.result())).ToDictionary(x => x.Item1, x => (float)x.Item2);
}
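The `Take`/`Skip` arithmetic above relies on the adapter packing each batch as `[x..., y..., sample_weight...]` (see `TensorLikeDataAdapter`), and it assumes one target tensor per input. A stand-in sketch of the slicing, with strings in place of tensors:

```csharp
using System;
using System.Linq;

// Stand-ins for the tensors in one weighted multi-input batch.
var data = new[] { "x1", "x2", "y1", "y2", "w1", "w2" };
int x_size = 2;  // FirstInputTensorCount

var x = data.Take(x_size).ToArray();               // inputs:  x1, x2
var y = data.Skip(x_size).Take(x_size).ToArray();  // targets: y1, y2
var w = data.Skip(2 * x_size).ToArray();           // weights: w1, w2
Console.WriteLine(string.Join(",", w));            // prints "w1,w2"
```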
@@ -6,10 +6,12 @@ using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine.DataAdapters;
using System.Diagnostics;
using Tensorflow.Keras.Callbacks;
using System.Data;
using Tensorflow.Util;
namespace Tensorflow.Keras.Engine
{
public partial class Model
{
/// <summary>
@@ -19,19 +21,29 @@ namespace Tensorflow.Keras.Engine
/// <param name="y"></param>
/// <param name="batch_size"></param>
/// <param name="epochs"></param>
/// <param name="callbacks"></param>
/// <param name="verbose"></param>
/// <param name="callbacks"></param>
/// <param name="validation_split"></param>
/// <param name="validation_data"></param>
/// <param name="shuffle"></param>
/// <param name="class_weight"></param>
/// <param name="sample_weight"></param>
/// <param name="initial_epoch"></param>
/// <param name="max_queue_size"></param>
/// <param name="workers"></param>
/// <param name="use_multiprocessing"></param>
/// <returns></returns>
/// <exception cref="InvalidArgumentError"></exception>
public ICallback fit(NDArray x, NDArray y,
int batch_size = -1,
int epochs = 1,
int verbose = 1,
List<ICallback> callbacks = null,
float validation_split = 0f,
(NDArray val_x, NDArray val_y)? validation_data = null,
ValidationDataPack validation_data = null,
bool shuffle = true,
Dictionary<int, float> class_weight = null,
NDArray sample_weight = null,
int initial_epoch = 0,
int max_queue_size = 10,
int workers = 1,
@@ -43,21 +55,25 @@ namespace Tensorflow.Keras.Engine
$"The arrays x and y should have the same size at dim 0, but got {x.dims[0]} and {y.dims[0]}");
}
var train_x = x;
var train_y = y;
// NDArray defaults to double, so cast sample_weight to float before it is multiplied with the float-typed loss.
sample_weight = sample_weight?.astype(TF_DataType.TF_FLOAT);
if (validation_split != 0f && validation_data == null)
{
int train_count = Convert.ToInt32(x.dims[0] * (1 - validation_split));
train_x = x[new Slice(0, train_count)];
train_y = y[new Slice(0, train_count)];
validation_data = (val_x: x[new Slice(train_count)], val_y: y[new Slice(train_count)]);
((x, y, sample_weight), validation_data) = DataAdapter.train_validation_split((x, y, sample_weight), validation_split);
}
// TODO(Wanglongzhi2001)
if (class_weight != null)
{
throw new NotImplementedException("class_weight is not implemented");
}
var data_handler = new DataHandler(new DataHandlerArgs
{
X = train_x,
Y = train_y,
X = x,
Y = y,
SampleWeight = sample_weight,
BatchSize = batch_size,
InitialEpoch = initial_epoch,
Epochs = epochs,
@@ -73,14 +89,17 @@ namespace Tensorflow.Keras.Engine
train_step_func: train_step_function);
}
public ICallback fit(IEnumerable<NDArray> x, NDArray y,
int batch_size = -1,
int epochs = 1,
int verbose = 1,
List<ICallback> callbacks = null,
float validation_split = 0f,
(IEnumerable<NDArray> val_x, NDArray val_y)? validation_data = null,
ValidationDataPack validation_data = null,
bool shuffle = true,
Dictionary<int, float> class_weight = null,
NDArray sample_weight = null,
int initial_epoch = 0,
int max_queue_size = 10,
int workers = 1,
@@ -95,27 +114,23 @@ namespace Tensorflow.Keras.Engine
}
}
var train_x = x;
var train_y = y;
sample_weight = sample_weight?.astype(TF_DataType.TF_FLOAT);
if (validation_split != 0f && validation_data == null)
{
int train_count = Convert.ToInt32(y.dims[0] * (1 - validation_split));
train_x = x.Select(x => x[new Slice(0, train_count)] as NDArray);
train_y = y[new Slice(0, train_count)];
var val_x = x.Select(x => x[new Slice(train_count)] as NDArray);
var val_y = y[new Slice(train_count)];
validation_data = (val_x, val_y);
((x, y, sample_weight), validation_data) = DataAdapter.train_validation_split((x, y, sample_weight), validation_split);
}
var data_handler = new DataHandler(new DataHandlerArgs
{
X = new Tensors(train_x.ToArray()),
Y = train_y,
X = new Tensors(x.ToArray()),
Y = y,
BatchSize = batch_size,
InitialEpoch = initial_epoch,
Epochs = epochs,
Shuffle = shuffle,
SampleWeight = sample_weight,
MaxQueueSize = max_queue_size,
Workers = workers,
UseMultiprocessing = use_multiprocessing,
@@ -142,8 +157,10 @@ namespace Tensorflow.Keras.Engine
int verbose = 1,
List<ICallback> callbacks = null,
IDatasetV2 validation_data = null,
int validation_step = 10, // how many epochs between validation runs
int validation_step = 10,
bool shuffle = true,
Dictionary<int, float> class_weight = null,
NDArray sample_weight = null,
int initial_epoch = 0,
int max_queue_size = 10,
int workers = 1,
@@ -210,7 +227,7 @@ namespace Tensorflow.Keras.Engine
{
if (validation_step > 0 && epoch == 0 || epoch % validation_step != 0)
continue;
var val_logs = evaluate(validation_data);
foreach (var log in val_logs)
{
@@ -233,7 +250,7 @@ namespace Tensorflow.Keras.Engine
return callbacks.History;
}
History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICallback> callbackList, (NDArray, NDArray)? validation_data,
History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICallback> callbackList, ValidationDataPack validation_data,
Func<DataHandler, OwnedIterator, Dictionary<string, float>> train_step_func)
{
stop_training = false;
@@ -274,7 +291,8 @@ namespace Tensorflow.Keras.Engine
{
// Because evaluate calls on_test_batch_end, it interferes with our output on the screen,
// so we pass an is_val parameter to suppress on_test_batch_end.
var val_logs = evaluate(validation_data.Value.Item1, validation_data.Value.Item2, is_val:true);
var (val_x, val_y, val_sample_weight) = validation_data;
var val_logs = evaluate(val_x, val_y, sample_weight: val_sample_weight, is_val: true);
foreach (var log in val_logs)
{
logs["val_" + log.Key] = log.Value;
@@ -296,64 +314,5 @@ namespace Tensorflow.Keras.Engine
return callbacks.History;
}
History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICallback> callbackList, (IEnumerable<Tensor>, NDArray)? validation_data,
Func<DataHandler, OwnedIterator, Dictionary<string, float>> train_step_func)
{
stop_training = false;
_train_counter.assign(0);
var callbacks = new CallbackList(new CallbackParams
{
Model = this,
Verbose = verbose,
Epochs = epochs,
Steps = data_handler.Inferredsteps
});
if (callbackList != null)
{
foreach (var callback in callbackList)
callbacks.callbacks.add(callback);
}
callbacks.on_train_begin();
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
{
reset_metrics();
callbacks.on_epoch_begin(epoch);
// data_handler.catch_stop_iteration();
var logs = new Dictionary<string, float>();
long End_step = 0;
foreach (var step in data_handler.steps())
{
callbacks.on_train_batch_begin(step);
logs = train_step_func(data_handler, iterator);
var end_step = step + data_handler.StepIncrement;
End_step = end_step;
callbacks.on_train_batch_end(end_step, logs);
}
if (validation_data != null)
{
var val_logs = evaluate(validation_data.Value.Item1, validation_data.Value.Item2);
foreach (var log in val_logs)
{
logs["val_" + log.Key] = log.Value;
callbacks.on_train_batch_end(End_step, logs);
}
}
callbacks.on_epoch_end(epoch, logs);
GC.Collect();
GC.WaitForPendingFinalizers();
if (stop_training)
{
break;
}
}
return callbacks.History;
}
}
}
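Putting the single-input pieces together: with both `validation_split` and `sample_weight` set, `fit` now casts the weights to float, delegates the slicing to `DataAdapter.train_validation_split`, and carries the validation tail (including its weights) in a `ValidationDataPack`. A sketch under the same assumed Keras surface as earlier:

```csharp
using static Tensorflow.KerasApi;
using Tensorflow.NumPy;

var inputs = keras.Input(shape: 4);
var outputs = keras.layers.Dense(1).Apply(inputs);
var model = keras.Model(inputs, outputs);
model.compile(optimizer: keras.optimizers.SGD(0.01f),
              loss: keras.losses.MeanSquaredError(),
              metrics: new string[0]);

var x = np.ones(new Shape(8, 4), np.float32);
var y = np.zeros(new Shape(8, 1), np.float32);
var w = np.ones(8);  // default double; fit casts to float before weighting

// Samples 0..5 train and 6..7 validate (8 * (1 - 0.25) = 6);
// the validation slice carries its own weights via ValidationDataPack.
model.fit(x, y, batch_size: 2, epochs: 1,
          validation_split: 0.25f, sample_weight: w);
```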
@@ -12,7 +12,9 @@ namespace Tensorflow.Keras.Engine
Dictionary<string, float> train_step_function(DataHandler data_handler, OwnedIterator iterator)
{
var data = iterator.next();
var outputs = train_step(data_handler, data[0], data[1]);
// branch on whether the batch carries a sample_weight tensor
var outputs = data.Length == 2 ? train_step(data_handler, data[0], data[1]) :
train_step(data_handler, data[0], data[1], data[2]);
tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1));
return outputs;
}
@@ -21,7 +23,13 @@
{
var data = iterator.next();
var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount;
var outputs = train_step(data_handler, new Tensors(data.Take(x_size).ToArray()), new Tensors(data.Skip(x_size).ToArray()));
var outputs = data.Length == 2 ?
train_step(data_handler, new Tensors(data.Take(x_size).ToArray()), new Tensors(data.Skip(x_size).ToArray())) :
train_step(
data_handler,
new Tensors(data.Take(x_size).ToArray()),
new Tensors(data.Skip(x_size).Take(x_size).ToArray()),
new Tensors(data.Skip(2 * x_size).ToArray()));
tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1));
return outputs;
}
@@ -61,6 +69,34 @@ namespace Tensorflow.Keras.Engine
});
return dict;
}
Dictionary<string, float> train_step(DataHandler data_handler, Tensors x, Tensors y, Tensors sample_weight = null)
{
(x, y, sample_weight) = data_handler.DataAdapter.Expand1d(x, y, sample_weight);
using var tape = tf.GradientTape();
var y_pred = Apply(x, training: true);
var loss = compiled_loss.Call(y, y_pred, sample_weight: sample_weight);
// For custom training steps, users can just write:
// trainable_variables = self.trainable_variables
// gradients = tape.gradient(loss, trainable_variables)
// self.optimizer.apply_gradients(zip(gradients, trainable_variables))
// The _minimize call does a few extra steps unnecessary in most cases,
// such as loss scaling and gradient clipping.
_minimize(tape, optimizer, loss, TrainableVariables);
compiled_metrics.update_state(y, y_pred);
var dict = new Dictionary<string, float>();
metrics.ToList().ForEach(x =>
{
var r = x.result();
if (r.ndim > 0)
{
r = tf.reduce_mean(r);
}
dict[x.Name] = (float)r;
});
return dict;
}
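As a concept check for what passing `sample_weight` into `compiled_loss.Call` is meant to do: each sample's loss is scaled by its weight before reduction. A tiny sketch of a weighted mean with TF.NET ops, illustrative only and not the exact `LossesContainer` reduction, which depends on the loss's configured reduction:

```csharp
using static Tensorflow.Binding;

var per_sample = tf.constant(new float[] { 1f, 2f, 3f });
var weights = tf.constant(new float[] { 1f, 0f, 1f });

// Weighted mean: the zero-weight sample drops out of the loss.
var weighted = tf.reduce_sum(per_sample * weights) / tf.reduce_sum(weights);
// weighted evaluates to 2.0f
```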
void _minimize(GradientTape tape, IOptimizer optimizer, Tensor loss, List<IVariableV1> trainable_variables)
{
@@ -74,8 +74,8 @@ namespace Tensorflow.Keras.UnitTest.Layers
OneHot = true,
ValidationSize = 55000,
}).Result;
model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size: 16, epochs: 1);
var sample_weight = np.ones((int)dataset.Train.Data.shape[0]);
model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size: 16, epochs: 1, sample_weight: sample_weight);
}
[TestMethod]