diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs
index 771b0d7d..b1770576 100644
--- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs
+++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs
@@ -90,6 +90,12 @@ namespace Tensorflow
feed_dict_tensor[subfeed_t] = (NDArray)val;
break;
case int val:
+ feed_dict_tensor[subfeed_t] = (NDArray)val;
+ break;
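+ // 64-bit scalars and arrays take the same NDArray conversion path as int.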
+ case long val:
+ feed_dict_tensor[subfeed_t] = (NDArray)val;
+ break;
+ case long[] val:
feed_dict_tensor[subfeed_t] = (NDArray)val;
break;
case int[] val:
diff --git a/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs b/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs
index 2c5b55f6..5989b364 100644
--- a/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs
+++ b/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs
@@ -62,6 +62,9 @@ namespace Tensorflow
case "Single":
full_values.Add(float.NaN);
break;
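+ // String outputs get the same NaN placeholder as the numeric cases above.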
+ case "String":
+ full_values.Add(float.NaN);
+ break;
default:
throw new NotImplementedException($"build_results tensor_values[0] {tensor_values[0].dtype.Name}");
}
diff --git a/src/TensorFlowNET.Core/Summaries/EventFileWriter.cs b/src/TensorFlowNET.Core/Summaries/EventFileWriter.cs
index d0faed84..30903137 100644
--- a/src/TensorFlowNET.Core/Summaries/EventFileWriter.cs
+++ b/src/TensorFlowNET.Core/Summaries/EventFileWriter.cs
@@ -5,17 +5,35 @@ using System.Text;
namespace Tensorflow.Summaries
{
+ /// <summary>
+ /// Creates an `EventFileWriter` and an event file to write to.
+ /// </summary>
public class EventFileWriter
{
string _logdir;
- Queue _event_queue;
+ // FIFO queue of events waiting to be written by the worker thread.
+ Queue<Event> _event_queue;
+ EventsWriter _ev_writer;
+ int _flush_secs;
+ Event _sentinel_event;
+ bool _closed;
+ EventLoggerThread _worker;
public EventFileWriter(string logdir, int max_queue = 10, int flush_secs= 120,
string filename_suffix = null)
{
_logdir = logdir;
Directory.CreateDirectory(_logdir);
- _event_queue = new Queue(max_queue);
+ _event_queue = new Queue<Event>(max_queue);
+ _ev_writer = new EventsWriter(Path.Combine(_logdir, "events"));
+ _flush_secs = flush_secs;
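+ // Sentinel event: a future close() can enqueue it to tell the worker to stop.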
+ _sentinel_event = new Event();
+ if (!string.IsNullOrEmpty(filename_suffix))
+ // self._ev_writer.InitWithSuffix(compat.as_bytes(filename_suffix))
+ throw new NotImplementedException("EventFileWriter: filename_suffix is not supported yet");
+ _closed = false;
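+ // Start the background worker that drains the queue into the events file.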
+ _worker = new EventLoggerThread(_event_queue, _ev_writer, _flush_secs, _sentinel_event);
+ _worker.start();
}
}
}
diff --git a/src/TensorFlowNET.Core/Summaries/EventLoggerThread.cs b/src/TensorFlowNET.Core/Summaries/EventLoggerThread.cs
new file mode 100644
index 00000000..1ebd6de5
--- /dev/null
+++ b/src/TensorFlowNET.Core/Summaries/EventLoggerThread.cs
@@ -0,0 +1,52 @@
+using Google.Protobuf;
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace Tensorflow.Summaries
+{
+ /// <summary>
+ /// Thread that logs events.
+ /// </summary>
+ public class EventLoggerThread
+ {
+ Queue<Event> _queue;
+ bool daemon;
+ EventsWriter _ev_writer;
+ int _flush_secs;
+ Event _sentinel_event;
+
+ public EventLoggerThread(Queue<Event> queue, EventsWriter ev_writer, int flush_secs, Event sentinel_event)
+ {
+ daemon = true;
+ _queue = queue;
+ _ev_writer = ev_writer;
+ _flush_secs = flush_secs;
+ _sentinel_event = sentinel_event;
+ }
+
+ public void start() => run();
+
+ public void run()
+ {
+ Task.Run(delegate
+ {
+ while (true)
+ {
+ if (_queue.Count == 0)
+ {
+ // Nothing queued yet: sleep one flush interval before polling again.
+ Thread.Sleep(_flush_secs * 1000);
+ continue;
+ }
+
+ var @event = _queue.Dequeue();
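+ // Write the serialized Event proto, throttling to roughly one event per second.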
+ _ev_writer._WriteSerializedEvent(@event.ToByteArray());
+ Thread.Sleep(1000);
+ }
+ });
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Summaries/EventsWriter.cs b/src/TensorFlowNET.Core/Summaries/EventsWriter.cs
index b80a133e..2cae7ade 100644
--- a/src/TensorFlowNET.Core/Summaries/EventsWriter.cs
+++ b/src/TensorFlowNET.Core/Summaries/EventsWriter.cs
@@ -1,19 +1,22 @@
using System;
using System.Collections.Generic;
+using System.IO;
using System.Text;
namespace Tensorflow.Summaries
{
public class EventsWriter
{
+ string _file_prefix;
+
public EventsWriter(string file_prefix)
{
-
+ _file_prefix = file_prefix;
}
public void _WriteSerializedEvent(byte[] event_str)
{
-
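+ // Simplified sketch: WriteAllBytes overwrites the file each call, whereas TensorFlow's events files append TFRecord-framed records.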
+ File.WriteAllBytes(_file_prefix, event_str);
}
}
}
diff --git a/src/TensorFlowNET.Core/Summaries/SummaryToEventTransformer.cs b/src/TensorFlowNET.Core/Summaries/SummaryToEventTransformer.cs
index 1d29fe4f..7dc4addf 100644
--- a/src/TensorFlowNET.Core/Summaries/SummaryToEventTransformer.cs
+++ b/src/TensorFlowNET.Core/Summaries/SummaryToEventTransformer.cs
@@ -1,4 +1,5 @@
-using System;
+using Google.Protobuf;
+using System;
using System.Collections.Generic;
using System.Text;
@@ -9,5 +10,10 @@ namespace Tensorflow.Summaries
/// </summary>
public abstract class SummaryToEventTransformer
{
+ public void add_summary(string summary, int global_step = 0)
+ {
+ // Use UTF-8 here: UTF8Encoding.Unicode actually returns a UTF-16 encoder.
+ var bytes = Encoding.UTF8.GetBytes(summary);
+ // var summ = Tensorflow.Summary.Parser.ParseFrom(bytes);
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
index eff138b5..0b64a212 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
@@ -71,6 +71,9 @@ namespace Tensorflow
case "Int32":
Marshal.Copy(nd1.Data(), 0, dotHandle, nd.size);
break;
+ case "Int64":
+ Marshal.Copy(nd1.Data(), 0, dotHandle, nd.size);
+ break;
case "Single":
Marshal.Copy(nd1.Data(), 0, dotHandle, nd.size);
break;
@@ -80,24 +83,8 @@ namespace Tensorflow
case "Byte":
Marshal.Copy(nd1.Data(), 0, dotHandle, nd.size);
break;
- //case "String":
- /*string ss = nd.Data()[0];
- var str = Marshal.StringToHGlobalAnsi(ss);
- ulong dst_len = c_api.TF_StringEncodedSize((ulong)ss.Length);
- var dataType1 = ToTFDataType(nd.dtype);
- // shape
- var dims1 = nd.shape.Select(x => (long)x).ToArray();
-
- var tfHandle1 = c_api.TF_AllocateTensor(dataType1,
- dims1,
- nd.ndim,
- dst_len + sizeof(Int64));
-
- dotHandle = c_api.TF_TensorData(tfHandle1);
- Marshal.WriteInt64(dotHandle, 0);
- c_api.TF_StringEncode(str, (ulong)ss.Length, dotHandle + sizeof(Int64), dst_len, status);
- return tfHandle1;*/
- break;
+ case "String":
+ return new Tensor(UTF8Encoding.UTF8.GetBytes(nd.Data<string>(0)));
default:
throw new NotImplementedException($"Marshal.Copy failed for {nd.dtype.Name}.");
}
diff --git a/test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs b/test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs
index 6f4c84d1..61ba33b3 100644
--- a/test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs
+++ b/test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs
@@ -31,11 +31,19 @@ namespace TensorFlowNET.Examples.ImageProcess
string summaries_dir = Path.Join(data_dir, "retrain_logs");
string image_dir = Path.Join(data_dir, "flower_photos");
string bottleneck_dir = Path.Join(data_dir, "bottleneck");
+ // The location where variable checkpoints will be stored.
+ string CHECKPOINT_NAME = Path.Join(data_dir, "_retrain_checkpoint");
string tfhub_module = "https://tfhub.dev/google/imagenet/inception_v3/feature_vector/3";
float testing_percentage = 0.1f;
float validation_percentage = 0.1f;
Tensor resized_image_tensor;
Dictionary<string, Dictionary<string, string[]>> image_lists;
+ int how_many_training_steps = 200;
+ int eval_step_interval = 10;
+ int train_batch_size = 100;
+ int validation_batch_size = 100;
+ int intermediate_store_frequency = 0;
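+ // 2^27 - 1, the same per-class image cap used by TensorFlow's retrain.py.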
+ const int MAX_NUM_IMAGES_PER_CLASS = 134217727;
public bool Run()
{
@@ -47,6 +55,9 @@ namespace TensorFlowNET.Examples.ImageProcess
Tensor resized_image_tensor = graph.OperationByName("Placeholder");
Tensor final_tensor = graph.OperationByName("final_result");
Tensor ground_truth_input = graph.OperationByName("input/GroundTruthInput");
+ Operation train_step = graph.OperationByName("train/GradientDescent");
+ Tensor bottleneck_input = graph.OperationByName("input/BottleneckInputPlaceholder");
+ Tensor cross_entropy = graph.OperationByName("cross_entropy/sparse_softmax_cross_entropy_loss/value");
var sw = new Stopwatch();
@@ -72,11 +83,104 @@ namespace TensorFlowNET.Examples.ImageProcess
// Merge all the summaries and write them out to the summaries_dir
var merged = tf.summary.merge_all();
var train_writer = tf.summary.FileWriter(summaries_dir + "/train", sess.graph);
+ var validation_writer = tf.summary.FileWriter(summaries_dir + "/validation", sess.graph);
+
+ // Create a train saver that is used to restore values into an eval graph
+ // when exporting models.
+ var train_saver = tf.train.Saver();
+
+ for (int i = 0; i < how_many_training_steps; i++)
+ {
+ var (train_bottlenecks, train_ground_truth, _) = get_random_cached_bottlenecks(
+ sess, image_lists, train_batch_size, "training",
+ bottleneck_dir, image_dir, jpeg_data_tensor,
+ decoded_image_tensor, resized_image_tensor, bottleneck_tensor,
+ tfhub_module);
+
+ // Feed the bottlenecks and ground truth into the graph, and run a training
+ // step. Capture training summaries for TensorBoard with the `merged` op.
+ var results = sess.run(
+ new ITensorOrOperation[] { merged, train_step },
+ new FeedItem(bottleneck_input, train_bottlenecks),
+ new FeedItem(ground_truth_input, train_ground_truth));
+ var train_summary = results[0];
+
+ // TODO: add_summary is still a stub, so this does not write summaries to disk yet.
+ train_writer.add_summary(train_summary, i);
+
+ // Every so often, print out how well the graph is training.
+ bool is_last_step = (i + 1 == how_many_training_steps);
+ if ((i % eval_step_interval) == 0 || is_last_step)
+ {
+ results = sess.run(
+ new Tensor[] { evaluation_step, cross_entropy },
+ new FeedItem(bottleneck_input, train_bottlenecks),
+ new FeedItem(ground_truth_input, train_ground_truth));
+ (float train_accuracy, float cross_entropy_value) = (results[0], results[1]);
+ print($"{DateTime.Now}: Step {i}: Train accuracy = {train_accuracy * 100}%");
+ print($"{DateTime.Now}: Step {i}: Cross entropy = {cross_entropy_value}");
+
+ var (validation_bottlenecks, validation_ground_truth, _) = get_random_cached_bottlenecks(
+ sess, image_lists, validation_batch_size, "validation",
+ bottleneck_dir, image_dir, jpeg_data_tensor,
+ decoded_image_tensor, resized_image_tensor, bottleneck_tensor,
+ tfhub_module);
+
+ // Run a validation step and capture training summaries for TensorBoard
+ // with the `merged` op.
+ results = sess.run(new Tensor[] { merged, evaluation_step },
+ new FeedItem(bottleneck_input, validation_bottlenecks),
+ new FeedItem(ground_truth_input, validation_ground_truth));
+
+ (string validation_summary, float validation_accuracy) = (results[0], results[1]);
+
+ validation_writer.add_summary(validation_summary, i);
+ print($"{DateTime.Now}: Step {i}: Validation accuracy = {validation_accuracy * 100}% (N={len(validation_bottlenecks)})");
+ }
+
+ // Store intermediate results
+ int intermediate_frequency = intermediate_store_frequency;
+ if (intermediate_frequency > 0 && i % intermediate_frequency == 0 && i > 0)
+ {
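+ // TODO: port intermediate checkpoint saving from retrain.py.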
+
+ }
+ }
+
+ // After training is complete, force one last save of the train checkpoint.
+ train_saver.save(sess, CHECKPOINT_NAME);
});
return false;
}
+ private (NDArray, long[], string[]) get_random_cached_bottlenecks(Session sess, Dictionary<string, Dictionary<string, string[]>> image_lists,
+ int how_many, string category, string bottleneck_dir, string image_dir,
+ Tensor jpeg_data_tensor, Tensor decoded_image_tensor, Tensor resized_input_tensor,
+ Tensor bottleneck_tensor, string module_name)
+ {
+ var bottlenecks = new List<float[]>();
+ var ground_truths = new List<long>();
+ var filenames = new List<string>();
+ int class_count = image_lists.Keys.Count;
+ // Use one shared Random: a fresh instance per draw would reuse the same time-based seed.
+ var random = new Random();
+ foreach (var unused_i in range(how_many))
+ {
+ // Sample a random class, then a random image index within it.
+ int label_index = random.Next(class_count);
+ string label_name = image_lists.Keys.ToArray()[label_index];
+ int image_index = random.Next(MAX_NUM_IMAGES_PER_CLASS);
+ string image_name = get_image_path(image_lists, label_name, image_index,
+ image_dir, category);
+ var bottleneck = get_or_create_bottleneck(
+ sess, image_lists, label_name, image_index, image_dir, category,
+ bottleneck_dir, jpeg_data_tensor, decoded_image_tensor,
+ resized_input_tensor, bottleneck_tensor, module_name);
+ bottlenecks.Add(bottleneck);
+ ground_truths.Add(label_index);
+ filenames.Add(image_name);
+ }
+
+ return (bottlenecks.ToArray(), ground_truths.ToArray(), filenames.ToArray());
+ }
+
/// <summary>
/// Inserts the operations we need to evaluate the accuracy of our results.
/// </summary>