Browse Source

Update losses_utils.cs

修改sample_weight初始化
pull/675/head
dataangel GitHub 4 years ago
parent
commit
78bc96b5a3
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 2 deletions
  1. +2
    -2
      src/TensorFlowNET.Keras/Utils/losses_utils.cs

+ 2
- 2
src/TensorFlowNET.Keras/Utils/losses_utils.cs View File

@@ -1,4 +1,4 @@
/*****************************************************************************
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
@@ -25,7 +25,7 @@ namespace Tensorflow.Keras.Utils
public static Tensor compute_weighted_loss(Tensor losses, Tensor sample_weight = null, string reduction = null, string name = null)
{
if (sample_weight == null)
sample_weight = tf.constant(1.0f);
sample_weight = losses.dtype.ToString()=="TF_DOUBLE"? tf.constant(1.0) : tf.constant(1.0f);
var weighted_losses = scale_losses_by_sample_weight(losses, sample_weight);
// Apply reduction function to the individual weighted losses.
var loss = reduce_weighted_loss(weighted_losses, reduction);


Loading…
Cancel
Save