diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs
index 91b96230..4d2c2dec 100644
--- a/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs
+++ b/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs
@@ -1,5 +1,6 @@
using System;
using System.Linq;
+using Microsoft.Extensions.Logging;
using Tensorflow.Gradients;
using static Tensorflow.Binding;
using static Tensorflow.tensorflow;
@@ -38,7 +39,7 @@ namespace Tensorflow.Eager
}*/
}
- // Console.WriteLine($"RecordGradient: should_record={should_record}, op_name={op_name}");
+ tf.Logger.LogDebug($"RecordGradient: should_record={should_record}, op_name={op_name}");
if (!should_record) return should_record;
Tensor[] op_outputs;
diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetPossibleGradientTypes.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetPossibleGradientTypes.cs
new file mode 100644
index 00000000..0a23cdd4
--- /dev/null
+++ b/src/TensorFlowNET.Core/Eager/EagerRunner.TapeSetPossibleGradientTypes.cs
@@ -0,0 +1,15 @@
+using System;
+using Tensorflow.Gradients;
+using static Tensorflow.Binding;
+using static Tensorflow.tensorflow;
+
+namespace Tensorflow.Eager
+{
+ public partial class EagerRunner
+ {
+ public int TapeSetPossibleGradientTypes(params Tensor[] args)
+ {
+ return 1;
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Eager/IEagerRunner.cs b/src/TensorFlowNET.Core/Eager/IEagerRunner.cs
index a6994d18..bbc1b882 100644
--- a/src/TensorFlowNET.Core/Eager/IEagerRunner.cs
+++ b/src/TensorFlowNET.Core/Eager/IEagerRunner.cs
@@ -40,5 +40,7 @@ namespace Tensorflow.Eager
Tensor[] results);
bool MustRecordGradient();
+
+ int TapeSetPossibleGradientTypes(params Tensor[] args);
}
}
diff --git a/src/TensorFlowNET.Core/Gradients/Tape.RecordOperation.cs b/src/TensorFlowNET.Core/Gradients/Tape.RecordOperation.cs
index 2550dd59..299be4a7 100644
--- a/src/TensorFlowNET.Core/Gradients/Tape.RecordOperation.cs
+++ b/src/TensorFlowNET.Core/Gradients/Tape.RecordOperation.cs
@@ -1,7 +1,9 @@
using System;
using System.Collections.Generic;
+using Microsoft.Extensions.Logging;
using Tensorflow.Util;
using static Tensorflow.tensorflow;
+using static Tensorflow.Binding;
namespace Tensorflow.Gradients
{
@@ -34,6 +36,7 @@ namespace Tensorflow.Gradients
foreach (var o in output_tensors)
{
tensor_tape_[o.GetID()] = op_id;
+ tf.Logger.LogDebug($"RecordOperation: tensor_tape_[{o.GetID()}] = {op_id}");
tensor_usage_[o.GetID()] = 1;
tensors.Add(o);
}
diff --git a/src/TensorFlowNET.Core/Gradients/Tape.cs b/src/TensorFlowNET.Core/Gradients/Tape.cs
index 33e3ea89..d43e4ee3 100644
--- a/src/TensorFlowNET.Core/Gradients/Tape.cs
+++ b/src/TensorFlowNET.Core/Gradients/Tape.cs
@@ -1,5 +1,7 @@
-using System.Collections.Generic;
+using System;
+using System.Collections.Generic;
using Tensorflow.Util;
+using Microsoft.Extensions.Logging;
using static Tensorflow.Binding;
using static Tensorflow.tensorflow;
@@ -42,6 +44,7 @@ namespace Tensorflow.Gradients
if (!CouldBackprop())
return;
+ tf.Logger.LogDebug($"Watch tensor_id={tensor_id}");
tensor_tape_.emplace(tensor_id, -1);
}
@@ -50,8 +53,13 @@ namespace Tensorflow.Gradients
for (int i = 0; i < tensor_ids.Length; ++i)
{
if (tensor_tape_.find(tensor_ids[i]))
+ {
if (IsDtypeTrainable(dtypes[i]))
+ {
+ tf.Logger.LogDebug($"tape.h->ShouldRecord: should_record = true, tensor_tape_.size()={tensor_tape_.Count}, tensor_ids[{i}]={tensor_ids[i]}");
return true;
+ }
+ }
}
return false;
}
diff --git a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj
index 94e1c001..f6442c20 100644
--- a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj
+++ b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj
@@ -84,6 +84,9 @@ TensorFlow .NET v0.30 is focused on making more Keras API work including:
+
+
+
diff --git a/src/TensorFlowNET.Core/tensorflow.cs b/src/TensorFlowNET.Core/tensorflow.cs
index 20d43baf..903ec3e6 100644
--- a/src/TensorFlowNET.Core/tensorflow.cs
+++ b/src/TensorFlowNET.Core/tensorflow.cs
@@ -14,6 +14,8 @@
limitations under the License.
******************************************************************************/
+using Microsoft.Extensions.DependencyInjection;
+using Microsoft.Extensions.Logging;
using System.Collections.Generic;
using Tensorflow.Contexts;
using Tensorflow.Eager;
@@ -41,9 +43,18 @@ namespace Tensorflow
public OpDefLibrary OpDefLib;
public Context Context;
public IEagerRunner Runner;
+ public ILogger Logger;
+ ServiceProvider serviceProvider;
public tensorflow()
{
+ serviceProvider = new ServiceCollection()
+ .AddLogging(cfg => cfg.AddConsole())
+ .Configure(cfg => cfg.MinLevel = LogLevel.Warning)
+ .BuildServiceProvider();
+
+ Logger = serviceProvider.GetService>();
+
Status = new Status();
Context = new Context(new ContextOptions(), Status);
OpDefLib = new OpDefLibrary();
diff --git a/src/TensorFlowNET.Keras/Engine/Functional.cs b/src/TensorFlowNET.Keras/Engine/Functional.cs
index f0b8a43d..f0f2ab2b 100644
--- a/src/TensorFlowNET.Keras/Engine/Functional.cs
+++ b/src/TensorFlowNET.Keras/Engine/Functional.cs
@@ -3,6 +3,7 @@ using System.Collections.Generic;
using System.Linq;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Utils;
+using Microsoft.Extensions.Logging;
using static Tensorflow.Binding;
namespace Tensorflow.Keras.Engine
@@ -335,7 +336,7 @@ namespace Tensorflow.Keras.Engine
var layer_inputs = node.MapArguments(tensor_dict);
- // Console.WriteLine($"{node.Layer}: {node.Layer.Name}");
+ tf.Logger.LogDebug($"{node.Layer}: {node.Layer.Name}");
var outputs = node.Layer.Apply(layer_inputs, is_training: training);
// Update tensor_dict for next input