
Freezing the trained model is complete. #248.

tags/v0.9
Oceania2018 6 years ago
parent
commit
a9330d277f
7 changed files with 240 additions and 17 deletions
  1. BIN      graph/InceptionV3.meta
  2. +179 -2  src/TensorFlowNET.Core/Framework/graph_util_impl.cs
  3. +3 -4    src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs
  4. +2 -0    src/TensorFlowNET.Core/Sessions/_FetchHandler.cs
  5. +5 -3    src/TensorFlowNET.Core/Sessions/_FetchMapper.cs
  6. +40 -1   src/TensorFlowNET.Core/Sessions/_ListFetchMapper.cs
  7. +11 -7   test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs
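
The bulk of this commit is the freeze-graph logic added to graph_util_impl.cs: after training, every variable in the inference graph is replaced by a Const node holding its current value, so the resulting GraphDef can be written to disk and served without a checkpoint. A minimal usage sketch follows; the names convert_variables_to_constants, as_graph_def and the "final_result" output node are assumptions modeled on the Python tf.graph_util API and the retrain example, not verified TF.NET v0.9 signatures.

    // Hypothetical sketch: freeze the trained graph once the retraining loop has finished.
    // convert_variables_to_constants / as_graph_def / "final_result" are assumed names.
    var graph_def = graph.as_graph_def();
    var frozen_graph_def = tf.graph_util.convert_variables_to_constants(
        sess,                         // live session holding the trained weights
        graph_def,                    // GraphDef of the training graph
        new[] { "final_result" });    // output nodes that must stay reachable
    File.WriteAllBytes(output_graph, frozen_graph_def.ToByteArray());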

BIN  graph/InceptionV3.meta


+179 -2  src/TensorFlowNET.Core/Framework/graph_util_impl.cs

@@ -1,6 +1,9 @@
using System;
using NumSharp;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using static Tensorflow.Python;

namespace Tensorflow
{
@@ -23,7 +26,181 @@ namespace Tensorflow
{
// This graph only includes the nodes needed to evaluate the output nodes, and
// removes unneeded nodes like those involved in saving and assignment.
throw new NotImplementedException("");
var inference_graph = extract_sub_graph(input_graph_def, output_node_names);

// Identify the ops in the graph.
var map_name_to_node = new Dictionary<string, NodeDef>();
inference_graph.Node.Select(x => map_name_to_node[x.Name] = x).ToArray();

// Get list of variables.
var variable_names = new List<string>();
var variable_dict_names = new List<string>();

foreach (var node in inference_graph.Node)
{
if(new string[] { "Variable", "VariableV2", "VarHandleOp" }.Contains(node.Op))
{
var variable_name = node.Name;

variable_dict_names.Add(variable_name);
if (node.Op == "VarHandleOp")
variable_names.Add(variable_name + "/Read/ReadVariableOp:0");
else
variable_names.Add(variable_name + ":0");
}
else if (new string[] { "ReadVariableOp", "ResourceGather" }.Contains(node.Op))
{
// There can be one or more Identity ops in between the ReadVariableOp and
// VarHandleOp. Store the Identity ops with the associated dtypes.
var source_op_name = get_input_name(node);
while(map_name_to_node[source_op_name].Op == "Identity")
{
throw new NotImplementedException("map_name_to_node[source_op_name].Op");
/*resource_identity_types[source_op_name] = node.attr["dtype"];
source_op_name = get_input_name(map_name_to_node[source_op_name]);*/
}
}
}

// Gets map of variables and the associated data.
NDArray returned_variables = null;
if (variable_names.Count > 0)
returned_variables = sess.run(variable_names);

var variables_data_map = new Dictionary<string, NDArray>();
foreach(var (i, name) in enumerate(variable_dict_names))
variables_data_map[name] = returned_variables[i];
print($"Froze {len(returned_variables)} variables.");

// Reconstruct the graph with constants in place of variables.
var output_graph_def = new GraphDef();
int how_many_converted = 0;
foreach(var input_node in inference_graph.Node)
{
var output_node = new NodeDef();
if (variables_data_map.ContainsKey(input_node.Name))
{
var data = variables_data_map[input_node.Name];
output_node = create_const_op(input_node.Name, input_node.Attr["dtype"],
data, data.shape);
how_many_converted += 1;
}
// else if (resource_identity_types.ContainsKey(input_node.Name))
else if(input_node.Op == "ReadVariableOp")
{
output_node.Op = "Identity";
output_node.Name = input_node.Name;
output_node.Input.AddRange(new[] { input_node.Input[0] });
output_node.Attr["T"] = input_node.Attr["dtype"];
}
else if (input_node.Op == "ResourceGather")
{

}
else
{
output_node.MergeFrom(input_node);
}

output_graph_def.Node.AddRange(new[] { output_node });
}

output_graph_def.Library = inference_graph.Library;
print($"Converted {how_many_converted} variables to const ops.");
return output_graph_def;
}

private NodeDef create_const_op(string node_name, AttrValue dtype, NDArray data, int[] data_shape = null)
{
var output_node = new NodeDef
{
Op = "Const",
Name = node_name
};
output_node.Attr["dtype"] = dtype;
output_node.Attr["value"] = new AttrValue()
{
Tensor = tensor_util.make_tensor_proto(
data, dtype: dtype.Type.as_tf_dtype(), shape: data_shape)
};

return output_node;
}

/// <summary>
/// Gets the name of the first input. Errors if suffix is not :0.
/// </summary>
/// <param name="node"></param>
/// <returns></returns>
private string get_input_name(NodeDef node)
{
var details = node.Input[0].Split(':');
if (details.Length == 1 || int.Parse(details[1]) == 0)
return details[0];
// While it is valid for input tensors to have a suffix that is not :0, this
// method is used to find the associated ops, not tensors, and therefore it
// is not valid.
throw new ValueError($"Tensor name '{node.Input[0]}' is invalid.");
}


private GraphDef extract_sub_graph(GraphDef graph_def, string[] dest_nodes)
{
var (name_to_input_name, name_to_node, name_to_seq_num) = _extract_graph_summary(
graph_def);

var nodes_to_keep = _bfs_for_reachable_nodes(dest_nodes, name_to_input_name);
var nodes_to_keep_list = nodes_to_keep.OrderBy(n => name_to_seq_num[n]).ToArray();
// Now construct the output GraphDef
var output = new GraphDef();
foreach (var n in nodes_to_keep_list)
output.Node.Add(name_to_node[n]); // need deep clone?
output.Library = graph_def.Library;
output.Versions = graph_def.Versions;

return output;
}

private string[] _bfs_for_reachable_nodes(string[] target_nodes, Dictionary<string, string[]> name_to_input_name)
{
var nodes_to_keep = new List<string>();
var next_to_visit = target_nodes.Select(x => x).ToList();
while(next_to_visit.Count > 0)
{
var node = next_to_visit[0];
next_to_visit.RemoveAt(0);
if (nodes_to_keep.Contains(node))
continue;
nodes_to_keep.Add(node);
if (name_to_input_name.Keys.Contains(node))
next_to_visit.AddRange(name_to_input_name[node]);
}

return nodes_to_keep.ToArray();
}

private (Dictionary<string, string[]>, Dictionary<string, NodeDef>, Dictionary<string, int>) _extract_graph_summary(GraphDef graph_def)
{
var name_to_input_name = new Dictionary<string, string[]>();
var name_to_node = new Dictionary<string, NodeDef>();
var name_to_seq_num = new Dictionary<string, int>();

int seq = 0;
foreach (var node in graph_def.Node)
{
var n = _node_name(node.Name);
name_to_node[n] = node;
name_to_input_name[n] = node.Input.Select(x => _node_name(x)).ToArray();
name_to_seq_num[n] = seq;
seq++;
}

return (name_to_input_name, name_to_node, name_to_seq_num);
}

private string _node_name(string n)
{
return n.StartsWith("^") ? n.Substring(1) : n.Split(':')[0];
}

private string get_input_name(string node)
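
To make the pruning step concrete, here is a self-contained toy walk-through of the breadth-first search used by extract_sub_graph and _bfs_for_reachable_nodes above (an illustrative sketch, not part of the commit): nodes reachable from the requested outputs are kept, everything else, such as saver ops, is dropped.

    using System;
    using System.Collections.Generic;

    class BfsSketch
    {
        static void Main()
        {
            // Toy dependency map: "softmax" needs "conv", "conv" needs "input";
            // "save/restore_all" is not reachable from the output, so it is pruned.
            var name_to_input_name = new Dictionary<string, string[]>
            {
                ["input"] = new string[0],
                ["conv"] = new[] { "input" },
                ["softmax"] = new[] { "conv" },
                ["save/restore_all"] = new[] { "conv" }
            };

            var nodes_to_keep = new List<string>();
            var next_to_visit = new List<string> { "softmax" };   // requested output nodes
            while (next_to_visit.Count > 0)
            {
                var node = next_to_visit[0];
                next_to_visit.RemoveAt(0);
                if (nodes_to_keep.Contains(node))
                    continue;
                nodes_to_keep.Add(node);
                if (name_to_input_name.ContainsKey(node))
                    next_to_visit.AddRange(name_to_input_name[node]);
            }

            Console.WriteLine(string.Join(", ", nodes_to_keep));   // softmax, conv, input
        }
    }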


+3 -4  src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs

@@ -15,14 +15,13 @@ namespace Tensorflow
public _ElementFetchMapper(object[] fetches, Func<List<NDArray>, object> contraction_fn)
{
var g = ops.get_default_graph();
ITensorOrOperation el = null;

foreach(var fetch in fetches)
{
el = g.as_graph_element(fetch, allow_tensor: true, allow_operation: true);
var el = g.as_graph_element(fetch, allow_tensor: true, allow_operation: true);
_unique_fetches.Add(el);
}

_unique_fetches.Add(el);
_contraction_fn = contraction_fn;
}
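
The fix above moves _unique_fetches.Add(el) inside the loop, so an _ElementFetchMapper built over several fetches now registers every resolved graph element instead of only the last one. A hypothetical illustration (the constructor signature is taken from the diff; the tensor names are made up):

    // Before this change only the last fetch, "cross_entropy:0", would have been recorded.
    var mapper = new _ElementFetchMapper(
        new object[] { "accuracy:0", "cross_entropy:0" },
        results => results[0]);
    // mapper.unique_fetches() now contains two elements, one per fetch.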



+2 -0  src/TensorFlowNET.Core/Sessions/_FetchHandler.cs

@@ -33,6 +33,8 @@ namespace Tensorflow
_fetches.Add(val);
_ops.Add(false);
break;
default:
throw new NotImplementedException("_FetchHandler fetch");
}

}


+5 -3  src/TensorFlowNET.Core/Sessions/_FetchMapper.cs

@@ -8,12 +8,14 @@ namespace Tensorflow
{
public class _FetchMapper
{
protected List<object> _unique_fetches = new List<object>();
protected List<ITensorOrOperation> _unique_fetches = new List<ITensorOrOperation>();
protected List<int[]> _value_indices = new List<int[]>();
public static _FetchMapper for_fetch(object fetch)
{
var fetches = fetch.GetType().IsArray ? (object[])fetch : new object[] { fetch };

if(fetch is List<string> fetches1)
return new _ListFetchMapper(fetches1.ToArray());
if (fetch.GetType().IsArray)
return new _ListFetchMapper(fetches);
else
@@ -28,7 +30,7 @@ namespace Tensorflow
return nd;
}

public virtual List<object> unique_fetches()
public virtual List<ITensorOrOperation> unique_fetches()
{
return _unique_fetches;
}


+40 -1  src/TensorFlowNET.Core/Sessions/_ListFetchMapper.cs

@@ -12,7 +12,46 @@ namespace Tensorflow
public _ListFetchMapper(object[] fetches)
{
_mappers = fetches.Select(fetch => _FetchMapper.for_fetch(fetch)).ToArray();
_unique_fetches.AddRange(fetches);
(_unique_fetches, _value_indices) = _uniquify_fetches(_mappers);
}

private (List<ITensorOrOperation>, List<int[]>) _uniquify_fetches(_FetchMapper[] fetch_mappers)
{
var unique_fetches = new List<ITensorOrOperation>();
var value_indices = new List<int[]>();
var seen_fetches = new Dictionary<ITensorOrOperation, int>();

foreach (var m in fetch_mappers)
{
var m_value_indices = new List<int>();
foreach (var uf in m.unique_fetches())
{
switch (uf)
{
case Tensor f:
if (!seen_fetches.ContainsKey(f))
{
seen_fetches[f] = seen_fetches.Count;
unique_fetches.Add(f);
}
m_value_indices.Add(seen_fetches[f]); // use the recorded index, even when f was already seen
break;
case Operation f:
if (!seen_fetches.ContainsKey(f))
{
seen_fetches[f] = seen_fetches.Count;
unique_fetches.Add(f);
}
m_value_indices.Add(seen_fetches[f]);
break;
default:
throw new NotImplementedException("_uniquify_fetches");
}
}
value_indices.Add(m_value_indices.ToArray());
}

return (unique_fetches, value_indices);
}
}
}
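
_uniquify_fetches mirrors the Python helper of the same name: each distinct Tensor or Operation is fetched from the session only once, and _value_indices records which run output feeds each requested position. This is what lets callers request several values in a single sess.run, as the retraining example below does; a hedged sketch of such a call (evaluation_step is an assumed tensor variable):

    // Sketch: one sess.run call returning two values; duplicate fetches would be
    // collapsed by _uniquify_fetches and their results shared via _value_indices.
    var results = sess.run(
        new object[] { evaluation_step, cross_entropy },
        new FeedItem(bottleneck_input, train_bottlenecks),
        new FeedItem(ground_truth_input, train_ground_truth));
    (float train_accuracy, float cross_entropy_value) = (results[0], results[1]);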

+11 -7  test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs

@@ -3,12 +3,14 @@ using NumSharp;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Drawing;
using System.IO;
using System.Linq;
using System.Text;
using Tensorflow;
using TensorFlowNET.Examples.Utility;
using static Tensorflow.Python;
using Console = Colorful.Console;

namespace TensorFlowNET.Examples.ImageProcess
{
@@ -84,7 +86,7 @@ namespace TensorFlowNET.Examples.ImageProcess

var sw = new Stopwatch();

with(tf.Session(graph), sess =>
return with(tf.Session(graph), sess =>
{
// Initialize all weights: for the module to their pretrained values,
// and for the newly added retraining layer to random initial values.
@@ -111,6 +113,7 @@ namespace TensorFlowNET.Examples.ImageProcess
// Create a train saver that is used to restore values into an eval graph
// when exporting models.
var train_saver = tf.train.Saver();
sw.Restart();

for (int i = 0; i < how_many_training_steps; i++)
{
@@ -140,8 +143,7 @@ namespace TensorFlowNET.Examples.ImageProcess
new FeedItem(bottleneck_input, train_bottlenecks),
new FeedItem(ground_truth_input, train_ground_truth));
(float train_accuracy, float cross_entropy_value) = (results[0], results[1]);
print($"{DateTime.Now}: Step {i}: Train accuracy = {train_accuracy * 100}%");
print($"{DateTime.Now}: Step {i}: Cross entropy = {cross_entropy_value}");
print($"{DateTime.Now}: Step {i + 1}: Train accuracy = {train_accuracy * 100}%, Cross entropy = {cross_entropy_value.ToString("G4")}");

var (validation_bottlenecks, validation_ground_truth, _) = get_random_cached_bottlenecks(
sess, image_lists, validation_batch_size, "validation",
@@ -158,7 +160,8 @@ namespace TensorFlowNET.Examples.ImageProcess
(string validation_summary, float validation_accuracy) = (results[0], results[1]);

validation_writer.add_summary(validation_summary, i);
print($"{DateTime.Now}: Step {i + 1}: Validation accuracy = {validation_accuracy * 100}% (N={len(validation_bottlenecks)})");
print($"{DateTime.Now}: Step {i + 1}: Validation accuracy = {validation_accuracy * 100}% (N={len(validation_bottlenecks)}) {sw.ElapsedMilliseconds}ms");
sw.Restart();
}

// Store intermediate results
@@ -180,12 +183,11 @@ namespace TensorFlowNET.Examples.ImageProcess

// Write out the trained graph and labels with the weights stored as
// constants.
print($"final test accuracy: {test_accuracy}");
print($"Save final result to : {output_graph}");
save_graph_to_file(output_graph, class_count);
File.WriteAllText(output_labels, string.Join("\n", image_lists.Keys));
return test_accuracy > 0.75f;
});

return false;
}

/// <summary>
@@ -215,6 +217,8 @@ namespace TensorFlowNET.Examples.ImageProcess
new FeedItem(bottleneck_input, test_bottlenecks),
new FeedItem(ground_truth_input, test_ground_truth));

print($"final test accuracy: {((float)results[0] * 100).ToString("G4")}% (N={len(test_bottlenecks)})");

return (results[0], results[1]);
}
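
Once output_graph has been written with the weights folded into constants, it can be reloaded for inference without a checkpoint. A consumption sketch follows; the import call and the operation lookup are assumptions about the API of this era, not code from the commit.

    // Hypothetical: parse the frozen GraphDef and import it for inference.
    // tf.import_graph_def and OperationByName are assumed names; verify against
    // the actual TF.NET build before relying on them.
    var frozen = GraphDef.Parser.ParseFrom(File.ReadAllBytes(output_graph));
    with(tf.Session(), sess =>
    {
        tf.import_graph_def(frozen, name: "");
        var output = sess.graph.OperationByName("final_result");
        // feed a decoded image and run `output` to obtain class scores
    });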


