diff --git a/graph/InceptionV3.meta b/graph/InceptionV3.meta index c23a3a36..0ded6221 100644 Binary files a/graph/InceptionV3.meta and b/graph/InceptionV3.meta differ diff --git a/src/TensorFlowNET.Core/Framework/graph_util_impl.cs b/src/TensorFlowNET.Core/Framework/graph_util_impl.cs index 7c21d33e..0ea334ea 100644 --- a/src/TensorFlowNET.Core/Framework/graph_util_impl.cs +++ b/src/TensorFlowNET.Core/Framework/graph_util_impl.cs @@ -1,6 +1,9 @@ -using System; +using NumSharp; +using System; using System.Collections.Generic; +using System.Linq; using System.Text; +using static Tensorflow.Python; namespace Tensorflow { @@ -23,7 +26,181 @@ namespace Tensorflow { // This graph only includes the nodes needed to evaluate the output nodes, and // removes unneeded nodes like those involved in saving and assignment. - throw new NotImplementedException(""); + var inference_graph = extract_sub_graph(input_graph_def, output_node_names); + + // Identify the ops in the graph. + var map_name_to_node = new Dictionary(); + inference_graph.Node.Select(x => map_name_to_node[x.Name] = x).ToArray(); + + // Get list of variables. + var variable_names = new List(); + var variable_dict_names = new List(); + + foreach (var node in inference_graph.Node) + { + if(new string[] { "Variable", "VariableV2", "VarHandleOp" }.Contains(node.Op)) + { + var variable_name = node.Name; + + variable_dict_names.Add(variable_name); + if (node.Op == "VarHandleOp") + variable_names.Add(variable_name + "/Read/ReadVariableOp:0"); + else + variable_names.Add(variable_name + ":0"); + } + else if (new string[] { "ReadVariableOp", "ResourceGather" }.Contains(node.Op)) + { + // There can be one or more Identity ops in between the ReadVariableOp and + // VarHandleOp. Store the Identity ops with the associated dtypes. + var source_op_name = get_input_name(node); + while(map_name_to_node[source_op_name].Op == "Identity") + { + throw new NotImplementedException("map_name_to_node[source_op_name].Op"); + /*resource_identity_types[source_op_name] = node.attr["dtype"]; + source_op_name = get_input_name(map_name_to_node[source_op_name]);*/ + } + } + } + + // Gets map of variables and the associated data. + NDArray returned_variables = null; + if (variable_names != null) + returned_variables = sess.run(variable_names); + + var variables_data_map = new Dictionary(); + foreach(var (i, name) in enumerate(variable_dict_names)) + variables_data_map[name] = returned_variables[i]; + print($"Froze {len(returned_variables)} variables."); + + // Reconstruct the graph with constants in place of variables. + var output_graph_def = new GraphDef(); + int how_many_converted = 0; + foreach(var input_node in inference_graph.Node) + { + var output_node = new NodeDef(); + if (variables_data_map.ContainsKey(input_node.Name)) + { + var data = variables_data_map[input_node.Name]; + output_node = create_const_op(input_node.Name, input_node.Attr["dtype"], + data, data.shape); + how_many_converted += 1; + } + // else if (resource_identity_types.ContainsKey(input_node.Name)) + else if(input_node.Op == "ReadVariableOp") + { + output_node.Op = "Identity"; + output_node.Name = input_node.Name; + output_node.Input.AddRange(new[] { input_node.Input[0] }); + output_node.Attr["T"] = input_node.Attr["dtype"]; + } + else if (input_node.Op == "ResourceGather") + { + + } + else + { + output_node.MergeFrom(input_node); + } + + output_graph_def.Node.AddRange(new[] { output_node }); + } + + output_graph_def.Library = inference_graph.Library; + print($"Converted {how_many_converted} variables to const ops."); + return output_graph_def; + } + + private NodeDef create_const_op(string node_name, AttrValue dtype, NDArray data, int[] data_shape = null) + { + var output_node = new NodeDef + { + Op = "Const", + Name = node_name + }; + output_node.Attr["dtype"] = dtype; + output_node.Attr["value"] = new AttrValue() + { + Tensor = tensor_util.make_tensor_proto( + data, dtype: dtype.Type.as_tf_dtype(), shape: data_shape) + }; + + return output_node; + } + + /// + /// Gets the name of the first input. Errors if suffix is not :0. + /// + /// + /// + private string get_input_name(NodeDef node) + { + var details = node.Input[0].Split(':'); + if (details.Length == 1 || int.Parse(details[1]) == 0) + return details[0]; + // While it is valid for input tensors to have a suffix that is not :0, this + // method is used to find the associated ops, not tensors, and therefore it + // is not valid. + throw new ValueError($"Tensor name '{node.Input[0]}' is invalid."); + } + + + private GraphDef extract_sub_graph(GraphDef graph_def, string[] dest_nodes) + { + var (name_to_input_name, name_to_node, name_to_seq_num) = _extract_graph_summary( + graph_def); + + var nodes_to_keep = _bfs_for_reachable_nodes(dest_nodes, name_to_input_name); + var nodes_to_keep_list = nodes_to_keep.OrderBy(n => name_to_seq_num[n]).ToArray(); + // Now construct the output GraphDef + var output = new GraphDef(); + foreach (var n in nodes_to_keep_list) + output.Node.Add(name_to_node[n]); // need deep clone? + output.Library = graph_def.Library; + output.Versions = graph_def.Versions; + + return output; + } + + private string[] _bfs_for_reachable_nodes(string[] target_nodes, Dictionary name_to_input_name) + { + var nodes_to_keep = new List(); + var next_to_visit = target_nodes.Select(x => x).ToList(); + while(next_to_visit.Count > 0) + { + var node = next_to_visit[0]; + next_to_visit.RemoveAt(0); + if (nodes_to_keep.Contains(node)) + continue; + nodes_to_keep.Add(node); + if (name_to_input_name.Keys.Contains(node)) + next_to_visit.AddRange(name_to_input_name[node]); + } + + return nodes_to_keep.ToArray(); + } + + private (Dictionary, Dictionary, Dictionary) _extract_graph_summary(GraphDef graph_def) + { + var name_to_input_name = new Dictionary(); + var name_to_node = new Dictionary(); + var name_to_seq_num = new Dictionary(); + + int seq = 0; + foreach (var node in graph_def.Node) + { + var n = _node_name(node.Name); + name_to_node[n] = node; + name_to_input_name[n] = node.Input.Select(x => _node_name(x)).ToArray(); + name_to_seq_num[n] = seq; + seq++; + } + + return (name_to_input_name, name_to_node, name_to_seq_num); + } + + private string _node_name(string n) + { + return n.StartsWith("^") ? n.Substring(1) : n.Split(':')[0]; } private string get_input_name(string node) diff --git a/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs b/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs index e18eebc0..9a233524 100644 --- a/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs +++ b/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs @@ -15,14 +15,13 @@ namespace Tensorflow public _ElementFetchMapper(object[] fetches, Func, object> contraction_fn) { var g = ops.get_default_graph(); - ITensorOrOperation el = null; foreach(var fetch in fetches) { - el = g.as_graph_element(fetch, allow_tensor: true, allow_operation: true); + var el = g.as_graph_element(fetch, allow_tensor: true, allow_operation: true); + _unique_fetches.Add(el); } - - _unique_fetches.Add(el); + _contraction_fn = contraction_fn; } diff --git a/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs b/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs index 5989b364..19383183 100644 --- a/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs +++ b/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs @@ -33,6 +33,8 @@ namespace Tensorflow _fetches.Add(val); _ops.Add(false); break; + default: + throw new NotImplementedException("_FetchHandler fetch"); } } diff --git a/src/TensorFlowNET.Core/Sessions/_FetchMapper.cs b/src/TensorFlowNET.Core/Sessions/_FetchMapper.cs index 4b2691a8..b8188985 100644 --- a/src/TensorFlowNET.Core/Sessions/_FetchMapper.cs +++ b/src/TensorFlowNET.Core/Sessions/_FetchMapper.cs @@ -8,12 +8,14 @@ namespace Tensorflow { public class _FetchMapper { - protected List _unique_fetches = new List(); - + protected List _unique_fetches = new List(); + protected List _value_indices = new List(); public static _FetchMapper for_fetch(object fetch) { var fetches = fetch.GetType().IsArray ? (object[])fetch : new object[] { fetch }; + if(fetch is List fetches1) + return new _ListFetchMapper(fetches1.ToArray()); if (fetch.GetType().IsArray) return new _ListFetchMapper(fetches); else @@ -28,7 +30,7 @@ namespace Tensorflow return nd; } - public virtual List unique_fetches() + public virtual List unique_fetches() { return _unique_fetches; } diff --git a/src/TensorFlowNET.Core/Sessions/_ListFetchMapper.cs b/src/TensorFlowNET.Core/Sessions/_ListFetchMapper.cs index f94a19da..c80e3f9e 100644 --- a/src/TensorFlowNET.Core/Sessions/_ListFetchMapper.cs +++ b/src/TensorFlowNET.Core/Sessions/_ListFetchMapper.cs @@ -12,7 +12,46 @@ namespace Tensorflow public _ListFetchMapper(object[] fetches) { _mappers = fetches.Select(fetch => _FetchMapper.for_fetch(fetch)).ToArray(); - _unique_fetches.AddRange(fetches); + (_unique_fetches, _value_indices) = _uniquify_fetches(_mappers); + } + + private (List, List) _uniquify_fetches(_FetchMapper[] fetch_mappers) + { + var unique_fetches = new List(); + var value_indices = new List(); + var seen_fetches = new Dictionary(); + + foreach (var m in fetch_mappers) + { + var m_value_indices = new List(); + foreach (var uf in m.unique_fetches()) + { + switch (uf) + { + case Tensor f: + if (!seen_fetches.ContainsKey(f)) + { + seen_fetches[f] = seen_fetches.Count; + unique_fetches.Add(f); + } + m_value_indices.Add(seen_fetches.Count - 1); + break; + case Operation f: + if (!seen_fetches.ContainsKey(f)) + { + seen_fetches[f] = seen_fetches.Count; + unique_fetches.Add(f); + } + m_value_indices.Add(seen_fetches.Count - 1); + break; + default: + throw new NotImplementedException("_uniquify_fetches"); + } + } + value_indices.Add(m_value_indices.ToArray()); + } + + return (unique_fetches, value_indices); } } } diff --git a/test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs b/test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs index 373d7355..6794931b 100644 --- a/test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs +++ b/test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs @@ -3,12 +3,14 @@ using NumSharp; using System; using System.Collections.Generic; using System.Diagnostics; +using System.Drawing; using System.IO; using System.Linq; using System.Text; using Tensorflow; using TensorFlowNET.Examples.Utility; using static Tensorflow.Python; +using Console = Colorful.Console; namespace TensorFlowNET.Examples.ImageProcess { @@ -84,7 +86,7 @@ namespace TensorFlowNET.Examples.ImageProcess var sw = new Stopwatch(); - with(tf.Session(graph), sess => + return with(tf.Session(graph), sess => { // Initialize all weights: for the module to their pretrained values, // and for the newly added retraining layer to random initial values. @@ -111,6 +113,7 @@ namespace TensorFlowNET.Examples.ImageProcess // Create a train saver that is used to restore values into an eval graph // when exporting models. var train_saver = tf.train.Saver(); + sw.Restart(); for (int i = 0; i < how_many_training_steps; i++) { @@ -140,8 +143,7 @@ namespace TensorFlowNET.Examples.ImageProcess new FeedItem(bottleneck_input, train_bottlenecks), new FeedItem(ground_truth_input, train_ground_truth)); (float train_accuracy, float cross_entropy_value) = (results[0], results[1]); - print($"{DateTime.Now}: Step {i}: Train accuracy = {train_accuracy * 100}%"); - print($"{DateTime.Now}: Step {i}: Cross entropy = {cross_entropy_value}"); + print($"{DateTime.Now}: Step {i + 1}: Train accuracy = {train_accuracy * 100}%, Cross entropy = {cross_entropy_value.ToString("G4")}"); var (validation_bottlenecks, validation_ground_truth, _) = get_random_cached_bottlenecks( sess, image_lists, validation_batch_size, "validation", @@ -158,7 +160,8 @@ namespace TensorFlowNET.Examples.ImageProcess (string validation_summary, float validation_accuracy) = (results[0], results[1]); validation_writer.add_summary(validation_summary, i); - print($"{DateTime.Now}: Step {i + 1}: Validation accuracy = {validation_accuracy * 100}% (N={len(validation_bottlenecks)})"); + print($"{DateTime.Now}: Step {i + 1}: Validation accuracy = {validation_accuracy * 100}% (N={len(validation_bottlenecks)}) {sw.ElapsedMilliseconds}ms"); + sw.Restart(); } // Store intermediate results @@ -180,12 +183,11 @@ namespace TensorFlowNET.Examples.ImageProcess // Write out the trained graph and labels with the weights stored as // constants. - print($"final test accuracy: {test_accuracy}"); print($"Save final result to : {output_graph}"); save_graph_to_file(output_graph, class_count); + File.WriteAllText(output_labels, string.Join("\n", image_lists.Keys)); + return test_accuracy > 0.75f; }); - - return false; } /// @@ -215,6 +217,8 @@ namespace TensorFlowNET.Examples.ImageProcess new FeedItem(bottleneck_input, test_bottlenecks), new FeedItem(ground_truth_input, test_ground_truth)); + print($"final test accuracy: {((float)results[0] * 100).ToString("G4")}% (N={len(test_bottlenecks)})"); + return (results[0], results[1]); }