diff --git a/src/TensorFlowNET.Core/APIs/tf.array.cs b/src/TensorFlowNET.Core/APIs/tf.array.cs
index a97a8dda..dd385b4c 100644
--- a/src/TensorFlowNET.Core/APIs/tf.array.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.array.cs
@@ -20,6 +20,13 @@ namespace Tensorflow
         public static Tensor expand_dims(Tensor input, int axis = -1, string name = null, int dim = -1)
             => array_ops.expand_dims(input, axis, name, dim);
 
+        /// <summary>
+        /// Return the elements, either from `x` or `y`, depending on the `condition`.
+        /// </summary>
+        /// <returns></returns>
+        public static Tensor where<Tx, Ty>(Tensor condition, Tx x, Ty y, string name = null)
+            => array_ops.where(condition, x, y, name);
+
         /// <summary>
         /// Transposes `a`. Permutes the dimensions according to `perm`.
         /// </summary>
diff --git a/src/TensorFlowNET.Core/APIs/tf.random.cs b/src/TensorFlowNET.Core/APIs/tf.random.cs
index b309fc06..b56df144 100644
--- a/src/TensorFlowNET.Core/APIs/tf.random.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.random.cs
@@ -25,7 +25,7 @@ namespace Tensorflow
 
         public static Tensor random_uniform(int[] shape,
             float minval = 0,
-            float? maxval = null,
+            float maxval = 1,
             TF_DataType dtype = TF_DataType.TF_FLOAT,
             int? seed = null,
             string name = null) => random_ops.random_uniform(shape, minval, maxval, dtype, seed, name);
diff --git a/src/TensorFlowNET.Core/Layers/Layer.cs b/src/TensorFlowNET.Core/Layers/Layer.cs
index aad687e6..a4aafc94 100644
--- a/src/TensorFlowNET.Core/Layers/Layer.cs
+++ b/src/TensorFlowNET.Core/Layers/Layer.cs
@@ -2,6 +2,7 @@
 using System.Collections.Generic;
 using System.Linq;
 using System.Text;
+using static Tensorflow.Python;
 
 namespace Tensorflow.Layers
 {
@@ -50,7 +51,7 @@ namespace Tensorflow.Layers
                     auxiliary_name_scope: false);
             }
 
-            Python.with(scope_context_manager, scope2 => _current_scope = scope2);
+            with(scope_context_manager, scope2 => _current_scope = scope2);
 
             // Actually call layer
             var outputs = base.__call__(new Tensor[] { inputs }, training: training);
@@ -60,6 +61,13 @@ namespace Tensorflow.Layers
             return outputs;
         }
 
+        protected override void _init_set_name(string name, bool zero_based = true)
+        {
+            // Determine layer name (non-unique).
+            base._init_set_name(name, zero_based: zero_based);
+            _base_name = this.name;
+        }
+
         protected virtual void _add_elements_to_collection(Operation[] elements, string[] collection_list)
         {
             foreach(var name in collection_list)
@@ -140,10 +148,18 @@ namespace Tensorflow.Layers
         {
             if (_scope == null)
             {
-                Python.with(tf.variable_scope(scope, default_name: _base_name), captured_scope =>
+                if(_reuse.HasValue && _reuse.Value)
                 {
-                    _scope = captured_scope;
-                });
+                    throw new NotImplementedException("_set_scope _reuse.HasValue");
+                    /*with(tf.variable_scope(scope == null ? _base_name : scope),
+                        captured_scope => _scope = captured_scope);*/
+                }
+                else
+                {
+                    with(tf.variable_scope(scope, default_name: _base_name),
+                        captured_scope => _scope = captured_scope);
+                }
+            }
         }
     }
diff --git a/src/TensorFlowNET.Core/Operations/array_ops.py.cs b/src/TensorFlowNET.Core/Operations/array_ops.py.cs
index 34adba6d..0e043198 100644
--- a/src/TensorFlowNET.Core/Operations/array_ops.py.cs
+++ b/src/TensorFlowNET.Core/Operations/array_ops.py.cs
@@ -233,7 +233,7 @@ namespace Tensorflow
             });
         }
 
-        public static Tensor where(Tensor condition, Tensor x = null, Tensor y = null, string name = null)
+        public static Tensor where(Tensor condition, object x = null, object y = null, string name = null)
         {
             if( x == null && y == null)
             {
diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
index a83c636e..0a49233c 100644
--- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
@@ -234,7 +234,7 @@ namespace Tensorflow
             return _op.outputs[0];
         }
 
-        public static Tensor select(Tensor condition, Tensor t, Tensor e, string name = null)
+        public static Tensor select<Tx, Ty>(Tensor condition, Tx t, Ty e, string name = null)
         {
             var _op = _op_def_lib._apply_op_helper("Select", name, new { condition, t, e });
             return _op.outputs[0];
diff --git a/src/TensorFlowNET.Core/Operations/random_ops.py.cs b/src/TensorFlowNET.Core/Operations/random_ops.py.cs
index 54e9cd28..25ef6da0 100644
--- a/src/TensorFlowNET.Core/Operations/random_ops.py.cs
+++ b/src/TensorFlowNET.Core/Operations/random_ops.py.cs
@@ -49,7 +49,7 @@ namespace Tensorflow
         /// <returns>A tensor of the specified shape filled with random uniform values.</returns>
         public static Tensor random_uniform(int[] shape,
             float minval = 0,
-            float? maxval = null,
+            float maxval = 1,
             TF_DataType dtype = TF_DataType.TF_FLOAT,
             int? seed = null,
             string name = null)
diff --git a/src/TensorFlowNET.Core/tf.cs b/src/TensorFlowNET.Core/tf.cs
index 33acdfbc..52c4f988 100644
--- a/src/TensorFlowNET.Core/tf.cs
+++ b/src/TensorFlowNET.Core/tf.cs
@@ -15,7 +15,7 @@ namespace Tensorflow
         public static TF_DataType float16 = TF_DataType.TF_HALF;
         public static TF_DataType float32 = TF_DataType.TF_FLOAT;
         public static TF_DataType float64 = TF_DataType.TF_DOUBLE;
-        public static TF_DataType boolean = TF_DataType.TF_BOOL;
+        public static TF_DataType @bool = TF_DataType.TF_BOOL;
         public static TF_DataType chars = TF_DataType.TF_STRING;
 
         public static Context context = new Context(new ContextOptions(), new Status());
diff --git a/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs b/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs
index 8bc2f9ef..289ed4d0 100644
--- a/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs
+++ b/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs
@@ -40,43 +40,157 @@ namespace TensorFlowNET.Examples
         protected float loss_value = 0;
         int vocabulary_size = 50000;
+        NDArray train_x, valid_x, train_y, valid_y;
 
         public bool Run()
         {
             PrepareData();
-            var graph = tf.Graph().as_default();
-            return with(tf.Session(graph), sess =>
+            Train();
+
+            return true;
+        }
+
+        // TODO: this originally is an SKLearn utility function. it randomizes train and test which we don't do here
+        private (NDArray, NDArray, NDArray, NDArray) train_test_split(NDArray x, NDArray y, float test_size = 0.3f)
+        {
+            Console.WriteLine("Splitting in Training and Testing data...");
+            int len = x.shape[0];
+            //int classes = y.Data<int>().Distinct().Count();
+            //int samples = len / classes;
+            int train_size = (int)Math.Round(len * (1 - test_size));
+            var train_x = x[new Slice(stop: train_size), new Slice()];
+            var valid_x = x[new Slice(start: train_size), new Slice()];
+            var train_y = y[new Slice(stop: train_size)];
+            var valid_y = y[new Slice(start: train_size)];
+            Console.WriteLine("\tDONE");
+            return (train_x, valid_x, train_y, valid_y);
+        }
+
+        private static void FillWithShuffledLabels(int[][] x, int[] y, int[][] shuffled_x, int[] shuffled_y, Random random, Dictionary<int, HashSet<int>> labels)
+        {
+            int i = 0;
+            var label_keys = labels.Keys.ToArray();
+            while (i < shuffled_x.Length)
             {
-                if (IsImportingGraph)
-                    return RunWithImportedGraph(sess, graph);
-                else
-                    return RunWithBuiltGraph(sess, graph);
-            });
+                var key = label_keys[random.Next(label_keys.Length)];
+                var set = labels[key];
+                var index = set.First();
+                if (set.Count == 0)
+                {
+                    labels.Remove(key); // remove the set as it is empty
+                    label_keys = labels.Keys.ToArray();
+                }
+                shuffled_x[i] = x[index];
+                shuffled_y[i] = y[index];
+                i++;
+            }
         }
 
-        protected virtual bool RunWithImportedGraph(Session sess, Graph graph)
+        private IEnumerable<(NDArray, NDArray, int)> batch_iter(NDArray inputs, NDArray outputs, int batch_size, int num_epochs)
         {
-            var stopwatch = Stopwatch.StartNew();
+            var num_batches_per_epoch = (len(inputs) - 1) / batch_size + 1;
+            var total_batches = num_batches_per_epoch * num_epochs;
+            foreach (var epoch in range(num_epochs))
+            {
+                foreach (var batch_num in range(num_batches_per_epoch))
+                {
+                    var start_index = batch_num * batch_size;
+                    var end_index = Math.Min((batch_num + 1) * batch_size, len(inputs));
+                    if (end_index <= start_index)
+                        break;
+                    yield return (inputs[new Slice(start_index, end_index)], outputs[new Slice(start_index, end_index)], total_batches);
+                }
+            }
+        }
+
+        public void PrepareData()
+        {
+            // full dataset https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz
+            var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/data/dbpedia_subset.zip";
+            Web.Download(url, dataDir, "dbpedia_subset.zip");
+            Compress.UnZip(Path.Combine(dataDir, "dbpedia_subset.zip"), Path.Combine(dataDir, "dbpedia_csv"));
+
             Console.WriteLine("Building dataset...");
-            int[][] x = null;
-            int[] y = null;
             var word_dict = DataHelpers.build_word_dict(TRAIN_PATH);
-            // vocabulary_size = len(word_dict);
-            (x, y) = DataHelpers.build_word_dataset(TRAIN_PATH, word_dict, WORD_MAX_LEN);
+            vocabulary_size = len(word_dict);
+            var (x, y) = DataHelpers.build_word_dataset(TRAIN_PATH, word_dict, WORD_MAX_LEN);
             Console.WriteLine("\tDONE ");
 
             var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f);
             Console.WriteLine("Training set size: " + train_x.len);
             Console.WriteLine("Test set size: " + valid_x.len);
+        }
 
-            Console.WriteLine("Import graph...");
+        public Graph ImportGraph()
+        {
+            var graph = tf.Graph().as_default();
+
+            // download graph meta data
             var meta_file = "word_cnn.meta";
+            var meta_path = Path.Combine("graph", meta_file);
+            if (File.GetLastWriteTime(meta_path) < new DateTime(2019, 05, 11))
+            {
+                // delete old cached file which contains errors
+                Console.WriteLine("Discarding cached file: " + meta_path);
+                File.Delete(meta_path);
+            }
+            var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file;
+            Web.Download(url, "graph", meta_file);
+
+            Console.WriteLine("Import graph...");
             tf.train.import_meta_graph(Path.Join("graph", meta_file));
-            Console.WriteLine("\tDONE " + stopwatch.Elapsed);
+            Console.WriteLine("\tDONE ");
+
+            return graph;
+        }
+
+        public Graph BuildGraph()
+        {
+            var graph = tf.Graph().as_default();
+
+            var embedding_size = 128;
+            var learning_rate = 0.001f;
+            var filter_sizes = new int[] { 3, 4, 5 };
+            var num_filters = 100;
+            var document_max_len = 100;
+
+            var x = tf.placeholder(tf.int32, new TensorShape(-1, document_max_len), name: "x");
+            var y = tf.placeholder(tf.int32, new TensorShape(-1), name: "y");
+            var is_training = tf.placeholder(tf.@bool, new TensorShape(), name: "is_training");
+            var global_step = tf.Variable(0, trainable: false);
+            var keep_prob = tf.where(is_training, 0.5, 1.0);
+            Tensor x_emb = null;
+
+            with(tf.name_scope("embedding"), scope =>
+            {
+                var init_embeddings = tf.random_uniform(new int[] { vocabulary_size, embedding_size });
+                var embeddings = tf.get_variable("embeddings", initializer: init_embeddings);
+                x_emb = tf.nn.embedding_lookup(embeddings, x);
+                x_emb = tf.expand_dims(x_emb, -1);
+            });
+
+            foreach(var filter_size in filter_sizes)
+            {
+                var conv = tf.layers.conv2d(
+                    x_emb,
+                    filters: num_filters,
+                    kernel_size: new int[] { filter_size, embedding_size },
+                    strides: new int[] { 1, 1 },
+                    padding: "VALID",
+                    activation: tf.nn.relu());
+            }
+
+            return graph;
+        }
+
+        private bool RunWithImportedGraph(Session sess, Graph graph)
+        {
+            var stopwatch = Stopwatch.StartNew();
             sess.run(tf.global_variables_initializer());
             var saver = tf.train.Saver(tf.global_variables());
@@ -149,107 +263,12 @@ namespace TensorFlowNET.Examples
             return false;
         }
 
-        protected virtual bool RunWithBuiltGraph(Session session, Graph graph)
-        {
-            Console.WriteLine("Building dataset...");
-            var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", "word_cnn", CHAR_MAX_LEN, DataLimit);
-
-            var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f);
-
-            ITextClassificationModel model = null;
-            // todo train the model
-            return false;
-        }
-
-        // TODO: this originally is an SKLearn utility function. it randomizes train and test which we don't do here
-        private (NDArray, NDArray, NDArray, NDArray) train_test_split(NDArray x, NDArray y, float test_size = 0.3f)
-        {
-            Console.WriteLine("Splitting in Training and Testing data...");
-            int len = x.shape[0];
-            //int classes = y.Data<int>().Distinct().Count();
-            //int samples = len / classes;
-            int train_size = (int)Math.Round(len * (1 - test_size));
-            var train_x = x[new Slice(stop: train_size), new Slice()];
-            var valid_x = x[new Slice(start: train_size), new Slice()];
-            var train_y = y[new Slice(stop: train_size)];
-            var valid_y = y[new Slice(start: train_size)];
-            Console.WriteLine("\tDONE");
-            return (train_x, valid_x, train_y, valid_y);
-        }
-
-        private static void FillWithShuffledLabels(int[][] x, int[] y, int[][] shuffled_x, int[] shuffled_y, Random random, Dictionary<int, HashSet<int>> labels)
-        {
-            int i = 0;
-            var label_keys = labels.Keys.ToArray();
-            while (i < shuffled_x.Length)
-            {
-                var key = label_keys[random.Next(label_keys.Length)];
-                var set = labels[key];
-                var index = set.First();
-                if (set.Count == 0)
-                {
-                    labels.Remove(key); // remove the set as it is empty
-                    label_keys = labels.Keys.ToArray();
-                }
-                shuffled_x[i] = x[index];
-                shuffled_y[i] = y[index];
-                i++;
-            }
-        }
-
-        private IEnumerable<(NDArray, NDArray, int)> batch_iter(NDArray inputs, NDArray outputs, int batch_size, int num_epochs)
-        {
-            var num_batches_per_epoch = (len(inputs) - 1) / batch_size + 1;
-            var total_batches = num_batches_per_epoch * num_epochs;
-            foreach (var epoch in range(num_epochs))
-            {
-                foreach (var batch_num in range(num_batches_per_epoch))
-                {
-                    var start_index = batch_num * batch_size;
-                    var end_index = Math.Min((batch_num + 1) * batch_size, len(inputs));
-                    if (end_index <= start_index)
-                        break;
-                    yield return (inputs[new Slice(start_index, end_index)], outputs[new Slice(start_index, end_index)], total_batches);
-                }
-            }
-        }
-
-        public void PrepareData()
-        {
-            // full dataset https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz
-            var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/data/dbpedia_subset.zip";
-            Web.Download(url, dataDir, "dbpedia_subset.zip");
-            Compress.UnZip(Path.Combine(dataDir, "dbpedia_subset.zip"), Path.Combine(dataDir, "dbpedia_csv"));
-
-            if (IsImportingGraph)
-            {
-                // download graph meta data
-                var meta_file = "word_cnn.meta";
-                var meta_path = Path.Combine("graph", meta_file);
-                if (File.GetLastWriteTime(meta_path) < new DateTime(2019, 05, 11))
-                {
-                    // delete old cached file which contains errors
-                    Console.WriteLine("Discarding cached file: " + meta_path);
-                    File.Delete(meta_path);
-                }
-                url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file;
-                Web.Download(url, "graph", meta_file);
-            }
-        }
-
-        public Graph ImportGraph()
-        {
-            throw new NotImplementedException();
-        }
-
-        public Graph BuildGraph()
-        {
-            throw new NotImplementedException();
-        }
-
         public bool Train()
         {
-            throw new NotImplementedException();
+            var graph = IsImportingGraph ? ImportGraph() : BuildGraph();
+
+            return with(tf.Session(graph), sess
+                => RunWithImportedGraph(sess, graph));
         }
 
         public bool Predict()
diff --git a/test/TensorFlowNET.Examples/TextProcess/cnn_models/VdCnn.cs b/test/TensorFlowNET.Examples/TextProcess/cnn_models/VdCnn.cs
index fdba49ec..f4f430a5 100644
--- a/test/TensorFlowNET.Examples/TextProcess/cnn_models/VdCnn.cs
+++ b/test/TensorFlowNET.Examples/TextProcess/cnn_models/VdCnn.cs
@@ -40,7 +40,7 @@ namespace TensorFlowNET.Examples.TextClassification
 
             x = tf.placeholder(tf.int32, new TensorShape(-1, document_max_len), name: "x");
             y = tf.placeholder(tf.int32, new TensorShape(-1), name: "y");
-            is_training = tf.placeholder(tf.boolean, new TensorShape(), name: "is_training");
+            is_training = tf.placeholder(tf.@bool, new TensorShape(), name: "is_training");
             global_step = tf.Variable(0, trainable: false);
 
             // Embedding Layer
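
Note (reviewer sketch, not part of the patch): a minimal usage example of the API surface this change introduces, namely the generic `tf.where<Tx, Ty>` overload, the renamed `tf.@bool` dtype field, and `tf.random_uniform` with `maxval` now defaulting to 1. It mirrors the `BuildGraph()` code above; the sizes shown are the ones used there and are illustrative only.

    // C# (TensorFlow.NET), assuming this patch is applied
    var vocabulary_size = 50000;
    var embedding_size = 128;
    // bool placeholder now uses tf.@bool instead of the old tf.boolean
    var is_training = tf.placeholder(tf.@bool, new TensorShape(), name: "is_training");
    // keep_prob resolves to 0.5 while training and 1.0 during inference
    var keep_prob = tf.where(is_training, 0.5, 1.0);
    // maxval defaults to 1, so this samples uniformly from [0, 1)
    var init_embeddings = tf.random_uniform(new int[] { vocabulary_size, embedding_size });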