From 1fa2f1d7b2e9102b9ad30668e03b5c3a2cb7d765 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 28 Aug 2021 14:12:24 -0500 Subject: [PATCH] fix stop_gradient in eager mode. --- .../Tensorflow.Console.csproj | 2 +- .../Distributions/distribution.py.cs | 20 +------------------ .../Operations/array_ops.cs | 2 +- 3 files changed, 3 insertions(+), 21 deletions(-) diff --git a/src/TensorFlowNET.Console/Tensorflow.Console.csproj b/src/TensorFlowNET.Console/Tensorflow.Console.csproj index 8efbf1bb..2ed2f41b 100644 --- a/src/TensorFlowNET.Console/Tensorflow.Console.csproj +++ b/src/TensorFlowNET.Console/Tensorflow.Console.csproj @@ -19,7 +19,7 @@ - + diff --git a/src/TensorFlowNET.Core/Operations/Distributions/distribution.py.cs b/src/TensorFlowNET.Core/Operations/Distributions/distribution.py.cs index 988c5326..4375788d 100644 --- a/src/TensorFlowNET.Core/Operations/Distributions/distribution.py.cs +++ b/src/TensorFlowNET.Core/Operations/Distributions/distribution.py.cs @@ -60,25 +60,7 @@ namespace Tensorflow { return tf_with(ops.name_scope(name, "moments", new { value }), scope => { - try - { - return _log_prob(value); - } -#pragma warning disable CS0168 // Variable is declared but never used - catch (Exception e1) -#pragma warning restore CS0168 // Variable is declared but never used - { - try - { - return math_ops.log(_prob(value)); -#pragma warning disable CS0168 // Variable is declared but never used - } - catch (Exception e2) -#pragma warning restore CS0168 // Variable is declared but never used - { - throw new NotImplementedException(); - } - } + return math_ops.log(value); }); } diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs index fc83cf4e..b0ef1f2d 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.cs @@ -656,7 +656,7 @@ namespace Tensorflow /// /// public static Tensor stop_gradient(Tensor input, string name = null) - => gen_array_ops.stop_gradient(input, name); + => tf.Context.ExecuteOp("StopGradient", name, new ExecuteOpArgs(input)); /// /// Extracts a strided slice of a tensor (generalized python array indexing).