diff --git a/src/TensorFlowNET.Core/APIs/tf.graph.cs b/src/TensorFlowNET.Core/APIs/tf.graph.cs
new file mode 100644
index 00000000..c343050a
--- /dev/null
+++ b/src/TensorFlowNET.Core/APIs/tf.graph.cs
@@ -0,0 +1,18 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+ public static partial class tf
+ {
+ public static graph_util_impl graph_util => new graph_util_impl();
+ public static Graph get_default_graph()
+ {
+ return ops.get_default_graph();
+ }
+
+ public static Graph Graph() => new Graph();
+
+ }
+}
diff --git a/src/TensorFlowNET.Core/Framework/tf.ops.cs b/src/TensorFlowNET.Core/APIs/tf.ops.cs
similarity index 100%
rename from src/TensorFlowNET.Core/Framework/tf.ops.cs
rename to src/TensorFlowNET.Core/APIs/tf.ops.cs
diff --git a/src/TensorFlowNET.Core/Variables/tf.variable.cs b/src/TensorFlowNET.Core/APIs/tf.variable.cs
similarity index 100%
rename from src/TensorFlowNET.Core/Variables/tf.variable.cs
rename to src/TensorFlowNET.Core/APIs/tf.variable.cs
diff --git a/src/TensorFlowNET.Core/Framework/graph_util_impl.cs b/src/TensorFlowNET.Core/Framework/graph_util_impl.cs
new file mode 100644
index 00000000..7c21d33e
--- /dev/null
+++ b/src/TensorFlowNET.Core/Framework/graph_util_impl.cs
@@ -0,0 +1,34 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+ public class graph_util_impl
+ {
+ ///
+ /// Replaces all the variables in a graph with constants of the same values.
+ ///
+ /// Active TensorFlow session containing the variables.
+ /// GraphDef object holding the network.
+ /// List of name strings for the result nodes of the graph.
+ ///
+ ///
+ /// GraphDef containing a simplified version of the original.
+ public GraphDef convert_variables_to_constants(Session sess,
+ GraphDef input_graph_def,
+ string[] output_node_names,
+ string[] variable_names_whitelist = null,
+ string[] variable_names_blacklist = null)
+ {
+ // This graph only includes the nodes needed to evaluate the output nodes, and
+ // removes unneeded nodes like those involved in saving and assignment.
+ throw new NotImplementedException("");
+ }
+
+ private string get_input_name(string node)
+ {
+ throw new NotImplementedException("");
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs b/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs
index ea52a6de..f0db80ce 100644
--- a/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs
+++ b/src/TensorFlowNET.Core/Train/Saving/BaseSaverBuilder.cs
@@ -111,6 +111,12 @@ namespace Tensorflow
var cols = graph.get_collection(collection_type);
switch (cols)
{
+ case List values:
+ foreach (var element in values) ;
+ break;
+ case List values:
+ foreach (var element in values) ;
+ break;
case List values:
foreach (var element in values) ;
break;
@@ -129,6 +135,9 @@ namespace Tensorflow
case List values:
foreach (var element in values) ;
break;
+ case List values:
+ foreach (var element in values) ;
+ break;
default:
throw new NotImplementedException("_build_internal.check_collection_list");
}
@@ -172,7 +181,7 @@ namespace Tensorflow
string name = "restore_all")
{
var all_tensors = bulk_restore(filename_tensor, saveables, preferred_shard, restore_sequentially);
- var assign_ops = new List();
+ var assign_ops = new List();
int idx = 0;
// Load and optionally reshape on the CPU, as string tensors are not
@@ -190,7 +199,7 @@ namespace Tensorflow
var saveable_tensors = all_tensors.Skip(idx).Take(saveable.specs.Length);
idx += saveable.specs.Length;
var restored = saveable.restore(saveable_tensors.ToArray(), shapes == null ? null : shapes.ToArray());
- assign_ops.Add(restored as Operation);
+ assign_ops.Add(restored as ITensorOrOperation);
}
return control_flow_ops.group(assign_ops.ToArray(), name: name);
diff --git a/src/TensorFlowNET.Core/Train/Saving/checkpoint_management.py.cs b/src/TensorFlowNET.Core/Train/Saving/checkpoint_management.py.cs
index 4bd224da..1ae4c585 100644
--- a/src/TensorFlowNET.Core/Train/Saving/checkpoint_management.py.cs
+++ b/src/TensorFlowNET.Core/Train/Saving/checkpoint_management.py.cs
@@ -73,23 +73,28 @@ namespace Tensorflow
string model_checkpoint_path,
List all_model_checkpoint_paths = null,
List all_model_checkpoint_timestamps = null,
- float? last_preserved_timestamp = null)
+ double? last_preserved_timestamp = null)
{
if (all_model_checkpoint_paths == null)
all_model_checkpoint_paths = new List();
+ if (!all_model_checkpoint_paths.Contains(model_checkpoint_path))
+ all_model_checkpoint_paths.Add(model_checkpoint_path);
+
// Relative paths need to be rewritten to be relative to the "save_dir"
// if model_checkpoint_path already contains "save_dir".
- all_model_checkpoint_paths.Add(model_checkpoint_path);
var coord_checkpoint_proto = new CheckpointState()
{
- ModelCheckpointPath = model_checkpoint_path,
- LastPreservedTimestamp = last_preserved_timestamp.Value
+ ModelCheckpointPath = model_checkpoint_path
};
+ if (last_preserved_timestamp.HasValue)
+ coord_checkpoint_proto.LastPreservedTimestamp = last_preserved_timestamp.Value;
+
coord_checkpoint_proto.AllModelCheckpointPaths.AddRange(all_model_checkpoint_paths);
- coord_checkpoint_proto.AllModelCheckpointTimestamps.AddRange(all_model_checkpoint_timestamps.Select(x => (double)x));
+ if (all_model_checkpoint_timestamps != null)
+ coord_checkpoint_proto.AllModelCheckpointTimestamps.AddRange(all_model_checkpoint_timestamps.Select(x => (double)x));
return coord_checkpoint_proto;
}
diff --git a/src/TensorFlowNET.Core/tf.cs b/src/TensorFlowNET.Core/tf.cs
index 32b045e6..33acdfbc 100644
--- a/src/TensorFlowNET.Core/tf.cs
+++ b/src/TensorFlowNET.Core/tf.cs
@@ -48,13 +48,6 @@ namespace Tensorflow
public static string VERSION => c_api.StringPiece(c_api.TF_Version());
- public static Graph get_default_graph()
- {
- return ops.get_default_graph();
- }
-
- public static Graph Graph() => new Graph();
-
public static Session Session()
{
defaultSession = new Session();
diff --git a/test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs b/test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs
index 8f4a6ecf..373d7355 100644
--- a/test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs
+++ b/test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs
@@ -1,4 +1,5 @@
-using NumSharp;
+using Google.Protobuf;
+using NumSharp;
using System;
using System.Collections.Generic;
using System.Diagnostics;
@@ -31,17 +32,21 @@ namespace TensorFlowNET.Examples.ImageProcess
string summaries_dir = Path.Join(data_dir, "retrain_logs");
string image_dir = Path.Join(data_dir, "flower_photos");
string bottleneck_dir = Path.Join(data_dir, "bottleneck");
+ string output_graph = Path.Join(data_dir, "output_graph.pb");
+ string output_labels = Path.Join(data_dir, "output_labels.txt");
// The location where variable checkpoints will be stored.
string CHECKPOINT_NAME = Path.Join(data_dir, "_retrain_checkpoint");
string tfhub_module = "https://tfhub.dev/google/imagenet/inception_v3/feature_vector/3";
+ string final_tensor_name = "final_result";
float testing_percentage = 0.1f;
float validation_percentage = 0.1f;
float learning_rate = 0.01f;
Tensor resized_image_tensor;
Dictionary> image_lists;
- int how_many_training_steps = 200;
+ int how_many_training_steps = 100;
int eval_step_interval = 10;
int train_batch_size = 100;
+ int test_batch_size = -1;
int validation_batch_size = 100;
int intermediate_store_frequency = 0;
int class_count = 0;
@@ -65,13 +70,13 @@ namespace TensorFlowNET.Examples.ImageProcess
{
(train_step, cross_entropy, bottleneck_input,
ground_truth_input, final_tensor) = add_final_retrain_ops(
- class_count, "final_result", bottleneck_tensor,
+ class_count, final_tensor_name, bottleneck_tensor,
wants_quantization, is_training: true);
});
/*Tensor bottleneck_tensor = graph.OperationByName("module_apply_default/hub_output/feature_vector/SpatialSqueeze");
Tensor resized_image_tensor = graph.OperationByName("Placeholder");
- Tensor final_tensor = graph.OperationByName("final_result");
+ Tensor final_tensor = graph.OperationByName(final_tensor_name);
Tensor ground_truth_input = graph.OperationByName("input/GroundTruthInput");
train_step = graph.OperationByName("train/GradientDescent");
Tensor bottleneck_input = graph.OperationByName("input/BottleneckInputPlaceholder");
@@ -105,7 +110,7 @@ namespace TensorFlowNET.Examples.ImageProcess
// Create a train saver that is used to restore values into an eval graph
// when exporting models.
- // var train_saver = tf.train.Saver();
+ var train_saver = tf.train.Saver();
for (int i = 0; i < how_many_training_steps; i++)
{
@@ -153,7 +158,7 @@ namespace TensorFlowNET.Examples.ImageProcess
(string validation_summary, float validation_accuracy) = (results[0], results[1]);
validation_writer.add_summary(validation_summary, i);
- print($"{DateTime.Now}: Step {i}: Validation accuracy = {validation_accuracy * 100}% (N={len(validation_bottlenecks)})");
+ print($"{DateTime.Now}: Step {i + 1}: Validation accuracy = {validation_accuracy * 100}% (N={len(validation_bottlenecks)})");
}
// Store intermediate results
@@ -165,13 +170,19 @@ namespace TensorFlowNET.Examples.ImageProcess
}
// After training is complete, force one last save of the train checkpoint.
- // train_saver.save(sess, CHECKPOINT_NAME);
+ train_saver.save(sess, CHECKPOINT_NAME);
// We've completed all our training, so run a final test evaluation on
// some new images we haven't used before.
- run_final_eval(sess, null, class_count, image_lists,
+ var (test_accuracy, predictions) = run_final_eval(sess, null, class_count, image_lists,
jpeg_data_tensor, decoded_image_tensor, resized_image_tensor,
bottleneck_tensor);
+
+ // Write out the trained graph and labels with the weights stored as
+ // constants.
+ print($"final test accuracy: {test_accuracy}");
+ print($"Save final result to : {output_graph}");
+ save_graph_to_file(output_graph, class_count);
});
return false;
@@ -188,26 +199,51 @@ namespace TensorFlowNET.Examples.ImageProcess
///
///
///
- private void run_final_eval(Session train_session, object module_spec, int class_count,
+ private (float, NDArray) run_final_eval(Session train_session, object module_spec, int class_count,
Dictionary> image_lists,
Tensor jpeg_data_tensor, Tensor decoded_image_tensor,
Tensor resized_image_tensor, Tensor bottleneck_tensor)
{
- /*var (eval_session, _, bottleneck_input, ground_truth_input, evaluation_step,
- prediction) = build_eval_session(module_spec, class_count);*/
+ var (test_bottlenecks, test_ground_truth, test_filenames) = get_random_cached_bottlenecks(train_session, image_lists,
+ test_batch_size, "testing", bottleneck_dir, image_dir, jpeg_data_tensor,
+ decoded_image_tensor, resized_image_tensor, bottleneck_tensor, tfhub_module);
+
+ var (eval_session, _, bottleneck_input, ground_truth_input, evaluation_step,
+ prediction) = build_eval_session(class_count);
+
+ var results = eval_session.run(new Tensor[] { evaluation_step, prediction },
+ new FeedItem(bottleneck_input, test_bottlenecks),
+ new FeedItem(ground_truth_input, test_ground_truth));
+
+ return (results[0], results[1]);
}
- private void build_eval_session(int class_count)
+ private (Session, Tensor, Tensor, Tensor, Tensor, Tensor)
+ build_eval_session(int class_count)
{
// If quantized, we need to create the correct eval graph for exporting.
var (eval_graph, bottleneck_tensor, resized_input_tensor, wants_quantization) = create_module_graph();
var eval_sess = tf.Session(graph: eval_graph);
+ Tensor evaluation_step = null;
+ Tensor prediction = null;
with(eval_graph.as_default(), graph =>
{
+ // Add the new layer for exporting.
+ var (_, _, bottleneck_input, ground_truth_input, final_tensor) =
+ add_final_retrain_ops(class_count, final_tensor_name, bottleneck_tensor,
+ wants_quantization, is_training: false);
+ // Now we need to restore the values from the training graph to the eval
+ // graph.
+ tf.train.Saver().restore(eval_sess, CHECKPOINT_NAME);
+ (evaluation_step, prediction) = add_evaluation_step(final_tensor,
+ ground_truth_input);
});
+
+ return (eval_sess, resized_input_tensor, bottleneck_input, ground_truth_input,
+ evaluation_step, prediction);
}
///
@@ -348,20 +384,42 @@ namespace TensorFlowNET.Examples.ImageProcess
var ground_truths = new List();
var filenames = new List();
class_count = image_lists.Keys.Count;
- foreach (var unused_i in range(how_many))
+ if (how_many >= 0)
{
- int label_index = new Random().Next(class_count);
- string label_name = image_lists.Keys.ToArray()[label_index];
- int image_index = new Random().Next(MAX_NUM_IMAGES_PER_CLASS);
- string image_name = get_image_path(image_lists, label_name, image_index,
- image_dir, category);
- var bottleneck = get_or_create_bottleneck(
- sess, image_lists, label_name, image_index, image_dir, category,
- bottleneck_dir, jpeg_data_tensor, decoded_image_tensor,
- resized_input_tensor, bottleneck_tensor, module_name);
- bottlenecks.Add(bottleneck);
- ground_truths.Add(label_index);
- filenames.Add(image_name);
+ // Retrieve a random sample of bottlenecks.
+ foreach (var unused_i in range(how_many))
+ {
+ int label_index = new Random().Next(class_count);
+ string label_name = image_lists.Keys.ToArray()[label_index];
+ int image_index = new Random().Next(MAX_NUM_IMAGES_PER_CLASS);
+ string image_name = get_image_path(image_lists, label_name, image_index,
+ image_dir, category);
+ var bottleneck = get_or_create_bottleneck(
+ sess, image_lists, label_name, image_index, image_dir, category,
+ bottleneck_dir, jpeg_data_tensor, decoded_image_tensor,
+ resized_input_tensor, bottleneck_tensor, module_name);
+ bottlenecks.Add(bottleneck);
+ ground_truths.Add(label_index);
+ filenames.Add(image_name);
+ }
+ }
+ else
+ {
+ // Retrieve all bottlenecks.
+ foreach (var (label_index, label_name) in enumerate(image_lists.Keys.ToArray()))
+ {
+ foreach(var (image_index, image_name) in enumerate(image_lists[label_name][category]))
+ {
+ var bottleneck = get_or_create_bottleneck(
+ sess, image_lists, label_name, image_index, image_dir, category,
+ bottleneck_dir, jpeg_data_tensor, decoded_image_tensor,
+ resized_input_tensor, bottleneck_tensor, module_name);
+
+ bottlenecks.Add(bottleneck);
+ ground_truths.Add(label_index);
+ filenames.Add(image_name);
+ }
+ }
}
return (bottlenecks.ToArray(), ground_truths.ToArray(), filenames.ToArray());
@@ -521,6 +579,20 @@ namespace TensorFlowNET.Examples.ImageProcess
return full_path;
}
+ ///
+ /// Saves an graph to file, creating a valid quantized one if necessary.
+ ///
+ ///
+ ///
+ private void save_graph_to_file(string graph_file_name, int class_count)
+ {
+ var (sess, _, _, _, _, _) = build_eval_session(class_count);
+ var graph = sess.graph;
+ var output_graph_def = tf.graph_util.convert_variables_to_constants(
+ sess, graph.as_graph_def(), new string[] { final_tensor_name });
+ File.WriteAllBytes(graph_file_name, output_graph_def.ToByteArray());
+ }
+
public void PrepareData()
{
// get a set of images to teach the network about the new classes