@@ -3,6 +3,8 @@ using System.Collections.Generic;
 using Tensorflow.Keras.ArgsDefinition;
 using static Tensorflow.Binding;
 using Tensorflow.Keras.Utils;
+using Tensorflow.Util;
+using Tensorflow.Framework;

 namespace Tensorflow.Keras.Engine.DataAdapters
 {
@@ -24,6 +26,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters
         long _steps_per_execution_value;
         int _initial_epoch => args.InitialEpoch;
         int _epochs => args.Epochs;
+        NDArray _sample_weight => args.SampleWeight;
         IVariableV1 _steps_per_execution;

         public DataHandler(DataHandlerArgs args)
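
With `SampleWeight` on `DataHandlerArgs` and the `_sample_weight` property above, per-sample weights now ride along with the data from the public API into the handler. A minimal sketch of the intended call site, assuming the `sample_weight` parameter that this change wires up on `model.fit` (the tiny model and data are illustrative only, not from the PR):

```csharp
using System.Collections.Generic;
using Tensorflow.Keras;
using Tensorflow.NumPy;
using static Tensorflow.KerasApi;

var model = keras.Sequential(new List<ILayer>
{
    keras.layers.Dense(1, activation: "sigmoid"),
});
model.compile(optimizer: keras.optimizers.SGD(0.01f),
              loss: keras.losses.BinaryCrossentropy(),
              metrics: new[] { "accuracy" });

var x = np.array(new float[,] { { 0f }, { 1f }, { 2f }, { 3f } });
var y = np.array(new float[] { 0f, 0f, 1f, 1f });
// Weight the last example twice as heavily as the rest.
var sample_weight = np.array(new float[] { 1f, 1f, 1f, 2f });
model.fit(x, y, batch_size: 2, epochs: 3, sample_weight: sample_weight);
```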
@@ -75,10 +78,75 @@ namespace Tensorflow.Keras.Engine.DataAdapters
             }

             _dataset = _adapter.GetDataset();
-            _inferred_steps = _infer_steps(args.StepsPerEpoch, _dataset);
             _current_step = 0;
             _step_increment = _steps_per_execution_value - 1;
             _insufficient_data = false;
+            _configure_dataset_and_inferred_steps(args.X, args.ClassWeight);
+        }
+
+        void _configure_dataset_and_inferred_steps(Tensors x, Dictionary<int, float> class_weight)
+        {
+            if (_dataset == null)
+            {
+                _dataset = _adapter.GetDataset();
+                _inferred_steps = _infer_steps(args.StepsPerEpoch, _dataset);
+            }
+
+            if (class_weight != null)
+            {
+                _dataset = _dataset.map(_make_class_weight_map_fn(class_weight));
+            }
+            _inferred_steps = _infer_steps(args.StepsPerEpoch, _dataset);
+        }
+
+        Func<Tensors, Tensors> _make_class_weight_map_fn(Dictionary<int, float> class_weight)
+        {
+            var class_ids = class_weight.Keys.OrderBy(key => key).ToList();
+            // The error message promises keys 0..n-1, so check exactly that rather
+            // than mere contiguity from the smallest key.
+            var expected_class_ids = range(0, class_ids.Count);
+            if (!class_ids.SequenceEqual(expected_class_ids))
+            {
+                throw new ValueError("Expected `class_weight` to be a dict with keys from 0 to one less " +
+                    $"than the number of classes, found [{string.Join(", ", class_ids)}]");
+            }
+
+            var class_weight_list = new List<float>();
+            foreach (var class_id in class_ids)
+            {
+                class_weight_list.Add(class_weight[class_id]);
+            }
+            var class_weight_tensor = tf.convert_to_tensor(class_weight_list.ToArray());
+
+            Func<Tensors, Tensors> _class_weight_map_fn = (Tensors data) =>
+            {
+                var x = data[0];
+                var y = data[1];
+                var sw = _sample_weight == null ? null : ops.convert_to_tensor(_sample_weight);
+
+                if (y.shape.rank > 2)
+                {
+                    throw new ValueError("`class_weight` not supported for 3+ dimensional targets.");
+                }
+
+                // One-hot targets are reduced to class ids with argmax; sparse targets
+                // are flattened and cast to int64.
+                var y_classes = smart_module.smart_cond(
+                    y.shape.rank == 2 && y.shape[1] > 1,
+                    () => math_ops.argmax(y, dimension: 1),
+                    () => math_ops.cast(tf.reshape(y, (-1)), TF_DataType.TF_INT64));
+
+                var cw = array_ops.gather(class_weight_tensor, y_classes);
+                if (sw != null)
+                {
+                    // Combine with any user-supplied sample weights multiplicatively.
+                    cw = tf.cast(cw, sw.dtype);
+                    cw *= sw;
+                }
+                else
+                {
+                    sw = cw;
+                }
+                return new Tensors { x, y, sw };
+            };
+            return _class_weight_map_fn;
         }

         long _infer_steps(int steps_per_epoch, IDatasetV2 dataset)
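
The pivotal operation in `_make_class_weight_map_fn` is the `gather`: the sorted weight vector is indexed by each example's class id, yielding one weight per example. A standalone sketch of just that semantics, with made-up values (this is not code from the PR):

```csharp
using Tensorflow;
using static Tensorflow.Binding;

// class_weight_tensor[c] holds the weight for class c.
var class_weight_tensor = tf.constant(new float[] { 1.0f, 5.0f, 0.5f });
// One class id per example, as produced by the argmax/reshape branch above.
var y_classes = tf.constant(new long[] { 0, 2, 1, 1 });
// gather broadcasts per-class weights out to per-example weights:
// [1.0, 0.5, 5.0, 5.0]
var per_example_weight = tf.gather(class_weight_tensor, y_classes);
print(per_example_weight);
```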
@@ -164,11 +164,20 @@ namespace Tensorflow.Keras.Engine
         }

-        Dictionary<string, float> test_step(DataHandler data_handler, Tensors x, Tensors y, Tensors sample_weight = null)
+        Dictionary<string, float> test_step(DataHandler data_handler, Tensors x, Tensors y)
+        {
+            (x, y) = data_handler.DataAdapter.Expand1d(x, y);
+            var y_pred = Apply(x, training: false);
+            var loss = compiled_loss.Call(y, y_pred);
+            compiled_metrics.update_state(y, y_pred);
+            return metrics.Select(x => (x.Name, x.result())).ToDictionary(x => x.Item1, x => (float)x.Item2);
+        }
+
+        Dictionary<string, float> test_step(DataHandler data_handler, Tensors x, Tensors y, Tensors sample_weight)
         {
             (x, y, sample_weight) = data_handler.DataAdapter.Expand1d(x, y, sample_weight);
             var y_pred = Apply(x, training: false);
-            var loss = compiled_loss.Call(y, y_pred, sample_weight:sample_weight);
+            var loss = compiled_loss.Call(y, y_pred, sample_weight: sample_weight);
             compiled_metrics.update_state(y, y_pred);
             return metrics.Select(x => (x.Name, x.result())).ToDictionary(x => x.Item1, x => (float)x.Item2);
         }
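
For intuition, passing `sample_weight` into `compiled_loss.Call` scales each example's loss before reduction; under a weighted-mean convention the result is sum(w_i * l_i) / sum(w_i), though the exact denominator depends on the configured loss reduction. A plain-C# arithmetic check of that convention, independent of the Keras types:

```csharp
// Weighted mean of per-example losses under a sum(w*l)/sum(w) convention.
float[] losses  = { 0.2f, 0.8f, 0.5f };
float[] weights = { 1f, 2f, 1f };
float num = 0f, den = 0f;
for (int i = 0; i < losses.Length; i++)
{
    num += weights[i] * losses[i];
    den += weights[i];
}
// (0.2 + 1.6 + 0.5) / 4 = 0.575, versus an unweighted mean of 0.5.
System.Console.WriteLine(num / den);
```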
@@ -63,12 +63,6 @@ namespace Tensorflow.Keras.Engine
                 ((x, y, sample_weight), validation_data) = DataAdapter.train_validation_split((x, y, sample_weight), validation_split);
             }

-            // TODO(Wanglongzhi2001)
-            if (class_weight != null)
-            {
-                throw new NotImplementedException("class_weight is not implemented");
-            }
-
             var data_handler = new DataHandler(new DataHandlerArgs
             {
                 X = x,
@@ -78,6 +72,7 @@ namespace Tensorflow.Keras.Engine
                 InitialEpoch = initial_epoch,
                 Epochs = epochs,
                 Shuffle = shuffle,
+                ClassWeight = class_weight,
                 MaxQueueSize = max_queue_size,
                 Workers = workers,
                 UseMultiprocessing = use_multiprocessing,
@@ -126,11 +121,12 @@ namespace Tensorflow.Keras.Engine
             {
                 X = new Tensors(x.ToArray()),
                 Y = y,
-                SampleWeight = sample_weight,
                 BatchSize = batch_size,
                 InitialEpoch = initial_epoch,
                 Epochs = epochs,
                 Shuffle = shuffle,
+                SampleWeight = sample_weight,
+                ClassWeight = class_weight,
                 MaxQueueSize = max_queue_size,
                 Workers = workers,
                 UseMultiprocessing = use_multiprocessing,
@@ -174,6 +170,7 @@ namespace Tensorflow.Keras.Engine
                 InitialEpoch = initial_epoch,
                 Epochs = epochs,
                 Shuffle = shuffle,
+                SampleWeight = sample_weight,
                 MaxQueueSize = max_queue_size,
                 Workers = workers,
                 UseMultiprocessing = use_multiprocessing,
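
With the `NotImplementedException` guard removed and `ClassWeight`/`SampleWeight` forwarded through every `DataHandlerArgs` construction, class-imbalanced training becomes a one-dictionary change. Reusing the `model`, `x`, and `y` from the earlier sketch, and assuming the `class_weight` parameter on `fit` matches the validation in `_make_class_weight_map_fn`:

```csharp
using System.Collections.Generic;

// Keys must run 0..n-1, exactly what _make_class_weight_map_fn now checks.
var class_weight = new Dictionary<int, float>
{
    [0] = 1.0f,  // majority class, baseline weight
    [1] = 3.0f,  // minority class, weighted 3x
};
model.fit(x, y, batch_size: 2, epochs: 3, class_weight: class_weight);
```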