diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs index d58fbe70..ed895ffc 100644 --- a/src/TensorFlowNET.Core/Binding.Util.cs +++ b/src/TensorFlowNET.Core/Binding.Util.cs @@ -21,6 +21,7 @@ using System.Collections; using System.Collections.Generic; using System.ComponentModel; using System.Diagnostics; +using System.IO; using System.Linq; namespace Tensorflow @@ -112,16 +113,33 @@ namespace Tensorflow } } + private static TextWriter writer = null; + + public static TextWriter tf_output_redirect { + set + { + var originWriter = writer ?? Console.Out; + originWriter.Flush(); + if (originWriter is StringWriter) + (originWriter as StringWriter).GetStringBuilder().Clear(); + writer = value; + } + get + { + return writer ?? Console.Out; + } + } + public static void print(object obj) { - Console.WriteLine(_tostring(obj)); + tf_output_redirect.WriteLine(_tostring(obj)); } public static void print(string format, params object[] objects) { if (!format.Contains("{}")) { - Console.WriteLine(format + " " + string.Join(" ", objects.Select(x => x.ToString()))); + tf_output_redirect.WriteLine(format + " " + string.Join(" ", objects.Select(x => x.ToString()))); return; } @@ -130,7 +148,7 @@ namespace Tensorflow } - Console.WriteLine(format); + tf_output_redirect.WriteLine(format); } public static int len(object a) diff --git a/src/TensorFlowNET.Core/Data/MnistDataSet.cs b/src/TensorFlowNET.Core/Data/MnistDataSet.cs index 3c5640fa..10e71a1d 100644 --- a/src/TensorFlowNET.Core/Data/MnistDataSet.cs +++ b/src/TensorFlowNET.Core/Data/MnistDataSet.cs @@ -24,7 +24,7 @@ namespace Tensorflow sw.Start(); images = np.multiply(images, 1.0f / 255.0f); sw.Stop(); - Console.WriteLine($"{sw.ElapsedMilliseconds}ms"); + Binding.tf_output_redirect.WriteLine($"{sw.ElapsedMilliseconds}ms"); Data = images; labels = labels.astype(dataType); diff --git a/src/TensorFlowNET.Core/Data/Utils.cs b/src/TensorFlowNET.Core/Data/Utils.cs index 3ca6ae23..082a9a68 100644 --- a/src/TensorFlowNET.Core/Data/Utils.cs +++ b/src/TensorFlowNET.Core/Data/Utils.cs @@ -27,14 +27,14 @@ namespace Tensorflow if (showProgressInConsole) { - Console.WriteLine($"Downloading {fileName}"); + Binding.tf_output_redirect.WriteLine($"Downloading {fileName}"); } if (File.Exists(fileSaveTo)) { if (showProgressInConsole) { - Console.WriteLine($"The file {fileName} already exists"); + Binding.tf_output_redirect.WriteLine($"The file {fileName} already exists"); } return; @@ -64,12 +64,12 @@ namespace Tensorflow var destFilePath = Path.Combine(saveTo, destFileName); if (showProgressInConsole) - Console.WriteLine($"Unzippinng {Path.GetFileName(zipFile)}"); + Binding.tf_output_redirect.WriteLine($"Unzippinng {Path.GetFileName(zipFile)}"); if (File.Exists(destFilePath)) { if (showProgressInConsole) - Console.WriteLine($"The file {destFileName} already exists"); + Binding.tf_output_redirect.WriteLine($"The file {destFileName} already exists"); } using (GZipStream unzipStream = new GZipStream(File.OpenRead(zipFile), CompressionMode.Decompress)) @@ -107,7 +107,7 @@ namespace Tensorflow } await showProgressTask; - Console.WriteLine("Done."); + Binding.tf_output_redirect.WriteLine("Done."); } private static async Task ShowProgressInConsole(CancellationTokenSource cts) @@ -119,17 +119,17 @@ namespace Tensorflow while (!cts.IsCancellationRequested) { await Task.Delay(100); - Console.Write("."); + Binding.tf_output_redirect.Write("."); cols++; if (cols % 50 == 0) { - Console.WriteLine(); + Binding.tf_output_redirect.WriteLine(); } } if (cols > 0) - Console.WriteLine(); + Binding.tf_output_redirect.WriteLine(); } } } diff --git a/src/TensorFlowNET.Core/Framework/c_api_util.cs b/src/TensorFlowNET.Core/Framework/c_api_util.cs index 5d5cb9b3..9cfbf0d0 100644 --- a/src/TensorFlowNET.Core/Framework/c_api_util.cs +++ b/src/TensorFlowNET.Core/Framework/c_api_util.cs @@ -62,18 +62,18 @@ namespace Tensorflow if (!File.Exists(file)) { var wc = new WebClient(); - Console.WriteLine($"Downloading Tensorflow library from {url}..."); + Binding.tf_output_redirect.WriteLine($"Downloading Tensorflow library from {url}..."); var download = Task.Run(() => wc.DownloadFile(url, file)); while (!download.IsCompleted) { Thread.Sleep(1000); - Console.Write("."); + Binding.tf_output_redirect.Write("."); } - Console.WriteLine(""); - Console.WriteLine($"Downloaded successfully."); + Binding.tf_output_redirect.WriteLine(""); + Binding.tf_output_redirect.WriteLine($"Downloaded successfully."); } - Console.WriteLine($"Extracting..."); + Binding.tf_output_redirect.WriteLine($"Extracting..."); var task = Task.Run(() => { switch (Environment.OSVersion.Platform) @@ -97,11 +97,11 @@ namespace Tensorflow while (!task.IsCompleted) { Thread.Sleep(100); - Console.Write("."); + Binding.tf_output_redirect.Write("."); } - Console.WriteLine(""); - Console.WriteLine("Extraction is completed."); + Binding.tf_output_redirect.WriteLine(""); + Binding.tf_output_redirect.WriteLine("Extraction is completed."); } isDllDownloaded = true; diff --git a/src/TensorFlowNET.Core/Framework/meta_graph.cs b/src/TensorFlowNET.Core/Framework/meta_graph.cs index 096b25a2..6ce3bf3c 100644 --- a/src/TensorFlowNET.Core/Framework/meta_graph.cs +++ b/src/TensorFlowNET.Core/Framework/meta_graph.cs @@ -134,7 +134,7 @@ namespace Tensorflow } break; default: - Console.WriteLine($"import_scoped_meta_graph_with_return_elements {col.Key}"); + Binding.tf_output_redirect.WriteLine($"import_scoped_meta_graph_with_return_elements {col.Key}"); continue; } } @@ -142,7 +142,7 @@ namespace Tensorflow break; default: - Console.WriteLine($"Cannot identify data type for collection {col.Key}. Skipping."); + Binding.tf_output_redirect.WriteLine($"Cannot identify data type for collection {col.Key}. Skipping."); break; } } diff --git a/src/TensorFlowNET.Core/Operations/Distributions/distribution.py.cs b/src/TensorFlowNET.Core/Operations/Distributions/distribution.py.cs index 615dc43d..988c5326 100644 --- a/src/TensorFlowNET.Core/Operations/Distributions/distribution.py.cs +++ b/src/TensorFlowNET.Core/Operations/Distributions/distribution.py.cs @@ -166,7 +166,7 @@ namespace Tensorflow public void repr() { - Console.WriteLine($""); + Binding.tf_output_redirect.WriteLine($""); } public bool eq(ReparameterizationType other) diff --git a/src/TensorFlowNET.Core/Training/Saving/Saver.cs b/src/TensorFlowNET.Core/Training/Saving/Saver.cs index 23d0f431..bdee7abc 100644 --- a/src/TensorFlowNET.Core/Training/Saving/Saver.cs +++ b/src/TensorFlowNET.Core/Training/Saving/Saver.cs @@ -242,7 +242,7 @@ namespace Tensorflow if (!checkpoint_management.checkpoint_exists(save_path)) throw new ValueError($"The passed save_path is not a valid checkpoint: {save_path}"); - Console.WriteLine($"Restoring parameters from {save_path}"); + Binding.tf_output_redirect.WriteLine($"Restoring parameters from {save_path}"); if (tf.Context.executing_eagerly()) #pragma warning disable CS0642 // Possible mistaken empty statement diff --git a/src/TensorFlowNET.Core/Training/Saving/saver.py.cs b/src/TensorFlowNET.Core/Training/Saving/saver.py.cs index 9307dc5d..4f583600 100644 --- a/src/TensorFlowNET.Core/Training/Saving/saver.py.cs +++ b/src/TensorFlowNET.Core/Training/Saving/saver.py.cs @@ -78,7 +78,7 @@ namespace Tensorflow else { // If no graph variables exist, then a Saver cannot be constructed. - Console.WriteLine("Saver not created because there are no variables in the" + + Binding.tf_output_redirect.WriteLine("Saver not created because there are no variables in the" + " graph to restore"); return null; } @@ -102,7 +102,7 @@ namespace Tensorflow var output_graph_def = tf.graph_util.convert_variables_to_constants(sess, graph.as_graph_def(), output_node_names); - Console.WriteLine($"Froze {output_graph_def.Node.Count} nodes."); + Binding.tf_output_redirect.WriteLine($"Froze {output_graph_def.Node.Count} nodes."); File.WriteAllBytes(output_pb, output_graph_def.ToByteArray()); return output_pb; } diff --git a/src/TensorFlowNET.Core/Util/CmdHelper.cs b/src/TensorFlowNET.Core/Util/CmdHelper.cs index 13acbcb0..9e9fb81f 100644 --- a/src/TensorFlowNET.Core/Util/CmdHelper.cs +++ b/src/TensorFlowNET.Core/Util/CmdHelper.cs @@ -31,7 +31,7 @@ namespace Tensorflow.Util proc.Start(); while (!proc.StandardOutput.EndOfStream) - Console.WriteLine(proc.StandardOutput.ReadLine()); + Binding.tf_output_redirect.WriteLine(proc.StandardOutput.ReadLine()); } public static void Bash(string command) @@ -44,7 +44,7 @@ namespace Tensorflow.Util proc.Start(); while (!proc.StandardOutput.EndOfStream) - Console.WriteLine(proc.StandardOutput.ReadLine()); + Binding.tf_output_redirect.WriteLine(proc.StandardOutput.ReadLine()); } } } diff --git a/src/TensorFlowNET.Keras/Datasets/MNIST.cs b/src/TensorFlowNET.Keras/Datasets/MNIST.cs index 8fa61b41..582404a2 100644 --- a/src/TensorFlowNET.Keras/Datasets/MNIST.cs +++ b/src/TensorFlowNET.Keras/Datasets/MNIST.cs @@ -61,7 +61,7 @@ namespace Tensorflow.Keras.Datasets if (File.Exists(fileSaveTo)) { - Console.WriteLine($"The file {fileSaveTo} already exists"); + Binding.tf_output_redirect.WriteLine($"The file {fileSaveTo} already exists"); return fileSaveTo; } diff --git a/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs b/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs index 8a484c3b..11910db4 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs @@ -46,7 +46,7 @@ namespace Tensorflow.Keras.Engine StepsPerExecution = _steps_per_execution }); - Console.WriteLine($"Testing..."); + Binding.tf_output_redirect.WriteLine($"Testing..."); foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) { // reset_metrics(); @@ -58,7 +58,7 @@ namespace Tensorflow.Keras.Engine // callbacks.on_train_batch_begin(step) results = test_function(iterator); } - Console.WriteLine($"iterator: {epoch + 1}, " + string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2}"))); + Binding.tf_output_redirect.WriteLine($"iterator: {epoch + 1}, " + string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2}"))); } } diff --git a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs index 2e83f75d..ad58efa1 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs @@ -99,7 +99,7 @@ namespace Tensorflow.Keras.Engine if (verbose == 1) { var result_pairs = string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2:F6}")); - Console.WriteLine($"Epoch: {epoch + 1:D3}/{epochs:D3}, Step: {step + 1:D4}/{data_handler.Inferredsteps:D4}, {result_pairs}"); + Binding.tf_output_redirect.WriteLine($"Epoch: {epoch + 1:D3}/{epochs:D3}, Step: {step + 1:D4}/{data_handler.Inferredsteps:D4}, {result_pairs}"); } } diff --git a/src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.get_training_or_validation_split.cs b/src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.get_training_or_validation_split.cs index a80e960a..2f3d8f52 100644 --- a/src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.get_training_or_validation_split.cs +++ b/src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.get_training_or_validation_split.cs @@ -24,13 +24,13 @@ namespace Tensorflow.Keras.Preprocessings var num_val_samples = Convert.ToInt32(samples.Length * validation_split); if (subset == "training") { - Console.WriteLine($"Using {samples.Length - num_val_samples} files for training."); + Binding.tf_output_redirect.WriteLine($"Using {samples.Length - num_val_samples} files for training."); samples = samples.Take(samples.Length - num_val_samples).ToArray(); labels = labels.Take(labels.Length - num_val_samples).ToArray(); } else if (subset == "validation") { - Console.WriteLine($"Using {num_val_samples} files for validation."); + Binding.tf_output_redirect.WriteLine($"Using {num_val_samples} files for validation."); samples = samples.Skip(samples.Length - num_val_samples).ToArray(); labels = labels.Skip(labels.Length - num_val_samples).ToArray(); } diff --git a/src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.index_directory.cs b/src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.index_directory.cs index 03c9f8d1..904b0805 100644 --- a/src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.index_directory.cs +++ b/src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.index_directory.cs @@ -61,7 +61,7 @@ namespace Tensorflow.Keras.Preprocessings } } - Console.WriteLine($"Found {return_file_paths.Length} files belonging to {class_names.Length} classes."); + Binding.tf_output_redirect.WriteLine($"Found {return_file_paths.Length} files belonging to {class_names.Length} classes."); return (return_file_paths, return_labels, class_names); } } diff --git a/src/TensorFlowNET.Keras/Utils/Compress.cs b/src/TensorFlowNET.Keras/Utils/Compress.cs index 5a4f99c7..a865d2ae 100644 --- a/src/TensorFlowNET.Keras/Utils/Compress.cs +++ b/src/TensorFlowNET.Keras/Utils/Compress.cs @@ -53,7 +53,7 @@ namespace Tensorflow.Keras.Utils var flag = gzArchiveName.Split(Path.DirectorySeparatorChar).Last().Split('.').First() + ".bin"; if (File.Exists(Path.Combine(destFolder, flag))) return; - Console.WriteLine($"Extracting."); + Binding.tf_output_redirect.WriteLine($"Extracting."); var task = Task.Run(() => { ZipFile.ExtractToDirectory(gzArchiveName, destFolder); @@ -62,12 +62,12 @@ namespace Tensorflow.Keras.Utils while (!task.IsCompleted) { Thread.Sleep(200); - Console.Write("."); + Binding.tf_output_redirect.Write("."); } File.Create(Path.Combine(destFolder, flag)); - Console.WriteLine(""); - Console.WriteLine("Extracting is completed."); + Binding.tf_output_redirect.WriteLine(""); + Binding.tf_output_redirect.WriteLine("Extracting is completed."); } public static void ExtractTGZ(String gzArchiveName, String destFolder) @@ -75,7 +75,7 @@ namespace Tensorflow.Keras.Utils var flag = gzArchiveName.Split(Path.DirectorySeparatorChar).Last().Split('.').First() + ".bin"; if (File.Exists(Path.Combine(destFolder, flag))) return; - Console.WriteLine($"Extracting."); + Binding.tf_output_redirect.WriteLine($"Extracting."); var task = Task.Run(() => { using (var inStream = File.OpenRead(gzArchiveName)) @@ -91,12 +91,12 @@ namespace Tensorflow.Keras.Utils while (!task.IsCompleted) { Thread.Sleep(200); - Console.Write("."); + Binding.tf_output_redirect.Write("."); } File.Create(Path.Combine(destFolder, flag)); - Console.WriteLine(""); - Console.WriteLine("Extracting is completed."); + Binding.tf_output_redirect.WriteLine(""); + Binding.tf_output_redirect.WriteLine("Extracting is completed."); } } } diff --git a/src/TensorFlowNET.Keras/Utils/Web.cs b/src/TensorFlowNET.Keras/Utils/Web.cs index 4e9f09d9..9f10feb8 100644 --- a/src/TensorFlowNET.Keras/Utils/Web.cs +++ b/src/TensorFlowNET.Keras/Utils/Web.cs @@ -36,20 +36,20 @@ namespace Tensorflow.Keras.Utils if (File.Exists(relativeFilePath)) { - Console.WriteLine($"{relativeFilePath} already exists."); + Binding.tf_output_redirect.WriteLine($"{relativeFilePath} already exists."); return false; } var wc = new WebClient(); - Console.WriteLine($"Downloading from {url}"); + Binding.tf_output_redirect.WriteLine($"Downloading from {url}"); var download = Task.Run(() => wc.DownloadFile(url, relativeFilePath)); while (!download.IsCompleted) { Thread.Sleep(1000); - Console.Write("."); + Binding.tf_output_redirect.Write("."); } - Console.WriteLine(""); - Console.WriteLine($"Downloaded to {relativeFilePath}"); + Binding.tf_output_redirect.WriteLine(""); + Binding.tf_output_redirect.WriteLine($"Downloaded to {relativeFilePath}"); return true; }