Browse Source

Can train and cross-validate, but train_saver.save failed. #248

tags/v0.9
Oceania2018 6 years ago
parent
commit
5751c2037d
8 changed files with 202 additions and 23 deletions
  1. +6
    -0
      src/TensorFlowNET.Core/Sessions/BaseSession.cs
  2. +3
    -0
      src/TensorFlowNET.Core/Sessions/_FetchHandler.cs
  3. +20
    -2
      src/TensorFlowNET.Core/Summaries/EventFileWriter.cs
  4. +52
    -0
      src/TensorFlowNET.Core/Summaries/EventLoggerThread.cs
  5. +5
    -2
      src/TensorFlowNET.Core/Summaries/EventsWriter.cs
  6. +7
    -1
      src/TensorFlowNET.Core/Summaries/SummaryToEventTransformer.cs
  7. +5
    -18
      src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
  8. +104
    -0
      test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs

+ 6
- 0
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -90,6 +90,12 @@ namespace Tensorflow
feed_dict_tensor[subfeed_t] = (NDArray)val;
break;
case int val:
feed_dict_tensor[subfeed_t] = (NDArray)val;
break;
case long val:
feed_dict_tensor[subfeed_t] = (NDArray)val;
break;
case long[] val:
feed_dict_tensor[subfeed_t] = (NDArray)val;
break;
case int[] val:


+ 3
- 0
src/TensorFlowNET.Core/Sessions/_FetchHandler.cs View File

@@ -62,6 +62,9 @@ namespace Tensorflow
case "Single":
full_values.Add(float.NaN);
break;
case "String":
full_values.Add(float.NaN);
break;
default:
throw new NotImplementedException($"build_results tensor_values[0] {tensor_values[0].dtype.Name}");
}


+ 20
- 2
src/TensorFlowNET.Core/Summaries/EventFileWriter.cs View File

@@ -5,17 +5,35 @@ using System.Text;

namespace Tensorflow.Summaries
{
/// <summary>
/// Creates a `EventFileWriter` and an event file to write to.
/// </summary>
public class EventFileWriter
{
string _logdir;
Queue<int> _event_queue;
// Represents a first-in, first-out collection of objects.
Queue<Event> _event_queue;
EventsWriter _ev_writer;
int _flush_secs;
Event _sentinel_event;
bool _closed;
EventLoggerThread _worker;

public EventFileWriter(string logdir, int max_queue = 10, int flush_secs= 120,
string filename_suffix = null)
{
_logdir = logdir;
Directory.CreateDirectory(_logdir);
_event_queue = new Queue<int>(max_queue);
_event_queue = new Queue<Event>(max_queue);
_ev_writer = new EventsWriter(Path.Combine(_logdir, "events"));
_flush_secs = flush_secs;
_sentinel_event = new Event();
if (!string.IsNullOrEmpty(filename_suffix))
// self._ev_writer.InitWithSuffix(compat.as_bytes(filename_suffix)))
throw new NotImplementedException("EventFileWriter filename_suffix is not null");
_closed = false;
_worker = new EventLoggerThread(_event_queue, _ev_writer, _flush_secs, _sentinel_event);
_worker.start();
}
}
}

+ 52
- 0
src/TensorFlowNET.Core/Summaries/EventLoggerThread.cs View File

@@ -0,0 +1,52 @@
using Google.Protobuf;
using System;
using System.Collections.Generic;
using System.IO;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

namespace Tensorflow.Summaries
{
    /// <summary>
    /// Background worker that drains a queue of <c>Event</c> protos and hands
    /// each serialized event to an <see cref="EventsWriter"/>.
    /// Port of TensorFlow's Python <c>EventLoggerThread</c>.
    /// </summary>
    public class EventLoggerThread
    {
        // Shared with the producer (EventFileWriter). Queue<T> is NOT
        // thread-safe, so every access below is guarded by lock(_queue);
        // the producer must take the same lock when enqueuing.
        Queue<Event> _queue;
        // Vestige of Python's Thread.daemon flag; set but never read.
        bool daemon;
        EventsWriter _ev_writer;
        // Seconds to sleep when the queue is empty before polling again.
        int _flush_secs;
        // Marker object: dequeuing it shuts the worker down.
        Event _sentinel_event;

