Browse Source

fix variable path for transfer learning and word2vec.

tags/v0.10
Oceania2018 6 years ago
parent
commit
b28687b245
3 changed files with 150 additions and 141 deletions
  1. BIN
      graph/InceptionV3.meta
  2. +147
    -138
      test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs
  3. +3
    -3
      test/TensorFlowNET.Examples/TextProcessing/Word2Vec.cs

BIN
graph/InceptionV3.meta View File


+ 147
- 138
test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs View File

@@ -74,131 +74,23 @@ namespace TensorFlowNET.Examples.ImageProcess
Tensor bottleneck_input; Tensor bottleneck_input;
Tensor cross_entropy; Tensor cross_entropy;
Tensor ground_truth_input; Tensor ground_truth_input;
Tensor bottleneck_tensor;
bool wants_quantization;
float test_accuracy;
NDArray predictions;


// Entry point of the retraining example.
//
// NOTE(review): this span is a side-by-side commit diff, not a single compilable body.
// Unchanged lines are rendered twice on one line (e.g. "{ {"), and the pre-change and
// post-change bodies are interleaved, which is why `graph` appears to be declared twice
// and two `with(` calls open back to back.
//
// Pre-change side: Run() calls create_module_graph(), adds the retraining ops with
// add_final_retrain_ops(..., is_training: true), then performs the whole training loop
// inline inside `return with(tf.Session(graph), sess => ...)`, returning
// `test_accuracy > 0.75f` from inside the session lambda.
//
// Post-change side: Run() calls PrepareData(), obtains the graph from ImportGraph() or
// BuildGraph() depending on IsImportingGraph, calls Train(sess) inside
// `with(tf.Session(graph), ...)`, and finally returns `test_accuracy > 0.75f`.
public bool Run() public bool Run()
{ {
PrepareData(); PrepareData();


// (pre-change side) Set up the pre-trained graph.
// Set up the pre-trained graph.
var (graph, bottleneck_tensor, resized_image_tensor, wants_quantization) =
create_module_graph();
// (post-change side) The graph now comes from ImportGraph() or BuildGraph().
var graph = IsImportingGraph ? ImportGraph() : BuildGraph();


// (pre-change side) Add the new layer that we'll be training.
// Add the new layer that we'll be training.
with(graph.as_default(), delegate
// (post-change side) Open a session on the graph and delegate all training to Train().
with(tf.Session(graph), sess =>
{ {
(train_step, cross_entropy, bottleneck_input,
ground_truth_input, final_tensor) = add_final_retrain_ops(
class_count, final_tensor_name, bottleneck_tensor,
wants_quantization, is_training: true);
Train(sess);
}); });


// (pre-change side) Everything from here down to the closing "});" of the session
// lambda is the old inline training loop; the same statements reappear in Train().
var sw = new Stopwatch();

return with(tf.Session(graph), sess =>
{
// Initialize all weights: for the module to their pretrained values,
// and for the newly added retraining layer to random initial values.
var init = tf.global_variables_initializer();
sess.run(init);

var (jpeg_data_tensor, decoded_image_tensor) = add_jpeg_decoding();

// We'll make sure we've calculated the 'bottleneck' image summaries and
// cached them on disk.
cache_bottlenecks(sess, image_lists, image_dir,
bottleneck_dir, jpeg_data_tensor,
decoded_image_tensor, resized_image_tensor,
bottleneck_tensor, tfhub_module);

// Create the operations we need to evaluate the accuracy of our new layer.
var (evaluation_step, _) = add_evaluation_step(final_tensor, ground_truth_input);

// Merge all the summaries and write them out to the summaries_dir
var merged = tf.summary.merge_all();
var train_writer = tf.summary.FileWriter(summaries_dir + "/train", sess.graph);
var validation_writer = tf.summary.FileWriter(summaries_dir + "/validation", sess.graph);

// Create a train saver that is used to restore values into an eval graph
// when exporting models.
var train_saver = tf.train.Saver();
train_saver.save(sess, CHECKPOINT_NAME);

sw.Restart();

for (int i = 0; i < how_many_training_steps; i++)
{
var (train_bottlenecks, train_ground_truth, _) = get_random_cached_bottlenecks(
sess, image_lists, train_batch_size, "training",
bottleneck_dir, image_dir, jpeg_data_tensor,
decoded_image_tensor, resized_image_tensor, bottleneck_tensor,
tfhub_module);

// Feed the bottlenecks and ground truth into the graph, and run a training
// step. Capture training summaries for TensorBoard with the `merged` op.
var results = sess.run(
new ITensorOrOperation[] { merged, train_step },
new FeedItem(bottleneck_input, train_bottlenecks),
new FeedItem(ground_truth_input, train_ground_truth));
var train_summary = results[0];

// TODO
train_writer.add_summary(train_summary, i);

// Every so often, print out how well the graph is training.
bool is_last_step = (i + 1 == how_many_training_steps);
if ((i % eval_step_interval) == 0 || is_last_step)
{
results = sess.run(
new Tensor[] { evaluation_step, cross_entropy },
new FeedItem(bottleneck_input, train_bottlenecks),
new FeedItem(ground_truth_input, train_ground_truth));
(float train_accuracy, float cross_entropy_value) = (results[0], results[1]);
print($"{DateTime.Now}: Step {i + 1}: Train accuracy = {train_accuracy * 100}%, Cross entropy = {cross_entropy_value.ToString("G4")}");

var (validation_bottlenecks, validation_ground_truth, _) = get_random_cached_bottlenecks(
sess, image_lists, validation_batch_size, "validation",
bottleneck_dir, image_dir, jpeg_data_tensor,
decoded_image_tensor, resized_image_tensor, bottleneck_tensor,
tfhub_module);

// Run a validation step and capture training summaries for TensorBoard
// with the `merged` op.
results = sess.run(new Tensor[] { merged, evaluation_step },
new FeedItem(bottleneck_input, validation_bottlenecks),
new FeedItem(ground_truth_input, validation_ground_truth));

(string validation_summary, float validation_accuracy) = (results[0], results[1]);

validation_writer.add_summary(validation_summary, i);
print($"{DateTime.Now}: Step {i + 1}: Validation accuracy = {validation_accuracy * 100}% (N={len(validation_bottlenecks)}) {sw.ElapsedMilliseconds}ms");
sw.Restart();
}

// Store intermediate results
int intermediate_frequency = intermediate_store_frequency;
if (intermediate_frequency > 0 && i % intermediate_frequency == 0 && i > 0)
{

}
}

// After training is complete, force one last save of the train checkpoint.
train_saver.save(sess, CHECKPOINT_NAME);

// We've completed all our training, so run a final test evaluation on
// some new images we haven't used before.
var (test_accuracy, predictions) = run_final_eval(sess, null, class_count, image_lists,
jpeg_data_tensor, decoded_image_tensor, resized_image_tensor,
bottleneck_tensor);

// Write out the trained graph and labels with the weights stored as
// constants.
print($"Save final result to : {output_graph}");
save_graph_to_file(output_graph, class_count);
File.WriteAllText(output_labels, string.Join("\n", image_lists.Keys));

return test_accuracy > 0.75f;
});
// (post-change side) The result is reported after the session block; test_accuracy
// is assigned inside Train().
return test_accuracy > 0.75f;
} }


