Browse Source

Merge pull request #1189 from Wanglongzhi2001/master

feat: add the implementation of class_weight in model.fit
tags/v0.150.0-BERT-Model
Haiping GitHub 2 years ago
parent
commit
43c3705183
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 84 additions and 10 deletions
  1. +69
    -1
      src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs
  2. +11
    -2
      src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs
  3. +4
    -7
      src/TensorFlowNET.Keras/Engine/Model.Fit.cs

+ 69
- 1
src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs View File

@@ -3,6 +3,8 @@ using System.Collections.Generic;
using Tensorflow.Keras.ArgsDefinition;
using static Tensorflow.Binding;
using Tensorflow.Keras.Utils;
using Tensorflow.Util;
using Tensorflow.Framework;

namespace Tensorflow.Keras.Engine.DataAdapters
{
@@ -24,6 +26,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters
long _steps_per_execution_value;
int _initial_epoch => args.InitialEpoch;
int _epochs => args.Epochs;
NDArray _sample_weight => args.SampleWeight;
IVariableV1 _steps_per_execution;

public DataHandler(DataHandlerArgs args)
@@ -75,10 +78,75 @@ namespace Tensorflow.Keras.Engine.DataAdapters
}
_dataset = _adapter.GetDataset();
_inferred_steps = _infer_steps(args.StepsPerEpoch, _dataset);
_current_step = 0;
_step_increment = _steps_per_execution_value - 1;
_insufficient_data = false;
_configure_dataset_and_inferred_steps(args.X, args.ClassWeight);
}

void _configure_dataset_and_inferred_steps(Tensors x, Dictionary<int, float> class_weight)
{
if (_dataset == null)
{
_dataset = _adapter.GetDataset();
_inferred_steps = _infer_steps(args.StepsPerEpoch, _dataset);
}

if (class_weight != null)
{
_dataset = _dataset.map(_make_class_weight_map_fn(class_weight));
}
_inferred_steps = _infer_steps(args.StepsPerEpoch, _dataset);
}


Func<Tensors, Tensors> _make_class_weight_map_fn(Dictionary<int, float> class_weight)
{
var class_ids = class_weight.Keys.OrderBy(key => key).ToList();
var expected_class_ids = range(class_ids[0], class_ids[class_ids.Count - 1] + 1);
if (!class_ids.SequenceEqual(expected_class_ids))
{
throw new ValueError("Expected `class_weight` to be a dict with keys from 0 to one less "+
$"than the number of classes, found {class_weight}");
}
var class_weight_list = new List<float>();
foreach (var class_id in class_ids)
{
class_weight_list.Add(class_weight[class_id]);
}
var class_weight_tensor = tf.convert_to_tensor(class_weight_list.ToArray());

Func<Tensors, Tensors> _class_weight_map_fn = (Tensors data) =>
{
var x = data[0];
var y = data[1];
var sw = _sample_weight == null ? null : ops.convert_to_tensor(_sample_weight);

if (y.shape.rank > 2)
{
throw new ValueError("`class_weight` not supported for 3+ dimensional targets.");
}

var y_classes = smart_module.smart_cond(
y.shape.rank == 2 && y.shape[1] > 1,
() => math_ops.argmax(y, dimension: 1),
() => math_ops.cast(tf.reshape(y, (-1)), TF_DataType.TF_INT64));

var cw = array_ops.gather(class_weight_tensor, y_classes);
if (sw != null)
{
cw = tf.cast(cw, sw.dtype);
cw *= sw;
}
else
{
sw = cw;
}
return new Tensors { x, y, sw };
};

return _class_weight_map_fn;
}

long _infer_steps(int steps_per_epoch, IDatasetV2 dataset)


+ 11
- 2
src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs View File

@@ -164,11 +164,20 @@ namespace Tensorflow.Keras.Engine
}


Dictionary<string, float> test_step(DataHandler data_handler, Tensors x, Tensors y, Tensors sample_weight = null)
Dictionary<string, float> test_step(DataHandler data_handler, Tensors x, Tensors y)
{
(x,y) = data_handler.DataAdapter.Expand1d(x, y);
var y_pred = Apply(x, training: false);
var loss = compiled_loss.Call(y, y_pred);
compiled_metrics.update_state(y, y_pred);
return metrics.Select(x => (x.Name, x.result())).ToDictionary(x => x.Item1, x => (float)x.Item2);
}

Dictionary<string, float> test_step(DataHandler data_handler, Tensors x, Tensors y, Tensors sample_weight)
{
(x, y, sample_weight) = data_handler.DataAdapter.Expand1d(x, y, sample_weight);
var y_pred = Apply(x, training: false);
var loss = compiled_loss.Call(y, y_pred, sample_weight:sample_weight);
var loss = compiled_loss.Call(y, y_pred, sample_weight: sample_weight);
compiled_metrics.update_state(y, y_pred);
return metrics.Select(x => (x.Name, x.result())).ToDictionary(x => x.Item1, x => (float)x.Item2);
}


+ 4
- 7
src/TensorFlowNET.Keras/Engine/Model.Fit.cs View File

@@ -63,12 +63,6 @@ namespace Tensorflow.Keras.Engine
((x, y, sample_weight), validation_data) = DataAdapter.train_validation_split((x, y, sample_weight), validation_split);
}

// TODO(Wanglongzhi2001)
if (class_weight != null)
{
throw new NotImplementedException("class_weight is not implemented");
}

var data_handler = new DataHandler(new DataHandlerArgs
{
X = x,
@@ -78,6 +72,7 @@ namespace Tensorflow.Keras.Engine
InitialEpoch = initial_epoch,
Epochs = epochs,
Shuffle = shuffle,
ClassWeight = class_weight,
MaxQueueSize = max_queue_size,
Workers = workers,
UseMultiprocessing = use_multiprocessing,
@@ -126,11 +121,12 @@ namespace Tensorflow.Keras.Engine
{
X = new Tensors(x.ToArray()),
Y = y,
SampleWeight = sample_weight,
BatchSize = batch_size,
InitialEpoch = initial_epoch,
Epochs = epochs,
Shuffle = shuffle,
SampleWeight = sample_weight,
ClassWeight = class_weight,
MaxQueueSize = max_queue_size,
Workers = workers,
UseMultiprocessing = use_multiprocessing,
@@ -174,6 +170,7 @@ namespace Tensorflow.Keras.Engine
InitialEpoch = initial_epoch,
Epochs = epochs,
Shuffle = shuffle,
SampleWeight = sample_weight,
MaxQueueSize = max_queue_size,
Workers = workers,
UseMultiprocessing = use_multiprocessing,


Loading…
Cancel
Save