        public EventLoggerThread(Queue<Event> queue, EventsWriter ev_writer, int flush_secs, Event sentinel_event)
        {
            daemon = true;
            _queue = queue;
            _ev_writer = ev_writer;
            _flush_secs = flush_secs;
            _sentinel_event = sentinel_event;
        }

        public void start() => run();

        public void run()
        {
            Task.Run(delegate
            {
                while (true)
                {
                    Event @event = null;
                    // BUG FIX: the original checked Count and called Dequeue
                    // without synchronization, racing the producer thread.
                    lock (_queue)
                    {
                        if (_queue.Count > 0)
                            @event = _queue.Dequeue();
                    }

                    if (@event == null)
                    {
                        // Queue empty: wait one flush interval before polling
                        // again (matches the original cadence).
                        Thread.Sleep(_flush_secs * 1000);
                        continue;
                    }

                    // BUG FIX: _sentinel_event was stored but never consulted,
                    // so the worker could never be stopped. Terminate when the
                    // sentinel is dequeued, as in the Python implementation.
                    if (ReferenceEquals(@event, _sentinel_event))
                        break;

                    _ev_writer._WriteSerializedEvent(@event.ToByteArray());
                    // Removed the original's unconditional Thread.Sleep(1000)
                    // per event, which capped throughput at one event/second.
                }
            });
        }
    }
}

+ 5
- 2
src/TensorFlowNET.Core/Summaries/EventsWriter.cs View File

@@ -1,19 +1,22 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Text;

namespace Tensorflow.Summaries
{
    /// <summary>
    /// Writes serialized <c>Event</c> protos to an events file on disk.
    /// </summary>
    public class EventsWriter
    {
        // Path of the events file that records are appended to.
        readonly string _file_prefix;

        public EventsWriter(string file_prefix)
        {
            _file_prefix = file_prefix;
        }

        /// <summary>
        /// Appends one serialized event record to the events file.
        /// </summary>
        /// <param name="event_str">Serialized <c>Event</c> proto bytes.</param>
        /// <exception cref="ArgumentNullException">If <paramref name="event_str"/> is null.</exception>
        public void _WriteSerializedEvent(byte[] event_str)
        {
            if (event_str == null)
                throw new ArgumentNullException(nameof(event_str));

            // BUG FIX: File.WriteAllBytes truncates the file, so every event
            // overwrote the previous one and the log only ever contained the
            // last event. Open in append mode so events accumulate.
            // NOTE(review): TensorFlow's native event files also frame each
            // record with a length header and CRC checksums; raw appending may
            // not be readable by TensorBoard — confirm the intended format.
            using (var fs = new FileStream(_file_prefix, FileMode.Append, FileAccess.Write))
            {
                fs.Write(event_str, 0, event_str.Length);
            }
        }
    }
}

+ 7
- 1
src/TensorFlowNET.Core/Summaries/SummaryToEventTransformer.cs View File

@@ -1,4 +1,5 @@
using System;
using Google.Protobuf;
using System;
using System.Collections.Generic;
using System.Text;

@@ -9,5 +10,10 @@ namespace Tensorflow.Summaries
/// </summary>
public abstract class SummaryToEventTransformer
{
/// <summary>
/// Adds a summary to the event file for the given global step.
/// NOTE(review): currently a stub — the computed bytes are never used because
/// the proto parse below is commented out, so calling this has no observable
/// effect yet.
/// </summary>
/// <param name="summary">Presumably a serialized `Summary` proto carried in a
/// string — TODO confirm against callers.</param>
/// <param name="global_step">Step to associate with the summary; currently
/// unused by this stub.</param>
public void add_summary(string summary, int global_step = 0)
{
// NOTE(review): UTF8Encoding.Unicode is the UTF-16 encoder despite the class
// name. If `summary` holds binary proto data this conversion is lossy —
// verify the intended encoding before enabling the parse below.
var bytes = UTF8Encoding.Unicode.GetBytes(summary);
// var summ = Tensorflow.Summary.Parser.ParseFrom(bytes);
}
}
}

+ 5
- 18
src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs View File

