|
|
@@ -15,6 +15,7 @@ |
|
|
|
******************************************************************************/ |
|
|
|
|
|
|
|
using System; |
|
|
|
using System.Xml.Linq; |
|
|
|
using Tensorflow.Keras.Losses; |
|
|
|
using static Tensorflow.Binding; |
|
|
|
|
|
|
@@ -37,15 +38,57 @@ namespace Tensorflow.Keras.Utils |
|
|
|
}); |
|
|
|
} |
|
|
|
|
|
|
|
public static (Tensor, Tensor) squeeze_or_expand_dimensions(Tensor y_pred, Tensor sample_weight) |
|
|
|
public static (Tensor, Tensor) squeeze_or_expand_dimensions(Tensor y_pred, Tensor y_true = null, Tensor sample_weight = null) |
|
|
|
{ |
|
|
|
var y_pred_shape = y_pred.shape; |
|
|
|
var y_pred_rank = y_pred_shape.ndim; |
|
|
|
if (y_true != null) |
|
|
|
{ |
|
|
|
var y_true_shape = y_true.shape; |
|
|
|
var y_true_rank = y_true_shape.ndim; |
|
|
|
if (y_true_rank > -1 && y_pred_rank > -1) |
|
|
|
{ |
|
|
|
if (y_pred_rank - y_true_rank != 1 || y_pred_shape[-1] == 1) |
|
|
|
{ |
|
|
|
(y_true, y_pred) = remove_squeezable_dimensions(y_true, y_pred); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
if (sample_weight == null) |
|
|
|
{ |
|
|
|
return (y_pred, y_true); |
|
|
|
} |
|
|
|
|
|
|
|
var weights_shape = sample_weight.shape; |
|
|
|
var weights_rank = weights_shape.ndim; |
|
|
|
if (weights_rank == 0) |
|
|
|
return (y_pred, sample_weight); |
|
|
|
|
|
|
|
if (y_pred_rank > -1 && weights_rank > -1) |
|
|
|
{ |
|
|
|
if (weights_rank - y_pred_rank == 1) |
|
|
|
{ |
|
|
|
sample_weight = tf.squeeze(sample_weight, -1); |
|
|
|
} |
|
|
|
else if (y_pred_rank - weights_rank == 1) |
|
|
|
{ |
|
|
|
sample_weight = tf.expand_dims(sample_weight, -1); |
|
|
|
} |
|
|
|
else |
|
|
|
{ |
|
|
|
return (y_pred, sample_weight); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
throw new NotImplementedException(""); |
|
|
|
} |
|
|
|
|
|
|
|
public static (Tensor, Tensor) remove_squeezable_dimensions(Tensor labels, Tensor predictions, int expected_rank_diff = 0, string name = null) |
|
|
|
{ |
|
|
|
return (labels, predictions); |
|
|
|
} |
|
|
|
|
|
|
|
public static Tensor reduce_weighted_loss(Tensor weighted_losses, string reduction) |
|
|
|
{ |
|
|
|
if (reduction == ReductionV2.NONE) |
|
|
|