Browse Source

optimize image dataset loading

tags/v0.9
Oceania2018 6 years ago
parent
commit
0484aab0f9
4 changed files with 14 additions and 5 deletions
  1. +10
    -1
      test/TensorFlowNET.Examples/LogisticRegression.cs
  2. +1
    -1
      test/TensorFlowNET.Examples/NeuralNetXor.cs
  3. +2
    -2
      test/TensorFlowNET.Examples/Utility/DataSet.cs
  4. +1
    -1
      test/TensorFlowNET.Examples/Utility/MnistDataSet.cs

+ 10
- 1
test/TensorFlowNET.Examples/LogisticRegression.cs View File

@@ -1,6 +1,7 @@
using NumSharp;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Text;
@@ -55,6 +56,8 @@ namespace TensorFlowNET.Examples
// Initialize the variables (i.e. assign their default value)
var init = tf.global_variables_initializer();

var sw = new Stopwatch();

return with(tf.Session(), sess =>
{
// Run the initializer
@@ -63,6 +66,8 @@ namespace TensorFlowNET.Examples
// Training cycle
foreach (var epoch in range(training_epochs))
{
sw.Start();

var avg_cost = 0.0f;
var total_batch = mnist.train.num_examples / batch_size;
// Loop over all batches
@@ -79,9 +84,13 @@ namespace TensorFlowNET.Examples
avg_cost += c / total_batch;
}

sw.Stop();

// Display logs per epoch step
if ((epoch + 1) % display_step == 0)
print($"Epoch: {(epoch + 1).ToString("D4")} cost= {avg_cost.ToString("G9")}");
print($"Epoch: {(epoch + 1).ToString("D4")} cost= {avg_cost.ToString("G9")} elapse= {sw.ElapsedMilliseconds}ms");

sw.Reset();
}

print("Optimization Finished!");


+ 1
- 1
test/TensorFlowNET.Examples/NeuralNetXor.cs View File

@@ -12,7 +12,7 @@ namespace TensorFlowNET.Examples
/// </summary>
public class NeuralNetXor : Python, IExample
{
public int Priority => 2;
public int Priority => 10;
public bool Enabled { get; set; } = true;
public string Name => "NN XOR";


+ 2
- 2
test/TensorFlowNET.Examples/Utility/DataSet.cs View File

@@ -54,8 +54,8 @@ namespace TensorFlowNET.Examples.Utility

// Get the rest examples in this epoch
var rest_num_examples = _num_examples - start;
var images_rest_part = _images[np.arange(start, _num_examples)];
var labels_rest_part = _labels[np.arange(start, _num_examples)];
//var images_rest_part = _images[np.arange(start, _num_examples)];
//var labels_rest_part = _labels[np.arange(start, _num_examples)];
// Shuffle the data
if (shuffle)
{


+ 1
- 1
test/TensorFlowNET.Examples/Utility/MnistDataSet.cs View File

@@ -102,7 +102,7 @@ namespace TensorFlowNET.Examples.Utility
for(int row = 0; row < num_labels; row++)
{
var col = labels_dense.Data<byte>(row);
labels_one_hot.SetData(1, row, col);
labels_one_hot.SetData(1.0, row, col);
}

return labels_one_hot;


Loading…
Cancel
Save