@@ -71,6 +71,9 @@ namespace Tensorflow
case "Int32":
Marshal.Copy(nd1.Data<int>(), 0, dotHandle, nd.size);
break;
case "Int64":
Marshal.Copy(nd1.Data<long>(), 0, dotHandle, nd.size);
break;
case "Single":
Marshal.Copy(nd1.Data<float>(), 0, dotHandle, nd.size);
break;
@@ -80,24 +83,8 @@ namespace Tensorflow
case "Byte":
Marshal.Copy(nd1.Data<byte>(), 0, dotHandle, nd.size);
break;
//case "String":
/*string ss = nd.Data<string>()[0];
var str = Marshal.StringToHGlobalAnsi(ss);
ulong dst_len = c_api.TF_StringEncodedSize((ulong)ss.Length);
var dataType1 = ToTFDataType(nd.dtype);
// shape
var dims1 = nd.shape.Select(x => (long)x).ToArray();

var tfHandle1 = c_api.TF_AllocateTensor(dataType1,
dims1,
nd.ndim,
dst_len + sizeof(Int64));

dotHandle = c_api.TF_TensorData(tfHandle1);
Marshal.WriteInt64(dotHandle, 0);
c_api.TF_StringEncode(str, (ulong)ss.Length, dotHandle + sizeof(Int64), dst_len, status);
return tfHandle1;*/
break;
case "String":
return new Tensor(UTF8Encoding.UTF8.GetBytes(nd.Data<string>(0)));
default:
throw new NotImplementedException($"Marshal.Copy failed for {nd.dtype.Name}.");
}


+ 104
- 0
test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs View File

@@ -31,11 +31,19 @@ namespace TensorFlowNET.Examples.ImageProcess
string summaries_dir = Path.Join(data_dir, "retrain_logs");
string image_dir = Path.Join(data_dir, "flower_photos");
string bottleneck_dir = Path.Join(data_dir, "bottleneck");
// The location where variable checkpoints will be stored.
string CHECKPOINT_NAME = Path.Join(data_dir, "_retrain_checkpoint");
string tfhub_module = "https://tfhub.dev/google/imagenet/inception_v3/feature_vector/3";
float testing_percentage = 0.1f;
float validation_percentage = 0.1f;
Tensor resized_image_tensor;
Dictionary<string, Dictionary<string, string[]>> image_lists;
int how_many_training_steps = 200;
int eval_step_interval = 10;
int train_batch_size = 100;
int validation_batch_size = 100;
int intermediate_store_frequency = 0;
const int MAX_NUM_IMAGES_PER_CLASS = 134217727;

public bool Run()
{
@@ -47,6 +55,9 @@ namespace TensorFlowNET.Examples.ImageProcess
Tensor resized_image_tensor = graph.OperationByName("Placeholder");
Tensor final_tensor = graph.OperationByName("final_result");
Tensor ground_truth_input = graph.OperationByName("input/GroundTruthInput");
Operation train_step = graph.OperationByName("train/GradientDescent");
Tensor bottleneck_input = graph.OperationByName("input/BottleneckInputPlaceholder");
Tensor cross_entropy = graph.OperationByName("cross_entropy/sparse_softmax_cross_entropy_loss/value");

var sw = new Stopwatch();

@@ -72,11 +83,104 @@ namespace TensorFlowNET.Examples.ImageProcess
// Merge all the summaries and write them out to the summaries_dir
var merged = tf.summary.merge_all();
var train_writer = tf.summary.FileWriter(summaries_dir + "/train", sess.graph);
var validation_writer = tf.summary.FileWriter(summaries_dir + "/validation", sess.graph);

// Create a train saver that is used to restore values into an eval graph
// when exporting models.
var train_saver = tf.train.Saver();

for (int i = 0; i < how_many_training_steps; i++)
{
var (train_bottlenecks, train_ground_truth, _) = get_random_cached_bottlenecks(
sess, image_lists, train_batch_size, "training",
bottleneck_dir, image_dir, jpeg_data_tensor,
decoded_image_tensor, resized_image_tensor, bottleneck_tensor,
tfhub_module);

// Feed the bottlenecks and ground truth into the graph, and run a training
// step. Capture training summaries for TensorBoard with the `merged` op.
var results = sess.run(
new ITensorOrOperation[] { merged, train_step },
new FeedItem(bottleneck_input, train_bottlenecks),
new FeedItem(ground_truth_input, train_ground_truth));
var train_summary = results[0];

// TODO
train_writer.add_summary(train_summary, i);

// Every so often, print out how well the graph is training.
bool is_last_step = (i + 1 == how_many_training_steps);
if ((i % eval_step_interval) == 0 || is_last_step)
{
results = sess.run(
new Tensor[] { evaluation_step, cross_entropy },
new FeedItem(bottleneck_input, train_bottlenecks),
new FeedItem(ground_truth_input, train_ground_truth));
(float train_accuracy, float cross_entropy_value) = (results[0], results[1]);
print($"{DateTime.Now}: Step {i}: Train accuracy = {train_accuracy * 100}%");
print($"{DateTime.Now}: Step {i}: Cross entropy = {cross_entropy_value}");

var (validation_bottlenecks, validation_ground_truth, _) = get_random_cached_bottlenecks(
sess, image_lists, validation_batch_size, "validation",
bottleneck_dir, image_dir, jpeg_data_tensor,
decoded_image_tensor, resized_image_tensor, bottleneck_tensor,
tfhub_module);

// Run a validation step and capture training summaries for TensorBoard
// with the `merged` op.
results = sess.run(new Tensor[] { merged, evaluation_step },
new FeedItem(bottleneck_input, validation_bottlenecks),
new FeedItem(ground_truth_input, validation_ground_truth));

(string validation_summary, float validation_accuracy) = (results[0], results[1]);

validation_writer.add_summary(validation_summary, i);
print($"{DateTime.Now}: Step {i}: Validation accuracy = {validation_accuracy * 100}% (N={len(validation_bottlenecks)})");
}

// Store intermediate results
int intermediate_frequency = intermediate_store_frequency;
if (intermediate_frequency > 0 && i % intermediate_frequency == 0 && i > 0)
{

}
}

// After training is complete, force one last save of the train checkpoint.
train_saver.save(sess, CHECKPOINT_NAME);
});

