diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index 771b0d7d..b1770576 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -90,6 +90,12 @@ namespace Tensorflow feed_dict_tensor[subfeed_t] = (NDArray)val; break; case int val: + feed_dict_tensor[subfeed_t] = (NDArray)val; + break; + case long val: + feed_dict_tensor[subfeed_t] = (NDArray)val; + break; + case long[] val: feed_dict_tensor[subfeed_t] = (NDArray)val; break; case int[] val: diff --git a/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs b/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs index 2c5b55f6..5989b364 100644 --- a/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs +++ b/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs @@ -62,6 +62,9 @@ namespace Tensorflow case "Single": full_values.Add(float.NaN); break; + case "String": + full_values.Add(float.NaN); + break; default: throw new NotImplementedException($"build_results tensor_values[0] {tensor_values[0].dtype.Name}"); } diff --git a/src/TensorFlowNET.Core/Summaries/EventFileWriter.cs b/src/TensorFlowNET.Core/Summaries/EventFileWriter.cs index d0faed84..30903137 100644 --- a/src/TensorFlowNET.Core/Summaries/EventFileWriter.cs +++ b/src/TensorFlowNET.Core/Summaries/EventFileWriter.cs @@ -5,17 +5,35 @@ using System.Text; namespace Tensorflow.Summaries { + /// + /// Creates a `EventFileWriter` and an event file to write to. + /// public class EventFileWriter { string _logdir; - Queue _event_queue; + // Represents a first-in, first-out collection of objects. + Queue _event_queue; + EventsWriter _ev_writer; + int _flush_secs; + Event _sentinel_event; + bool _closed; + EventLoggerThread _worker; public EventFileWriter(string logdir, int max_queue = 10, int flush_secs= 120, string filename_suffix = null) { _logdir = logdir; Directory.CreateDirectory(_logdir); - _event_queue = new Queue(max_queue); + _event_queue = new Queue(max_queue); + _ev_writer = new EventsWriter(Path.Combine(_logdir, "events")); + _flush_secs = flush_secs; + _sentinel_event = new Event(); + if (!string.IsNullOrEmpty(filename_suffix)) + // self._ev_writer.InitWithSuffix(compat.as_bytes(filename_suffix))) + throw new NotImplementedException("EventFileWriter filename_suffix is not null"); + _closed = false; + _worker = new EventLoggerThread(_event_queue, _ev_writer, _flush_secs, _sentinel_event); + _worker.start(); } } } diff --git a/src/TensorFlowNET.Core/Summaries/EventLoggerThread.cs b/src/TensorFlowNET.Core/Summaries/EventLoggerThread.cs new file mode 100644 index 00000000..1ebd6de5 --- /dev/null +++ b/src/TensorFlowNET.Core/Summaries/EventLoggerThread.cs @@ -0,0 +1,52 @@ +using Google.Protobuf; +using System; +using System.Collections.Generic; +using System.IO; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace Tensorflow.Summaries +{ + /// + /// Thread that logs events. + /// + public class EventLoggerThread + { + Queue _queue; + bool daemon; + EventsWriter _ev_writer; + int _flush_secs; + Event _sentinel_event; + + public EventLoggerThread(Queue queue, EventsWriter ev_writer, int flush_secs, Event sentinel_event) + { + daemon = true; + _queue = queue; + _ev_writer = ev_writer; + _flush_secs = flush_secs; + _sentinel_event = sentinel_event; + } + + public void start() => run(); + + public void run() + { + Task.Run(delegate + { + while (true) + { + if(_queue.Count == 0) + { + Thread.Sleep(_flush_secs * 1000); + continue; + } + + var @event = _queue.Dequeue(); + _ev_writer._WriteSerializedEvent(@event.ToByteArray()); + Thread.Sleep(1000); + } + }); + } + } +} diff --git a/src/TensorFlowNET.Core/Summaries/EventsWriter.cs b/src/TensorFlowNET.Core/Summaries/EventsWriter.cs index b80a133e..2cae7ade 100644 --- a/src/TensorFlowNET.Core/Summaries/EventsWriter.cs +++ b/src/TensorFlowNET.Core/Summaries/EventsWriter.cs @@ -1,19 +1,22 @@ using System; using System.Collections.Generic; +using System.IO; using System.Text; namespace Tensorflow.Summaries { public class EventsWriter { + string _file_prefix; + public EventsWriter(string file_prefix) { - + _file_prefix = file_prefix; } public void _WriteSerializedEvent(byte[] event_str) { - + File.WriteAllBytes(_file_prefix, event_str); } } } diff --git a/src/TensorFlowNET.Core/Summaries/SummaryToEventTransformer.cs b/src/TensorFlowNET.Core/Summaries/SummaryToEventTransformer.cs index 1d29fe4f..7dc4addf 100644 --- a/src/TensorFlowNET.Core/Summaries/SummaryToEventTransformer.cs +++ b/src/TensorFlowNET.Core/Summaries/SummaryToEventTransformer.cs @@ -1,4 +1,5 @@ -using System; +using Google.Protobuf; +using System; using System.Collections.Generic; using System.Text; @@ -9,5 +10,10 @@ namespace Tensorflow.Summaries /// public abstract class SummaryToEventTransformer { + public void add_summary(string summary, int global_step = 0) + { + var bytes = UTF8Encoding.Unicode.GetBytes(summary); + // var summ = Tensorflow.Summary.Parser.ParseFrom(bytes); + } } } diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs index eff138b5..0b64a212 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs @@ -71,6 +71,9 @@ namespace Tensorflow case "Int32": Marshal.Copy(nd1.Data(), 0, dotHandle, nd.size); break; + case "Int64": + Marshal.Copy(nd1.Data(), 0, dotHandle, nd.size); + break; case "Single": Marshal.Copy(nd1.Data(), 0, dotHandle, nd.size); break; @@ -80,24 +83,8 @@ namespace Tensorflow case "Byte": Marshal.Copy(nd1.Data(), 0, dotHandle, nd.size); break; - //case "String": - /*string ss = nd.Data()[0]; - var str = Marshal.StringToHGlobalAnsi(ss); - ulong dst_len = c_api.TF_StringEncodedSize((ulong)ss.Length); - var dataType1 = ToTFDataType(nd.dtype); - // shape - var dims1 = nd.shape.Select(x => (long)x).ToArray(); - - var tfHandle1 = c_api.TF_AllocateTensor(dataType1, - dims1, - nd.ndim, - dst_len + sizeof(Int64)); - - dotHandle = c_api.TF_TensorData(tfHandle1); - Marshal.WriteInt64(dotHandle, 0); - c_api.TF_StringEncode(str, (ulong)ss.Length, dotHandle + sizeof(Int64), dst_len, status); - return tfHandle1;*/ - break; + case "String": + return new Tensor(UTF8Encoding.UTF8.GetBytes(nd.Data(0))); default: throw new NotImplementedException($"Marshal.Copy failed for {nd.dtype.Name}."); } diff --git a/test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs b/test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs index 6f4c84d1..61ba33b3 100644 --- a/test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs +++ b/test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs @@ -31,11 +31,19 @@ namespace TensorFlowNET.Examples.ImageProcess string summaries_dir = Path.Join(data_dir, "retrain_logs"); string image_dir = Path.Join(data_dir, "flower_photos"); string bottleneck_dir = Path.Join(data_dir, "bottleneck"); + // The location where variable checkpoints will be stored. + string CHECKPOINT_NAME = Path.Join(data_dir, "_retrain_checkpoint"); string tfhub_module = "https://tfhub.dev/google/imagenet/inception_v3/feature_vector/3"; float testing_percentage = 0.1f; float validation_percentage = 0.1f; Tensor resized_image_tensor; Dictionary> image_lists; + int how_many_training_steps = 200; + int eval_step_interval = 10; + int train_batch_size = 100; + int validation_batch_size = 100; + int intermediate_store_frequency = 0; + const int MAX_NUM_IMAGES_PER_CLASS = 134217727; public bool Run() { @@ -47,6 +55,9 @@ namespace TensorFlowNET.Examples.ImageProcess Tensor resized_image_tensor = graph.OperationByName("Placeholder"); Tensor final_tensor = graph.OperationByName("final_result"); Tensor ground_truth_input = graph.OperationByName("input/GroundTruthInput"); + Operation train_step = graph.OperationByName("train/GradientDescent"); + Tensor bottleneck_input = graph.OperationByName("input/BottleneckInputPlaceholder"); + Tensor cross_entropy = graph.OperationByName("cross_entropy/sparse_softmax_cross_entropy_loss/value"); var sw = new Stopwatch(); @@ -72,11 +83,104 @@ namespace TensorFlowNET.Examples.ImageProcess // Merge all the summaries and write them out to the summaries_dir var merged = tf.summary.merge_all(); var train_writer = tf.summary.FileWriter(summaries_dir + "/train", sess.graph); + var validation_writer = tf.summary.FileWriter(summaries_dir + "/validation", sess.graph); + + // Create a train saver that is used to restore values into an eval graph + // when exporting models. + var train_saver = tf.train.Saver(); + + for (int i = 0; i < how_many_training_steps; i++) + { + var (train_bottlenecks, train_ground_truth, _) = get_random_cached_bottlenecks( + sess, image_lists, train_batch_size, "training", + bottleneck_dir, image_dir, jpeg_data_tensor, + decoded_image_tensor, resized_image_tensor, bottleneck_tensor, + tfhub_module); + + // Feed the bottlenecks and ground truth into the graph, and run a training + // step. Capture training summaries for TensorBoard with the `merged` op. + var results = sess.run( + new ITensorOrOperation[] { merged, train_step }, + new FeedItem(bottleneck_input, train_bottlenecks), + new FeedItem(ground_truth_input, train_ground_truth)); + var train_summary = results[0]; + + // TODO + train_writer.add_summary(train_summary, i); + + // Every so often, print out how well the graph is training. + bool is_last_step = (i + 1 == how_many_training_steps); + if ((i % eval_step_interval) == 0 || is_last_step) + { + results = sess.run( + new Tensor[] { evaluation_step, cross_entropy }, + new FeedItem(bottleneck_input, train_bottlenecks), + new FeedItem(ground_truth_input, train_ground_truth)); + (float train_accuracy, float cross_entropy_value) = (results[0], results[1]); + print($"{DateTime.Now}: Step {i}: Train accuracy = {train_accuracy * 100}%"); + print($"{DateTime.Now}: Step {i}: Cross entropy = {cross_entropy_value}"); + + var (validation_bottlenecks, validation_ground_truth, _) = get_random_cached_bottlenecks( + sess, image_lists, validation_batch_size, "validation", + bottleneck_dir, image_dir, jpeg_data_tensor, + decoded_image_tensor, resized_image_tensor, bottleneck_tensor, + tfhub_module); + + // Run a validation step and capture training summaries for TensorBoard + // with the `merged` op. + results = sess.run(new Tensor[] { merged, evaluation_step }, + new FeedItem(bottleneck_input, validation_bottlenecks), + new FeedItem(ground_truth_input, validation_ground_truth)); + + (string validation_summary, float validation_accuracy) = (results[0], results[1]); + + validation_writer.add_summary(validation_summary, i); + print($"{DateTime.Now}: Step {i}: Validation accuracy = {validation_accuracy * 100}% (N={len(validation_bottlenecks)})"); + } + + // Store intermediate results + int intermediate_frequency = intermediate_store_frequency; + if (intermediate_frequency > 0 && i % intermediate_frequency == 0 && i > 0) + { + + } + } + + // After training is complete, force one last save of the train checkpoint. + train_saver.save(sess, CHECKPOINT_NAME); }); return false; } + private (NDArray, long[], string[]) get_random_cached_bottlenecks(Session sess, Dictionary> image_lists, + int how_many, string category, string bottleneck_dir, string image_dir, + Tensor jpeg_data_tensor, Tensor decoded_image_tensor, Tensor resized_input_tensor, + Tensor bottleneck_tensor, string module_name) + { + var bottlenecks = new List(); + var ground_truths = new List(); + var filenames = new List(); + int class_count = image_lists.Keys.Count; + foreach (var unused_i in range(how_many)) + { + int label_index = new Random().Next(class_count); + string label_name = image_lists.Keys.ToArray()[label_index]; + int image_index = new Random().Next(MAX_NUM_IMAGES_PER_CLASS); + string image_name = get_image_path(image_lists, label_name, image_index, + image_dir, category); + var bottleneck = get_or_create_bottleneck( + sess, image_lists, label_name, image_index, image_dir, category, + bottleneck_dir, jpeg_data_tensor, decoded_image_tensor, + resized_input_tensor, bottleneck_tensor, module_name); + bottlenecks.Add(bottleneck); + ground_truths.Add(label_index); + filenames.Add(image_name); + } + + return (bottlenecks.ToArray(), ground_truths.ToArray(), filenames.ToArray()); + } + /// /// Inserts the operations we need to evaluate the accuracy of our results. ///