@@ -13,6 +13,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||
{ | |||
DataHandlerArgs args; | |||
IDataAdapter _adapter; | |||
/// <summary>
/// Read-only access to the adapter held in the <c>_adapter</c> field.
/// </summary>
public IDataAdapter DataAdapter
{
    get { return _adapter; }
}
IDatasetV2 _dataset; | |||
int _inferred_steps; | |||
int _current_step; | |||
@@ -20,5 +20,6 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||
bool CanHandle(Tensor x, Tensor y = null); | |||
IDatasetV2 GetDataset(); | |||
int GetSize(); | |||
(Tensor, Tensor) Expand1d(Tensor x, Tensor y); | |||
} | |||
} |
@@ -89,5 +89,12 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||
// Returns the value currently held in the _size field.
public int GetSize()
{
    return _size;
}
/// <summary>
/// Appends a trailing axis to <paramref name="y"/> when it is rank 1.
/// <paramref name="x"/> is passed through unchanged.
/// </summary>
/// <param name="x">Input tensor; returned as-is.</param>
/// <param name="y">Target tensor; may be null, in which case nothing is reshaped.</param>
/// <returns>The pair (x, y), with y expanded via <c>array_ops.expand_dims(y, axis: -1)</c> when it had one dimension.</returns>
public (Tensor, Tensor) Expand1d(Tensor x, Tensor y)
{
    // Targets are optional elsewhere in the adapter API, so do not dereference a missing one.
    if (y is null)
        return (x, y);

    if (y.TensorShape.ndim == 1)
        y = array_ops.expand_dims(y, axis: -1);

    return (x, y);
}
} | |||
} |
@@ -1,9 +1,11 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Losses; | |||
using Tensorflow.Keras.Metrics; | |||
using Tensorflow.Keras.Utils; | |||
namespace Tensorflow.Keras.Engine | |||
{ | |||
@@ -12,6 +14,10 @@ namespace Tensorflow.Keras.Engine | |||
ILossFunc _user_losses; | |||
ILossFunc _losses; | |||
Mean _loss_metric; | |||
bool _built; | |||
Tensor[] _per_output_metrics; | |||
List<Tensor> loss_values; | |||
List<Tensor> loss_metric_values; | |||
public LossesContainer(ILossFunc losses, string[] output_names = null) | |||
: base(output_names) | |||
@@ -19,6 +25,8 @@ namespace Tensorflow.Keras.Engine | |||
_user_losses = losses; | |||
_losses = losses; | |||
_loss_metric = new Mean(name: "loss"); | |||
loss_values = new List<Tensor>(); | |||
loss_metric_values = new List<Tensor>(); | |||
_built = false; | |||
} | |||
@@ -27,14 +35,44 @@ namespace Tensorflow.Keras.Engine | |||
/// </summary> | |||
/// <param name="y_true"></param> | |||
/// <param name="y_pred"></param> | |||
/// <returns>
/// The sum (<c>math_ops.add_n</c>) of the loss values gathered during this call,
/// or a scalar zero tensor when none were gathered.
/// </returns>
public Tensor Call(Tensor y_true, Tensor y_pred)
{
    // Build once; later calls reuse the built state.
    if (!_built)
        Build(y_pred);

    var loss_value = _losses.Call(y_true, y_pred);
    var loss_metric_value = loss_value;
    var batch_dim = array_ops.shape(y_true)[0];

    /*if (_losses.Reduction == ReductionV2.SUM_OVER_BATCH_SIZE
        || _losses.Reduction == ReductionV2.AUTO)
        loss_value = losses_utils.scale_loss_for_distribution(loss_value);*/

    // Accumulate per call. The instance-level lists are never cleared, so
    // appending to them made every call add up the losses of all previous
    // calls as well (and kept those tensors alive indefinitely).
    var call_loss_values = new List<Tensor> { loss_value };
    var call_loss_metric_values = new List<Tensor> { loss_metric_value };

    if (call_loss_values.Count > 0)
    {
        var total_loss_metric_value = math_ops.add_n(call_loss_metric_values.ToArray());
        _loss_metric.update_state(total_loss_metric_value, batch_dim);

        // loss_values = losses_utils.cast_losses_to_common_dtype(loss_values);
        var total_loss = math_ops.add_n(call_loss_values.ToArray());
        return total_loss;
    }
    else
    {
        // Ok for a model to have no compiled loss.
        return array_ops.zeros(new TensorShape());
    }
}
/// <summary>
/// Runs <c>_create_metrics</c> and marks the container as built.
/// </summary>
/// <param name="y_pred">Prediction tensor; not read by the current body.</param>
public void Build(Tensor y_pred)
{
    _create_metrics();
    _built = true;
}
/// <summary>
/// Placeholder: the body currently creates nothing.
/// </summary>
void _create_metrics()
{
    // Not implemented yet. Sketch kept from the original:
    // _per_output_metrics = _output_names.Select(x => null);
}
} | |||
} |
@@ -16,5 +16,18 @@ namespace Tensorflow.Keras.Engine | |||
_metrics = metrics; | |||
_built = false; | |||
} | |||
/// <summary>
/// Makes sure the container has been built. The tensor arguments are not
/// read by the current body.
/// </summary>
/// <param name="y_true">Unused for now.</param>
/// <param name="y_pred">Unused for now.</param>
/// <param name="sample_weight">Unused for now.</param>
public void update_state(Tensor y_true, Tensor y_pred, Tensor sample_weight = null)
{
    // Already built: nothing left to do.
    if (_built)
        return;

    Build();
    _built = true;
}
/// <summary>
/// One-time setup hook invoked from <c>update_state</c>; intentionally empty for now.
/// </summary>
void Build() { }
} | |||
} |
@@ -10,6 +10,8 @@ namespace Tensorflow.Keras.Engine | |||
{ | |||
public partial class Model | |||
{ | |||
LossesContainer compiled_loss; | |||
MetricsContainer compiled_metrics; | |||
public void compile(string optimizerName, ILossFunc lossName) | |||
{ | |||
throw new NotImplementedException(""); | |||
@@ -18,8 +20,8 @@ namespace Tensorflow.Keras.Engine | |||
public void compile(ILossFunc loss, OptimizerV2 optimizer, string[] metrics) | |||
{ | |||
this.optimizer = optimizer; | |||
var compiled_loss = new LossesContainer(loss, output_names: output_names); | |||
var compiled_metrics = new MetricsContainer(metrics, output_names: output_names); | |||
compiled_loss = new LossesContainer(loss, output_names: output_names); | |||
compiled_metrics = new MetricsContainer(metrics, output_names: output_names); | |||
int experimental_steps_per_execution = 1; | |||
_configure_steps_per_execution(experimental_steps_per_execution); | |||
@@ -37,7 +37,7 @@ namespace Tensorflow.Keras.Engine | |||
var val_x = x[new Slice(train_count)]; | |||
var val_y = y[new Slice(train_count)]; | |||
var data_handler = new DataHandler(new DataHandlerArgs | |||
data_handler = new DataHandler(new DataHandlerArgs | |||
{ | |||
X = train_x, | |||
Y = train_y, | |||
@@ -1,7 +1,10 @@ | |||
using NumSharp; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using Tensorflow.Gradients; | |||
using Tensorflow.Keras.Optimizers; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Keras.Engine | |||
@@ -11,7 +14,8 @@ namespace Tensorflow.Keras.Engine | |||
// Pulls the next element from the iterator, runs a single train_step on its
// first two entries and bumps _train_counter by one.
// Still unfinished: it ends by throwing NotImplementedException.
Tensor step_function(OwnedIterator iterator)
{
    var data = iterator.next();

    // Exactly one training step per batch (a leftover second call to
    // train_step on the same data has been dropped).
    var outputs = train_step(data[0], data[1]);

    tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1));
    throw new NotImplementedException("");
}
@@ -20,11 +24,33 @@ namespace Tensorflow.Keras.Engine | |||
/// </summary> | |||
/// <param name="data"></param> | |||
/// <returns></returns> | |||
IEnumerable<(string, Tensor)> train_step(Tensor x, Tensor y)
{
    // Let the data adapter normalise the pair first (see Expand1d).
    (x, y) = data_handler.DataAdapter.Expand1d(x, y);

    // Forward pass and loss are recorded on the tape that _minimize differentiates.
    using var tape = tf.GradientTape();
    var y_pred = Apply(x, is_training: true);
    var loss = compiled_loss.Call(y, y_pred);

    // For custom training steps, users can just write:
    //   trainable_variables = self.trainable_variables
    //   gradients = tape.gradient(loss, trainable_variables)
    //   self.optimizer.apply_gradients(zip(gradients, trainable_variables))
    // The _minimize call does a few extra steps unnecessary in most cases,
    // such as loss scaling and gradient clipping.
    _minimize(tape, optimizer, loss, trainable_variables);

    compiled_metrics.update_state(y, y_pred);

    // Single named result: the loss of this step.
    return new[] { ("loss", loss) };
}
/// <summary>
/// Computes the gradients of <paramref name="loss"/> with respect to
/// <paramref name="trainable_variables"/>, passes them through the optimizer's
/// aggregate and clip hooks, and applies them.
/// </summary>
/// <param name="tape">Tape that recorded the forward computation.</param>
/// <param name="optimizer">Optimizer that aggregates, clips and applies the gradients.</param>
/// <param name="loss">Loss tensor to differentiate.</param>
/// <param name="trainable_variables">Variables to update; each must be a <c>ResourceVariable</c>.</param>
/// <exception cref="InvalidCastException">A variable is not a <c>ResourceVariable</c>.</exception>
void _minimize(GradientTape tape, OptimizerV2 optimizer, Tensor loss, List<IVariableV1> trainable_variables)
{
    var gradients = tape.gradient(loss, trainable_variables);
    gradients = optimizer._aggregate_gradients(zip(gradients, trainable_variables));
    gradients = optimizer._clip_gradients(gradients);

    // A direct cast fails here, at the source, for an unsupported variable
    // type; `as` would have forwarded a silent null to apply_gradients.
    var resource_variables = trainable_variables.Cast<ResourceVariable>();

    // Aggregation already happened above, so the optimizer is told to skip it.
    optimizer.apply_gradients(zip(gradients, resource_variables),
        experimental_aggregate_gradients: false);
}
} | |||
} |
@@ -6,6 +6,7 @@ using Tensorflow.Keras.Losses; | |||
using Tensorflow.Keras.Optimizers; | |||
using NumSharp; | |||
using System.Collections.Generic; | |||
using System.Data.Common; | |||
namespace Tensorflow.Keras.Engine | |||
{ | |||
@@ -23,7 +24,7 @@ namespace Tensorflow.Keras.Engine | |||
#pragma warning restore CS0414 // The field 'Model._is_compiled' is assigned but its value is never used | |||
#pragma warning restore CS0108 // Member hides inherited member; missing new keyword | |||
ILossFunc loss; | |||
IOptimizer optimizer; | |||
OptimizerV2 optimizer; | |||
IVariableV1 _steps_per_execution; | |||
protected bool _is_graph_network; | |||
protected Tensors inputs; | |||
@@ -34,6 +35,7 @@ namespace Tensorflow.Keras.Engine | |||
IVariableV1 _predict_counter; | |||
bool _base_model_initialized; | |||
bool stop_training; | |||
DataHandler data_handler; | |||
public Model(ModelArgs args) | |||
: base(args) | |||
@@ -6,5 +6,7 @@ namespace Tensorflow.Keras.Losses | |||
{ | |||
/// <summary>
/// Contract for a loss object: it can be evaluated on a target/prediction
/// pair and reports which reduction it is configured with.
/// </summary>
public interface ILossFunc
{
    /// <summary>
    /// Evaluates the loss.
    /// </summary>
    /// <param name="y_true">Target tensor.</param>
    /// <param name="y_pred">Prediction tensor.</param>
    /// <returns>The loss tensor.</returns>
    Tensor Call(Tensor y_true, Tensor y_pred);

    /// <summary>
    /// Reduction identifier of this loss.
    /// NOTE(review): presumably one of the ReductionV2 constants - confirm against implementers.
    /// </summary>
    string Reduction { get; }
}
} |
@@ -1,6 +1,8 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Keras.Utils; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Keras.Losses | |||
{ | |||
@@ -14,6 +16,8 @@ namespace Tensorflow.Keras.Losses | |||
bool _allow_sum_over_batch_size; | |||
string _name_scope; | |||
/// <summary>
/// The reduction identifier stored in the <c>reduction</c> member.
/// </summary>
public string Reduction
{
    get { return reduction; }
}
public Loss(string reduction = ReductionV2.AUTO, string name = null) | |||
{ | |||
this.reduction = reduction; | |||
@@ -21,6 +25,17 @@ namespace Tensorflow.Keras.Losses | |||
_allow_sum_over_batch_size = false; | |||
} | |||
/// <summary>
/// Extension point for subclasses. The base version always throws.
/// </summary>
/// <param name="y_true">Target tensor.</param>
/// <param name="y_pred">Prediction tensor.</param>
/// <param name="from_logits">Not read by the base version.</param>
/// <param name="axis">Not read by the base version.</param>
/// <exception cref="NotImplementedException">Always, unless overridden.</exception>
public virtual Tensor Apply(Tensor y_true, Tensor y_pred, bool from_logits = false, int axis = -1)
    => throw new NotImplementedException("");
/// <summary>
/// Computes the losses through <c>Apply</c> and reduces them according to the
/// reduction this instance was constructed with.
/// </summary>
/// <param name="y_true">Target tensor.</param>
/// <param name="y_pred">Prediction tensor.</param>
/// <returns>The result of <c>losses_utils.compute_weighted_loss</c> for the resolved reduction.</returns>
public Tensor Call(Tensor y_true, Tensor y_pred)
{
    var losses = Apply(y_true, y_pred);

    // Honour the configured reduction instead of a hard-coded one. AUTO (the
    // constructor default) and an unset value resolve to SUM_OVER_BATCH_SIZE,
    // so default-constructed losses behave exactly as before.
    var effective_reduction = string.IsNullOrEmpty(reduction) || reduction == ReductionV2.AUTO
        ? ReductionV2.SUM_OVER_BATCH_SIZE
        : reduction;

    return losses_utils.compute_weighted_loss(losses, reduction: effective_reduction);
}
void _set_name_scope() | |||
{ | |||
_name_scope = name; | |||
@@ -6,15 +6,11 @@ namespace Tensorflow.Keras.Losses | |||
{ | |||
/// <summary>
/// <c>Loss</c> subclass whose constructor only forwards <c>reduction</c> and
/// <c>name</c> to the base class. It no longer carries a delegate.
/// </summary>
public class LossFunctionWrapper : Loss
{
    /// <param name="reduction">Reduction identifier; defaults to <c>ReductionV2.AUTO</c>.</param>
    /// <param name="name">Optional name passed to the base constructor.</param>
    public LossFunctionWrapper(string reduction = ReductionV2.AUTO,
        string name = null)
        : base(reduction: reduction,
              name: name)
    {
    }
}
} |
@@ -6,6 +6,9 @@ namespace Tensorflow.Keras.Losses | |||
{ | |||
/// <summary>
/// String identifiers for the loss reduction modes.
/// </summary>
public class ReductionV2
{
    /// <summary>Identifier "auto".</summary>
    public const string AUTO = "auto";

    /// <summary>Identifier "none".</summary>
    public const string NONE = "none";

    /// <summary>Identifier "sum_over_batch_size".</summary>
    public const string SUM_OVER_BATCH_SIZE = "sum_over_batch_size";

    /// <summary>Identifier "weighted_mean".</summary>
    public const string WEIGHTED_MEAN = "weighted_mean";
}
} |
@@ -1,6 +1,8 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Data; | |||
using System.Text; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Keras.Losses | |||
{ | |||
@@ -9,16 +11,27 @@ namespace Tensorflow.Keras.Losses | |||
/// <summary>
/// Creates the loss, forwarding <paramref name="reduction"/> and
/// <paramref name="name"/> to the base constructor.
/// </summary>
/// <param name="from_logits">NOTE(review): accepted but not stored or forwarded by this constructor.</param>
/// <param name="reduction">Reduction identifier; defaults to <c>ReductionV2.AUTO</c>.</param>
/// <param name="name">Name of the loss.</param>
public SparseCategoricalCrossentropy(bool from_logits = false,
    string reduction = ReductionV2.AUTO,
    string name = "sparse_categorical_crossentropy") :
    base(reduction: reduction,
         name: name)
{
}
/// <summary>
/// Casts <paramref name="target"/> to int64, flattens target and output when
/// their ranks do not line up, and returns
/// <c>tf.nn.sparse_softmax_cross_entropy_with_logits(target, output)</c>.
/// </summary>
/// <param name="target">Label tensor.</param>
/// <param name="output">Prediction tensor, handed to the logits-based op.</param>
/// <param name="from_logits">NOTE(review): not consulted by this body.</param>
/// <param name="axis">NOTE(review): not consulted by this body.</param>
public override Tensor Apply(Tensor target, Tensor output, bool from_logits = false, int axis = -1)
{
    target = tf.cast(target, dtype: TF_DataType.TF_INT64);

    // Try to adjust the shape so that rank of labels = rank of logits - 1.
    var output_rank = output.TensorShape.ndim;
    var target_rank = target.TensorShape.ndim;
    var update_shape = target_rank != output_rank - 1;
    if (update_shape)
    {
        // The shape tensor is only needed on this path, so it is created here.
        var output_shape = array_ops.shape_v2(output);
        target = array_ops.reshape(target, new int[] { -1 });
        output = array_ops.reshape(output, new int[] { -1, output_shape[-1].numpy() });
    }

    return tf.nn.sparse_softmax_cross_entropy_with_logits(target, output);
}
} | |||
} |
@@ -2,6 +2,8 @@ | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Losses; | |||
using Tensorflow.Keras.Utils; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Keras.Metrics | |||
@@ -13,9 +15,14 @@ namespace Tensorflow.Keras.Metrics | |||
{ | |||
IVariableV1 total; | |||
IVariableV1 count; | |||
string _reduction; | |||
TF_DataType _dtype; | |||
public Reduce(string reduction, string name, TF_DataType dtype = TF_DataType.DtInvalid) | |||
: base(name: name, dtype: dtype) | |||
{ | |||
_reduction = reduction; | |||
_dtype = dtype; | |||
total = add_weight("total", initializer: tf.zeros_initializer); | |||
if (reduction == Reduction.WEIGHTED_MEAN || | |||
@@ -24,5 +31,36 @@ namespace Tensorflow.Keras.Metrics | |||
count = add_weight("count", initializer: tf.zeros_initializer); | |||
} | |||
} | |||
/// <summary>
/// Adds the (optionally weighted) sum of <paramref name="values"/> to
/// <c>total</c> and, when this metric keeps a <c>count</c>, adds the matching
/// number of values to it.
/// </summary>
/// <param name="values">Values to accumulate.</param>
/// <param name="sample_weight">Optional weights; cast to the dtype of <paramref name="values"/> and multiplied in.</param>
/// <returns>
/// The update of <c>count</c>, or the update of <c>total</c> when no
/// <c>count</c> variable exists for this metric.
/// </returns>
public Tensor update_state(Tensor values, Tensor sample_weight = null)
{
    if (sample_weight != null)
    {
        (values, sample_weight) = losses_utils.squeeze_or_expand_dimensions(
            values, sample_weight: sample_weight);

        sample_weight = math_ops.cast(sample_weight, dtype: values.dtype);
        values = math_ops.multiply(values, sample_weight);
    }

    Tensor update_total_op = null;
    var value_sum = math_ops.reduce_sum(values);
    tf_with(ops.control_dependencies(new[] { value_sum }), ctl =>
    {
        // Assign the outer local. Declaring a new `var` here shadowed it and
        // left the control dependency below permanently null.
        update_total_op = total.assign_add(value_sum);
    });

    // `count` is only created for reductions that need a denominator.
    if (count is null)
        return update_total_op;

    // Update `count` for reductions that require a denominator.
    Tensor num_values = null;
    if (_reduction == ReductionV2.SUM_OVER_BATCH_SIZE)
    {
        num_values = math_ops.cast(array_ops.size(values), _dtype);
    }
    else if (_reduction == ReductionV2.WEIGHTED_MEAN)
    {
        if (sample_weight == null)
            num_values = math_ops.cast(array_ops.size(values), _dtype);
        else
            num_values = math_ops.reduce_sum(sample_weight);
    }

    return tf_with(ops.control_dependencies(new[] { update_total_op }), ctl
        => count.assign_add(num_values));
}
} | |||
} |
@@ -111,11 +111,16 @@ namespace Tensorflow.Keras.Optimizers | |||
}); | |||
} | |||
/// <summary>
/// Extracts the gradient (first item) of every pair, in order.
/// The variables are not consulted.
/// </summary>
/// <param name="grads_and_vars">Sequence of (gradient, variable) pairs.</param>
/// <returns>Array holding the gradients only.</returns>
public Tensor[] _aggregate_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars)
    => grads_and_vars.Select(pair => pair.Item1).ToArray();
