diff --git a/src/TensorFlowNET.Keras/Utils/losses_utils.cs b/src/TensorFlowNET.Keras/Utils/losses_utils.cs index 8a1ebbc5..ec6f6e4e 100644 --- a/src/TensorFlowNET.Keras/Utils/losses_utils.cs +++ b/src/TensorFlowNET.Keras/Utils/losses_utils.cs @@ -1,4 +1,4 @@ -/***************************************************************************** +/***************************************************************************** Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); @@ -25,7 +25,7 @@ namespace Tensorflow.Keras.Utils public static Tensor compute_weighted_loss(Tensor losses, Tensor sample_weight = null, string reduction = null, string name = null) { if (sample_weight == null) - sample_weight = tf.constant(1.0f); + sample_weight = losses.dtype == TF_DataType.TF_DOUBLE ? tf.constant(1.0) : tf.constant(1.0f); var weighted_losses = scale_losses_by_sample_weight(losses, sample_weight); // Apply reduction function to the individual weighted losses. var loss = reduce_weighted_loss(weighted_losses, reduction);