Browse Source

added FeedDict

tags/v0.9
Meinrad Recheis 6 years ago
parent
commit
6479706535
3 changed files with 35 additions and 22 deletions
  1. +8
    -8
      src/TensorFlowNET.Core/Sessions/BaseSession.cs
  2. +11
    -0
      src/TensorFlowNET.Core/Sessions/FeedDict.cs
  3. +16
    -14
      test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs

+ 8
- 8
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -19,7 +19,7 @@ namespace Tensorflow

public BaseSession(string target = "", Graph graph = null)
{
if(graph is null)
if (graph is null)
{
_graph = ops.get_default_graph();
}
@@ -41,9 +41,9 @@ namespace Tensorflow
return _run(fetches, feed_dict);
}

public virtual NDArray run(ITensorOrOperation[] fetches, Hashtable feed_dict = null)
public virtual NDArray run(object fetches, Hashtable feed_dict = null)
{
var feed_items = feed_dict == null ? new FeedItem[0] :
var feed_items = feed_dict == null ? new FeedItem[0] :
feed_dict.Keys.OfType<object>().Select(key => new FeedItem(key, feed_dict[key])).ToArray();
return _run(fetches, feed_items);
}
@@ -98,7 +98,7 @@ namespace Tensorflow
feed_dict_tensor[subfeed_t] = (NDArray)val;
break;
case bool val:
feed_dict_tensor[subfeed_t] = (NDArray) val;
feed_dict_tensor[subfeed_t] = (NDArray)val;
break;
case bool[] val:
feed_dict_tensor[subfeed_t] = (NDArray)val;
@@ -106,8 +106,8 @@ namespace Tensorflow
default:
Console.WriteLine($"can't handle data type of subfeed_val");
throw new NotImplementedException("_run subfeed");
}
}
feed_map[subfeed_t.name] = (subfeed_t, subfeed_val);
}
}
@@ -146,9 +146,9 @@ namespace Tensorflow
/// </returns>
private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict)
{
var feeds = feed_dict.Select(x =>
var feeds = feed_dict.Select(x =>
{
if(x.Key is Tensor tensor)
if (x.Key is Tensor tensor)
{
switch (x.Value)
{


+ 11
- 0
src/TensorFlowNET.Core/Sessions/FeedDict.cs View File

@@ -0,0 +1,11 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Text;
namespace Tensorflow.Sessions
{
public class FeedDict : Hashtable
{
}
}

+ 16
- 14
test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs View File

@@ -8,6 +8,7 @@ using System.Text;
using NumSharp;
using Tensorflow;
using Tensorflow.Keras.Engine;
using Tensorflow.Sessions;
using TensorFlowNET.Examples.Text.cnn_models;
using TensorFlowNET.Examples.TextClassification;
using TensorFlowNET.Examples.Utility;
@@ -91,7 +92,7 @@ namespace TensorFlowNET.Examples.CnnTextClassification
foreach (var (x_batch, y_batch, total) in train_batches)
{
i++;
var train_feed_dict = new Hashtable
var train_feed_dict = new FeedDict
{
[model_x] = x_batch,
[model_y] = y_batch,
@@ -113,25 +114,26 @@ namespace TensorFlowNET.Examples.CnnTextClassification

if (step % 100 == 0)
{
continue;
// # Test accuracy with validation data for each epoch.
var valid_batches = batch_iter(valid_x, valid_y, BATCH_SIZE, 1);
var (sum_accuracy, cnt) = (0, 0);
var (sum_accuracy, cnt) = (0.0f, 0);
foreach (var (valid_x_batch, valid_y_batch, total_validation_batches) in valid_batches)
{
// valid_feed_dict = {
// model.x: valid_x_batch,
// model.y: valid_y_batch,
// model.is_training: False
// }

// accuracy = sess.run(model.accuracy, feed_dict = valid_feed_dict)
// sum_accuracy += accuracy
// cnt += 1
var valid_feed_dict = new FeedDict
{
[model_x] = valid_x_batch,
[model_y] = valid_y_batch,
[is_training] = false
};
var result1 = sess.run(accuracy, valid_feed_dict);
float accuracy_value = result1;
sum_accuracy += accuracy_value;
cnt += 1;
}
// valid_accuracy = sum_accuracy / cnt

// print("\nValidation Accuracy = {1}\n".format(step // num_batches_per_epoch, sum_accuracy / cnt))
var valid_accuracy = sum_accuracy / cnt;

print($"\nValidation Accuracy = {valid_accuracy}\n");

// # Save model
// if valid_accuracy > max_accuracy:


Loading…
Cancel
Save