return false;
}

/// <summary>
/// Retrieves bottleneck values for <paramref name="how_many"/> randomly chosen
/// cached images of the given <paramref name="category"/> (e.g. "training" or
/// "validation"), creating and caching bottlenecks on demand.
/// </summary>
/// <returns>Tuple of (bottleneck matrix, ground-truth label indices, file names).</returns>
private (NDArray, long[], string[]) get_random_cached_bottlenecks(Session sess, Dictionary<string, Dictionary<string, string[]>> image_lists,
    int how_many, string category, string bottleneck_dir, string image_dir,
    Tensor jpeg_data_tensor, Tensor decoded_image_tensor, Tensor resized_input_tensor,
    Tensor bottleneck_tensor, string module_name)
{
    var bottlenecks = new List<float[]>();
    var ground_truths = new List<long>();
    var filenames = new List<string>();
    int class_count = image_lists.Keys.Count;
    // BUG FIX: the original constructed `new Random()` twice per loop
    // iteration. Random is seeded from the system clock, so instances created
    // in rapid succession share seeds and produce duplicate/correlated picks.
    // Reuse a single generator for the whole batch.
    var rnd = new Random();
    // Hoisted: Keys.ToArray() was re-materialized on every iteration.
    var label_names = image_lists.Keys.ToArray();
    foreach (var unused_i in range(how_many))
    {
        int label_index = rnd.Next(class_count);
        string label_name = label_names[label_index];
        // Index is taken modulo the actual image count downstream;
        // MAX_NUM_IMAGES_PER_CLASS just bounds the draw — TODO confirm
        // get_image_path wraps the index.
        int image_index = rnd.Next(MAX_NUM_IMAGES_PER_CLASS);
        string image_name = get_image_path(image_lists, label_name, image_index,
            image_dir, category);
        var bottleneck = get_or_create_bottleneck(
            sess, image_lists, label_name, image_index, image_dir, category,
            bottleneck_dir, jpeg_data_tensor, decoded_image_tensor,
            resized_input_tensor, bottleneck_tensor, module_name);
        bottlenecks.Add(bottleneck);
        ground_truths.Add(label_index);
        filenames.Add(image_name);
    }

    return (bottlenecks.ToArray(), ground_truths.ToArray(), filenames.ToArray());
}

/// <summary>
/// Inserts the operations we need to evaluate the accuracy of our results.
/// </summary>


Loading…
Cancel
Save