Browse Source

Merge branch 'master' into Conv_1D

tags/v0.40-tf2.4-tstring
Niklas Gustafsson 4 years ago
parent
commit
58f3194909
17 changed files with 121 additions and 48 deletions
  1. +21
    -3
      src/TensorFlowNET.Core/Binding.Util.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Data/MnistDataSet.cs
  3. +8
    -8
      src/TensorFlowNET.Core/Data/Utils.cs
  4. +8
    -8
      src/TensorFlowNET.Core/Framework/c_api_util.cs
  5. +2
    -2
      src/TensorFlowNET.Core/Framework/meta_graph.cs
  6. +1
    -1
      src/TensorFlowNET.Core/Operations/Distributions/distribution.py.cs
  7. +1
    -1
      src/TensorFlowNET.Core/Training/Saving/Saver.cs
  8. +2
    -2
      src/TensorFlowNET.Core/Training/Saving/saver.py.cs
  9. +2
    -2
      src/TensorFlowNET.Core/Util/CmdHelper.cs
  10. +1
    -1
      src/TensorFlowNET.Keras/Datasets/MNIST.cs
  11. +2
    -2
      src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs
  12. +1
    -1
      src/TensorFlowNET.Keras/Engine/Model.Fit.cs
  13. +2
    -2
      src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.get_training_or_validation_split.cs
  14. +1
    -1
      src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.index_directory.cs
  15. +8
    -8
      src/TensorFlowNET.Keras/Utils/Compress.cs
  16. +5
    -5
      src/TensorFlowNET.Keras/Utils/Web.cs
  17. +55
    -0
      test/TensorFlowNET.Keras.UnitTest/OutputTest.cs

+ 21
- 3
src/TensorFlowNET.Core/Binding.Util.cs View File

@@ -21,6 +21,7 @@ using System.Collections;
using System.Collections.Generic;
using System.ComponentModel;
using System.Diagnostics;
using System.IO;
using System.Linq;

namespace Tensorflow
@@ -112,16 +113,33 @@ namespace Tensorflow
}
}

private static TextWriter writer = null;

public static TextWriter tf_output_redirect {
set
{
var originWriter = writer ?? Console.Out;
originWriter.Flush();
if (originWriter is StringWriter)
(originWriter as StringWriter).GetStringBuilder().Clear();
writer = value;
}
get
{
return writer ?? Console.Out;
}
}

public static void print(object obj)
{
Console.WriteLine(_tostring(obj));
tf_output_redirect.WriteLine(_tostring(obj));
}

public static void print(string format, params object[] objects)
{
if (!format.Contains("{}"))
{
Console.WriteLine(format + " " + string.Join(" ", objects.Select(x => x.ToString())));
tf_output_redirect.WriteLine(format + " " + string.Join(" ", objects.Select(x => x.ToString())));
return;
}

@@ -130,7 +148,7 @@ namespace Tensorflow

}

Console.WriteLine(format);
tf_output_redirect.WriteLine(format);
}

public static int len(object a)


+ 1
- 1
src/TensorFlowNET.Core/Data/MnistDataSet.cs View File

@@ -24,7 +24,7 @@ namespace Tensorflow
sw.Start();
images = np.multiply(images, 1.0f / 255.0f);
sw.Stop();
Console.WriteLine($"{sw.ElapsedMilliseconds}ms");
Binding.tf_output_redirect.WriteLine($"{sw.ElapsedMilliseconds}ms");
Data = images;

labels = labels.astype(dataType);


+ 8
- 8
src/TensorFlowNET.Core/Data/Utils.cs View File

@@ -27,14 +27,14 @@ namespace Tensorflow

if (showProgressInConsole)
{
Console.WriteLine($"Downloading {fileName}");
Binding.tf_output_redirect.WriteLine($"Downloading {fileName}");
}

if (File.Exists(fileSaveTo))
{
if (showProgressInConsole)
{
Console.WriteLine($"The file {fileName} already exists");
Binding.tf_output_redirect.WriteLine($"The file {fileName} already exists");
}

return;
@@ -64,12 +64,12 @@ namespace Tensorflow
var destFilePath = Path.Combine(saveTo, destFileName);

if (showProgressInConsole)
Console.WriteLine($"Unzippinng {Path.GetFileName(zipFile)}");
Binding.tf_output_redirect.WriteLine($"Unzippinng {Path.GetFileName(zipFile)}");

if (File.Exists(destFilePath))
{
if (showProgressInConsole)
Console.WriteLine($"The file {destFileName} already exists");
Binding.tf_output_redirect.WriteLine($"The file {destFileName} already exists");
}

using (GZipStream unzipStream = new GZipStream(File.OpenRead(zipFile), CompressionMode.Decompress))
@@ -107,7 +107,7 @@ namespace Tensorflow
}