/// <summary>
/// Clipping hook; the current version hands the input array back unchanged.
/// </summary>
/// <param name="grads">Gradients to clip.</param>
/// <returns>The same array instance that was passed in.</returns>
public Tensor[] _clip_gradients(Tensor[] grads)
    => grads;
protected IVariableV1 get_slot(IVariableV1 var, string slot_name) | |||
{ | |||
var slot_dict = _slots[var.UniqueId]; | |||
@@ -0,0 +1,83 @@ | |||
/***************************************************************************** | |||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
******************************************************************************/ | |||
using System; | |||
using System.Linq; | |||
using Tensorflow.Keras.Losses; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Keras.Utils | |||
{ | |||
/// <summary>
/// Helpers for weighting and reducing loss tensors.
/// </summary>
public class losses_utils
{
    /// <summary>
    /// Multiplies <paramref name="losses"/> by <paramref name="sample_weight"/>
    /// and reduces the product.
    /// </summary>
    /// <param name="losses">Loss tensor.</param>
    /// <param name="sample_weight">Weights; a float constant 1.0 is used when null.</param>
    /// <param name="reduction">
    /// One of the <c>ReductionV2</c> identifiers. null and <c>AUTO</c> are
    /// treated as <c>SUM_OVER_BATCH_SIZE</c>.
    /// </param>
    /// <param name="name">Currently not used.</param>
    /// <returns>The reduced (or, for <c>NONE</c>, unreduced) weighted loss.</returns>
    public static Tensor compute_weighted_loss(Tensor losses, Tensor sample_weight = null, string reduction = null, string name = null)
    {
        // Resolve "no explicit choice" to the batch mean. Without this, a
        // null or AUTO reduction fell through to the plain-sum branch of
        // reduce_weighted_loss.
        if (reduction == null || reduction == ReductionV2.AUTO)
            reduction = ReductionV2.SUM_OVER_BATCH_SIZE;

        if (sample_weight == null)
            sample_weight = tf.constant(1.0f);

        var weighted_losses = scale_losses_by_sample_weight(losses, sample_weight);

        // Apply reduction function to the individual weighted losses.
        var loss = reduce_weighted_loss(weighted_losses, reduction);

        // Convert the result back to the input type.
        // loss = math_ops.cast(loss, losses.dtype);
        return loss;
    }

    /// <summary>
    /// Element-wise product of <paramref name="losses"/> and
    /// <paramref name="sample_weight"/>. No casting or shape alignment is done here.
    /// </summary>
    public static Tensor scale_losses_by_sample_weight(Tensor losses, Tensor sample_weight)
    {
        // losses = math_ops.cast(losses, dtypes.float32);
        // sample_weight = math_ops.cast(sample_weight, dtypes.float32);
        // Update dimensions of `sample_weight` to match with `losses` if possible.
        // (losses, sample_weight) = squeeze_or_expand_dimensions(losses, sample_weight);
        return math_ops.multiply(losses, sample_weight);
    }

    /// <summary>
    /// Returns both tensors unchanged when <paramref name="sample_weight"/> is
    /// rank 0; every other rank is not supported yet.
    /// </summary>
    /// <exception cref="NotImplementedException">The weight tensor is not a scalar.</exception>
    public static (Tensor, Tensor) squeeze_or_expand_dimensions(Tensor y_pred, Tensor sample_weight)
    {
        var weights_shape = sample_weight.TensorShape;
        var weights_rank = weights_shape.ndim;
        if (weights_rank == 0)
            return (y_pred, sample_weight);
        throw new NotImplementedException("");
    }

    /// <summary>
    /// Reduces <paramref name="weighted_losses"/>: <c>NONE</c> returns the
    /// input as-is, <c>SUM_OVER_BATCH_SIZE</c> returns sum / element count,
    /// and any other value returns the plain sum.
    /// </summary>
    public static Tensor reduce_weighted_loss(Tensor weighted_losses, string reduction)
    {
        if (reduction == ReductionV2.NONE)
            return weighted_losses;
        else
        {
            var loss = math_ops.reduce_sum(weighted_losses);
            if (reduction == ReductionV2.SUM_OVER_BATCH_SIZE)
                loss = _safe_mean(loss, _num_elements(weighted_losses));
            return loss;
        }
    }

    // Sum of `losses` divided by `num_present` through div_no_nan.
    static Tensor _safe_mean(Tensor losses, Tensor num_present)
    {
        var total_loss = math_ops.reduce_sum(losses);
        return math_ops.div_no_nan(total_loss, num_present, name: "value");
    }

    // Element count of `losses`, cast to the dtype of `losses`.
    static Tensor _num_elements(Tensor losses)
    {
        return tf_with(ops.name_scope("num_elements"), scope =>
        {
            return math_ops.cast(array_ops.size(losses, name: scope), dtype: losses.dtype);
        });
    }
}
} |
@@ -24,6 +24,9 @@ namespace Tensorflow | |||
public Tensor compute_weighted_loss(Tensor losses, Tensor weights = null, string scope = null, | |||
string loss_collection = "losses", string reduction = Reduction.SUM_BY_NONZERO_WEIGHTS) | |||
{ | |||
if (weights == null) | |||
weights = tf.constant(1.0f); | |||
return tf_with(ops.name_scope(scope, default_name: "weighted_loss", (losses, weights)), delegate | |||
{ | |||
// Save the `reduction` argument for loss normalization when distributing | |||
@@ -521,6 +521,9 @@ namespace Tensorflow | |||
// Forwards to shape_internal with optimize fixed to true.
public static Tensor shape(Tensor input, string name = null, TF_DataType out_type = TF_DataType.TF_INT32)
{
    return shape_internal(input, name, optimize: true, out_type: out_type);
}
/// <summary>
/// Same forwarding as <c>shape</c>: calls <c>shape_internal</c> with
/// <c>optimize: true</c>.
/// </summary>
public static Tensor shape_v2(Tensor input, string name = null, TF_DataType out_type = TF_DataType.TF_INT32)
{
    return shape_internal(input, name, optimize: true, out_type: out_type);
}
/// <summary>
/// Forwards to <c>size_internal</c>, passing <paramref name="optimize"/> and
/// <paramref name="out_type"/> through.
/// </summary>
public static Tensor size(Tensor input, string name = null, bool optimize = true, TF_DataType out_type = TF_DataType.TF_INT32)
{
    return size_internal(input, name, optimize: optimize, out_type: out_type);
}
@@ -118,10 +118,13 @@ namespace Tensorflow | |||
/// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) | |||
/// </remarks> | |||
public static Tensor div_no_nan(Tensor x, Tensor y, string name = null)
    // Two implementations of the "DivNoNan" op are handed to RunInAutoMode:
    // the first goes through OpDefLib._apply_op_helper, the second through
    // TFE_FastPathExecute.
    // NOTE(review): presumably RunInAutoMode picks one based on the current
    // execution mode - confirm against its definition.
    => tf.Context.RunInAutoMode(()
        => tf.OpDefLib._apply_op_helper("DivNoNan", name: name, new { x, y }).output, ()
        => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
            "DivNoNan", name,
            null,
            x, y).FirstOrDefault(),
        x, y);
/// <summary> | |||
/// Computes the mean of elements across dimensions of a tensor. | |||