@@ -13,6 +13,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||||
{ | { | ||||
DataHandlerArgs args; | DataHandlerArgs args; | ||||
IDataAdapter _adapter; | IDataAdapter _adapter; | ||||
public IDataAdapter DataAdapter => _adapter; | |||||
IDatasetV2 _dataset; | IDatasetV2 _dataset; | ||||
int _inferred_steps; | int _inferred_steps; | ||||
int _current_step; | int _current_step; | ||||
@@ -20,5 +20,6 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||||
bool CanHandle(Tensor x, Tensor y = null); | bool CanHandle(Tensor x, Tensor y = null); | ||||
IDatasetV2 GetDataset(); | IDatasetV2 GetDataset(); | ||||
int GetSize(); | int GetSize(); | ||||
(Tensor, Tensor) Expand1d(Tensor x, Tensor y); | |||||
} | } | ||||
} | } |
@@ -89,5 +89,12 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||||
public int GetSize() | public int GetSize() | ||||
=> _size; | => _size; | ||||
public (Tensor, Tensor) Expand1d(Tensor x, Tensor y) | |||||
{ | |||||
if (y.TensorShape.ndim == 1) | |||||
y = array_ops.expand_dims(y, axis: -1); | |||||
return (x, y); | |||||
} | |||||
} | } | ||||
} | } |
@@ -1,9 +1,11 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | |||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
using Tensorflow.Keras.Losses; | using Tensorflow.Keras.Losses; | ||||
using Tensorflow.Keras.Metrics; | using Tensorflow.Keras.Metrics; | ||||
using Tensorflow.Keras.Utils; | |||||
namespace Tensorflow.Keras.Engine | namespace Tensorflow.Keras.Engine | ||||
{ | { | ||||
@@ -12,6 +14,10 @@ namespace Tensorflow.Keras.Engine | |||||
ILossFunc _user_losses; | ILossFunc _user_losses; | ||||
ILossFunc _losses; | ILossFunc _losses; | ||||
Mean _loss_metric; | Mean _loss_metric; | ||||
bool _built; | |||||
Tensor[] _per_output_metrics; | |||||
List<Tensor> loss_values; | |||||
List<Tensor> loss_metric_values; | |||||
public LossesContainer(ILossFunc losses, string[] output_names = null) | public LossesContainer(ILossFunc losses, string[] output_names = null) | ||||
: base(output_names) | : base(output_names) | ||||
@@ -19,6 +25,8 @@ namespace Tensorflow.Keras.Engine | |||||
_user_losses = losses; | _user_losses = losses; | ||||
_losses = losses; | _losses = losses; | ||||
_loss_metric = new Mean(name: "loss"); | _loss_metric = new Mean(name: "loss"); | ||||
loss_values = new List<Tensor>(); | |||||
loss_metric_values = new List<Tensor>(); | |||||
_built = false; | _built = false; | ||||
} | } | ||||
@@ -27,14 +35,44 @@ namespace Tensorflow.Keras.Engine | |||||
/// </summary> | /// </summary> | ||||
/// <param name="y_true"></param> | /// <param name="y_true"></param> | ||||
/// <param name="y_pred"></param> | /// <param name="y_pred"></param> | ||||
public void Apply(Tensor y_true, Tensor y_pred) | |||||
public Tensor Call(Tensor y_true, Tensor y_pred) | |||||
{ | { | ||||
Build(y_pred); | |||||
var loss_value = _losses.Call(y_true, y_pred); | |||||
var loss_metric_value = loss_value; | |||||
var batch_dim = array_ops.shape(y_true)[0]; | |||||
/*if (_losses.Reduction == ReductionV2.SUM_OVER_BATCH_SIZE | |||||
|| _losses.Reduction == ReductionV2.AUTO) | |||||
loss_value = losses_utils.scale_loss_for_distribution(loss_value);*/ | |||||
loss_values.append(loss_value); | |||||
loss_metric_values.append(loss_metric_value); | |||||
if(loss_values.Count > 0) | |||||
{ | |||||
var total_loss_metric_value = math_ops.add_n(loss_metric_values.ToArray()); | |||||
_loss_metric.update_state(total_loss_metric_value, batch_dim); | |||||
// loss_values = losses_utils.cast_losses_to_common_dtype(loss_values); | |||||
var total_loss = math_ops.add_n(loss_values.ToArray()); | |||||
return total_loss; | |||||
} | |||||
else | |||||
{ | |||||
// Ok for a model to have no compiled loss. | |||||
return array_ops.zeros(new TensorShape()); | |||||
} | |||||
} | } | ||||
public void Build() | |||||
public void Build(Tensor y_pred) | |||||
{ | { | ||||
_create_metrics(); | |||||
_built = true; | |||||
} | |||||
void _create_metrics() | |||||
{ | |||||
// _per_output_metrics = _output_names.Select(x => null); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -16,5 +16,18 @@ namespace Tensorflow.Keras.Engine | |||||
_metrics = metrics; | _metrics = metrics; | ||||
_built = false; | _built = false; | ||||
} | } | ||||
public void update_state(Tensor y_true, Tensor y_pred, Tensor sample_weight = null) | |||||
{ | |||||
if (!_built) | |||||
Build(); | |||||
_built = true; | |||||
} | |||||
void Build() | |||||
{ | |||||
} | |||||
} | } | ||||
} | } |
@@ -10,6 +10,8 @@ namespace Tensorflow.Keras.Engine | |||||
{ | { | ||||
public partial class Model | public partial class Model | ||||
{ | { | ||||
LossesContainer compiled_loss; | |||||
MetricsContainer compiled_metrics; | |||||
public void compile(string optimizerName, ILossFunc lossName) | public void compile(string optimizerName, ILossFunc lossName) | ||||
{ | { | ||||
throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
@@ -18,8 +20,8 @@ namespace Tensorflow.Keras.Engine | |||||
public void compile(ILossFunc loss, OptimizerV2 optimizer, string[] metrics) | public void compile(ILossFunc loss, OptimizerV2 optimizer, string[] metrics) | ||||
{ | { | ||||
this.optimizer = optimizer; | this.optimizer = optimizer; | ||||
var compiled_loss = new LossesContainer(loss, output_names: output_names); | |||||
var compiled_metrics = new MetricsContainer(metrics, output_names: output_names); | |||||
compiled_loss = new LossesContainer(loss, output_names: output_names); | |||||
compiled_metrics = new MetricsContainer(metrics, output_names: output_names); | |||||
int experimental_steps_per_execution = 1; | int experimental_steps_per_execution = 1; | ||||
_configure_steps_per_execution(experimental_steps_per_execution); | _configure_steps_per_execution(experimental_steps_per_execution); | ||||
@@ -37,7 +37,7 @@ namespace Tensorflow.Keras.Engine | |||||
var val_x = x[new Slice(train_count)]; | var val_x = x[new Slice(train_count)]; | ||||
var val_y = y[new Slice(train_count)]; | var val_y = y[new Slice(train_count)]; | ||||
var data_handler = new DataHandler(new DataHandlerArgs | |||||
data_handler = new DataHandler(new DataHandlerArgs | |||||
{ | { | ||||
X = train_x, | X = train_x, | ||||
Y = train_y, | Y = train_y, | ||||
@@ -1,7 +1,10 @@ | |||||
using NumSharp; | using NumSharp; | ||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | |||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Gradients; | |||||
using Tensorflow.Keras.Optimizers; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace Tensorflow.Keras.Engine | namespace Tensorflow.Keras.Engine | ||||
@@ -11,7 +14,8 @@ namespace Tensorflow.Keras.Engine | |||||
Tensor step_function(OwnedIterator iterator) | Tensor step_function(OwnedIterator iterator) | ||||
{ | { | ||||
var data = iterator.next(); | var data = iterator.next(); | ||||
train_step(data[0], data[1]); | |||||
var outputs = train_step(data[0], data[1]); | |||||
tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1)); | |||||
throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
} | } | ||||
@@ -20,11 +24,33 @@ namespace Tensorflow.Keras.Engine | |||||
/// </summary> | /// </summary> | ||||
/// <param name="data"></param> | /// <param name="data"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
Tensor train_step(Tensor x, Tensor y) | |||||
IEnumerable<(string, Tensor)> train_step(Tensor x, Tensor y) | |||||
{ | { | ||||
(x, y) = data_handler.DataAdapter.Expand1d(x, y); | |||||
using var tape = tf.GradientTape(); | using var tape = tf.GradientTape(); | ||||
var y_pred = Apply(x, is_training: true); | var y_pred = Apply(x, is_training: true); | ||||
throw new NotImplementedException(""); | |||||
var loss = compiled_loss.Call(y, y_pred); | |||||
// For custom training steps, users can just write: | |||||
// trainable_variables = self.trainable_variables | |||||
// gradients = tape.gradient(loss, trainable_variables) | |||||
// self.optimizer.apply_gradients(zip(gradients, trainable_variables)) | |||||
// The _minimize call does a few extra steps unnecessary in most cases, | |||||
// such as loss scaling and gradient clipping. | |||||
_minimize(tape, optimizer, loss, trainable_variables); | |||||
compiled_metrics.update_state(y, y_pred); | |||||
return new[] { ("loss", loss) }; | |||||
} | |||||
void _minimize(GradientTape tape, OptimizerV2 optimizer, Tensor loss, List<IVariableV1> trainable_variables) | |||||
{ | |||||
var gradients = tape.gradient(loss, trainable_variables); | |||||
gradients = optimizer._aggregate_gradients(zip(gradients, trainable_variables)); | |||||
gradients = optimizer._clip_gradients(gradients); | |||||
optimizer.apply_gradients(zip(gradients, trainable_variables.Select(x => x as ResourceVariable)), | |||||
experimental_aggregate_gradients: false); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -6,6 +6,7 @@ using Tensorflow.Keras.Losses; | |||||
using Tensorflow.Keras.Optimizers; | using Tensorflow.Keras.Optimizers; | ||||
using NumSharp; | using NumSharp; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Data.Common; | |||||
namespace Tensorflow.Keras.Engine | namespace Tensorflow.Keras.Engine | ||||
{ | { | ||||
@@ -23,7 +24,7 @@ namespace Tensorflow.Keras.Engine | |||||
#pragma warning restore CS0414 // The field 'Model._is_compiled' is assigned but its value is never used | #pragma warning restore CS0414 // The field 'Model._is_compiled' is assigned but its value is never used | ||||
#pragma warning restore CS0108 // Member hides inherited member; missing new keyword | #pragma warning restore CS0108 // Member hides inherited member; missing new keyword | ||||
ILossFunc loss; | ILossFunc loss; | ||||
IOptimizer optimizer; | |||||
OptimizerV2 optimizer; | |||||
IVariableV1 _steps_per_execution; | IVariableV1 _steps_per_execution; | ||||
protected bool _is_graph_network; | protected bool _is_graph_network; | ||||
protected Tensors inputs; | protected Tensors inputs; | ||||
@@ -34,6 +35,7 @@ namespace Tensorflow.Keras.Engine | |||||
IVariableV1 _predict_counter; | IVariableV1 _predict_counter; | ||||
bool _base_model_initialized; | bool _base_model_initialized; | ||||
bool stop_training; | bool stop_training; | ||||
DataHandler data_handler; | |||||
public Model(ModelArgs args) | public Model(ModelArgs args) | ||||
: base(args) | : base(args) | ||||
@@ -6,5 +6,7 @@ namespace Tensorflow.Keras.Losses | |||||
{ | { | ||||
public interface ILossFunc | public interface ILossFunc | ||||
{ | { | ||||
string Reduction { get; } | |||||
Tensor Call(Tensor y_true, Tensor y_pred); | |||||
} | } | ||||
} | } |
@@ -1,6 +1,8 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Keras.Utils; | |||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow.Keras.Losses | namespace Tensorflow.Keras.Losses | ||||
{ | { | ||||
@@ -14,6 +16,8 @@ namespace Tensorflow.Keras.Losses | |||||
bool _allow_sum_over_batch_size; | bool _allow_sum_over_batch_size; | ||||
string _name_scope; | string _name_scope; | ||||
public string Reduction => reduction; | |||||
public Loss(string reduction = ReductionV2.AUTO, string name = null) | public Loss(string reduction = ReductionV2.AUTO, string name = null) | ||||
{ | { | ||||
this.reduction = reduction; | this.reduction = reduction; | ||||
@@ -21,6 +25,17 @@ namespace Tensorflow.Keras.Losses | |||||
_allow_sum_over_batch_size = false; | _allow_sum_over_batch_size = false; | ||||
} | } | ||||
public virtual Tensor Apply(Tensor y_true, Tensor y_pred, bool from_logits = false, int axis = -1) | |||||
{ | |||||
throw new NotImplementedException(""); | |||||
} | |||||
public Tensor Call(Tensor y_true, Tensor y_pred) | |||||
{ | |||||
var losses = Apply(y_true, y_pred); | |||||
return losses_utils.compute_weighted_loss(losses, reduction: ReductionV2.SUM_OVER_BATCH_SIZE); | |||||
} | |||||
void _set_name_scope() | void _set_name_scope() | ||||
{ | { | ||||
_name_scope = name; | _name_scope = name; | ||||
@@ -6,15 +6,11 @@ namespace Tensorflow.Keras.Losses | |||||
{ | { | ||||
public class LossFunctionWrapper : Loss | public class LossFunctionWrapper : Loss | ||||
{ | { | ||||
Action fn; | |||||
public LossFunctionWrapper(Action fn, | |||||
string reduction = ReductionV2.AUTO, | |||||
public LossFunctionWrapper(string reduction = ReductionV2.AUTO, | |||||
string name = null) | string name = null) | ||||
: base(reduction: reduction, | : base(reduction: reduction, | ||||
name: name) | name: name) | ||||
{ | { | ||||
this.fn = fn; | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -6,6 +6,9 @@ namespace Tensorflow.Keras.Losses | |||||
{ | { | ||||
public class ReductionV2 | public class ReductionV2 | ||||
{ | { | ||||
public const string NONE = "none"; | |||||
public const string AUTO = "auto"; | public const string AUTO = "auto"; | ||||
public const string SUM_OVER_BATCH_SIZE = "sum_over_batch_size"; | |||||
public const string WEIGHTED_MEAN = "weighted_mean"; | |||||
} | } | ||||
} | } |
@@ -1,6 +1,8 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Data; | |||||
using System.Text; | using System.Text; | ||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow.Keras.Losses | namespace Tensorflow.Keras.Losses | ||||
{ | { | ||||
@@ -9,16 +11,27 @@ namespace Tensorflow.Keras.Losses | |||||
public SparseCategoricalCrossentropy(bool from_logits = false, | public SparseCategoricalCrossentropy(bool from_logits = false, | ||||
string reduction = ReductionV2.AUTO, | string reduction = ReductionV2.AUTO, | ||||
string name = "sparse_categorical_crossentropy") : | string name = "sparse_categorical_crossentropy") : | ||||
base(sparse_categorical_crossentropy, | |||||
reduction: reduction, | |||||
base(reduction: reduction, | |||||
name: name) | name: name) | ||||
{ | { | ||||
} | } | ||||
static void sparse_categorical_crossentropy() | |||||
public override Tensor Apply(Tensor target, Tensor output, bool from_logits = false, int axis = -1) | |||||
{ | { | ||||
target = tf.cast(target, dtype: TF_DataType.TF_INT64); | |||||
// Try to adjust the shape so that rank of labels = rank of logits - 1. | |||||
var output_shape = array_ops.shape_v2(output); | |||||
var output_rank = output.TensorShape.ndim; | |||||
var target_rank = target.TensorShape.ndim; | |||||
var update_shape = target_rank != output_rank - 1; | |||||
if (update_shape) | |||||
{ | |||||
target = array_ops.reshape(target, new int[] { -1 }); | |||||
output = array_ops.reshape(output, new int[] { -1, output_shape[-1].numpy() }); | |||||
} | |||||
return tf.nn.sparse_softmax_cross_entropy_with_logits(target, output); | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -2,6 +2,8 @@ | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
using Tensorflow.Keras.Losses; | |||||
using Tensorflow.Keras.Utils; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace Tensorflow.Keras.Metrics | namespace Tensorflow.Keras.Metrics | ||||
@@ -13,9 +15,14 @@ namespace Tensorflow.Keras.Metrics | |||||
{ | { | ||||
IVariableV1 total; | IVariableV1 total; | ||||
IVariableV1 count; | IVariableV1 count; | ||||
string _reduction; | |||||
TF_DataType _dtype; | |||||
public Reduce(string reduction, string name, TF_DataType dtype = TF_DataType.DtInvalid) | public Reduce(string reduction, string name, TF_DataType dtype = TF_DataType.DtInvalid) | ||||
: base(name: name, dtype: dtype) | : base(name: name, dtype: dtype) | ||||
{ | { | ||||
_reduction = reduction; | |||||
_dtype = dtype; | |||||
total = add_weight("total", initializer: tf.zeros_initializer); | total = add_weight("total", initializer: tf.zeros_initializer); | ||||
if (reduction == Reduction.WEIGHTED_MEAN || | if (reduction == Reduction.WEIGHTED_MEAN || | ||||
@@ -24,5 +31,36 @@ namespace Tensorflow.Keras.Metrics | |||||
count = add_weight("count", initializer: tf.zeros_initializer); | count = add_weight("count", initializer: tf.zeros_initializer); | ||||
} | } | ||||
} | } | ||||
public Tensor update_state(Tensor values, Tensor sample_weight = null) | |||||
{ | |||||
if(sample_weight != null) | |||||
{ | |||||
(values, sample_weight) = losses_utils.squeeze_or_expand_dimensions( | |||||
values, sample_weight: sample_weight); | |||||
sample_weight = math_ops.cast(sample_weight, dtype: values.dtype); | |||||
values = math_ops.multiply(values, sample_weight); | |||||
} | |||||
Tensor update_total_op = null; | |||||
var value_sum = math_ops.reduce_sum(values); | |||||
tf_with(ops.control_dependencies(new[] { value_sum }), ctl => | |||||
{ | |||||
var update_total_op = total.assign_add(value_sum); | |||||
}); | |||||
Tensor num_values = null; | |||||
if (_reduction == ReductionV2.WEIGHTED_MEAN) | |||||
{ | |||||
if (sample_weight == null) | |||||
num_values = math_ops.cast(array_ops.size(values), _dtype); | |||||
else | |||||
num_values = math_ops.reduce_sum(sample_weight); | |||||
} | |||||
return tf_with(ops.control_dependencies(new[] { update_total_op }), ctl | |||||
=> count.assign_add(num_values)); | |||||
} | |||||
} | } | ||||
} | } |
@@ -111,11 +111,16 @@ namespace Tensorflow.Keras.Optimizers | |||||
}); | }); | ||||
} | } | ||||
Tensor[] _aggregate_gradients(IEnumerable<(Tensor, ResourceVariable)> grads_and_vars) | |||||
public Tensor[] _aggregate_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars) | |||||
{ | { | ||||
return grads_and_vars.Select(x => x.Item1).ToArray(); | return grads_and_vars.Select(x => x.Item1).ToArray(); | ||||
} | } | ||||
public Tensor[] _clip_gradients(Tensor[] grads) | |||||
{ | |||||
return grads; | |||||
} | |||||
protected IVariableV1 get_slot(IVariableV1 var, string slot_name) | protected IVariableV1 get_slot(IVariableV1 var, string slot_name) | ||||
{ | { | ||||
var slot_dict = _slots[var.UniqueId]; | var slot_dict = _slots[var.UniqueId]; | ||||
@@ -0,0 +1,83 @@ | |||||
/***************************************************************************** | |||||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
******************************************************************************/ | |||||
using System; | |||||
using System.Linq; | |||||
using Tensorflow.Keras.Losses; | |||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow.Keras.Utils | |||||
{ | |||||
public class losses_utils | |||||
{ | |||||
public static Tensor compute_weighted_loss(Tensor losses, Tensor sample_weight = null, string reduction = null, string name = null) | |||||
{ | |||||
if (sample_weight == null) | |||||
sample_weight = tf.constant(1.0f); | |||||
var weighted_losses = scale_losses_by_sample_weight(losses, sample_weight); | |||||
// Apply reduction function to the individual weighted losses. | |||||
var loss = reduce_weighted_loss(weighted_losses, reduction); | |||||
// Convert the result back to the input type. | |||||
// loss = math_ops.cast(loss, losses.dtype); | |||||
return loss; | |||||
} | |||||
public static Tensor scale_losses_by_sample_weight(Tensor losses, Tensor sample_weight) | |||||
{ | |||||
// losses = math_ops.cast(losses, dtypes.float32); | |||||
// sample_weight = math_ops.cast(sample_weight, dtypes.float32); | |||||
// Update dimensions of `sample_weight` to match with `losses` if possible. | |||||
// (losses, sample_weight) = squeeze_or_expand_dimensions(losses, sample_weight); | |||||
return math_ops.multiply(losses, sample_weight); | |||||
} | |||||
public static (Tensor, Tensor) squeeze_or_expand_dimensions(Tensor y_pred, Tensor sample_weight) | |||||
{ | |||||
var weights_shape = sample_weight.TensorShape; | |||||
var weights_rank = weights_shape.ndim; | |||||
if (weights_rank == 0) | |||||
return (y_pred, sample_weight); | |||||
throw new NotImplementedException(""); | |||||
} | |||||
public static Tensor reduce_weighted_loss(Tensor weighted_losses, string reduction) | |||||
{ | |||||
if (reduction == ReductionV2.NONE) | |||||
return weighted_losses; | |||||
else | |||||
{ | |||||
var loss = math_ops.reduce_sum(weighted_losses); | |||||
if (reduction == ReductionV2.SUM_OVER_BATCH_SIZE) | |||||
loss = _safe_mean(loss, _num_elements(weighted_losses)); | |||||
return loss; | |||||
} | |||||
} | |||||
static Tensor _safe_mean(Tensor losses, Tensor num_present) | |||||
{ | |||||
var total_loss = math_ops.reduce_sum(losses); | |||||
return math_ops.div_no_nan(total_loss, num_present, name: "value"); | |||||
} | |||||
static Tensor _num_elements(Tensor losses) | |||||
{ | |||||
return tf_with(ops.name_scope("num_elements"), scope => | |||||
{ | |||||
return math_ops.cast(array_ops.size(losses, name: scope), dtype: losses.dtype); | |||||
}); | |||||
} | |||||
} | |||||
} |
@@ -24,6 +24,9 @@ namespace Tensorflow | |||||
public Tensor compute_weighted_loss(Tensor losses, Tensor weights = null, string scope = null, | public Tensor compute_weighted_loss(Tensor losses, Tensor weights = null, string scope = null, | ||||
string loss_collection = "losses", string reduction = Reduction.SUM_BY_NONZERO_WEIGHTS) | string loss_collection = "losses", string reduction = Reduction.SUM_BY_NONZERO_WEIGHTS) | ||||
{ | { | ||||
if (weights == null) | |||||
weights = tf.constant(1.0f); | |||||
return tf_with(ops.name_scope(scope, default_name: "weighted_loss", (losses, weights)), delegate | return tf_with(ops.name_scope(scope, default_name: "weighted_loss", (losses, weights)), delegate | ||||
{ | { | ||||
// Save the `reduction` argument for loss normalization when distributing | // Save the `reduction` argument for loss normalization when distributing | ||||
@@ -521,6 +521,9 @@ namespace Tensorflow | |||||
public static Tensor shape(Tensor input, string name = null, TF_DataType out_type = TF_DataType.TF_INT32) | public static Tensor shape(Tensor input, string name = null, TF_DataType out_type = TF_DataType.TF_INT32) | ||||
=> shape_internal(input, name, optimize: true, out_type: out_type); | => shape_internal(input, name, optimize: true, out_type: out_type); | ||||
public static Tensor shape_v2(Tensor input, string name = null, TF_DataType out_type = TF_DataType.TF_INT32) | |||||
=> shape_internal(input, name, optimize: true, out_type: out_type); | |||||
public static Tensor size(Tensor input, string name = null, bool optimize = true, TF_DataType out_type = TF_DataType.TF_INT32) | public static Tensor size(Tensor input, string name = null, bool optimize = true, TF_DataType out_type = TF_DataType.TF_INT32) | ||||
=> size_internal(input, name, optimize: optimize, out_type: out_type); | => size_internal(input, name, optimize: optimize, out_type: out_type); | ||||
@@ -118,10 +118,13 @@ namespace Tensorflow | |||||
/// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) | /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) | ||||
/// </remarks> | /// </remarks> | ||||
public static Tensor div_no_nan(Tensor x, Tensor y, string name = null) | public static Tensor div_no_nan(Tensor x, Tensor y, string name = null) | ||||
{ | |||||
var op = tf.OpDefLib._apply_op_helper("DivNoNan", name: name, args: new { x, y }); | |||||
return op.output; | |||||
} | |||||
=> tf.Context.RunInAutoMode(() | |||||
=> tf.OpDefLib._apply_op_helper("DivNoNan", name: name, new { x, y }).output, () | |||||
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||||
"DivNoNan", name, | |||||
null, | |||||
x, y).FirstOrDefault(), | |||||
x, y); | |||||
/// <summary> | /// <summary> | ||||
/// Computes the mean of elements across dimensions of a tensor. | /// Computes the mean of elements across dimensions of a tensor. | ||||