await showProgressTask;
Console.WriteLine("Done.");
Binding.tf_output_redirect.WriteLine("Done.");
}

private static async Task ShowProgressInConsole(CancellationTokenSource cts)
@@ -119,17 +119,17 @@ namespace Tensorflow
while (!cts.IsCancellationRequested)
{
await Task.Delay(100);
Console.Write(".");
Binding.tf_output_redirect.Write(".");
cols++;

if (cols % 50 == 0)
{
Console.WriteLine();
Binding.tf_output_redirect.WriteLine();
}
}

if (cols > 0)
Console.WriteLine();
Binding.tf_output_redirect.WriteLine();
}
}
}

+ 8
- 8
src/TensorFlowNET.Core/Framework/c_api_util.cs View File

@@ -62,18 +62,18 @@ namespace Tensorflow
if (!File.Exists(file))
{
var wc = new WebClient();
Console.WriteLine($"Downloading Tensorflow library from {url}...");
Binding.tf_output_redirect.WriteLine($"Downloading Tensorflow library from {url}...");
var download = Task.Run(() => wc.DownloadFile(url, file));
while (!download.IsCompleted)
{
Thread.Sleep(1000);
Console.Write(".");
Binding.tf_output_redirect.Write(".");
}
Console.WriteLine("");
Console.WriteLine($"Downloaded successfully.");
Binding.tf_output_redirect.WriteLine("");
Binding.tf_output_redirect.WriteLine($"Downloaded successfully.");
}

Console.WriteLine($"Extracting...");
Binding.tf_output_redirect.WriteLine($"Extracting...");
var task = Task.Run(() =>
{
switch (Environment.OSVersion.Platform)
@@ -97,11 +97,11 @@ namespace Tensorflow
while (!task.IsCompleted)
{
Thread.Sleep(100);
Console.Write(".");
Binding.tf_output_redirect.Write(".");
}

Console.WriteLine("");
Console.WriteLine("Extraction is completed.");
Binding.tf_output_redirect.WriteLine("");
Binding.tf_output_redirect.WriteLine("Extraction is completed.");
}

isDllDownloaded = true;


+ 2
- 2
src/TensorFlowNET.Core/Framework/meta_graph.cs View File

@@ -134,7 +134,7 @@ namespace Tensorflow
}
break;
default:
Console.WriteLine($"import_scoped_meta_graph_with_return_elements {col.Key}");
Binding.tf_output_redirect.WriteLine($"import_scoped_meta_graph_with_return_elements {col.Key}");
continue;
}
}
@@ -142,7 +142,7 @@ namespace Tensorflow

break;
default:
Console.WriteLine($"Cannot identify data type for collection {col.Key}. Skipping.");
Binding.tf_output_redirect.WriteLine($"Cannot identify data type for collection {col.Key}. Skipping.");
break;
}
}


+ 1
- 1
src/TensorFlowNET.Core/Operations/Distributions/distribution.py.cs View File

@@ -166,7 +166,7 @@ namespace Tensorflow

public void repr()
{
Console.WriteLine($"<Reparameteriation Type: {this._rep_type}>");
Binding.tf_output_redirect.WriteLine($"<Reparameteriation Type: {this._rep_type}>");
}

public bool eq(ReparameterizationType other)


+ 1
- 1
src/TensorFlowNET.Core/Training/Saving/Saver.cs View File

@@ -242,7 +242,7 @@ namespace Tensorflow
if (!checkpoint_management.checkpoint_exists(save_path))
throw new ValueError($"The passed save_path is not a valid checkpoint: {save_path}");

Console.WriteLine($"Restoring parameters from {save_path}");
Binding.tf_output_redirect.WriteLine($"Restoring parameters from {save_path}");

if (tf.Context.executing_eagerly())
#pragma warning disable CS0642 // Possible mistaken empty statement


+ 2
- 2
src/TensorFlowNET.Core/Training/Saving/saver.py.cs View File