/// <summary> /// <summary>
@@ -212,7 +104,7 @@ namespace TensorFlowNET.Examples.ImageProcess
/// <param name="decoded_image_tensor"></param> /// <param name="decoded_image_tensor"></param>
/// <param name="resized_image_tensor"></param> /// <param name="resized_image_tensor"></param>
/// <param name="bottleneck_tensor"></param> /// <param name="bottleneck_tensor"></param>
private (float, NDArray) run_final_eval(Session train_session, object module_spec, int class_count,
private (float, NDArray) run_final_eval(Session train_session, object module_spec, int class_count,
Dictionary<string, Dictionary<string, string[]>> image_lists, Dictionary<string, Dictionary<string, string[]>> image_lists,
Tensor jpeg_data_tensor, Tensor decoded_image_tensor, Tensor jpeg_data_tensor, Tensor decoded_image_tensor,
Tensor resized_image_tensor, Tensor bottleneck_tensor) Tensor resized_image_tensor, Tensor bottleneck_tensor)
@@ -233,7 +125,7 @@ namespace TensorFlowNET.Examples.ImageProcess
return (results[0], results[1]); return (results[0], results[1]);
} }


private (Session, Tensor, Tensor, Tensor, Tensor, Tensor)
private (Session, Tensor, Tensor, Tensor, Tensor, Tensor)
build_eval_session(int class_count) build_eval_session(int class_count)
{ {
// If quantized, we need to create the correct eval graph for exporting. // If quantized, we need to create the correct eval graph for exporting.
@@ -245,7 +137,7 @@ namespace TensorFlowNET.Examples.ImageProcess
with(eval_graph.as_default(), graph => with(eval_graph.as_default(), graph =>
{ {
// Add the new layer for exporting. // Add the new layer for exporting.
var (_, _, bottleneck_input, ground_truth_input, final_tensor) =
var (_, _, bottleneck_input, ground_truth_input, final_tensor) =
add_final_retrain_ops(class_count, final_tensor_name, bottleneck_tensor, add_final_retrain_ops(class_count, final_tensor_name, bottleneck_tensor,
wants_quantization, is_training: false); wants_quantization, is_training: false);


@@ -277,7 +169,7 @@ namespace TensorFlowNET.Examples.ImageProcess
/// <param name="quantize_layer"></param> /// <param name="quantize_layer"></param>
/// <param name="is_training"></param> /// <param name="is_training"></param>
/// <returns></returns> /// <returns></returns>
private (Operation, Tensor, Tensor, Tensor, Tensor) add_final_retrain_ops(int class_count, string final_tensor_name,
private (Operation, Tensor, Tensor, Tensor, Tensor) add_final_retrain_ops(int class_count, string final_tensor_name,
Tensor bottleneck_tensor, bool quantize_layer, bool is_training) Tensor bottleneck_tensor, bool quantize_layer, bool is_training)
{ {
var (batch_size, bottleneck_tensor_size) = (bottleneck_tensor.TensorShape.Dimensions[0], bottleneck_tensor.TensorShape.Dimensions[1]); var (batch_size, bottleneck_tensor_size) = (bottleneck_tensor.TensorShape.Dimensions[0], bottleneck_tensor.TensorShape.Dimensions[1]);
@@ -365,7 +257,8 @@ namespace TensorFlowNET.Examples.ImageProcess
var mean = tf.reduce_mean(var); var mean = tf.reduce_mean(var);
tf.summary.scalar("mean", mean); tf.summary.scalar("mean", mean);
Tensor stddev = null; Tensor stddev = null;
with(tf.name_scope("stddev"), delegate {
with(tf.name_scope("stddev"), delegate
{
stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean))); stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)));
}); });
tf.summary.scalar("stddev", stddev); tf.summary.scalar("stddev", stddev);
@@ -378,7 +271,7 @@ namespace TensorFlowNET.Examples.ImageProcess
private (Graph, Tensor, Tensor, bool) create_module_graph() private (Graph, Tensor, Tensor, bool) create_module_graph()
{ {
var (height, width) = (299, 299); var (height, width) = (299, 299);
return with(tf.Graph().as_default(), graph => return with(tf.Graph().as_default(), graph =>
{ {
tf.train.import_meta_graph("graph/InceptionV3.meta"); tf.train.import_meta_graph("graph/InceptionV3.meta");
@@ -390,8 +283,8 @@ namespace TensorFlowNET.Examples.ImageProcess
}); });
} }


private (NDArray, long[], string[]) get_random_cached_bottlenecks(Session sess, Dictionary<string, Dictionary<string, string[]>> image_lists,
int how_many, string category, string bottleneck_dir, string image_dir,
private (NDArray, long[], string[]) get_random_cached_bottlenecks(Session sess, Dictionary<string, Dictionary<string, string[]>> image_lists,
int how_many, string category, string bottleneck_dir, string image_dir,
Tensor jpeg_data_tensor, Tensor decoded_image_tensor, Tensor resized_input_tensor, Tensor jpeg_data_tensor, Tensor decoded_image_tensor, Tensor resized_input_tensor,
Tensor bottleneck_tensor, string module_name) Tensor bottleneck_tensor, string module_name)
{ {
@@ -423,7 +316,7 @@ namespace TensorFlowNET.Examples.ImageProcess
// Retrieve all bottlenecks. // Retrieve all bottlenecks.
foreach (var (label_index, label_name) in enumerate(image_lists.Keys.ToArray())) foreach (var (label_index, label_name) in enumerate(image_lists.Keys.ToArray()))
{ {
foreach(var (image_index, image_name) in enumerate(image_lists[label_name][category]))
foreach (var (image_index, image_name) in enumerate(image_lists[label_name][category]))
{ {
var bottleneck = get_or_create_bottleneck( var bottleneck = get_or_create_bottleneck(
sess, image_lists, label_name, image_index, image_dir, category, sess, image_lists, label_name, image_index, image_dir, category,
@@ -480,17 +373,17 @@ namespace TensorFlowNET.Examples.ImageProcess
/// <param name="resized_image_tensor"></param> /// <param name="resized_image_tensor"></param>
/// <param name="bottleneck_tensor"></param> /// <param name="bottleneck_tensor"></param>
/// <param name="tfhub_module"></param> /// <param name="tfhub_module"></param>
private void cache_bottlenecks(Session sess, Dictionary<string, Dictionary<string, string[]>> image_lists,
private void cache_bottlenecks(Session sess, Dictionary<string, Dictionary<string, string[]>> image_lists,
string image_dir, string bottleneck_dir, Tensor jpeg_data_tensor, Tensor decoded_image_tensor, string image_dir, string bottleneck_dir, Tensor jpeg_data_tensor, Tensor decoded_image_tensor,
Tensor resized_input_tensor, Tensor bottleneck_tensor, string module_name) Tensor resized_input_tensor, Tensor bottleneck_tensor, string module_name)
{ {
int how_many_bottlenecks = 0; int how_many_bottlenecks = 0;
foreach(var (label_name, label_lists) in image_lists)
foreach (var (label_name, label_lists) in image_lists)
{ {
foreach(var category in new string[] { "training", "testing", "validation" })
foreach (var category in new string[] { "training", "testing", "validation" })
{ {
var category_list = label_lists[category]; var category_list = label_lists[category];
foreach(var (index, unused_base_name) in enumerate(category_list))
foreach (var (index, unused_base_name) in enumerate(category_list))
{ {
get_or_create_bottleneck(sess, image_lists, label_name, index, image_dir, category, get_or_create_bottleneck(sess, image_lists, label_name, index, image_dir, category,
bottleneck_dir, jpeg_data_tensor, decoded_image_tensor, bottleneck_dir, jpeg_data_tensor, decoded_image_tensor,
@@ -503,8 +396,8 @@ namespace TensorFlowNET.Examples.ImageProcess
} }
} }


private float[] get_or_create_bottleneck(Session sess, Dictionary<string, Dictionary<string, string[]>> image_lists,
string label_name, int index, string image_dir, string category, string bottleneck_dir,
private float[] get_or_create_bottleneck(Session sess, Dictionary<string, Dictionary<string, string[]>> image_lists,
string label_name, int index, string image_dir, string category, string bottleneck_dir,
Tensor jpeg_data_tensor, Tensor decoded_image_tensor, Tensor resized_input_tensor, Tensor jpeg_data_tensor, Tensor decoded_image_tensor, Tensor resized_input_tensor,
Tensor bottleneck_tensor, string module_name) Tensor bottleneck_tensor, string module_name)
{ {
@@ -524,8 +417,8 @@ namespace TensorFlowNET.Examples.ImageProcess
return bottleneck_values; return bottleneck_values;
} }


private void create_bottleneck_file(string bottleneck_path, Dictionary<string, Dictionary<string, string[]>> image_lists,
string label_name, int index, string image_dir, string category, Session sess,
private void create_bottleneck_file(string bottleneck_path, Dictionary<string, Dictionary<string, string[]>> image_lists,
string label_name, int index, string image_dir, string category, Session sess,
Tensor jpeg_data_tensor, Tensor decoded_image_tensor, Tensor resized_input_tensor, Tensor bottleneck_tensor) Tensor jpeg_data_tensor, Tensor decoded_image_tensor, Tensor resized_input_tensor, Tensor bottleneck_tensor)
{ {
// Create a single bottleneck file. // Create a single bottleneck file.
@@ -557,14 +450,14 @@ namespace TensorFlowNET.Examples.ImageProcess
Tensor decoded_image_tensor, Tensor resized_input_tensor, Tensor bottleneck_tensor) Tensor decoded_image_tensor, Tensor resized_input_tensor, Tensor bottleneck_tensor)
{ {
// First decode the JPEG image, resize it, and rescale the pixel values. // First decode the JPEG image, resize it, and rescale the pixel values.
var resized_input_values = sess.run(decoded_image_tensor, new FeedItem(image_data_tensor, new Tensor( image_data, TF_DataType.TF_STRING)));
var resized_input_values = sess.run(decoded_image_tensor, new FeedItem(image_data_tensor, new Tensor(image_data, TF_DataType.TF_STRING)));
// Then run it through the recognition network. // Then run it through the recognition network.
var bottleneck_values = sess.run(bottleneck_tensor, new FeedItem(resized_input_tensor, resized_input_values)); var bottleneck_values = sess.run(bottleneck_tensor, new FeedItem(resized_input_tensor, resized_input_values));
bottleneck_values = np.squeeze(bottleneck_values); bottleneck_values = np.squeeze(bottleneck_values);
return bottleneck_values; return bottleneck_values;
} }


private string get_bottleneck_path(Dictionary<string, Dictionary<string, string[]>> image_lists, string label_name, int index,
private string get_bottleneck_path(Dictionary<string, Dictionary<string, string[]>> image_lists, string label_name, int index,
string bottleneck_dir, string category, string module_name) string bottleneck_dir, string category, string module_name)
{ {
module_name = (module_name.Replace("://", "~") // URL scheme. module_name = (module_name.Replace("://", "~") // URL scheme.
@@ -667,7 +560,7 @@ namespace TensorFlowNET.Examples.ImageProcess


var result = new Dictionary<string, Dictionary<string, string[]>>(); var result = new Dictionary<string, Dictionary<string, string[]>>();


foreach(var sub_dir in sub_dirs)
foreach (var sub_dir in sub_dirs)
{ {
var dir_name = sub_dir.Split(Path.DirectorySeparatorChar).Last(); var dir_name = sub_dir.Split(Path.DirectorySeparatorChar).Last();
print($"Looking for images in '{dir_name}'"); print($"Looking for images in '{dir_name}'");
@@ -689,7 +582,22 @@ namespace TensorFlowNET.Examples.ImageProcess


/// <summary>
/// Builds the graph used for training: obtains the pre-trained module graph from
/// create_module_graph() and then adds the retraining ops to it.
/// Assigns graph-related members (bottleneck_tensor, resized_image_tensor,
/// wants_quantization, train_step, cross_entropy, bottleneck_input,
/// ground_truth_input, final_tensor) as a side effect.
/// </summary>
/// <returns>The graph returned by create_module_graph(), with the retrain ops added.</returns>
public Graph ImportGraph() public Graph ImportGraph()
{ {
// NOTE(review): diff-view artifact - the next line is the pre-change body that this
// commit removes; the statements after it are the replacement implementation.
throw new NotImplementedException();
Graph graph;

// Set up the pre-trained graph.
(graph, bottleneck_tensor, resized_image_tensor, wants_quantization) =
create_module_graph();

// Add the new layer that we'll be training.
// add_final_retrain_ops is called with graph as the default graph and is_training: true.
with(graph.as_default(), delegate
{
(train_step, cross_entropy, bottleneck_input,
ground_truth_input, final_tensor) = add_final_retrain_ops(
class_count, final_tensor_name, bottleneck_tensor,
wants_quantization, is_training: true);
});

return graph;
} }


public Graph BuildGraph() public Graph BuildGraph()
@@ -699,7 +607,108 @@ namespace TensorFlowNET.Examples.ImageProcess


/// <summary>
/// Runs the full retraining procedure in the given session: initializes variables,
/// caches bottlenecks, runs how_many_training_steps training steps with periodic
/// train/validation reporting, saves the checkpoint, runs the final evaluation and
/// writes the output graph and labels file.
/// Stores the final evaluation results in test_accuracy and predictions.
/// </summary>
/// <param name="sess">Session in which all ops are run.</param>
public void Train(Session sess) public void Train(Session sess)
{ {
// NOTE(review): diff-view artifact - the next line is the pre-change body that this
// commit removes; the statements after it are the replacement implementation.
throw new NotImplementedException();
// Stopwatch used to report the elapsed time between validation printouts.
var sw = new Stopwatch();

// Initialize all weights: for the module to their pretrained values,
// and for the newly added retraining layer to random initial values.
var init = tf.global_variables_initializer();
sess.run(init);

var (jpeg_data_tensor, decoded_image_tensor) = add_jpeg_decoding();

// We'll make sure we've calculated the 'bottleneck' image summaries and
// cached them on disk.
cache_bottlenecks(sess, image_lists, image_dir,
bottleneck_dir, jpeg_data_tensor,
decoded_image_tensor, resized_image_tensor,
bottleneck_tensor, tfhub_module);

// Create the operations we need to evaluate the accuracy of our new layer.
var (evaluation_step, _) = add_evaluation_step(final_tensor, ground_truth_input);

// Merge all the summaries and write them out to the summaries_dir
var merged = tf.summary.merge_all();
var train_writer = tf.summary.FileWriter(summaries_dir + "/train", sess.graph);
var validation_writer = tf.summary.FileWriter(summaries_dir + "/validation", sess.graph);

// Create a train saver that is used to restore values into an eval graph
// when exporting models.
var train_saver = tf.train.Saver();
train_saver.save(sess, CHECKPOINT_NAME);

sw.Restart();

for (int i = 0; i < how_many_training_steps; i++)
{
var (train_bottlenecks, train_ground_truth, _) = get_random_cached_bottlenecks(
sess, image_lists, train_batch_size, "training",
bottleneck_dir, image_dir, jpeg_data_tensor,
decoded_image_tensor, resized_image_tensor, bottleneck_tensor,
tfhub_module);

// Feed the bottlenecks and ground truth into the graph, and run a training
// step. Capture training summaries for TensorBoard with the `merged` op.
var results = sess.run(
new ITensorOrOperation[] { merged, train_step },
new FeedItem(bottleneck_input, train_bottlenecks),
new FeedItem(ground_truth_input, train_ground_truth));
var train_summary = results[0];

// TODO
train_writer.add_summary(train_summary, i);

// Every so often, print out how well the graph is training.
// Reporting happens on steps where i is a multiple of eval_step_interval and on the last step.
bool is_last_step = (i + 1 == how_many_training_steps);
if ((i % eval_step_interval) == 0 || is_last_step)
{
// Evaluate accuracy and cross entropy on the same training batch that was just used.
results = sess.run(
new Tensor[] { evaluation_step, cross_entropy },
new FeedItem(bottleneck_input, train_bottlenecks),
new FeedItem(ground_truth_input, train_ground_truth));
(float train_accuracy, float cross_entropy_value) = (results[0], results[1]);
print($"{DateTime.Now}: Step {i + 1}: Train accuracy = {train_accuracy * 100}%, Cross entropy = {cross_entropy_value.ToString("G4")}");

var (validation_bottlenecks, validation_ground_truth, _) = get_random_cached_bottlenecks(
sess, image_lists, validation_batch_size, "validation",
bottleneck_dir, image_dir, jpeg_data_tensor,
decoded_image_tensor, resized_image_tensor, bottleneck_tensor,
tfhub_module);

// Run a validation step and capture training summaries for TensorBoard
// with the `merged` op.
results = sess.run(new Tensor[] { merged, evaluation_step },
new FeedItem(bottleneck_input, validation_bottlenecks),
new FeedItem(ground_truth_input, validation_ground_truth));

(string validation_summary, float validation_accuracy) = (results[0], results[1]);

validation_writer.add_summary(validation_summary, i);
print($"{DateTime.Now}: Step {i + 1}: Validation accuracy = {validation_accuracy * 100}% (N={len(validation_bottlenecks)}) {sw.ElapsedMilliseconds}ms");
sw.Restart();
}

// Store intermediate results
// NOTE(review): the body of this condition is empty, so nothing is stored yet.
int intermediate_frequency = intermediate_store_frequency;
if (intermediate_frequency > 0 && i % intermediate_frequency == 0 && i > 0)
{

}
}

// After training is complete, force one last save of the train checkpoint.
train_saver.save(sess, CHECKPOINT_NAME);

// We've completed all our training, so run a final test evaluation on
// some new images we haven't used before.
// The result is assigned to existing names rather than new locals, so it stays
// available after Train() returns.
(test_accuracy, predictions) = run_final_eval(sess, null, class_count, image_lists,
jpeg_data_tensor, decoded_image_tensor, resized_image_tensor,
bottleneck_tensor);

// Write out the trained graph and labels with the weights stored as
// constants.
print($"Save final result to : {output_graph}");
save_graph_to_file(output_graph, class_count);
// One label per line, taken from the keys of image_lists.
File.WriteAllText(output_labels, string.Join("\n", image_lists.Keys));
} }


public void Predict(Session sess) public void Predict(Session sess)


+ 3
- 3
test/TensorFlowNET.Examples/TextProcessing/Word2Vec.cs View File

@@ -51,7 +51,7 @@ namespace TensorFlowNET.Examples


var graph = tf.Graph().as_default(); var graph = tf.Graph().as_default();


tf.train.import_meta_graph("graph/word2vec.meta");
tf.train.import_meta_graph($"graph{Path.DirectorySeparatorChar}word2vec.meta");


// Input data // Input data
Tensor X = graph.OperationByName("Placeholder"); Tensor X = graph.OperationByName("Placeholder");
@@ -169,10 +169,10 @@ namespace TensorFlowNET.Examples
url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/data/text8.zip"; url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/data/text8.zip";
Web.Download(url, "word2vec", "text8.zip"); Web.Download(url, "word2vec", "text8.zip");
// Unzip the dataset file. Text has already been processed // Unzip the dataset file. Text has already been processed
Compress.UnZip(@"word2vec\text8.zip", "word2vec");
Compress.UnZip($"word2vec{Path.DirectorySeparatorChar}text8.zip", "word2vec");


int wordId = 0; int wordId = 0;
text_words = File.ReadAllText(@"word2vec\text8").Trim().ToLower().Split();
text_words = File.ReadAllText($"word2vec{Path.DirectorySeparatorChar}text8").Trim().ToLower().Split();
// Build the dictionary and replace rare words with UNK token // Build the dictionary and replace rare words with UNK token
word2id = text_words.GroupBy(x => x) word2id = text_words.GroupBy(x => x)
.Select(x => new WordId .Select(x => new WordId


Loading…
Cancel
Save