Browse Source

optimize image dataset loading

tags/v0.9
Oceania2018 6 years ago
parent
commit
0484aab0f9
4 changed files with 14 additions and 5 deletions
  1. +10
    -1
      test/TensorFlowNET.Examples/LogisticRegression.cs
  2. +1
    -1
      test/TensorFlowNET.Examples/NeuralNetXor.cs
  3. +2
    -2
      test/TensorFlowNET.Examples/Utility/DataSet.cs
  4. +1
    -1
      test/TensorFlowNET.Examples/Utility/MnistDataSet.cs

+ 10
- 1
test/TensorFlowNET.Examples/LogisticRegression.cs View File

@@ -1,6 +1,7 @@
using NumSharp; using NumSharp;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Diagnostics;
using System.IO; using System.IO;
using System.Linq; using System.Linq;
using System.Text; using System.Text;
@@ -55,6 +56,8 @@ namespace TensorFlowNET.Examples
// Initialize the variables (i.e. assign their default value) // Initialize the variables (i.e. assign their default value)
var init = tf.global_variables_initializer(); var init = tf.global_variables_initializer();


var sw = new Stopwatch();

return with(tf.Session(), sess => return with(tf.Session(), sess =>
{ {
// Run the initializer // Run the initializer
@@ -63,6 +66,8 @@ namespace TensorFlowNET.Examples
// Training cycle // Training cycle
foreach (var epoch in range(training_epochs)) foreach (var epoch in range(training_epochs))
{ {
sw.Start();

var avg_cost = 0.0f; var avg_cost = 0.0f;
var total_batch = mnist.train.num_examples / batch_size; var total_batch = mnist.train.num_examples / batch_size;
// Loop over all batches // Loop over all batches
@@ -79,9 +84,13 @@ namespace TensorFlowNET.Examples
avg_cost += c / total_batch; avg_cost += c / total_batch;
} }


sw.Stop();

// Display logs per epoch step // Display logs per epoch step
if ((epoch + 1) % display_step == 0) if ((epoch + 1) % display_step == 0)
print($"Epoch: {(epoch + 1).ToString("D4")} cost= {avg_cost.ToString("G9")}");
print($"Epoch: {(epoch + 1).ToString("D4")} cost= {avg_cost.ToString("G9")} elapse= {sw.ElapsedMilliseconds}ms");

sw.Reset();
} }


print("Optimization Finished!"); print("Optimization Finished!");


+ 1
- 1
test/TensorFlowNET.Examples/NeuralNetXor.cs View File

@@ -12,7 +12,7 @@ namespace TensorFlowNET.Examples
/// </summary> /// </summary>
public class NeuralNetXor : Python, IExample public class NeuralNetXor : Python, IExample
{ {
public int Priority => 2;
public int Priority => 10;
public bool Enabled { get; set; } = true; public bool Enabled { get; set; } = true;
public string Name => "NN XOR"; public string Name => "NN XOR";


+ 2
- 2
test/TensorFlowNET.Examples/Utility/DataSet.cs View File

@@ -54,8 +54,8 @@ namespace TensorFlowNET.Examples.Utility


// Get the rest examples in this epoch // Get the rest examples in this epoch
var rest_num_examples = _num_examples - start; var rest_num_examples = _num_examples - start;
var images_rest_part = _images[np.arange(start, _num_examples)];
var labels_rest_part = _labels[np.arange(start, _num_examples)];
//var images_rest_part = _images[np.arange(start, _num_examples)];
//var labels_rest_part = _labels[np.arange(start, _num_examples)];
// Shuffle the data // Shuffle the data
if (shuffle) if (shuffle)
{ {


+ 1
- 1
test/TensorFlowNET.Examples/Utility/MnistDataSet.cs View File

@@ -102,7 +102,7 @@ namespace TensorFlowNET.Examples.Utility
for(int row = 0; row < num_labels; row++) for(int row = 0; row < num_labels; row++)
{ {
var col = labels_dense.Data<byte>(row); var col = labels_dense.Data<byte>(row);
labels_one_hot.SetData(1, row, col);
labels_one_hot.SetData(1.0, row, col);
} }


return labels_one_hot; return labels_one_hot;


Loading…
Cancel
Save