@@ -78,7 +78,7 @@ namespace Tensorflow
else
{
// If no graph variables exist, then a Saver cannot be constructed.
Console.WriteLine("Saver not created because there are no variables in the" +
Binding.tf_output_redirect.WriteLine("Saver not created because there are no variables in the" +
" graph to restore");
return null;
}
@@ -102,7 +102,7 @@ namespace Tensorflow
var output_graph_def = tf.graph_util.convert_variables_to_constants(sess,
graph.as_graph_def(),
output_node_names);
Console.WriteLine($"Froze {output_graph_def.Node.Count} nodes.");
Binding.tf_output_redirect.WriteLine($"Froze {output_graph_def.Node.Count} nodes.");
File.WriteAllBytes(output_pb, output_graph_def.ToByteArray());
return output_pb;
}


+ 2
- 2
src/TensorFlowNET.Core/Util/CmdHelper.cs View File

@@ -31,7 +31,7 @@ namespace Tensorflow.Util
proc.Start();

while (!proc.StandardOutput.EndOfStream)
Console.WriteLine(proc.StandardOutput.ReadLine());
Binding.tf_output_redirect.WriteLine(proc.StandardOutput.ReadLine());
}

public static void Bash(string command)
@@ -44,7 +44,7 @@ namespace Tensorflow.Util
proc.Start();

while (!proc.StandardOutput.EndOfStream)
Console.WriteLine(proc.StandardOutput.ReadLine());
Binding.tf_output_redirect.WriteLine(proc.StandardOutput.ReadLine());
}
}
}

+ 1
- 1
src/TensorFlowNET.Keras/Datasets/MNIST.cs View File

@@ -61,7 +61,7 @@ namespace Tensorflow.Keras.Datasets

if (File.Exists(fileSaveTo))
{
Console.WriteLine($"The file {fileSaveTo} already exists");
Binding.tf_output_redirect.WriteLine($"The file {fileSaveTo} already exists");
return fileSaveTo;
}



+ 2
- 2
src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs View File

@@ -46,7 +46,7 @@ namespace Tensorflow.Keras.Engine
StepsPerExecution = _steps_per_execution
});

Console.WriteLine($"Testing...");
Binding.tf_output_redirect.WriteLine($"Testing...");
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
{
// reset_metrics();
@@ -58,7 +58,7 @@ namespace Tensorflow.Keras.Engine
// callbacks.on_train_batch_begin(step)
results = test_function(iterator);
}
Console.WriteLine($"iterator: {epoch + 1}, " + string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2}")));
Binding.tf_output_redirect.WriteLine($"iterator: {epoch + 1}, " + string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2}")));
}
}



+ 1
- 1
src/TensorFlowNET.Keras/Engine/Model.Fit.cs View File

@@ -99,7 +99,7 @@ namespace Tensorflow.Keras.Engine
if (verbose == 1)
{
var result_pairs = string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2:F6}"));
Console.WriteLine($"Epoch: {epoch + 1:D3}/{epochs:D3}, Step: {step + 1:D4}/{data_handler.Inferredsteps:D4}, {result_pairs}");
Binding.tf_output_redirect.WriteLine($"Epoch: {epoch + 1:D3}/{epochs:D3}, Step: {step + 1:D4}/{data_handler.Inferredsteps:D4}, {result_pairs}");
}
}



+ 2
- 2
src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.get_training_or_validation_split.cs View File

@@ -24,13 +24,13 @@ namespace Tensorflow.Keras.Preprocessings
var num_val_samples = Convert.ToInt32(samples.Length * validation_split);
if (subset == "training")
{
Console.WriteLine($"Using {samples.Length - num_val_samples} files for training.");
Binding.tf_output_redirect.WriteLine($"Using {samples.Length - num_val_samples} files for training.");
samples = samples.Take(samples.Length - num_val_samples).ToArray();
labels = labels.Take(labels.Length - num_val_samples).ToArray();
}
else if (subset == "validation")
{
Console.WriteLine($"Using {num_val_samples} files for validation.");
Binding.tf_output_redirect.WriteLine($"Using {num_val_samples} files for validation.");
samples = samples.Skip(samples.Length - num_val_samples).ToArray();
labels = labels.Skip(labels.Length - num_val_samples).ToArray();
}


+ 1
- 1
src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.index_directory.cs View File

@@ -61,7 +61,7 @@ namespace Tensorflow.Keras.Preprocessings
}
}

Console.WriteLine($"Found {return_file_paths.Length} files belonging to {class_names.Length} classes.");
Binding.tf_output_redirect.WriteLine($"Found {return_file_paths.Length} files belonging to {class_names.Length} classes.");
return (return_file_paths, return_labels, class_names);
}
}


+ 8
- 8
src/TensorFlowNET.Keras/Utils/Compress.cs View File

