diff --git a/README.md b/README.md index ee89da4c..457a1533 100644 --- a/README.md +++ b/README.md @@ -142,13 +142,13 @@ Example runner will download all the required files like training data and model * [Logistic Regression](test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs) * [Nearest Neighbor](test/TensorFlowNET.Examples/BasicModels/NearestNeighbor.cs) * [Naive Bayes Classification](test/TensorFlowNET.Examples/BasicModels/NaiveBayesClassifier.cs) +* [Fully Connected Neural Network](test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionNN.cs) * [Image Recognition](test/TensorFlowNET.Examples/ImageProcess) * [K-means Clustering](test/TensorFlowNET.Examples/BasicModels/KMeansClustering.cs) * [NN XOR](test/TensorFlowNET.Examples/BasicModels/NeuralNetXor.cs) * [Object Detection](test/TensorFlowNET.Examples/ImageProcess/ObjectDetection.cs) * [Text Classification](test/TensorFlowNET.Examples/TextProcess/BinaryTextClassification.cs) * [CNN Text Classification](test/TensorFlowNET.Examples/TextProcess/cnn_models/VdCnn.cs) - * [Named Entity Recognition](test/TensorFlowNET.Examples/TextProcess/NER) * [Transfer Learning for Image Classification in InceptionV3](test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs) diff --git a/src/KerasNET.Core/Layers/Dense.cs b/src/KerasNET.Core/Layers/Dense.cs index c6c086a4..0970b421 100644 --- a/src/KerasNET.Core/Layers/Dense.cs +++ b/src/KerasNET.Core/Layers/Dense.cs @@ -40,7 +40,7 @@ namespace Keras.Layers var dot = tf.matmul(x, W); if (this.activation != null) dot = activation.Activate(dot); - Console.WriteLine("Calling Layer \"" + name + "(" + np.array(dot.GetShape().Dimensions).ToString() + ")\" ..."); + Console.WriteLine("Calling Layer \"" + name + "(" + np.array(dot.TensorShape.Dimensions).ToString() + ")\" ..."); return dot; } public TensorShape __shape__() diff --git a/src/KerasNET.Core/Model.cs b/src/KerasNET.Core/Model.cs index cb960169..c7606106 100644 --- a/src/KerasNET.Core/Model.cs +++
b/src/KerasNET.Core/Model.cs @@ -65,7 +65,7 @@ namespace Keras #endregion #region Model Graph Form Layer Stack - var flow_shape = features.GetShape(); + var flow_shape = features.TensorShape; Flow = features; for (int i = 0; i < layer_stack.Count; i++) { diff --git a/src/TensorFlowNET.Core/Framework/common_shapes.py.cs b/src/TensorFlowNET.Core/Framework/common_shapes.py.cs index e0a08184..770756dc 100644 --- a/src/TensorFlowNET.Core/Framework/common_shapes.py.cs +++ b/src/TensorFlowNET.Core/Framework/common_shapes.py.cs @@ -37,7 +37,7 @@ namespace Tensorflow.Framework public static bool has_fully_defined_shape(Tensor tensor) { - return tensor.GetShape().is_fully_defined(); + return tensor.TensorShape.is_fully_defined(); } } } diff --git a/src/TensorFlowNET.Core/Keras/Layers/Layer.cs b/src/TensorFlowNET.Core/Keras/Layers/Layer.cs index 6c61b12a..ebb2ed12 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/Layer.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/Layer.cs @@ -161,7 +161,7 @@ namespace Tensorflow.Keras.Layers if (_dtype == TF_DataType.DtInvalid) _dtype = input.dtype; - var input_shapes = input.GetShape(); + var input_shapes = input.TensorShape; build(input_shapes); built = true; } diff --git a/src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs b/src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs index c07712a1..85157fac 100644 --- a/src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs +++ b/src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs @@ -118,8 +118,8 @@ namespace Tensorflow if(weights > 0) { var weights_tensor = ops.convert_to_tensor(weights); - var labels_rank = labels.GetShape().NDim; - var weights_shape = weights_tensor.GetShape(); + var labels_rank = labels.TensorShape.NDim; + var weights_shape = weights_tensor.TensorShape; var weights_rank = weights_shape.NDim; if (labels_rank > -1 && weights_rank > -1) diff --git a/src/TensorFlowNET.Core/Operations/NnOps/_WithSpaceToBatch.cs 
b/src/TensorFlowNET.Core/Operations/NnOps/_WithSpaceToBatch.cs index 8f77e0ea..71de87df 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/_WithSpaceToBatch.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/_WithSpaceToBatch.cs @@ -18,7 +18,7 @@ namespace Tensorflow.Operations string data_format = null) { var dilation_rate_tensor = ops.convert_to_tensor(dilation_rate, TF_DataType.TF_INT32, name: "dilation_rate"); - var rate_shape = dilation_rate_tensor.GetShape(); + var rate_shape = dilation_rate_tensor.TensorShape; var num_spatial_dims = rate_shape.Dimensions[0]; int starting_spatial_dim = -1; if (!string.IsNullOrEmpty(data_format) && data_format.StartsWith("NC")) diff --git a/src/TensorFlowNET.Core/Operations/confusion_matrix.py.cs b/src/TensorFlowNET.Core/Operations/confusion_matrix.py.cs index 0cb54647..59e8ebdb 100644 --- a/src/TensorFlowNET.Core/Operations/confusion_matrix.py.cs +++ b/src/TensorFlowNET.Core/Operations/confusion_matrix.py.cs @@ -24,9 +24,9 @@ namespace Tensorflow { predictions = ops.convert_to_tensor(predictions); labels = ops.convert_to_tensor(labels); - var predictions_shape = predictions.GetShape(); + var predictions_shape = predictions.TensorShape; var predictions_rank = predictions_shape.NDim; - var labels_shape = labels.GetShape(); + var labels_shape = labels.TensorShape; var labels_rank = labels_shape.NDim; if(labels_rank > -1 && predictions_rank > -1) { diff --git a/src/TensorFlowNET.Core/Operations/nn_ops.cs b/src/TensorFlowNET.Core/Operations/nn_ops.cs index 69d5f166..f3c39f9b 100644 --- a/src/TensorFlowNET.Core/Operations/nn_ops.cs +++ b/src/TensorFlowNET.Core/Operations/nn_ops.cs @@ -83,7 +83,7 @@ namespace Tensorflow // float to be selected, hence we use a >= comparison. 
var keep_mask = random_tensor >= rate; var ret = x * scale * math_ops.cast(keep_mask, x.dtype); - ret.SetShape(x.GetShape()); + ret.SetShape(x.TensorShape); return ret; }); } @@ -131,14 +131,14 @@ namespace Tensorflow var precise_logits = logits.dtype == TF_DataType.TF_HALF ? math_ops.cast(logits, dtypes.float32) : logits; // Store label shape for result later. - var labels_static_shape = labels.GetShape(); + var labels_static_shape = labels.TensorShape; var labels_shape = array_ops.shape(labels); /*bool static_shapes_fully_defined = ( labels_static_shape.is_fully_defined() && logits.get_shape()[:-1].is_fully_defined());*/ // Check if no reshapes are required. - if(logits.GetShape().NDim == 2) + if(logits.TensorShape.NDim == 2) { var (cost, _) = gen_nn_ops.sparse_softmax_cross_entropy_with_logits( precise_logits, labels, name: name); @@ -163,7 +163,7 @@ namespace Tensorflow { var precise_logits = logits; var input_rank = array_ops.rank(precise_logits); - var shape = logits.GetShape(); + var shape = logits.TensorShape; if (axis != -1) throw new NotImplementedException("softmax_cross_entropy_with_logits_v2_helper axis != -1"); diff --git a/src/TensorFlowNET.Core/Operations/weights_broadcast_ops.cs b/src/TensorFlowNET.Core/Operations/weights_broadcast_ops.cs index f0afa1fe..5e3a2cd6 100644 --- a/src/TensorFlowNET.Core/Operations/weights_broadcast_ops.cs +++ b/src/TensorFlowNET.Core/Operations/weights_broadcast_ops.cs @@ -16,8 +16,8 @@ namespace Tensorflow weights, dtype: values.dtype.as_base_dtype(), name: "weights"); // Try static check for exact match. 
- var weights_shape = weights.GetShape(); - var values_shape = values.GetShape(); + var weights_shape = weights.TensorShape; + var values_shape = values.TensorShape; if (weights_shape.is_fully_defined() && values_shape.is_fully_defined()) return weights; diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs index c69419ec..103161fd 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs @@ -1,4 +1,5 @@ -using System; +using NumSharp; +using System; using System.Collections.Generic; using System.Text; diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 6e1a63ba..8918c03c 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -110,10 +110,7 @@ namespace Tensorflow return shape.Select(x => (int)x).ToArray(); } - public TensorShape GetShape() - { - return tensor_util.to_shape(shape); - } + public TensorShape TensorShape => tensor_util.to_shape(shape); public void SetShape(Shape shape) { diff --git a/src/TensorFlowNET.Core/Tensors/TensorShape.cs b/src/TensorFlowNET.Core/Tensors/TensorShape.cs index 9498cead..d144bfb6 100644 --- a/src/TensorFlowNET.Core/Tensors/TensorShape.cs +++ b/src/TensorFlowNET.Core/Tensors/TensorShape.cs @@ -37,5 +37,7 @@ namespace Tensorflow { throw new NotImplementedException("TensorShape is_compatible_with"); } + + public static implicit operator TensorShape((int, int) dims) => new TensorShape(dims.Item1, dims.Item2); } } diff --git a/src/TensorFlowNET.Core/Tensors/tf.constant.cs b/src/TensorFlowNET.Core/Tensors/tf.constant.cs index c4c5ffaf..2783e51c 100644 --- a/src/TensorFlowNET.Core/Tensors/tf.constant.cs +++ b/src/TensorFlowNET.Core/Tensors/tf.constant.cs @@ -20,6 +20,15 @@ namespace Tensorflow verify_shape: verify_shape, allow_broadcast: false); + public static Tensor constant(float value, + int shape, + string name = 
"Const") => constant_op._constant_impl(value, + tf.float32, + new int[] { shape }, + name, + verify_shape: false, + allow_broadcast: false); + public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) => array_ops.zeros(shape, dtype, name); public static Tensor size(Tensor input, diff --git a/src/TensorFlowNET.Core/Train/tf.optimizers.cs b/src/TensorFlowNET.Core/Train/tf.optimizers.cs index 9e3d66a6..95ab868a 100644 --- a/src/TensorFlowNET.Core/Train/tf.optimizers.cs +++ b/src/TensorFlowNET.Core/Train/tf.optimizers.cs @@ -10,9 +10,11 @@ namespace Tensorflow { public static class train { - public static Optimizer GradientDescentOptimizer(float learning_rate) => new GradientDescentOptimizer(learning_rate); + public static Optimizer GradientDescentOptimizer(float learning_rate) + => new GradientDescentOptimizer(learning_rate); - public static Optimizer AdamOptimizer(float learning_rate) => new AdamOptimizer(learning_rate); + public static Optimizer AdamOptimizer(float learning_rate, string name = null) + => new AdamOptimizer(learning_rate, name: name); public static Saver Saver(VariableV1[] var_list = null) => new Saver(var_list: var_list); diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index 7b1de63c..c1912455 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -153,7 +153,7 @@ namespace Tensorflow // Manually overrides the variable's shape with the initial value's. 
if (validate_shape) { - var initial_value_shape = _initial_value.GetShape(); + var initial_value_shape = _initial_value.TensorShape; if (!initial_value_shape.is_fully_defined()) throw new ValueError($"initial_value must have a shape specified: {_initial_value}"); } diff --git a/test/KerasNET.Test/BaseTests.cs b/test/KerasNET.Test/BaseTests.cs index 6ab72276..92c777a0 100644 --- a/test/KerasNET.Test/BaseTests.cs +++ b/test/KerasNET.Test/BaseTests.cs @@ -15,8 +15,8 @@ namespace Keras.Test { var dense_1 = new Dense(1, name: "dense_1", activation: tf.nn.relu()); var input = new Tensor(np.array(new int[] { 3 })); - dense_1.__build__(input.GetShape()); - var outputShape = dense_1.output_shape(input.GetShape()); + dense_1.__build__(input.TensorShape); + var outputShape = dense_1.output_shape(input.TensorShape); var a = (int[])(outputShape.Dimensions); var b = (int[])(new int[] { 1 }); var _a = np.array(a); diff --git a/test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionNN.cs b/test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionNN.cs index c46453dd..9e30e421 100644 --- a/test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionNN.cs +++ b/test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionNN.cs @@ -1,14 +1,17 @@ -using System; +using NumSharp; +using System; using System.Collections.Generic; using System.Text; using Tensorflow; using TensorFlowNET.Examples.Utility; +using static Tensorflow.Python; namespace TensorFlowNET.Examples.ImageProcess { /// /// Neural Network classifier for Hand Written Digits - /// Sample Neural Network architecture with two layers implemented for classifying MNIST digits + /// Sample Neural Network architecture with two layers implemented for classifying MNIST digits. + /// Trained on mini-batches with the Adam optimizer.
/// http://www.easy-tensorflow.com/tf-tutorials/neural-networks /// public class DigitRecognitionNN : IExample @@ -22,24 +25,74 @@ namespace TensorFlowNET.Examples.ImageProcess const int img_w = 28; int img_size_flat = img_h * img_w; // 784, the total number of pixels int n_classes = 10; // Number of classes, one class per digit - int training_epochs = 10; - int? train_size = null; - int validation_size = 5000; - int? test_size = null; + // Hyper-parameters + int epochs = 10; int batch_size = 100; + float learning_rate = 0.001f; + int h1 = 200; // number of nodes in the 1st hidden layer Datasets mnist; + Tensor x, y; + Tensor loss, accuracy; + Operation optimizer; + + int display_freq = 100; + public bool Run() { PrepareData(); + BuildGraph(); + Train(); + return true; } public Graph BuildGraph() { - throw new NotImplementedException(); + var g = tf.Graph(); + + // Placeholders for inputs (x) and outputs(y) + x = tf.placeholder(tf.float32, shape: (-1, img_size_flat), name: "X"); + y = tf.placeholder(tf.float32, shape: (-1, n_classes), name: "Y"); + + // Create a fully-connected layer with h1 nodes as hidden layer + var fc1 = fc_layer(x, h1, "FC1", use_relu: true); + // Create a fully-connected layer with n_classes nodes as output layer + var output_logits = fc_layer(fc1, n_classes, "OUT", use_relu: false); + // Define the loss function, optimizer, and accuracy + loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels: y, logits: output_logits), name: "loss"); + optimizer = tf.train.AdamOptimizer(learning_rate: learning_rate, name: "Adam-op").minimize(loss); + var correct_prediction = tf.equal(tf.argmax(output_logits, 1), tf.argmax(y, 1), name: "correct_pred"); + accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name: "accuracy"); + + // Network predictions + var cls_prediction = tf.argmax(output_logits, axis: 1, name: "predictions"); + + return g; } + private Tensor fc_layer(Tensor x, int num_units, string name, bool use_relu = 
true) + { + var in_dim = x.shape[1]; + + var initer = tf.truncated_normal_initializer(stddev: 0.01f); + var W = tf.get_variable("W_" + name, + dtype: tf.float32, + shape: (in_dim, num_units), + initializer: initer); + + var initial = tf.constant(0f, num_units); + var b = tf.get_variable("b_" + name, + dtype: tf.float32, + initializer: initial); + + var layer = tf.matmul(x, W) + b; + if (use_relu) + layer = tf.nn.relu(layer); + + return layer; + } + public Graph ImportGraph() { throw new NotImplementedException(); @@ -52,12 +105,82 @@ namespace TensorFlowNET.Examples.ImageProcess public void PrepareData() { - mnist = MnistDataSet.read_data_sets("mnist", one_hot: true, train_size: train_size, validation_size: validation_size, test_size: test_size); + mnist = MnistDataSet.read_data_sets("mnist", one_hot: true); } public bool Train() { - throw new NotImplementedException(); + // Number of training iterations in each epoch + var num_tr_iter = mnist.train.labels.len / batch_size; + + return with(tf.Session(), sess => + { + var init = tf.global_variables_initializer(); + sess.run(init); + + float loss_val = 100.0f; + float accuracy_val = 0f; + + foreach (var epoch in range(epochs)) + { + print($"Training epoch: {epoch + 1}"); + // Randomly shuffle the training data at the beginning of each epoch + var (x_train, y_train) = randomize(mnist.train.images, mnist.train.labels); + + foreach (var iteration in range(num_tr_iter)) + { + var start = iteration * batch_size; + var end = (iteration + 1) * batch_size; + var (x_batch, y_batch) = get_next_batch(x_train, y_train, start, end); + + // Run optimization op (backprop) + sess.run(optimizer, new FeedItem(x, x_batch), new FeedItem(y, y_batch)); + + if (iteration % display_freq == 0) + { + // Calculate and display the batch loss and accuracy + var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, x_batch), new FeedItem(y, y_batch)); + loss_val = result[0]; + accuracy_val = result[1]; + print($"iter 
{iteration.ToString("000")}: Loss={loss_val.ToString("0.0000")}, Training Accuracy={accuracy_val.ToString("P")}"); + } + } + + // Run validation after every epoch + var results1 = sess.run(new[] { loss, accuracy }, new FeedItem(x, mnist.validation.images), new FeedItem(y, mnist.validation.labels)); + loss_val = results1[0]; + accuracy_val = results1[1]; + print("---------------------------------------------------------"); + print($"Epoch: {epoch + 1}, validation loss: {loss_val.ToString("0.0000")}, validation accuracy: {accuracy_val.ToString("P")}"); + print("---------------------------------------------------------"); + + } + + return accuracy_val > 0.9; + }); + } + + private (NDArray, NDArray) randomize(NDArray x, NDArray y) + { + var perm = np.random.permutation(y.shape[0]); + // Index the arguments with one shared permutation so every sample stays aligned with its label + np.random.shuffle(perm); + return (x[perm], y[perm]); + } + + /// <summary> + /// Selects one mini-batch from the given samples and labels; the caller derives start and end from the batch_size variable (if you don't know why, read about Stochastic Gradient Method) + /// </summary> + /// <param name="x">Samples to take the batch from</param> + /// <param name="y">Labels to take the batch from</param> + /// <param name="start">First index of the "start:end" slice applied to both arrays</param> + /// <param name="end">End index of the "start:end" slice applied to both arrays</param> + /// <returns>The sliced samples and the sliced labels, in that order</returns> + private (NDArray, NDArray) get_next_batch(NDArray x, NDArray y, int start, int end) + { + var x_batch = x[$"{start}:{end}"]; + var y_batch = y[$"{start}:{end}"]; + return (x_batch, y_batch); } } } diff --git a/test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs b/test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs index 3793027a..53cea791 100644 --- a/test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs +++ b/test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs @@ -264,12 +264,12 @@ namespace TensorFlowNET.Examples.ImageProcess private (Operation, Tensor, Tensor, Tensor, Tensor) add_final_retrain_ops(int class_count, string final_tensor_name, Tensor bottleneck_tensor, bool quantize_layer, bool is_training) { - var (batch_size, bottleneck_tensor_size) = (bottleneck_tensor.GetShape().Dimensions[0], bottleneck_tensor.GetShape().Dimensions[1]); + var (batch_size,
bottleneck_tensor_size) = (bottleneck_tensor.TensorShape.Dimensions[0], bottleneck_tensor.TensorShape.Dimensions[1]); with(tf.name_scope("input"), scope => { bottleneck_input = tf.placeholder_with_default( bottleneck_tensor, - shape: bottleneck_tensor.GetShape().Dimensions, + shape: bottleneck_tensor.TensorShape.Dimensions, name: "BottleneckInputPlaceholder"); ground_truth_input = tf.placeholder(tf.int64, new TensorShape(batch_size), name: "GroundTruthInput");