diff --git a/src/TensorFlowNET.Core/APIs/tf.array.cs b/src/TensorFlowNET.Core/APIs/tf.array.cs
index a97a8dda..dd385b4c 100644
--- a/src/TensorFlowNET.Core/APIs/tf.array.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.array.cs
@@ -20,6 +20,13 @@ namespace Tensorflow
public static Tensor expand_dims(Tensor input, int axis = -1, string name = null, int dim = -1)
=> array_ops.expand_dims(input, axis, name, dim);
+ /// <summary>
+ /// Return the elements, either from `x` or `y`, depending on the `condition`.
+ /// </summary>
+ /// <returns></returns>
+ public static Tensor where<Tx, Ty>(Tensor condition, Tx x, Ty y, string name = null)
+ => array_ops.where(condition, x, y, name);
+
/// <summary>
/// Transposes `a`. Permutes the dimensions according to `perm`.
/// </summary>
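
A minimal usage sketch of the new generic overload, mirroring the `keep_prob` call added later in `CnnTextClassification.BuildGraph`; `is_training` is assumed to be a scalar bool placeholder as in that method:

```csharp
// Given a scalar bool tensor `is_training`, pick 0.5 while training and 1.0 otherwise.
// Tx and Ty are both inferred as double here.
var keep_prob = tf.where(is_training, 0.5, 1.0);
```
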
diff --git a/src/TensorFlowNET.Core/APIs/tf.random.cs b/src/TensorFlowNET.Core/APIs/tf.random.cs
index b309fc06..b56df144 100644
--- a/src/TensorFlowNET.Core/APIs/tf.random.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.random.cs
@@ -25,7 +25,7 @@ namespace Tensorflow
public static Tensor random_uniform(int[] shape,
float minval = 0,
- float? maxval = null,
+ float maxval = 1,
TF_DataType dtype = TF_DataType.TF_FLOAT,
int? seed = null,
string name = null) => random_ops.random_uniform(shape, minval, maxval, dtype, seed, name);
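With `maxval` now defaulting to `1`, callers can drop the argument entirely; a sketch of the embedding initializer used later in this change (shape values come from `BuildGraph`):

```csharp
// Uniform values drawn from [0, 1) with the default minval/maxval.
var init_embeddings = tf.random_uniform(new int[] { vocabulary_size, embedding_size });
var embeddings = tf.get_variable("embeddings", initializer: init_embeddings);
```
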
diff --git a/src/TensorFlowNET.Core/Layers/Layer.cs b/src/TensorFlowNET.Core/Layers/Layer.cs
index aad687e6..a4aafc94 100644
--- a/src/TensorFlowNET.Core/Layers/Layer.cs
+++ b/src/TensorFlowNET.Core/Layers/Layer.cs
@@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Linq;
using System.Text;
+using static Tensorflow.Python;
namespace Tensorflow.Layers
{
@@ -50,7 +51,7 @@ namespace Tensorflow.Layers
auxiliary_name_scope: false);
}
- Python.with(scope_context_manager, scope2 => _current_scope = scope2);
+ with(scope_context_manager, scope2 => _current_scope = scope2);
// Actually call layer
var outputs = base.__call__(new Tensor[] { inputs }, training: training);
@@ -60,6 +61,13 @@ namespace Tensorflow.Layers
return outputs;
}
+ protected override void _init_set_name(string name, bool zero_based = true)
+ {
+ // Determine layer name (non-unique).
+ base._init_set_name(name, zero_based: zero_based);
+ _base_name = this.name;
+ }
+
protected virtual void _add_elements_to_collection(Operation[] elements, string[] collection_list)
{
foreach(var name in collection_list)
@@ -140,10 +148,18 @@ namespace Tensorflow.Layers
{
if (_scope == null)
{
- Python.with(tf.variable_scope(scope, default_name: _base_name), captured_scope =>
+ if(_reuse.HasValue && _reuse.Value)
{
- _scope = captured_scope;
- });
+ throw new NotImplementedException("_set_scope _reuse.HasValue");
+ /*with(tf.variable_scope(scope == null ? _base_name : scope),
+ captured_scope => _scope = captured_scope);*/
+ }
+ else
+ {
+ with(tf.variable_scope(scope, default_name: _base_name),
+ captured_scope => _scope = captured_scope);
+ }
+
}
}
}
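
With `using static Tensorflow.Python;` imported above, `with` can now be called unqualified; a minimal sketch of the pattern `_set_scope` uses (the default scope name and the local variable are illustrative, not part of this change):

```csharp
// Minimal sketch: capture the scope object produced by tf.variable_scope.
VariableScope captured = null;
with(tf.variable_scope(null, default_name: "dense"), scope => captured = scope);
```
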
diff --git a/src/TensorFlowNET.Core/Operations/array_ops.py.cs b/src/TensorFlowNET.Core/Operations/array_ops.py.cs
index 34adba6d..0e043198 100644
--- a/src/TensorFlowNET.Core/Operations/array_ops.py.cs
+++ b/src/TensorFlowNET.Core/Operations/array_ops.py.cs
@@ -233,7 +233,7 @@ namespace Tensorflow
});
}
- public static Tensor where(Tensor condition, Tensor x = null, Tensor y = null, string name = null)
+ public static Tensor where(Tensor condition, object x = null, object y = null, string name = null)
{
if( x == null && y == null)
{
diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
index a83c636e..0a49233c 100644
--- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
@@ -234,7 +234,7 @@ namespace Tensorflow
return _op.outputs[0];
}
- public static Tensor select(Tensor condition, Tensor t, Tensor e, string name = null)
+ public static Tensor select<Tx, Ty>(Tensor condition, Tx t, Ty e, string name = null)
{
var _op = _op_def_lib._apply_op_helper("Select", name, new { condition, t, e });
return _op.outputs[0];
diff --git a/src/TensorFlowNET.Core/Operations/random_ops.py.cs b/src/TensorFlowNET.Core/Operations/random_ops.py.cs
index 54e9cd28..25ef6da0 100644
--- a/src/TensorFlowNET.Core/Operations/random_ops.py.cs
+++ b/src/TensorFlowNET.Core/Operations/random_ops.py.cs
@@ -49,7 +49,7 @@ namespace Tensorflow
/// <returns>A tensor of the specified shape filled with random uniform values.</returns>
public static Tensor random_uniform(int[] shape,
float minval = 0,
- float? maxval = null,
+ float maxval = 1,
TF_DataType dtype = TF_DataType.TF_FLOAT,
int? seed = null,
string name = null)
diff --git a/src/TensorFlowNET.Core/tf.cs b/src/TensorFlowNET.Core/tf.cs
index 33acdfbc..52c4f988 100644
--- a/src/TensorFlowNET.Core/tf.cs
+++ b/src/TensorFlowNET.Core/tf.cs
@@ -15,7 +15,7 @@ namespace Tensorflow
public static TF_DataType float16 = TF_DataType.TF_HALF;
public static TF_DataType float32 = TF_DataType.TF_FLOAT;
public static TF_DataType float64 = TF_DataType.TF_DOUBLE;
- public static TF_DataType boolean = TF_DataType.TF_BOOL;
+ public static TF_DataType @bool = TF_DataType.TF_BOOL;
public static TF_DataType chars = TF_DataType.TF_STRING;
public static Context context = new Context(new ContextOptions(), new Status());
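
Since `bool` is a C# keyword, the dtype field is now exposed as `tf.@bool`; call sites such as the placeholders updated in this change escape the keyword accordingly:

```csharp
// `@bool` escapes the C# keyword; the placeholder's dtype is TF_BOOL.
var is_training = tf.placeholder(tf.@bool, new TensorShape(), name: "is_training");
```
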
diff --git a/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs b/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs
index 8bc2f9ef..289ed4d0 100644
--- a/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs
+++ b/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs
@@ -40,43 +40,157 @@ namespace TensorFlowNET.Examples
protected float loss_value = 0;
int vocabulary_size = 50000;
+ NDArray train_x, valid_x, train_y, valid_y;
public bool Run()
{
PrepareData();
- var graph = tf.Graph().as_default();
- return with(tf.Session(graph), sess =>
+ Train();
+
+ return true;
+ }
+
+ // TODO: this originally is an SKLearn utility function. it randomizes train and test which we don't do here
+ private (NDArray, NDArray, NDArray, NDArray) train_test_split(NDArray x, NDArray y, float test_size = 0.3f)
+ {
+ Console.WriteLine("Splitting in Training and Testing data...");
+ int len = x.shape[0];
+ //int classes = y.Data().Distinct().Count();
+ //int samples = len / classes;
+ int train_size = (int)Math.Round(len * (1 - test_size));
+ var train_x = x[new Slice(stop: train_size), new Slice()];
+ var valid_x = x[new Slice(start: train_size), new Slice()];
+ var train_y = y[new Slice(stop: train_size)];
+ var valid_y = y[new Slice(start: train_size)];
+ Console.WriteLine("\tDONE");
+ return (train_x, valid_x, train_y, valid_y);
+ }
+
+ private static void FillWithShuffledLabels(int[][] x, int[] y, int[][] shuffled_x, int[] shuffled_y, Random random, Dictionary<int, HashSet<int>> labels)
+ {
+ int i = 0;
+ var label_keys = labels.Keys.ToArray();
+ while (i < shuffled_x.Length)
{
- if (IsImportingGraph)
- return RunWithImportedGraph(sess, graph);
- else
- return RunWithBuiltGraph(sess, graph);
- });
+ var key = label_keys[random.Next(label_keys.Length)];
+ var set = labels[key];
+ var index = set.First();
+ if (set.Count == 0)
+ {
+ labels.Remove(key); // remove the set as it is empty
+ label_keys = labels.Keys.ToArray();
+ }
+ shuffled_x[i] = x[index];
+ shuffled_y[i] = y[index];
+ i++;
+ }
}
- protected virtual bool RunWithImportedGraph(Session sess, Graph graph)
+ private IEnumerable<(NDArray, NDArray, int)> batch_iter(NDArray inputs, NDArray outputs, int batch_size, int num_epochs)
{
- var stopwatch = Stopwatch.StartNew();
+ var num_batches_per_epoch = (len(inputs) - 1) / batch_size + 1;
+ var total_batches = num_batches_per_epoch * num_epochs;
+ foreach (var epoch in range(num_epochs))
+ {
+ foreach (var batch_num in range(num_batches_per_epoch))
+ {
+ var start_index = batch_num * batch_size;
+ var end_index = Math.Min((batch_num + 1) * batch_size, len(inputs));
+ if (end_index <= start_index)
+ break;
+ yield return (inputs[new Slice(start_index, end_index)], outputs[new Slice(start_index, end_index)], total_batches);
+ }
+ }
+ }
+
+ public void PrepareData()
+ {
+ // full dataset https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz
+ var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/data/dbpedia_subset.zip";
+ Web.Download(url, dataDir, "dbpedia_subset.zip");
+ Compress.UnZip(Path.Combine(dataDir, "dbpedia_subset.zip"), Path.Combine(dataDir, "dbpedia_csv"));
+
Console.WriteLine("Building dataset...");
- int[][] x = null;
- int[] y = null;
+
int alphabet_size = 0;
var word_dict = DataHelpers.build_word_dict(TRAIN_PATH);
- // vocabulary_size = len(word_dict);
- (x, y) = DataHelpers.build_word_dataset(TRAIN_PATH, word_dict, WORD_MAX_LEN);
+ vocabulary_size = len(word_dict);
+ var (x, y) = DataHelpers.build_word_dataset(TRAIN_PATH, word_dict, WORD_MAX_LEN);
Console.WriteLine("\tDONE ");
var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f);
Console.WriteLine("Training set size: " + train_x.len);
Console.WriteLine("Test set size: " + valid_x.len);
+ }
- Console.WriteLine("Import graph...");
+ public Graph ImportGraph()
+ {
+ var graph = tf.Graph().as_default();
+
+ // download graph meta data
var meta_file = "word_cnn.meta";
+ var meta_path = Path.Combine("graph", meta_file);
+ if (File.GetLastWriteTime(meta_path) < new DateTime(2019, 05, 11))
+ {
+ // delete old cached file which contains errors
+ Console.WriteLine("Discarding cached file: " + meta_path);
+ File.Delete(meta_path);
+ }
+ var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file;
+ Web.Download(url, "graph", meta_file);
+
+ Console.WriteLine("Import graph...");
tf.train.import_meta_graph(Path.Join("graph", meta_file));
- Console.WriteLine("\tDONE " + stopwatch.Elapsed);
+ Console.WriteLine("\tDONE ");
+
+ return graph;
+ }
+
+ public Graph BuildGraph()
+ {
+ var graph = tf.Graph().as_default();
+
+ var embedding_size = 128;
+ var learning_rate = 0.001f;
+ var filter_sizes = new int[] { 3, 4, 5 };
+ var num_filters = 100;
+ var document_max_len = 100;
+
+ var x = tf.placeholder(tf.int32, new TensorShape(-1, document_max_len), name: "x");
+ var y = tf.placeholder(tf.int32, new TensorShape(-1), name: "y");
+ var is_training = tf.placeholder(tf.@bool, new TensorShape(), name: "is_training");
+ var global_step = tf.Variable(0, trainable: false);
+ var keep_prob = tf.where(is_training, 0.5, 1.0);
+ Tensor x_emb = null;
+
+ with(tf.name_scope("embedding"), scope =>
+ {
+ var init_embeddings = tf.random_uniform(new int[] { vocabulary_size, embedding_size });
+ var embeddings = tf.get_variable("embeddings", initializer: init_embeddings);
+ x_emb = tf.nn.embedding_lookup(embeddings, x);
+ x_emb = tf.expand_dims(x_emb, -1);
+ });
+
+ foreach(var filter_size in filter_sizes)
+ {
+ var conv = tf.layers.conv2d(
+ x_emb,
+ filters: num_filters,
+ kernel_size: new int[] { filter_size, embedding_size },
+ strides: new int[] { 1, 1 },
+ padding: "VALID",
+ activation: tf.nn.relu());
+ }
+
+ return graph;
+ }
+
+ private bool RunWithImportedGraph(Session sess, Graph graph)
+ {
+ var stopwatch = Stopwatch.StartNew();
sess.run(tf.global_variables_initializer());
var saver = tf.train.Saver(tf.global_variables());
@@ -149,107 +263,12 @@ namespace TensorFlowNET.Examples
return false;
}
- protected virtual bool RunWithBuiltGraph(Session session, Graph graph)
- {
- Console.WriteLine("Building dataset...");
- var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", "word_cnn", CHAR_MAX_LEN, DataLimit);
-
- var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f);
-
- ITextClassificationModel model = null;
- // todo train the model
- return false;
- }
-
- // TODO: this originally is an SKLearn utility function. it randomizes train and test which we don't do here
- private (NDArray, NDArray, NDArray, NDArray) train_test_split(NDArray x, NDArray y, float test_size = 0.3f)
- {
- Console.WriteLine("Splitting in Training and Testing data...");
- int len = x.shape[0];
- //int classes = y.Data().Distinct().Count();
- //int samples = len / classes;
- int train_size = (int)Math.Round(len * (1 - test_size));
- var train_x = x[new Slice(stop: train_size), new Slice()];
- var valid_x = x[new Slice(start: train_size), new Slice()];
- var train_y = y[new Slice(stop: train_size)];
- var valid_y = y[new Slice(start: train_size)];
- Console.WriteLine("\tDONE");
- return (train_x, valid_x, train_y, valid_y);
- }
-
- private static void FillWithShuffledLabels(int[][] x, int[] y, int[][] shuffled_x, int[] shuffled_y, Random random, Dictionary<int, HashSet<int>> labels)
- {
- int i = 0;
- var label_keys = labels.Keys.ToArray();
- while (i < shuffled_x.Length)
- {
- var key = label_keys[random.Next(label_keys.Length)];
- var set = labels[key];
- var index = set.First();
- if (set.Count == 0)
- {
- labels.Remove(key); // remove the set as it is empty
- label_keys = labels.Keys.ToArray();
- }
- shuffled_x[i] = x[index];
- shuffled_y[i] = y[index];
- i++;
- }
- }
-
- private IEnumerable<(NDArray, NDArray, int)> batch_iter(NDArray inputs, NDArray outputs, int batch_size, int num_epochs)
- {
- var num_batches_per_epoch = (len(inputs) - 1) / batch_size + 1;
- var total_batches = num_batches_per_epoch * num_epochs;
- foreach (var epoch in range(num_epochs))
- {
- foreach (var batch_num in range(num_batches_per_epoch))
- {
- var start_index = batch_num * batch_size;
- var end_index = Math.Min((batch_num + 1) * batch_size, len(inputs));
- if (end_index <= start_index)
- break;
- yield return (inputs[new Slice(start_index, end_index)], outputs[new Slice(start_index, end_index)], total_batches);
- }
- }
- }
-
- public void PrepareData()
- {
- // full dataset https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz
- var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/data/dbpedia_subset.zip";
- Web.Download(url, dataDir, "dbpedia_subset.zip");
- Compress.UnZip(Path.Combine(dataDir, "dbpedia_subset.zip"), Path.Combine(dataDir, "dbpedia_csv"));
-
- if (IsImportingGraph)
- {
- // download graph meta data
- var meta_file = "word_cnn.meta";
- var meta_path = Path.Combine("graph", meta_file);
- if (File.GetLastWriteTime(meta_path) < new DateTime(2019, 05, 11))
- {
- // delete old cached file which contains errors
- Console.WriteLine("Discarding cached file: " + meta_path);
- File.Delete(meta_path);
- }
- url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file;
- Web.Download(url, "graph", meta_file);
- }
- }
-
- public Graph ImportGraph()
- {
- throw new NotImplementedException();
- }
-
- public Graph BuildGraph()
- {
- throw new NotImplementedException();
- }
-
public bool Train()
{
- throw new NotImplementedException();
+ var graph = IsImportingGraph ? ImportGraph() : BuildGraph();
+
+ return with(tf.Session(graph), sess
+ => RunWithImportedGraph(sess, graph));
}
public bool Predict()
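
A hypothetical training loop over the mini-batches produced by the relocated `batch_iter`; the batch size and epoch count are illustrative, not values taken from this change:

```csharp
foreach (var (x_batch, y_batch, total) in batch_iter(train_x, train_y, batch_size: 64, num_epochs: 10))
{
    // feed x_batch / y_batch into the session and run the training op here
}
```
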
diff --git a/test/TensorFlowNET.Examples/TextProcess/cnn_models/VdCnn.cs b/test/TensorFlowNET.Examples/TextProcess/cnn_models/VdCnn.cs
index fdba49ec..f4f430a5 100644
--- a/test/TensorFlowNET.Examples/TextProcess/cnn_models/VdCnn.cs
+++ b/test/TensorFlowNET.Examples/TextProcess/cnn_models/VdCnn.cs
@@ -40,7 +40,7 @@ namespace TensorFlowNET.Examples.TextClassification
x = tf.placeholder(tf.int32, new TensorShape(-1, document_max_len), name: "x");
y = tf.placeholder(tf.int32, new TensorShape(-1), name: "y");
- is_training = tf.placeholder(tf.boolean, new TensorShape(), name: "is_training");
+ is_training = tf.placeholder(tf.@bool, new TensorShape(), name: "is_training");
global_step = tf.Variable(0, trainable: false);
// Embedding Layer