feat: add the implementation of sample_weight in model.fit
@@ -1,5 +1,6 @@
 using Tensorflow.Keras.Engine;
 using Tensorflow.Keras.Saving;
+using Tensorflow.NumPy;
 
 namespace Tensorflow.Keras.ArgsDefinition
 {
@@ -16,5 +17,7 @@ namespace Tensorflow.Keras.ArgsDefinition
         public int Worker { get; set; }
         public bool UseMultiprocessing { get; set; }
         public IModel Model { get; set; }
+        public Dictionary<int, float> ClassWeight = null;
+        public NDArray SampleWeight = null;
     }
 }
@@ -1,5 +1,6 @@
 using Tensorflow.Keras.Engine;
 using Tensorflow.Keras.Saving;
+using Tensorflow.NumPy;
 
 namespace Tensorflow.Keras.ArgsDefinition
 {
@@ -18,5 +19,7 @@ namespace Tensorflow.Keras.ArgsDefinition
         public bool UseMultiprocessing { get; set; } = false;
         public IModel Model { get; set; }
         public IVariableV1 StepsPerExecution { get; set; }
+        public Dictionary<int, float> ClassWeight = null;
+        public NDArray SampleWeight = null;
     }
 }
@@ -3,6 +3,7 @@ using Tensorflow.Keras.Losses;
 using Tensorflow.Keras.Metrics;
 using Tensorflow.Keras.Saving;
 using Tensorflow.NumPy;
+using Tensorflow.Util;
 
 namespace Tensorflow.Keras.Engine;
@@ -22,8 +23,10 @@ public interface IModel : ILayer
        int verbose = 1,
        List<ICallback> callbacks = null,
        float validation_split = 0f,
-       (NDArray val_x, NDArray val_y)? validation_data = null,
+       ValidationDataPack validation_data = null,
        bool shuffle = true,
+       Dictionary<int, float> class_weight = null,
+       NDArray sample_weight = null,
        int initial_epoch = 0,
        int max_queue_size = 10,
        int workers = 1,
@@ -35,8 +38,10 @@ public interface IModel : ILayer
        int verbose = 1,
        List<ICallback> callbacks = null,
        float validation_split = 0f,
-       (IEnumerable<NDArray> val_x, NDArray val_y)? validation_data = null,
+       ValidationDataPack validation_data = null,
        bool shuffle = true,
+       Dictionary<int, float> class_weight = null,
+       NDArray sample_weight = null,
        int initial_epoch = 0,
        int max_queue_size = 10,
        int workers = 1,
@@ -63,6 +68,8 @@ public interface IModel : ILayer
    Dictionary<string, float> evaluate(NDArray x, NDArray y,
        int batch_size = -1,
        int verbose = 1,
+       NDArray sample_weight = null,
        int steps = -1,
        int max_queue_size = 10,
        int workers = 1,
@@ -0,0 +1,66 @@
+using Tensorflow.NumPy;
+
+namespace Tensorflow.Util
+{
+    /// <summary>
+    /// ValidationDataPack is used to pass validation data to the fit method.
+    /// It can receive data as a tuple `(x_val, y_val)` or `(x_val, y_val, sample_weight_val)` of NumPy arrays.
+    /// </summary>
+    public class ValidationDataPack
+    {
+        public NDArray val_x;
+        public NDArray val_y;
+        public NDArray val_sample_weight = null;
+
+        public ValidationDataPack((NDArray, NDArray) validation_data)
+        {
+            this.val_x = validation_data.Item1;
+            this.val_y = validation_data.Item2;
+        }
+
+        public ValidationDataPack((NDArray, NDArray, NDArray) validation_data)
+        {
+            this.val_x = validation_data.Item1;
+            this.val_y = validation_data.Item2;
+            this.val_sample_weight = validation_data.Item3;
+        }
+
+        public ValidationDataPack((IEnumerable<NDArray>, NDArray) validation_data)
+        {
+            this.val_x = validation_data.Item1.ToArray()[0];
+            this.val_y = validation_data.Item2;
+        }
+
+        public ValidationDataPack((IEnumerable<NDArray>, NDArray, NDArray) validation_data)
+        {
+            this.val_x = validation_data.Item1.ToArray()[0];
+            this.val_y = validation_data.Item2;
+            this.val_sample_weight = validation_data.Item3;
+        }
+
+        public static implicit operator ValidationDataPack((NDArray, NDArray) validation_data)
+            => new ValidationDataPack(validation_data);
+
+        public static implicit operator ValidationDataPack((NDArray, NDArray, NDArray) validation_data)
+            => new ValidationDataPack(validation_data);
+
+        public static implicit operator ValidationDataPack((IEnumerable<NDArray>, NDArray) validation_data)
+            => new ValidationDataPack(validation_data);
+
+        public static implicit operator ValidationDataPack((IEnumerable<NDArray>, NDArray, NDArray) validation_data)
+            => new ValidationDataPack(validation_data);
+
+        public void Deconstruct(out NDArray val_x, out NDArray val_y)
+        {
+            val_x = this.val_x;
+            val_y = this.val_y;
+        }
+
+        public void Deconstruct(out NDArray val_x, out NDArray val_y, out NDArray val_sample_weight)
+        {
+            val_x = this.val_x;
+            val_y = this.val_y;
+            val_sample_weight = this.val_sample_weight;
+        }
+    }
+}
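
Reviewer note: a minimal sketch of how a caller constructs and consumes ValidationDataPack through the implicit tuple conversions and Deconstruct overloads above. val_x, val_y, and val_w are hypothetical pre-built NDArrays, not names from this diff:

    ValidationDataPack pack = (val_x, val_y);            // implicit 2-tuple conversion
    ValidationDataPack weighted = (val_x, val_y, val_w); // implicit 3-tuple conversion

    var (x2, y2) = pack;          // 2-element Deconstruct ignores weights
    var (x3, y3, w3) = weighted;  // w3 == val_w here; it would stay null for 'pack'

Note that the IEnumerable<NDArray> constructors keep only the first input (`Item1.ToArray()[0]`), so multi-input validation data is effectively truncated to a single input array here.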
@@ -2,6 +2,7 @@
 using System.Collections.Generic;
 using System.Text;
 using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Util;
 
 namespace Tensorflow.Keras.Engine.DataAdapters
 {
@@ -34,9 +35,67 @@ namespace Tensorflow.Keras.Engine.DataAdapters
             return (x, y);
         }
 
+        public virtual (Tensors, Tensors, Tensors) Expand1d(Tensors x, Tensors y, Tensors sample_weight)
+        {
+            for (int i = 0; i < x.Length; i++)
+            {
+                if (x[i].shape.ndim == 1)
+                    x[i] = array_ops.expand_dims(x[i], axis: -1);
+            }
+            for (int i = 0; i < y.Length; i++)
+            {
+                if (y[i].shape.ndim == 1)
+                    y[i] = array_ops.expand_dims(y[i], axis: -1);
+            }
+            for (int i = 0; i < sample_weight.Length; i++)
+            {
+                if (sample_weight[i].shape.ndim == 1)
+                    sample_weight[i] = array_ops.expand_dims(sample_weight[i], axis: -1);
+            }
+            return (x, y, sample_weight);
+        }
+
         public virtual bool ShouldRecreateIterator()
         {
             return true;
         }
+        public static ((NDArray, NDArray, NDArray), ValidationDataPack) train_validation_split((NDArray, NDArray, NDArray) x_y_sample_weight, float validation_split)
+        {
+            var x = x_y_sample_weight.Item1;
+            var y = x_y_sample_weight.Item2;
+            var sample_weight = x_y_sample_weight.Item3;
+            int train_count = Convert.ToInt32(x.dims[0] * (1 - validation_split));
+            var train_x = x[new Slice(0, train_count)];
+            var train_y = y[new Slice(0, train_count)];
+            ValidationDataPack validation_data;
+            if (sample_weight != null)
+            {
+                validation_data = (x[new Slice(train_count)], y[new Slice(train_count)], sample_weight[new Slice(train_count)]);
+                sample_weight = sample_weight[new Slice(0, train_count)];
+            }
+            else
+            {
+                validation_data = (x[new Slice(train_count)], y[new Slice(train_count)]);
+            }
+            return ((train_x, train_y, sample_weight), validation_data);
+        }
+
+        public static ((IEnumerable<NDArray>, NDArray, NDArray), ValidationDataPack) train_validation_split((IEnumerable<NDArray>, NDArray, NDArray) x_y_sample_weight, float validation_split)
+        {
+            var x = x_y_sample_weight.Item1;
+            var y = x_y_sample_weight.Item2;
+            var sample_weight = x_y_sample_weight.Item3;
+            int train_count = Convert.ToInt32(y.dims[0] * (1 - validation_split));
+            var train_x = x.Select(v => v[new Slice(0, train_count)] as NDArray);
+            var train_y = y[new Slice(0, train_count)];
+            var val_x = x.Select(v => v[new Slice(train_count)] as NDArray);
+            var val_y = y[new Slice(train_count)];
+            ValidationDataPack validation_data;
+            if (sample_weight != null)
+            {
+                validation_data = (val_x, val_y, sample_weight[new Slice(train_count)]);
+                sample_weight = sample_weight[new Slice(0, train_count)];
+            }
+            else
+            {
+                validation_data = (val_x, val_y);
+            }
+            return ((train_x, train_y, sample_weight), validation_data);
+        }
     }
 }
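
For intuition about the split: with 100 samples and validation_split = 0.2, train_count = Convert.ToInt32(100 * 0.8) = 80, so rows [0, 80) train, rows [80, 100) validate, and the sample weights are sliced at exactly the same boundary. A hedged usage sketch, assuming x, y, and w are hypothetical NDArrays with matching dim 0:

    var ((train_x, train_y, train_w), val_pack) =
        DataAdapter.train_validation_split((x, y, w), validation_split: 0.2f);
    var (val_x, val_y, val_w) = val_pack;  // 20 validation rows plus their weights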
@@ -2,6 +2,7 @@
 using System.Collections.Generic;
 using Tensorflow.Keras.ArgsDefinition;
 using static Tensorflow.Binding;
+using Tensorflow.Keras.Utils;
 
 namespace Tensorflow.Keras.Engine.DataAdapters
 {
@@ -28,6 +29,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters
         public DataHandler(DataHandlerArgs args)
         {
             this.args = args;
+
             if (args.StepsPerExecution == null)
             {
                 _steps_per_execution = tf.Variable(1L);
@@ -48,6 +50,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters
                 BatchSize = args.BatchSize,
                 Steps = args.StepsPerEpoch,
                 Epochs = args.Epochs - args.InitialEpoch,
+                SampleWeight = args.SampleWeight,
                 Shuffle = args.Shuffle,
                 MaxQueueSize = args.MaxQueueSize,
                 Worker = args.Workers,
@@ -17,6 +17,8 @@
         IDatasetV2 GetDataset();
         int GetSize();
         (Tensors, Tensors) Expand1d(Tensors x, Tensors y);
+        (Tensors, Tensors, Tensors) Expand1d(Tensors x, Tensors y, Tensors sample_weight);
+
         bool ShouldRecreateIterator();
     }
 }
@@ -20,7 +20,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters
         public TensorLikeDataAdapter(DataAdapterArgs args)
         {
             this.args = args;
-            _process_tensorlike();
+            Tensor sample_weight_tensor = args.SampleWeight != null ? _process_tensorlike(args.SampleWeight) : null;
             num_samples = (int)args.X.shape[0];
             var batch_size = args.BatchSize == -1 ? 32 : args.BatchSize;
             _batch_size = batch_size;
@@ -37,6 +37,8 @@ namespace Tensorflow.Keras.Engine.DataAdapters
             inputs.AddRange(args.X);
             if (args.Y != null)
                 inputs.AddRange(args.Y);
+            if (sample_weight_tensor != null)
+                inputs.Add(sample_weight_tensor);
             dataset = slice_inputs(indices_dataset, inputs);
             dataset.FirstInputTensorCount = args.X.Length;
         }
@@ -94,8 +96,9 @@ namespace Tensorflow.Keras.Engine.DataAdapters
         public override bool ShouldRecreateIterator() => false;
 
-        void _process_tensorlike()
+        Tensor _process_tensorlike(NDArray sample_weights)
         {
+            return tf.convert_to_tensor(sample_weights);
         }
     }
 }
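
The adapter appends the converted weight tensor to the flattened (inputs, targets) dataset element, so the element length, together with FirstInputTensorCount, is what signals the presence of weights downstream. Schematically (layout notes, not code from this diff):

    // Single-input model, with weights: element = [ x, y, sample_weight ]  -> Length == 3
    // Single-input model, no weights:   element = [ x, y ]                 -> Length == 2
    // FirstInputTensorCount == args.X.Length records how many leading tensors
    // are inputs; consumers split the remainder into targets and weights.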
@@ -26,11 +26,11 @@ namespace Tensorflow.Keras.Engine
         /// </summary>
         /// <param name="y_true"></param>
         /// <param name="y_pred"></param>
-        public Tensor Call(Tensor y_true, Tensor y_pred)
+        public Tensor Call(Tensor y_true, Tensor y_pred, Tensor sample_weight = null)
         {
             if (!_built)
                 Build(y_pred);
-            var loss_value = _losses.Call(y_true, y_pred);
+            var loss_value = _losses.Call(y_true, y_pred, sample_weight: sample_weight);
             var loss_metric_value = loss_value;
             var batch_dim = array_ops.shape(y_true)[0];
@@ -30,6 +30,7 @@ namespace Tensorflow.Keras.Engine
         public Dictionary<string, float> evaluate(NDArray x, NDArray y,
             int batch_size = -1,
             int verbose = 1,
+            NDArray sample_weight = null,
             int steps = -1,
             int max_queue_size = 10,
             int workers = 1,
@@ -51,6 +52,7 @@ namespace Tensorflow.Keras.Engine
                 StepsPerEpoch = steps,
                 InitialEpoch = 0,
                 Epochs = 1,
+                SampleWeight = sample_weight,
                 MaxQueueSize = max_queue_size,
                 Workers = workers,
                 UseMultiprocessing = use_multiprocessing,
@@ -140,7 +142,8 @@ namespace Tensorflow.Keras.Engine
         Dictionary<string, float> test_function(DataHandler data_handler, OwnedIterator iterator)
         {
             var data = iterator.next();
-            var outputs = test_step(data_handler, data[0], data[1]);
+            var outputs = data.Length == 2 ? test_step(data_handler, data[0], data[1]) :
+                test_step(data_handler, data[0], data[1], data[2]);
             tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1));
             return outputs;
         }
@@ -149,17 +152,23 @@ namespace Tensorflow.Keras.Engine
         {
             var data = iterator.next();
             var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount;
-            var outputs = test_step(data_handler, data.Take(x_size).ToArray(), data.Skip(x_size).ToArray());
+            var outputs = data.Length == 2 ?
+                test_step(data_handler, new Tensors(data.Take(x_size).ToArray()), new Tensors(data.Skip(x_size).ToArray())) :
+                test_step(
+                    data_handler,
+                    new Tensors(data.Take(x_size).ToArray()),
+                    new Tensors(data.Skip(x_size).Take(x_size).ToArray()),
+                    new Tensors(data.Skip(2 * x_size).ToArray()));
             tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1));
             return outputs;
         }
 
-        Dictionary<string, float> test_step(DataHandler data_handler, Tensors x, Tensors y)
+        Dictionary<string, float> test_step(DataHandler data_handler, Tensors x, Tensors y, Tensors sample_weight = null)
         {
-            (x, y) = data_handler.DataAdapter.Expand1d(x, y);
+            (x, y, sample_weight) = data_handler.DataAdapter.Expand1d(x, y, sample_weight);
             var y_pred = Apply(x, training: false);
-            var loss = compiled_loss.Call(y, y_pred);
+            var loss = compiled_loss.Call(y, y_pred, sample_weight: sample_weight);
             compiled_metrics.update_state(y, y_pred);
             return metrics.Select(x => (x.Name, x.result())).ToDictionary(x => x.Item1, x => (float)x.Item2);
         }
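
On the evaluate side, the weights flow through the same DataHandler path as in fit. A small usage sketch; model, x_test, and y_test are assumed to exist already:

    // Each sample's loss contribution is scaled by its weight before reduction;
    // uniform np.ones weights reproduce the unweighted result.
    var w_test = np.ones((int)x_test.shape[0]);
    var results = model.evaluate(x_test, y_test, batch_size: 32, sample_weight: w_test);
    // results maps metric names to values, e.g. results["loss"].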
@@ -6,10 +6,12 @@ using Tensorflow.Keras.ArgsDefinition;
 using Tensorflow.Keras.Engine.DataAdapters;
 using System.Diagnostics;
 using Tensorflow.Keras.Callbacks;
+using System.Data;
+using Tensorflow.Util;
 
 namespace Tensorflow.Keras.Engine
 {
     public partial class Model
     {
         /// <summary>
@@ -19,19 +21,29 @@ namespace Tensorflow.Keras.Engine
         /// <param name="y"></param>
         /// <param name="batch_size"></param>
         /// <param name="epochs"></param>
-        /// <param name="callbacks"></param>
         /// <param name="verbose"></param>
+        /// <param name="callbacks"></param>
         /// <param name="validation_split"></param>
         /// <param name="validation_data"></param>
         /// <param name="shuffle"></param>
+        /// <param name="class_weight"></param>
+        /// <param name="sample_weight"></param>
+        /// <param name="initial_epoch"></param>
+        /// <param name="max_queue_size"></param>
+        /// <param name="workers"></param>
+        /// <param name="use_multiprocessing"></param>
+        /// <returns></returns>
+        /// <exception cref="InvalidArgumentError"></exception>
         public ICallback fit(NDArray x, NDArray y,
             int batch_size = -1,
             int epochs = 1,
             int verbose = 1,
             List<ICallback> callbacks = null,
             float validation_split = 0f,
-            (NDArray val_x, NDArray val_y)? validation_data = null,
+            ValidationDataPack validation_data = null,
             bool shuffle = true,
+            Dictionary<int, float> class_weight = null,
+            NDArray sample_weight = null,
             int initial_epoch = 0,
             int max_queue_size = 10,
             int workers = 1,
@@ -43,21 +55,25 @@ namespace Tensorflow.Keras.Engine
                     $"The array x and y should have same value at dim 0, but got {x.dims[0]} and {y.dims[0]}");
             }
 
-            var train_x = x;
-            var train_y = y;
+            // The default dtype of NDArray is double, so cast sample_weight to float
+            // before it is multiplied with the loss, whose dtype is float.
+            sample_weight = sample_weight?.astype(TF_DataType.TF_FLOAT);
+
             if (validation_split != 0f && validation_data == null)
             {
-                int train_count = Convert.ToInt32(x.dims[0] * (1 - validation_split));
-                train_x = x[new Slice(0, train_count)];
-                train_y = y[new Slice(0, train_count)];
-                validation_data = (val_x: x[new Slice(train_count)], val_y: y[new Slice(train_count)]);
+                ((x, y, sample_weight), validation_data) = DataAdapter.train_validation_split((x, y, sample_weight), validation_split);
+            }
+
+            // TODO(Wanglongzhi2001)
+            if (class_weight != null)
+            {
+                throw new NotImplementedException("class_weight is not implemented");
             }
 
             var data_handler = new DataHandler(new DataHandlerArgs
             {
-                X = train_x,
-                Y = train_y,
+                X = x,
+                Y = y,
+                SampleWeight = sample_weight,
                 BatchSize = batch_size,
                 InitialEpoch = initial_epoch,
                 Epochs = epochs,
@@ -73,14 +89,17 @@ namespace Tensorflow.Keras.Engine
                 train_step_func: train_step_function);
         }
 
         public ICallback fit(IEnumerable<NDArray> x, NDArray y,
             int batch_size = -1,
             int epochs = 1,
             int verbose = 1,
             List<ICallback> callbacks = null,
             float validation_split = 0f,
-            (IEnumerable<NDArray> val_x, NDArray val_y)? validation_data = null,
+            ValidationDataPack validation_data = null,
             bool shuffle = true,
+            Dictionary<int, float> class_weight = null,
+            NDArray sample_weight = null,
             int initial_epoch = 0,
             int max_queue_size = 10,
             int workers = 1,
@@ -95,27 +114,23 @@ namespace Tensorflow.Keras.Engine
                 }
             }
 
-            var train_x = x;
-            var train_y = y;
+            sample_weight = sample_weight?.astype(TF_DataType.TF_FLOAT);
+
             if (validation_split != 0f && validation_data == null)
             {
-                int train_count = Convert.ToInt32(y.dims[0] * (1 - validation_split));
-                train_x = x.Select(x => x[new Slice(0, train_count)] as NDArray);
-                train_y = y[new Slice(0, train_count)];
-                var val_x = x.Select(x => x[new Slice(train_count)] as NDArray);
-                var val_y = y[new Slice(train_count)];
-                validation_data = (val_x, val_y);
+                ((x, y, sample_weight), validation_data) = DataAdapter.train_validation_split((x, y, sample_weight), validation_split);
             }
 
             var data_handler = new DataHandler(new DataHandlerArgs
             {
-                X = new Tensors(train_x.ToArray()),
-                Y = train_y,
+                X = new Tensors(x.ToArray()),
+                Y = y,
                 BatchSize = batch_size,
                 InitialEpoch = initial_epoch,
                 Epochs = epochs,
                 Shuffle = shuffle,
+                SampleWeight = sample_weight,
                 MaxQueueSize = max_queue_size,
                 Workers = workers,
                 UseMultiprocessing = use_multiprocessing,
@@ -142,8 +157,10 @@ namespace Tensorflow.Keras.Engine
             int verbose = 1,
             List<ICallback> callbacks = null,
             IDatasetV2 validation_data = null,
-            int validation_step = 10, // how many epochs between validation runs
+            int validation_step = 10,
             bool shuffle = true,
+            Dictionary<int, float> class_weight = null,
+            NDArray sample_weight = null,
             int initial_epoch = 0,
             int max_queue_size = 10,
             int workers = 1,
@@ -210,7 +227,7 @@ namespace Tensorflow.Keras.Engine
             {
                 if (validation_step > 0 && epoch == 0 || (epoch) % validation_step != 0)
                     continue;
 
                 var val_logs = evaluate(validation_data);
                 foreach (var log in val_logs)
                 {
@@ -233,7 +250,7 @@ namespace Tensorflow.Keras.Engine
             return callbacks.History;
         }
 
-        History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICallback> callbackList, (NDArray, NDArray)? validation_data,
+        History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICallback> callbackList, ValidationDataPack validation_data,
             Func<DataHandler, OwnedIterator, Dictionary<string, float>> train_step_func)
         {
             stop_training = false;
@@ -274,7 +291,8 @@ namespace Tensorflow.Keras.Engine
                 {
                     // Because evaluate calls call_test_batch_end, this interferes with our output on the screen,
                     // so we need to pass an is_val parameter to stop on_test_batch_end.
-                    var val_logs = evaluate(validation_data.Value.Item1, validation_data.Value.Item2, is_val: true);
+                    var (val_x, val_y, val_sample_weight) = validation_data;
+                    var val_logs = evaluate(val_x, val_y, sample_weight: val_sample_weight, is_val: true);
                     foreach (var log in val_logs)
                     {
                         logs["val_" + log.Key] = log.Value;
@@ -296,64 +314,5 @@
             return callbacks.History;
         }
 
-        History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICallback> callbackList, (IEnumerable<Tensor>, NDArray)? validation_data,
-            Func<DataHandler, OwnedIterator, Dictionary<string, float>> train_step_func)
-        {
-            stop_training = false;
-            _train_counter.assign(0);
-            var callbacks = new CallbackList(new CallbackParams
-            {
-                Model = this,
-                Verbose = verbose,
-                Epochs = epochs,
-                Steps = data_handler.Inferredsteps
-            });
-
-            if (callbackList != null)
-            {
-                foreach (var callback in callbackList)
-                    callbacks.callbacks.add(callback);
-            }
-
-            callbacks.on_train_begin();
-
-            foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
-            {
-                reset_metrics();
-                callbacks.on_epoch_begin(epoch);
-                // data_handler.catch_stop_iteration();
-                var logs = new Dictionary<string, float>();
-                long End_step = 0;
-                foreach (var step in data_handler.steps())
-                {
-                    callbacks.on_train_batch_begin(step);
-                    logs = train_step_func(data_handler, iterator);
-                    var end_step = step + data_handler.StepIncrement;
-                    End_step = end_step;
-                    callbacks.on_train_batch_end(end_step, logs);
-                }
-
-                if (validation_data != null)
-                {
-                    var val_logs = evaluate(validation_data.Value.Item1, validation_data.Value.Item2);
-                    foreach (var log in val_logs)
-                    {
-                        logs["val_" + log.Key] = log.Value;
-                        callbacks.on_train_batch_end(End_step, logs);
-                    }
-                }
-
-                callbacks.on_epoch_end(epoch, logs);
-
-                GC.Collect();
-                GC.WaitForPendingFinalizers();
-                if (stop_training)
-                {
-                    break;
-                }
-            }
-
-            return callbacks.History;
-        }
     }
 }
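
End to end, the new fit surface looks like the sketch below; x and y are assumed NDArrays whose dim 0 agrees, which is what the check at the top of fit enforces. When validation_split is set, the weights are split at the same train/validation boundary, and class_weight currently throws NotImplementedException:

    var sample_weight = np.ones((int)x.shape[0]);  // uniform weights == unweighted
    model.fit(x, y,
        batch_size: 32,
        epochs: 3,
        validation_split: 0.2f,
        sample_weight: sample_weight);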
@@ -12,7 +12,9 @@ namespace Tensorflow.Keras.Engine
         Dictionary<string, float> train_step_function(DataHandler data_handler, OwnedIterator iterator)
         {
             var data = iterator.next();
-            var outputs = train_step(data_handler, data[0], data[1]);
+            // Check whether the batch also carries a sample_weight tensor.
+            var outputs = data.Length == 2 ? train_step(data_handler, data[0], data[1]) :
+                train_step(data_handler, data[0], data[1], data[2]);
             tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1));
             return outputs;
         }
@@ -21,7 +23,13 @@ namespace Tensorflow.Keras.Engine
         {
             var data = iterator.next();
             var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount;
-            var outputs = train_step(data_handler, new Tensors(data.Take(x_size).ToArray()), new Tensors(data.Skip(x_size).ToArray()));
+            var outputs = data.Length == 2 ?
+                train_step(data_handler, new Tensors(data.Take(x_size).ToArray()), new Tensors(data.Skip(x_size).ToArray())) :
+                train_step(
+                    data_handler,
+                    new Tensors(data.Take(x_size).ToArray()),
+                    new Tensors(data.Skip(x_size).Take(x_size).ToArray()),
+                    new Tensors(data.Skip(2 * x_size).ToArray()));
             tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1));
             return outputs;
         }
@@ -61,6 +69,34 @@
             });
             return dict;
         }
 
+        Dictionary<string, float> train_step(DataHandler data_handler, Tensors x, Tensors y, Tensors sample_weight = null)
+        {
+            (x, y, sample_weight) = data_handler.DataAdapter.Expand1d(x, y, sample_weight);
+
+            using var tape = tf.GradientTape();
+            var y_pred = Apply(x, training: true);
+            var loss = compiled_loss.Call(y, y_pred, sample_weight: sample_weight);
+
+            // For custom training steps, users can just write:
+            //   trainable_variables = self.trainable_variables
+            //   gradients = tape.gradient(loss, trainable_variables)
+            //   self.optimizer.apply_gradients(zip(gradients, trainable_variables))
+            // The _minimize call does a few extra steps unnecessary in most cases,
+            // such as loss scaling and gradient clipping.
+            _minimize(tape, optimizer, loss, TrainableVariables);
+            compiled_metrics.update_state(y, y_pred);
+
+            var dict = new Dictionary<string, float>();
+            metrics.ToList().ForEach(x =>
+            {
+                var r = x.result();
+                if (r.ndim > 0)
+                {
+                    r = tf.reduce_mean(r);
+                }
+                dict[x.Name] = (float)r;
+            });
+            return dict;
+        }
+
         void _minimize(GradientTape tape, IOptimizer optimizer, Tensor loss, List<IVariableV1> trainable_variables)
         {
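
The comment inside train_step sketches the hand-rolled alternative to _minimize. A rough C# equivalent follows; treat it as a sketch, since the apply_gradients overload taking (gradient, variable) pairs is an assumption about the optimizer API rather than something this diff touches:

    // Hypothetical custom training step mirroring the Python-style comment above;
    // no loss scaling or gradient clipping, unlike _minimize.
    using var tape = tf.GradientTape();
    var y_pred = model.Apply(x, training: true);
    var loss = compiled_loss.Call(y, y_pred, sample_weight: sample_weight);
    var gradients = tape.gradient(loss, model.TrainableVariables);
    optimizer.apply_gradients(gradients.Zip(model.TrainableVariables,
        (g, v) => (g, v as ResourceVariable)));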
@@ -74,8 +74,8 @@ namespace Tensorflow.Keras.UnitTest.Layers
                 OneHot = true,
                 ValidationSize = 55000,
             }).Result;
 
-            model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size: 16, epochs: 1);
+            var sample_weight = np.ones(((int)dataset.Train.Data.shape[0]));
+            model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size: 16, epochs: 1, sample_weight: sample_weight);
         }
 
         [TestMethod]