@@ -19,7 +19,7 @@ namespace Tensorflow | |||
public BaseSession(string target = "", Graph graph = null) | |||
{ | |||
if(graph is null) | |||
if (graph is null) | |||
{ | |||
_graph = ops.get_default_graph(); | |||
} | |||
@@ -41,9 +41,9 @@ namespace Tensorflow | |||
return _run(fetches, feed_dict); | |||
} | |||
public virtual NDArray run(ITensorOrOperation[] fetches, Hashtable feed_dict = null) | |||
public virtual NDArray run(object fetches, Hashtable feed_dict = null) | |||
{ | |||
var feed_items = feed_dict == null ? new FeedItem[0] : | |||
var feed_items = feed_dict == null ? new FeedItem[0] : | |||
feed_dict.Keys.OfType<object>().Select(key => new FeedItem(key, feed_dict[key])).ToArray(); | |||
return _run(fetches, feed_items); | |||
} | |||
@@ -98,7 +98,7 @@ namespace Tensorflow | |||
feed_dict_tensor[subfeed_t] = (NDArray)val; | |||
break; | |||
case bool val: | |||
feed_dict_tensor[subfeed_t] = (NDArray) val; | |||
feed_dict_tensor[subfeed_t] = (NDArray)val; | |||
break; | |||
case bool[] val: | |||
feed_dict_tensor[subfeed_t] = (NDArray)val; | |||
@@ -106,8 +106,8 @@ namespace Tensorflow | |||
default: | |||
Console.WriteLine($"can't handle data type of subfeed_val"); | |||
throw new NotImplementedException("_run subfeed"); | |||
} | |||
} | |||
feed_map[subfeed_t.name] = (subfeed_t, subfeed_val); | |||
} | |||
} | |||
@@ -146,9 +146,9 @@ namespace Tensorflow | |||
/// </returns> | |||
private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict) | |||
{ | |||
var feeds = feed_dict.Select(x => | |||
var feeds = feed_dict.Select(x => | |||
{ | |||
if(x.Key is Tensor tensor) | |||
if (x.Key is Tensor tensor) | |||
{ | |||
switch (x.Value) | |||
{ | |||
@@ -0,0 +1,11 @@ | |||
using System; | |||
using System.Collections; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Sessions | |||
{ | |||
public class FeedDict : Hashtable | |||
{ | |||
} | |||
} |
@@ -8,6 +8,7 @@ using System.Text; | |||
using NumSharp; | |||
using Tensorflow; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Sessions; | |||
using TensorFlowNET.Examples.Text.cnn_models; | |||
using TensorFlowNET.Examples.TextClassification; | |||
using TensorFlowNET.Examples.Utility; | |||
@@ -91,7 +92,7 @@ namespace TensorFlowNET.Examples.CnnTextClassification | |||
foreach (var (x_batch, y_batch, total) in train_batches) | |||
{ | |||
i++; | |||
var train_feed_dict = new Hashtable | |||
var train_feed_dict = new FeedDict | |||
{ | |||
[model_x] = x_batch, | |||
[model_y] = y_batch, | |||
@@ -113,25 +114,26 @@ namespace TensorFlowNET.Examples.CnnTextClassification | |||
if (step % 100 == 0) | |||
{ | |||
continue; | |||
// # Test accuracy with validation data for each epoch. | |||
var valid_batches = batch_iter(valid_x, valid_y, BATCH_SIZE, 1); | |||
var (sum_accuracy, cnt) = (0, 0); | |||
var (sum_accuracy, cnt) = (0.0f, 0); | |||
foreach (var (valid_x_batch, valid_y_batch, total_validation_batches) in valid_batches) | |||
{ | |||
// valid_feed_dict = { | |||
// model.x: valid_x_batch, | |||
// model.y: valid_y_batch, | |||
// model.is_training: False | |||
// } | |||
// accuracy = sess.run(model.accuracy, feed_dict = valid_feed_dict) | |||
// sum_accuracy += accuracy | |||
// cnt += 1 | |||
var valid_feed_dict = new FeedDict | |||
{ | |||
[model_x] = valid_x_batch, | |||
[model_y] = valid_y_batch, | |||
[is_training] = false | |||
}; | |||
var result1 = sess.run(accuracy, valid_feed_dict); | |||
float accuracy_value = result1; | |||
sum_accuracy += accuracy_value; | |||
cnt += 1; | |||
} | |||
// valid_accuracy = sum_accuracy / cnt | |||
// print("\nValidation Accuracy = {1}\n".format(step // num_batches_per_epoch, sum_accuracy / cnt)) | |||
var valid_accuracy = sum_accuracy / cnt; | |||
print($"\nValidation Accuracy = {valid_accuracy}\n"); | |||
// # Save model | |||
// if valid_accuracy > max_accuracy: | |||