@@ -53,7 +53,7 @@ namespace Tensorflow.Keras.Utils
var flag = gzArchiveName.Split(Path.DirectorySeparatorChar).Last().Split('.').First() + ".bin";
if (File.Exists(Path.Combine(destFolder, flag))) return;

Console.WriteLine($"Extracting.");
Binding.tf_output_redirect.WriteLine($"Extracting.");
var task = Task.Run(() =>
{
ZipFile.ExtractToDirectory(gzArchiveName, destFolder);
@@ -62,12 +62,12 @@ namespace Tensorflow.Keras.Utils
while (!task.IsCompleted)
{
Thread.Sleep(200);
Console.Write(".");
Binding.tf_output_redirect.Write(".");
}

File.Create(Path.Combine(destFolder, flag));
Console.WriteLine("");
Console.WriteLine("Extracting is completed.");
Binding.tf_output_redirect.WriteLine("");
Binding.tf_output_redirect.WriteLine("Extracting is completed.");
}

public static void ExtractTGZ(String gzArchiveName, String destFolder)
@@ -75,7 +75,7 @@ namespace Tensorflow.Keras.Utils
var flag = gzArchiveName.Split(Path.DirectorySeparatorChar).Last().Split('.').First() + ".bin";
if (File.Exists(Path.Combine(destFolder, flag))) return;

Console.WriteLine($"Extracting.");
Binding.tf_output_redirect.WriteLine($"Extracting.");
var task = Task.Run(() =>
{
using (var inStream = File.OpenRead(gzArchiveName))
@@ -91,12 +91,12 @@ namespace Tensorflow.Keras.Utils
while (!task.IsCompleted)
{
Thread.Sleep(200);
Console.Write(".");
Binding.tf_output_redirect.Write(".");
}

File.Create(Path.Combine(destFolder, flag));
Console.WriteLine("");
Console.WriteLine("Extracting is completed.");
Binding.tf_output_redirect.WriteLine("");
Binding.tf_output_redirect.WriteLine("Extracting is completed.");
}
}
}

+ 5
- 5
src/TensorFlowNET.Keras/Utils/Web.cs View File

@@ -36,20 +36,20 @@ namespace Tensorflow.Keras.Utils

if (File.Exists(relativeFilePath))
{
Console.WriteLine($"{relativeFilePath} already exists.");
Binding.tf_output_redirect.WriteLine($"{relativeFilePath} already exists.");
return false;
}

var wc = new WebClient();
Console.WriteLine($"Downloading from {url}");
Binding.tf_output_redirect.WriteLine($"Downloading from {url}");
var download = Task.Run(() => wc.DownloadFile(url, relativeFilePath));
while (!download.IsCompleted)
{
Thread.Sleep(1000);
Console.Write(".");
Binding.tf_output_redirect.Write(".");
}
Console.WriteLine("");
Console.WriteLine($"Downloaded to {relativeFilePath}");
Binding.tf_output_redirect.WriteLine("");
Binding.tf_output_redirect.WriteLine($"Downloaded to {relativeFilePath}");

return true;
}


+ 55
- 0
test/TensorFlowNET.Keras.UnitTest/OutputTest.cs View File

@@ -0,0 +1,55 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
using Tensorflow.Keras;

namespace Tensorflow.Keras.UnitTest
{
[TestClass]
public class OutputTest
{
[TestMethod]
public void OutputRedirectTest()
{
using var newOutput = new System.IO.StringWriter();
tf_output_redirect = newOutput;
var model = keras.Sequential();
model.add(keras.Input(shape: 16));
model.summary();
string output = newOutput.ToString();
Assert.IsTrue(output.StartsWith("Model: sequential"));
tf_output_redirect = null; // don't forget to change it to null !!!!
}

[TestMethod]
public void SwitchOutputsTest()
{
using var newOutput = new System.IO.StringWriter();
var model = keras.Sequential();
model.add(keras.Input(shape: 16));
model.summary(); // Console.Out

tf_output_redirect = newOutput; // change to the custom one
model.summary();
string firstOutput = newOutput.ToString();
Assert.IsTrue(firstOutput.StartsWith("Model: sequential"));

// if tf_output_reditect is StringWriter, calling "set" will make the writer clear.
tf_output_redirect = null; // null means Console.Out
model.summary();

tf_output_redirect = newOutput; // again, to test whether the newOutput is clear.
model.summary();
string secondOutput = newOutput.ToString();
Assert.IsTrue(secondOutput.StartsWith("Model: sequential"));

Assert.IsTrue(firstOutput == secondOutput);
tf_output_redirect = null; // don't forget to change it to null !!!!
}
}
}

Loading…
Cancel
Save