From 78bc96b5a37fc62dc0e7905af5a2336e59d5abbb Mon Sep 17 00:00:00 2001 From: dataangel Date: Mon, 14 Dec 2020 06:53:48 +0800 Subject: [PATCH] Update losses_utils.cs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 修改sample_weight初始化 --- src/TensorFlowNET.Keras/Utils/losses_utils.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/TensorFlowNET.Keras/Utils/losses_utils.cs b/src/TensorFlowNET.Keras/Utils/losses_utils.cs index 8a1ebbc5..4cccc08d 100644 --- a/src/TensorFlowNET.Keras/Utils/losses_utils.cs +++ b/src/TensorFlowNET.Keras/Utils/losses_utils.cs @@ -1,4 +1,4 @@ -/***************************************************************************** +/***************************************************************************** Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); @@ -25,7 +25,7 @@ namespace Tensorflow.Keras.Utils public static Tensor compute_weighted_loss(Tensor losses, Tensor sample_weight = null, string reduction = null, string name = null) { if (sample_weight == null) - sample_weight = tf.constant(1.0f); + sample_weight = losses.dtype.ToString()=="TF_DOUBLE"? tf.constant(1.0) : tf.constant(1.0f); var weighted_losses = scale_losses_by_sample_weight(losses, sample_weight); // Apply reduction function to the individual weighted losses. var loss = reduce_weighted_loss(weighted_losses, reduction);