Browse Source

Merge pull request #1089 from DevNullx64/SimpleRefacto

refactor: Standardize TensorFlowNET.Keras/Losses/*
tags/v0.110.0-LSTM-Model
Haiping GitHub 2 years ago
parent
commit
367ac9efce
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 200 additions and 253 deletions
  1. +2
    -2
      src/TensorFlowNET.Keras/Losses/BinaryCrossentropy.cs
  2. +2
    -2
      src/TensorFlowNET.Keras/Losses/CategoricalCrossentropy.cs
  3. +17
    -23
      src/TensorFlowNET.Keras/Losses/CosineSimilarity.cs
  4. +23
    -30
      src/TensorFlowNET.Keras/Losses/Huber.cs
  5. +15
    -22
      src/TensorFlowNET.Keras/Losses/LogCosh.cs
  6. +43
    -47
      src/TensorFlowNET.Keras/Losses/Loss.cs
  7. +10
    -12
      src/TensorFlowNET.Keras/Losses/LossFunctionWrapper.cs
  8. +11
    -18
      src/TensorFlowNET.Keras/Losses/MeanAbsoluteError.cs
  9. +12
    -19
      src/TensorFlowNET.Keras/Losses/MeanAbsolutePercentageError.cs
  10. +11
    -18
      src/TensorFlowNET.Keras/Losses/MeanSquaredError.cs
  11. +22
    -27
      src/TensorFlowNET.Keras/Losses/MeanSquaredLogarithmicError.cs
  12. +1
    -2
      src/TensorFlowNET.Keras/Losses/SigmoidFocalCrossEntropy.cs
  13. +31
    -31
      src/TensorFlowNET.Keras/Losses/SparseCategoricalCrossentropy.cs

+ 2
- 2
src/TensorFlowNET.Keras/Losses/BinaryCrossentropy.cs View File

@@ -1,8 +1,9 @@
namespace Tensorflow.Keras.Losses;

public class BinaryCrossentropy : LossFunctionWrapper, ILossFunc
public class BinaryCrossentropy : LossFunctionWrapper
{
float label_smoothing;

public BinaryCrossentropy(
bool from_logits = false,
float label_smoothing = 0,
@@ -15,7 +16,6 @@ public class BinaryCrossentropy : LossFunctionWrapper, ILossFunc
this.label_smoothing = label_smoothing;
}


public override Tensor Apply(Tensor y_true, Tensor y_pred, bool from_logits = false, int axis = -1)
{
var sum = keras.backend.binary_crossentropy(y_true, y_pred, from_logits: from_logits);


+ 2
- 2
src/TensorFlowNET.Keras/Losses/CategoricalCrossentropy.cs View File

@@ -1,8 +1,9 @@
namespace Tensorflow.Keras.Losses;

public class CategoricalCrossentropy : LossFunctionWrapper, ILossFunc
public class CategoricalCrossentropy : LossFunctionWrapper
{
float label_smoothing;

public CategoricalCrossentropy(
bool from_logits = false,
float label_smoothing = 0,
@@ -15,7 +16,6 @@ public class CategoricalCrossentropy : LossFunctionWrapper, ILossFunc
this.label_smoothing = label_smoothing;
}


public override Tensor Apply(Tensor y_true, Tensor y_pred, bool from_logits = false, int axis = -1)
{
// Try to adjust the shape so that rank of labels = rank of logits - 1.


+ 17
- 23
src/TensorFlowNET.Keras/Losses/CosineSimilarity.cs View File

@@ -1,28 +1,22 @@
using System;
using System.Collections.Generic;
using System.Text;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
namespace Tensorflow.Keras.Losses;

namespace Tensorflow.Keras.Losses
public class CosineSimilarity : LossFunctionWrapper
{
public class CosineSimilarity : LossFunctionWrapper, ILossFunc
protected int axis = -1;

public CosineSimilarity(
string reduction = null,
int axis = -1,
string name = null) :
base(reduction: reduction, name: name == null ? "cosine_similarity" : name)
{
protected int axis=-1;
public CosineSimilarity(
string reduction = null,
int axis=-1,
string name = null) :
base(reduction: reduction, name: name == null ? "cosine_similarity" : name)
{
this.axis = axis;
}
this.axis = axis;
}

public override Tensor Apply(Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1)
{
Tensor y_true_normalize = nn_impl.l2_normalize(y_true, axis : this.axis);
Tensor y_pred_normalize = nn_impl.l2_normalize(y_pred, axis: this.axis);
return -math_ops.reduce_sum(y_true_normalize * y_pred_normalize, axis : constant_op.constant(this.axis));
}
/// <summary>
/// Returns the negated sum of the element-wise product of the L2-normalized
/// <paramref name="y_true"/> and <paramref name="y_pred"/>.
/// </summary>
/// <param name="y_true">Tensor that is L2-normalized along <c>this.axis</c>.</param>
/// <param name="y_pred">Tensor that is L2-normalized along <c>this.axis</c>.</param>
/// <param name="from_logits">Not read by this method body.</param>
/// <param name="axis">Not read by this method body; the <c>this.axis</c> field is used for every axis argument.</param>
/// <returns>The negated result of <c>reduce_sum</c> over the product of the two normalized tensors.</returns>
public override Tensor Apply(Tensor y_true = null, Tensor y_pred = null, bool from_logits = false, int axis = -1)
{
// Normalization uses the instance field, not the method parameter of the same name.
Tensor y_true_normalize = nn_impl.l2_normalize(y_true, axis: this.axis);
Tensor y_pred_normalize = nn_impl.l2_normalize(y_pred, axis: this.axis);
// The reduction axis is passed as a constant tensor built from this.axis; the sum is negated before returning.
return -math_ops.reduce_sum(y_true_normalize * y_pred_normalize, axis: constant_op.constant(this.axis));
}
}
}

+ 23
- 30
src/TensorFlowNET.Keras/Losses/Huber.cs View File

@@ -1,36 +1,29 @@
using System;
using System.Collections.Generic;
using System.Text;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
namespace Tensorflow.Keras.Losses;

namespace Tensorflow.Keras.Losses
public class Huber : LossFunctionWrapper
{
public class Huber : LossFunctionWrapper, ILossFunc
protected Tensor delta = tf.Variable(1.0);

public Huber(
string reduction = null,
Tensor delta = null,
string name = null) :
base(reduction: reduction, name: name == null ? "huber" : name)
{
protected Tensor delta = tf.Variable(1.0) ;
public Huber (
string reduction = null,
Tensor delta = null,
string name = null) :
base(reduction: reduction, name: name == null ? "huber" : name)
{
this.delta = delta==null? this.delta: delta;
}
this.delta = delta == null ? this.delta : delta;
}

public override Tensor Apply(Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1)
{
Tensor y_pred_cast = math_ops.cast(y_pred, dtype: TF_DataType.TF_FLOAT);
Tensor y_true_cast = math_ops.cast(y_true, dtype: TF_DataType.TF_FLOAT);
Tensor delta = math_ops.cast(this.delta, dtype: TF_DataType.TF_FLOAT);
Tensor error = math_ops.subtract(y_pred_cast, y_true_cast);
Tensor abs_error = math_ops.abs(error);
Tensor half = ops.convert_to_tensor(0.5, dtype: abs_error.dtype);
return gen_math_ops.mean(array_ops.where_v2(abs_error <= delta,
half * math_ops.pow(error, 2),
half * math_ops.pow(delta, 2) + delta * (abs_error - delta)),
ops.convert_to_tensor(-1));
}
/// <summary>
/// Computes a piecewise loss on <c>y_pred - y_true</c>: half the squared error where the
/// absolute error is at most <c>delta</c>, and
/// <c>0.5 * delta^2 + delta * (|error| - delta)</c> elsewhere, then takes the mean.
/// </summary>
/// <param name="y_true">Target tensor; cast to <c>TF_FLOAT</c> before use.</param>
/// <param name="y_pred">Prediction tensor; cast to <c>TF_FLOAT</c> before use.</param>
/// <param name="from_logits">Not read by this method body.</param>
/// <param name="axis">Not read by this method body; the mean is always called with a tensor holding -1.</param>
/// <returns>The result of <c>gen_math_ops.mean</c> over the piecewise values.</returns>
public override Tensor Apply(Tensor y_true = null, Tensor y_pred = null, bool from_logits = false, int axis = -1)
{
// Inputs and the threshold are all cast to TF_FLOAT before any arithmetic.
Tensor y_pred_cast = math_ops.cast(y_pred, dtype: TF_DataType.TF_FLOAT);
Tensor y_true_cast = math_ops.cast(y_true, dtype: TF_DataType.TF_FLOAT);
// Local 'delta' shadows the field; it is the field's value cast to TF_FLOAT.
Tensor delta = math_ops.cast(this.delta, dtype: TF_DataType.TF_FLOAT);
Tensor error = math_ops.subtract(y_pred_cast, y_true_cast);
Tensor abs_error = math_ops.abs(error);
// 0.5 is materialized with the same dtype as abs_error.
Tensor half = ops.convert_to_tensor(0.5, dtype: abs_error.dtype);
// where_v2 selects the squared branch where abs_error <= delta, otherwise the other branch.
// The second argument to mean is a tensor holding -1 (presumably the reduction axis - confirm against gen_math_ops.mean).
return gen_math_ops.mean(array_ops.where_v2(abs_error <= delta,
half * math_ops.pow(error, 2),
half * math_ops.pow(delta, 2) + delta * (abs_error - delta)),
ops.convert_to_tensor(-1));
}
}

+ 15
- 22
src/TensorFlowNET.Keras/Losses/LogCosh.cs View File

@@ -1,27 +1,20 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Operations;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
namespace Tensorflow.Keras.Losses;

namespace Tensorflow.Keras.Losses
public class LogCosh : LossFunctionWrapper
{
public class LogCosh : LossFunctionWrapper, ILossFunc
{
public LogCosh(
string reduction = null,
string name = null) :
base(reduction: reduction, name: name == null ? "log_cosh" : name){ }
public LogCosh(
string reduction = null,
string name = null) :
base(reduction: reduction, name: name == null ? "log_cosh" : name)
{ }

public override Tensor Apply(Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1)
{
Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred);
Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype);
Tensor x = y_pred_dispatch - y_true_cast;
public override Tensor Apply(Tensor y_true = null, Tensor y_pred = null, bool from_logits = false, int axis = -1)
{
Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred);
Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype);
Tensor x = y_pred_dispatch - y_true_cast;

return gen_math_ops.mean(x + gen_nn_ops.softplus(-2.0 * x) - math_ops.cast(math_ops.log(tf.Variable(2.0)), x.dtype),
ops.convert_to_tensor(-1));
}
return gen_math_ops.mean(x + gen_nn_ops.softplus(-2.0 * x) - math_ops.cast(math_ops.log(tf.Variable(2.0)), x.dtype),
ops.convert_to_tensor(-1));
}
}
}

+ 43
- 47
src/TensorFlowNET.Keras/Losses/Loss.cs View File

@@ -1,55 +1,51 @@
using System;
using Tensorflow.Keras.Utils;
using Tensorflow.Keras.Utils;

namespace Tensorflow.Keras.Losses
namespace Tensorflow.Keras.Losses;

/// <summary>
/// Loss base class.
/// </summary>
public abstract class Loss : ILossFunc
{
/// <summary>
/// Loss base class.
/// </summary>
public abstract class Loss
protected string reduction;
protected string name;
bool _allow_sum_over_batch_size;
protected bool from_logits = false;
string _name_scope;

public string Reduction => reduction;
public string Name => name;

public Loss(string reduction = ReductionV2.AUTO,
string name = null,
bool from_logits = false)
{
protected string reduction;
protected string name;
bool _allow_sum_over_batch_size;
protected bool from_logits = false;
string _name_scope;

public string Reduction => reduction;
public string Name => name;
public Loss(string reduction = ReductionV2.AUTO,
string name = null,
bool from_logits = false)
{
this.reduction = reduction == null ? ReductionV2.SUM_OVER_BATCH_SIZE : reduction;
this.name = name;
this.from_logits = from_logits;
_allow_sum_over_batch_size = false;
}
this.reduction = reduction == null ? ReductionV2.SUM_OVER_BATCH_SIZE : reduction;
this.name = name;
this.from_logits = from_logits;
_allow_sum_over_batch_size = false;
}

public virtual Tensor Apply(Tensor y_true, Tensor y_pred, bool from_logits = false, int axis = -1)
{
throw new NotImplementedException("");
}
public abstract Tensor Apply(Tensor y_true, Tensor y_pred, bool from_logits = false, int axis = -1);

public Tensor Call(Tensor y_true, Tensor y_pred, Tensor sample_weight = null)
{
var losses = Apply(y_true, y_pred, from_logits: from_logits);
var reduction = GetReduction();
return losses_utils.compute_weighted_loss(losses, reduction: reduction, sample_weight: sample_weight);
}
/// <summary>
/// Evaluates the loss: calls <c>Apply</c> on the inputs, then passes the result, the
/// resolved reduction and the optional sample weights to
/// <c>losses_utils.compute_weighted_loss</c>.
/// </summary>
/// <param name="y_true">Forwarded unchanged to <c>Apply</c>.</param>
/// <param name="y_pred">Forwarded unchanged to <c>Apply</c>.</param>
/// <param name="sample_weight">Optional weights forwarded to <c>compute_weighted_loss</c>; defaults to null.</param>
/// <returns>The tensor returned by <c>losses_utils.compute_weighted_loss</c>.</returns>
public Tensor Call(Tensor y_true, Tensor y_pred, Tensor sample_weight = null)
{
// The from_logits value passed here is the instance field; Apply's axis argument is left at its default.
var losses = Apply(y_true, y_pred, from_logits: from_logits);
// Local 'reduction' shadows the field and holds the value returned by GetReduction().
var reduction = GetReduction();
return losses_utils.compute_weighted_loss(losses, reduction: reduction, sample_weight: sample_weight);
}

string GetReduction()
{
return reduction switch
{
ReductionV2.AUTO => ReductionV2.SUM_OVER_BATCH_SIZE,
_ => reduction
};
}

void _set_name_scope()
string GetReduction()
{
return reduction switch
{
_name_scope = name;
}
ReductionV2.AUTO => ReductionV2.SUM_OVER_BATCH_SIZE,
_ => reduction
};
}

void _set_name_scope()
{
_name_scope = name;
}
}
}

+ 10
- 12
src/TensorFlowNET.Keras/Losses/LossFunctionWrapper.cs View File

@@ -1,16 +1,14 @@
using Tensorflow.Keras.Utils;

namespace Tensorflow.Keras.Losses
namespace Tensorflow.Keras.Losses;

public abstract class LossFunctionWrapper : Loss
{
public class LossFunctionWrapper : Loss
{
public LossFunctionWrapper(string reduction = ReductionV2.AUTO,
string name = null,
bool from_logits = false)
: base(reduction: reduction,
name: name,
from_logits: from_logits)
{
}
}
public LossFunctionWrapper(string reduction = ReductionV2.AUTO,
string name = null,
bool from_logits = false)
: base(reduction: reduction,
name: name,
from_logits: from_logits)
{ }
}

+ 11
- 18
src/TensorFlowNET.Keras/Losses/MeanAbsoluteError.cs View File

@@ -1,23 +1,16 @@
using System;
using System.Collections.Generic;
using System.Text;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
namespace Tensorflow.Keras.Losses;

namespace Tensorflow.Keras.Losses
public class MeanAbsoluteError : LossFunctionWrapper
{
public class MeanAbsoluteError : LossFunctionWrapper, ILossFunc
{
public MeanAbsoluteError(
string reduction = null,
string name = null) :
base(reduction: reduction, name: name == null ? "mean_absolute_error" : name){ }
public MeanAbsoluteError(
string reduction = null,
string name = null) :
base(reduction: reduction, name: name == null ? "mean_absolute_error" : name){ }

public override Tensor Apply(Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1)
{
Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred);
Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype);
return gen_math_ops.mean(math_ops.abs(y_pred_dispatch - y_true_cast), ops.convert_to_tensor(-1));
}
/// <summary>
/// Computes the mean of <c>|y_pred - y_true|</c>.
/// </summary>
/// <param name="y_true">Target tensor; cast to the dtype of the converted <paramref name="y_pred"/>.</param>
/// <param name="y_pred">Prediction value; passed through <c>ops.convert_to_tensor</c>.</param>
/// <param name="from_logits">Not read by this method body.</param>
/// <param name="axis">Not read by this method body; the mean is always called with a tensor holding -1.</param>
/// <returns>The result of <c>gen_math_ops.mean</c> over the absolute differences.</returns>
public override Tensor Apply(Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1)
{
Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred);
// Targets are cast to the prediction dtype so the subtraction operates on matching types.
Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype);
// The second argument to mean is a tensor holding -1 (presumably the reduction axis - confirm).
return gen_math_ops.mean(math_ops.abs(y_pred_dispatch - y_true_cast), ops.convert_to_tensor(-1));
}
}

+ 12
- 19
src/TensorFlowNET.Keras/Losses/MeanAbsolutePercentageError.cs View File

@@ -1,24 +1,17 @@
using System;
using System.Collections.Generic;
using System.Text;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
namespace Tensorflow.Keras.Losses;

namespace Tensorflow.Keras.Losses
public class MeanAbsolutePercentageError : LossFunctionWrapper
{
public class MeanAbsolutePercentageError : LossFunctionWrapper, ILossFunc
{
public MeanAbsolutePercentageError(
string reduction = null,
string name = null) :
base(reduction: reduction, name: name == null ? "mean_absolute_percentage_error" : name){ }
public MeanAbsolutePercentageError(
string reduction = null,
string name = null) :
base(reduction: reduction, name: name == null ? "mean_absolute_percentage_error" : name){ }

public override Tensor Apply(Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1)
{
Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred);
Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype);
Tensor diff = math_ops.abs(y_true_cast - y_pred_dispatch) / gen_math_ops.maximum(math_ops.abs(y_true_cast), gen_math_ops.cast(tf.constant(1e-7), y_pred_dispatch.dtype));
return gen_math_ops.cast(tf.constant(100), y_pred_dispatch.dtype) * gen_math_ops.mean(diff, ops.convert_to_tensor(-1));
}
/// <summary>
/// Computes <c>100 * mean(|y_true - y_pred| / max(|y_true|, 1e-7))</c>.
/// </summary>
/// <param name="y_true">Target tensor; cast to the dtype of the converted <paramref name="y_pred"/>.</param>
/// <param name="y_pred">Prediction value; passed through <c>ops.convert_to_tensor</c>.</param>
/// <param name="from_logits">Not read by this method body.</param>
/// <param name="axis">Not read by this method body; the mean is always called with a tensor holding -1.</param>
/// <returns>The mean ratio multiplied by the constant 100 cast to the prediction dtype.</returns>
public override Tensor Apply(Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1)
{
Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred);
Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype);
// The denominator is floored at 1e-7 (cast to the prediction dtype) via maximum, so it is never smaller than that constant.
Tensor diff = math_ops.abs(y_true_cast - y_pred_dispatch) / gen_math_ops.maximum(math_ops.abs(y_true_cast), gen_math_ops.cast(tf.constant(1e-7), y_pred_dispatch.dtype));
// The constant 100 is cast to the prediction dtype before scaling the mean.
return gen_math_ops.cast(tf.constant(100), y_pred_dispatch.dtype) * gen_math_ops.mean(diff, ops.convert_to_tensor(-1));
}
}

+ 11
- 18
src/TensorFlowNET.Keras/Losses/MeanSquaredError.cs View File

@@ -1,23 +1,16 @@
using System;
using System.Collections.Generic;
using System.Text;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
namespace Tensorflow.Keras.Losses;

namespace Tensorflow.Keras.Losses
public class MeanSquaredError : LossFunctionWrapper
{
public class MeanSquaredError : LossFunctionWrapper, ILossFunc
{
public MeanSquaredError(
string reduction = null,
string name = null) :
base(reduction: reduction, name: name==null? "mean_squared_error" : name){ }
public MeanSquaredError(
string reduction = null,
string name = null) :
base(reduction: reduction, name: name==null? "mean_squared_error" : name){ }

public override Tensor Apply(Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1)
{
Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred);
Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype);
return gen_math_ops.mean(gen_math_ops.squared_difference(y_pred_dispatch, y_true_cast), ops.convert_to_tensor(-1));
}
/// <summary>
/// Computes the mean of the squared difference between <paramref name="y_pred"/> and <paramref name="y_true"/>.
/// </summary>
/// <param name="y_true">Target tensor; cast to the dtype of the converted <paramref name="y_pred"/>.</param>
/// <param name="y_pred">Prediction value; passed through <c>ops.convert_to_tensor</c>.</param>
/// <param name="from_logits">Not read by this method body.</param>
/// <param name="axis">Not read by this method body; the mean is always called with a tensor holding -1.</param>
/// <returns>The result of <c>gen_math_ops.mean</c> over <c>gen_math_ops.squared_difference</c>.</returns>
public override Tensor Apply(Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1)
{
Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred);
// Targets are cast to the prediction dtype before the squared difference is taken.
Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype);
// The second argument to mean is a tensor holding -1 (presumably the reduction axis - confirm).
return gen_math_ops.mean(gen_math_ops.squared_difference(y_pred_dispatch, y_true_cast), ops.convert_to_tensor(-1));
}
}

+ 22
- 27
src/TensorFlowNET.Keras/Losses/MeanSquaredLogarithmicError.cs View File

@@ -1,33 +1,28 @@
using System;
using System.Collections.Generic;
using System.Text;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
namespace Tensorflow.Keras.Losses;

namespace Tensorflow.Keras.Losses
public class MeanSquaredLogarithmicError : LossFunctionWrapper
{
public class MeanSquaredLogarithmicError : LossFunctionWrapper, ILossFunc
{
public MeanSquaredLogarithmicError(
string reduction = null,
string name = null) :
base(reduction: reduction, name: name == null ? "mean_squared_logarithmic_error" : name){ }

public MeanSquaredLogarithmicError(
string reduction = null,
string name = null) :
base(reduction: reduction, name: name == null ? "mean_squared_logarithmic_error" : name)
{ }

public override Tensor Apply(Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1)
public override Tensor Apply(Tensor y_true = null, Tensor y_pred = null, bool from_logits = false, int axis = -1)
{
Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred);
Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype);
Tensor first_log = null, second_log = null;
if (y_pred_dispatch.dtype == TF_DataType.TF_DOUBLE)
{
first_log = math_ops.log(math_ops.maximum(y_pred_dispatch, 1e-7) + 1.0);
second_log = math_ops.log(math_ops.maximum(y_true_cast, 1e-7) + 1.0);
}
else
{
Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred);
Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype);
Tensor first_log=null, second_log=null;
if (y_pred_dispatch.dtype == TF_DataType.TF_DOUBLE) {
first_log = math_ops.log(math_ops.maximum(y_pred_dispatch, 1e-7) + 1.0);
second_log = math_ops.log(math_ops.maximum(y_true_cast, 1e-7) + 1.0);
}
else {
first_log = math_ops.log(math_ops.maximum(y_pred_dispatch, 1e-7f) + 1.0f);
second_log = math_ops.log(math_ops.maximum(y_true_cast, 1e-7f) + 1.0f);
}
return gen_math_ops.mean(gen_math_ops.squared_difference(first_log, second_log), ops.convert_to_tensor(-1));
first_log = math_ops.log(math_ops.maximum(y_pred_dispatch, 1e-7f) + 1.0f);
second_log = math_ops.log(math_ops.maximum(y_true_cast, 1e-7f) + 1.0f);
}
return gen_math_ops.mean(gen_math_ops.squared_difference(first_log, second_log), ops.convert_to_tensor(-1));
}
}
}

+ 1
- 2
src/TensorFlowNET.Keras/Losses/SigmoidFocalCrossEntropy.cs View File

@@ -2,7 +2,7 @@

namespace Tensorflow.Keras.Losses;

public class SigmoidFocalCrossEntropy : LossFunctionWrapper, ILossFunc
public class SigmoidFocalCrossEntropy : LossFunctionWrapper
{
float _alpha;
float _gamma;
@@ -20,7 +20,6 @@ public class SigmoidFocalCrossEntropy : LossFunctionWrapper, ILossFunc
_gamma = gamma;
}


public override Tensor Apply(Tensor y_true, Tensor y_pred, bool from_logits = false, int axis = -1)
{
y_true = tf.cast(y_true, dtype: y_pred.dtype);


+ 31
- 31
src/TensorFlowNET.Keras/Losses/SparseCategoricalCrossentropy.cs View File

@@ -1,41 +1,41 @@
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Losses
namespace Tensorflow.Keras.Losses;

public class SparseCategoricalCrossentropy : LossFunctionWrapper
{
public class SparseCategoricalCrossentropy : LossFunctionWrapper, ILossFunc
private bool _from_logits = false;

public SparseCategoricalCrossentropy(
bool from_logits = false,
string reduction = null,
string name = null) :
base(reduction: reduction, name: name == null ? "sparse_categorical_crossentropy" : name)
{
_from_logits = from_logits;
}

public override Tensor Apply(Tensor target, Tensor output, bool from_logits = false, int axis = -1)
{
private bool _from_logits = false;
public SparseCategoricalCrossentropy(
bool from_logits = false,
string reduction = null,
string name = null) :
base(reduction: reduction, name: name == null ? "sparse_categorical_crossentropy" : name)
target = tf.cast(target, dtype: TF_DataType.TF_INT64);

if (!_from_logits)
{
_from_logits = from_logits;
var epsilon = tf.constant(KerasApi.keras.backend.epsilon(), output.dtype);
output = tf.clip_by_value(output, epsilon, 1 - epsilon);
output = tf.log(output);
}

public override Tensor Apply(Tensor target, Tensor output, bool from_logits = false, int axis = -1)
// Try to adjust the shape so that rank of labels = rank of logits - 1.
var output_shape = array_ops.shape_v2(output);
var output_rank = output.shape.ndim;
var target_rank = target.shape.ndim;
var update_shape = target_rank != output_rank - 1;
if (update_shape)
{
target = tf.cast(target, dtype: TF_DataType.TF_INT64);

if (!_from_logits)
{
var epsilon = tf.constant(KerasApi.keras.backend.epsilon(), output.dtype);
output = tf.clip_by_value(output, epsilon, 1 - epsilon);
output = tf.log(output);
}

// Try to adjust the shape so that rank of labels = rank of logits - 1.
var output_shape = array_ops.shape_v2(output);
var output_rank = output.shape.ndim;
var target_rank = target.shape.ndim;
var update_shape = target_rank != output_rank - 1;
if (update_shape)
{
target = array_ops.reshape(target, new int[] { -1 });
output = array_ops.reshape(output, new int[] { -1, output_shape[-1].numpy() });
}
return tf.nn.sparse_softmax_cross_entropy_with_logits(target, output);
target = array_ops.reshape(target, new int[] { -1 });
output = array_ops.reshape(output, new int[] { -1, output_shape[-1].numpy() });
}
return tf.nn.sparse_softmax_cross_entropy_with_logits(target, output);
}
}
}

Loading…
Cancel
Save