diff --git a/src/TensorFlowNET.Core/APIs/tf.nn.cs b/src/TensorFlowNET.Core/APIs/tf.nn.cs
index bfa71f31..8683809f 100644
--- a/src/TensorFlowNET.Core/APIs/tf.nn.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.nn.cs
@@ -3,6 +3,7 @@ using System.Collections.Generic;
 using System.Text;
 using Tensorflow.Operations;
 using Tensorflow.Operations.Activation;
+using static Tensorflow.Python;
 
 namespace Tensorflow
 {
@@ -101,6 +102,25 @@ namespace Tensorflow
             Tensor logits = null,
             string name = null)
             => nn_ops.sparse_softmax_cross_entropy_with_logits(labels: labels, logits: logits, name: name);
 
+        /// <summary>
+        /// Computes softmax cross entropy between `logits` and `labels`.
+        /// </summary>
+        /// <param name="labels"></param>
+        /// <param name="logits"></param>
+        /// <param name="dim"></param>
+        /// <param name="name"></param>
+        public static Tensor softmax_cross_entropy_with_logits(Tensor labels, Tensor logits, int dim = -1, string name = null)
+        {
+            with(ops.name_scope(name, "softmax_cross_entropy_with_logits_sg", new { logits, labels }), scope =>
+            {
+                name = scope;
+                labels = array_ops.stop_gradient(labels, name: "labels_stop_gradient");
+            });
+
+            return softmax_cross_entropy_with_logits_v2(labels, logits, axis: dim, name: name);
+        }
+
         public static Tensor softmax_cross_entropy_with_logits_v2(Tensor labels, Tensor logits, int axis = -1, string name = null)
             => nn_ops.softmax_cross_entropy_with_logits_v2_helper(labels, logits, axis: axis, name: name);
     }
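Reviewer note: the new overload reproduces TF's v1 semantics by detaching `labels` with `stop_gradient` before delegating to the v2 helper, so gradients flow into `logits` only. A minimal usage sketch (placeholder shapes are hypothetical, mirroring the MNIST example later in this changeset):

```csharp
// Sketch only: one-hot labels `y` and unscaled scores `output_logits`.
var y = tf.placeholder(tf.float32, shape: (-1, 10), name: "Y");
var output_logits = tf.placeholder(tf.float32, shape: (-1, 10), name: "logits");

// `y` is detached internally via array_ops.stop_gradient, so only
// `output_logits` receives gradients from the loss.
var xent = tf.nn.softmax_cross_entropy_with_logits(labels: y, logits: output_logits);
var loss = tf.reduce_mean(xent, name: "loss");
```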
diff --git a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs
index 00be4db8..84941e8a 100644
--- a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs
+++ b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs
@@ -94,14 +94,28 @@ namespace Tensorflow
                     if (attrs.ContainsKey(input_arg.TypeAttr))
                         dtype = (DataType)attrs[input_arg.TypeAttr];
                     else
-                        if (values is Tensor[] values1)
-                            dtype = values1[0].dtype.as_datatype_enum();
+                        switch (values)
+                        {
+                            case Tensor[] values1:
+                                dtype = values1[0].dtype.as_datatype_enum();
+                                break;
+                            case object[] values1:
+                                foreach(var t in values1)
+                                    if(t is Tensor tensor)
+                                    {
+                                        dtype = tensor.dtype.as_datatype_enum();
+                                        break;
+                                    }
+                                break;
+                            default:
+                                throw new NotImplementedException($"can't infer the dtype for {values.GetType()}");
+                        }
 
                     if (dtype == DataType.DtInvalid && default_type_attr_map.ContainsKey(input_arg.TypeAttr))
                         default_dtype = (DataType)default_type_attr_map[input_arg.TypeAttr];
                 }
 
-                if(input_arg.IsRef && dtype != DataType.DtInvalid)
+                if(!input_arg.IsRef && dtype != DataType.DtInvalid)
                     dtype = dtype.as_base_dtype();
 
                 values = ops.internal_convert_n_to_tensor(values,
diff --git a/src/TensorFlowNET.Core/Operations/Operation.Output.cs b/src/TensorFlowNET.Core/Operations/Operation.Output.cs
index 52cfa2ad..635009eb 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.Output.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.Output.cs
@@ -17,9 +17,7 @@ namespace Tensorflow
         private Tensor[] _outputs;
         public Tensor[] outputs => _outputs;
 
-#if GRAPH_SERIALIZE
-        [JsonIgnore]
-#endif
+
         public Tensor output => _outputs.FirstOrDefault();
 
         public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle);
diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs
index c5bd77f6..3c1c632e 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.cs
@@ -1,7 +1,4 @@
 using Google.Protobuf.Collections;
-#if GRAPH_SERIALIZE
-using Newtonsoft.Json;
-#endif
 using System;
 using System.Collections.Generic;
 using System.Linq;
@@ -37,21 +34,11 @@ namespace Tensorflow
         private Graph _graph;
 
         public string type => OpType;
 
-#if GRAPH_SERIALIZE
-        [JsonIgnore]
         public Graph graph => _graph;
-        [JsonIgnore]
         public int _id => _id_value;
-        [JsonIgnore]
         public int _id_value;
-        [JsonIgnore]
         public Operation op => this;
-#else
-        public Graph graph => _graph;
-        public int _id => _id_value;
-        public int _id_value;
-        public Operation op => this;
-#endif
+
         public TF_DataType dtype => TF_DataType.DtInvalid;
 
         private Status status = new Status();
@@ -60,9 +47,6 @@ namespace Tensorflow
         public string Device => c_api.StringPiece(c_api.TF_OperationDevice(_handle));
 
         private NodeDef _node_def;
-#if GRAPH_SERIALIZE
-        [JsonIgnore]
-#endif
         public NodeDef node_def
         {
             get
diff --git a/src/TensorFlowNET.Core/Operations/array_ops.py.cs b/src/TensorFlowNET.Core/Operations/array_ops.py.cs
index 262752b2..e2ce9446 100644
--- a/src/TensorFlowNET.Core/Operations/array_ops.py.cs
+++ b/src/TensorFlowNET.Core/Operations/array_ops.py.cs
@@ -492,13 +492,18 @@ namespace Tensorflow
             {
                 return with(ops.name_scope(name), scope =>
                 {
                     var t = ops.convert_to_tensor(axis, name: "concat_dim", dtype: TF_DataType.TF_INT32);
-                    return identity(values[0], name = scope);
+                    return identity(values[0], name: scope);
                 });
             }
 
             return gen_array_ops.concat_v2(values, axis, name: name);
         }
 
+        public static Tensor concat(object[] values, int axis, string name = "concat")
+        {
+            return gen_array_ops.concat_v2(values, axis, name: name);
+        }
+
         public static Tensor gather(Tensor @params, Tensor indices, string name = null, int axis = 0)
             => gen_array_ops.gather_v2(@params, indices, axis, name: name);
 
diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
index 8308d48d..70e710a6 100644
--- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
@@ -19,7 +19,7 @@ namespace Tensorflow
         /// <param name="axis"></param>
         /// <param name="name"></param>
         /// <returns></returns>
-        public static Tensor concat_v2(Tensor[] values, int axis, string name = null)
+        public static Tensor concat_v2<T>(T[] values, int axis, string name = null)
        {
            var _op = _op_def_lib._apply_op_helper("ConcatV2", name: name, args: new { values, axis });
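Reviewer note: making `concat_v2<T>` generic works together with the new `object[]` branch in `_apply_op_helper`: callers can now mix plain CLR arrays with graph tensors, and the op's dtype is inferred from the first `Tensor` encountered. A sketch of the call shape this enables (`last_dim_size` is a hypothetical Tensor here; `_flatten_outer_dims` below does exactly this):

```csharp
// Mixed inputs: an int[] literal plus a Tensor in one object[].
// _apply_op_helper walks the array and takes the dtype of
// `last_dim_size`, the first Tensor it finds.
object[] mixed = { new[] { -1 }, last_dim_size };
var target_shape = gen_array_ops.concat_v2(mixed, 0);  // T inferred as object
```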
diff --git a/src/TensorFlowNET.Core/Operations/nn_ops.cs b/src/TensorFlowNET.Core/Operations/nn_ops.cs
index f3c39f9b..0954b673 100644
--- a/src/TensorFlowNET.Core/Operations/nn_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/nn_ops.cs
@@ -1,5 +1,6 @@
 using System;
 using System.Collections.Generic;
+using System.Linq;
 using System.Text;
 using Tensorflow.Operations;
 using static Tensorflow.Python;
@@ -159,8 +160,9 @@ namespace Tensorflow
             int axis = -1,
             string name = null)
         {
-            return Python.with(ops.name_scope(name, "softmax_cross_entropy_with_logits", new { }), scope =>
+            return with(ops.name_scope(name, "softmax_cross_entropy_with_logits", new { logits, labels }), scope =>
             {
+                name = scope;
                 var precise_logits = logits;
                 var input_rank = array_ops.rank(precise_logits);
                 var shape = logits.TensorShape;
@@ -170,6 +172,10 @@ namespace Tensorflow
 
                 var input_shape = array_ops.shape(precise_logits);
 
+                // Make precise_logits and labels into matrices.
+                precise_logits = _flatten_outer_dims(precise_logits);
+                labels = _flatten_outer_dims(labels);
+
                 // Do the actual op computation.
                 // The second output tensor contains the gradients. We use it in
                 // _CrossEntropyGrad() in nn_grad but not here.
@@ -186,5 +192,50 @@ namespace Tensorflow
                 return cost;
             });
         }
+
+        /// <summary>
+        /// Flattens logits' outer dimensions and keeps its last dimension.
+        /// </summary>
+        /// <param name="logits"></param>
+        /// <returns></returns>
+        private static Tensor _flatten_outer_dims(Tensor logits)
+        {
+            var rank = array_ops.rank(logits);
+            var last_dim_size = array_ops.slice(array_ops.shape(logits),
+                new[] { math_ops.subtract(rank, 1) },
+                new[] { 1 });
+
+            var ops = array_ops.concat(new[] { new[] { -1 }, (object)last_dim_size }, 0);
+            var output = array_ops.reshape(logits, ops);
+
+            // Set output shape if known.
+            // if not context.executing_eagerly():
+            var shape = logits.TensorShape;
+            if(shape != null && shape.NDim > 0)
+            {
+                var product = 1;
+                var product_valid = true;
+                foreach(var d in shape.Dimensions.Take(shape.NDim - 1))
+                {
+                    if(d == -1)
+                    {
+                        product_valid = false;
+                        break;
+                    }
+                    else
+                    {
+                        product *= d;
+                    }
+                }
+
+                if (product_valid)
+                {
+                    var output_shape = new[] { product };
+                    throw new NotImplementedException("_flatten_outer_dims product_valid");
+                }
+            }
+
+            return output;
+        }
     }
 }
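Reviewer note: `_flatten_outer_dims` collapses every dimension except the last, so the slice/concat/reshape pipeline always hands the cross-entropy kernel a 2-D `[outer, classes]` matrix regardless of the caller's rank. A worked shape example (values illustrative, names as in the helper above):

```csharp
// Assume logits has static shape [2, 3, 10] (batch, time, classes):
//   rank(logits)                                  -> 3
//   shape(logits)                                 -> [2, 3, 10]
//   last_dim_size = slice(shape, [rank - 1], [1]) -> [10]
//   concat([-1], last_dim_size)                   -> [-1, 10]
//   reshape(logits, [-1, 10])                     -> shape [6, 10] (2 * 3 outer rows)
var flat = array_ops.reshape(logits,
    array_ops.concat(new[] { new[] { -1 }, (object)last_dim_size }, 0));
```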
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs
index 8918c03c..3ca6ed9b 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensor.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs
@@ -22,21 +22,11 @@ namespace Tensorflow
         private int _id;
         private Operation _op;
-#if GRAPH_SERIALIZE
-        [JsonIgnore]
-        public int Id => _id;
-        [JsonIgnore]
-        public Graph graph => op?.graph;
-        [JsonIgnore]
-        public Operation op => _op;
-        [JsonIgnore]
-        public Tensor[] outputs => op.outputs;
-#else
+
         public int Id => _id;
         public Graph graph => op?.graph;
         public Operation op => _op;
         public Tensor[] outputs => op.outputs;
-#endif
 
         /// <summary>
         /// The string name of this tensor.
         /// </summary>
@@ -50,18 +40,12 @@ namespace Tensorflow
         private TF_DataType _dtype = TF_DataType.DtInvalid;
         public TF_DataType dtype => _handle == IntPtr.Zero ? _dtype : c_api.TF_TensorType(_handle);
 
-#if GRAPH_SERIALIZE
-        [JsonIgnore]
-#endif
+
         public ulong bytesize => _handle == IntPtr.Zero ? 0 : c_api.TF_TensorByteSize(_handle);
-#if GRAPH_SERIALIZE
-        [JsonIgnore]
-#endif
+
         public ulong itemsize => _handle == IntPtr.Zero ? 0 : c_api.TF_DataTypeSize(dtype);
         public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / itemsize;
-#if GRAPH_SERIALIZE
-        [JsonIgnore]
-#endif
+
         public IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle);
 
         public int num_consumers(TF_Output oper_out) => _handle == IntPtr.Zero ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out);
@@ -70,9 +54,6 @@ namespace Tensorflow
         /// <summary>
         /// used to keep another pointer alive when doing implicit operations
         /// </summary>
-#if GRAPH_SERIALIZE
-        [JsonIgnore]
-#endif
         public object Tag { get; set; }
 
         public int[] shape
@@ -140,9 +121,7 @@ namespace Tensorflow
                 }
             }
         }
-#if GRAPH_SERIALIZE
-        [JsonIgnore]
-#endif
+
         public int NDims => rank;
 
         public string Device => op.Device;
diff --git a/src/TensorFlowNET.Core/Train/AdamOptimizer.cs b/src/TensorFlowNET.Core/Train/AdamOptimizer.cs
index 15557679..8364e86c 100644
--- a/src/TensorFlowNET.Core/Train/AdamOptimizer.cs
+++ b/src/TensorFlowNET.Core/Train/AdamOptimizer.cs
@@ -110,7 +110,7 @@ namespace Tensorflow.Train
                 var update_beta2 = beta2_power.assign(beta2_power * _beta2_t, use_locking: _use_locking);
 
                 operations.Add(update_beta1);
-                operations.Add(update_beta1);
+                operations.Add(update_beta2);
             });
 
             return control_flow_ops.group(operations.ToArray(), name: name_scope);
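Reviewer note: the one-line AdamOptimizer change fixes a copy-paste slip. `update_beta2` was built but never queued, so `beta2_power` stayed at its initial value and the bias correction `sqrt(1 - beta2^t) / (1 - beta1^t)` was computed with a stale exponent. Condensed form of the corrected grouping (fields as in the surrounding class):

```csharp
// Both power accumulators must advance once per training step.
var update_beta1 = beta1_power.assign(beta1_power * _beta1_t, use_locking: _use_locking);
var update_beta2 = beta2_power.assign(beta2_power * _beta2_t, use_locking: _use_locking);
operations.Add(update_beta1);
operations.Add(update_beta2);  // previously update_beta1 was added twice
return control_flow_ops.group(operations.ToArray(), name: name_scope);
```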
diff --git a/test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionNN.cs b/test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionNN.cs
index 9e30e421..d57cda11 100644
--- a/test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionNN.cs
+++ b/test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionNN.cs
@@ -49,8 +49,6 @@ namespace TensorFlowNET.Examples.ImageProcess
 
         public Graph BuildGraph()
         {
-            var g = tf.Graph();
-
             // Placeholders for inputs (x) and outputs(y)
             x = tf.placeholder(tf.float32, shape: (-1, img_size_flat), name: "X");
             y = tf.placeholder(tf.float32, shape: (-1, n_classes), name: "Y");
@@ -60,7 +58,8 @@ namespace TensorFlowNET.Examples.ImageProcess
             // Create a fully-connected layer with n_classes nodes as output layer
             var output_logits = fc_layer(fc1, n_classes, "OUT", use_relu: false);
             // Define the loss function, optimizer, and accuracy
-            loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels: y, logits: output_logits), name: "loss");
+            var logits = tf.nn.softmax_cross_entropy_with_logits(labels: y, logits: output_logits);
+            loss = tf.reduce_mean(logits, name: "loss");
             optimizer = tf.train.AdamOptimizer(learning_rate: learning_rate, name: "Adam-op").minimize(loss);
             var correct_prediction = tf.equal(tf.argmax(output_logits, 1), tf.argmax(y, 1), name: "correct_pred");
             accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name: "accuracy");
@@ -68,7 +67,7 @@ namespace TensorFlowNET.Examples.ImageProcess
 
             // Network predictions
             var cls_prediction = tf.argmax(output_logits, axis: 1, name: "predictions");
 
-            return g;
+            return tf.get_default_graph();
         }
 
         private Tensor fc_layer(Tensor x, int num_units, string name, bool use_relu = true)
@@ -93,16 +92,10 @@ namespace TensorFlowNET.Examples.ImageProcess
             return layer;
         }
 
-        public Graph ImportGraph()
-        {
-            throw new NotImplementedException();
-        }
-
-        public bool Predict()
-        {
-            throw new NotImplementedException();
-        }
+        public Graph ImportGraph() => throw new NotImplementedException();
 
+        public bool Predict() => throw new NotImplementedException();
+
         public void PrepareData()
         {
             mnist = MnistDataSet.read_data_sets("mnist", one_hot: true);
@@ -112,7 +105,6 @@ namespace TensorFlowNET.Examples.ImageProcess
         {
             // Number of training iterations in each epoch
             var num_tr_iter = mnist.train.labels.len / batch_size;
-
             return with(tf.Session(), sess =>
             {
                 var init = tf.global_variables_initializer();
@@ -153,10 +145,9 @@ namespace TensorFlowNET.Examples.ImageProcess
                         print("---------------------------------------------------------");
                         print($"Epoch: {epoch + 1}, validation loss: {loss_val.ToString("0.0000")}, validation accuracy: {accuracy_val.ToString("P")}");
                         print("---------------------------------------------------------");
-
                 }
 
-                return accuracy_val > 0.9;
+                return accuracy_val > 0.95;
             });
         }
diff --git a/test/TensorFlowNET.Examples/python/neural_network.py b/test/TensorFlowNET.Examples/python/neural_network.py
new file mode 100644
index 00000000..ac9597ca
--- /dev/null
+++ b/test/TensorFlowNET.Examples/python/neural_network.py
@@ -0,0 +1,164 @@
+
+# imports
+import tensorflow as tf
+import numpy as np
+import matplotlib.pyplot as plt
+
+img_h = img_w = 28              # MNIST images are 28x28
+img_size_flat = img_h * img_w   # 28x28=784, the total number of pixels
+n_classes = 10                  # Number of classes, one class per digit
+
+def load_data(mode='train'):
+    """
+    Function to (download and) load the MNIST data
+    :param mode: train or test
+    :return: images and the corresponding labels
+    """
+    from tensorflow.examples.tutorials.mnist import input_data
+    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
+    if mode == 'train':
+        x_train, y_train, x_valid, y_valid = mnist.train.images, mnist.train.labels, \
+                                             mnist.validation.images, mnist.validation.labels
+        return x_train, y_train, x_valid, y_valid
+    elif mode == 'test':
+        x_test, y_test = mnist.test.images, mnist.test.labels
+        return x_test, y_test
+
+def randomize(x, y):
+    """ Randomizes the order of data samples and their corresponding labels"""
+    permutation = np.random.permutation(y.shape[0])
+    shuffled_x = x[permutation, :]
+    shuffled_y = y[permutation]
+    return shuffled_x, shuffled_y
+
+def get_next_batch(x, y, start, end):
+    x_batch = x[start:end]
+    y_batch = y[start:end]
+    return x_batch, y_batch
+
+# Load MNIST data
+x_train, y_train, x_valid, y_valid = load_data(mode='train')
+print("Size of:")
+print("- Training-set:\t\t{}".format(len(y_train)))
+print("- Validation-set:\t{}".format(len(y_valid)))
+
+print('x_train:\t{}'.format(x_train.shape))
+print('y_train:\t{}'.format(y_train.shape))
+print('x_valid:\t{}'.format(x_valid.shape))
+print('y_valid:\t{}'.format(y_valid.shape))
+
+print(y_valid[:5, :])
+
+# Hyper-parameters
+epochs = 10             # Total number of training epochs
+batch_size = 100        # Training batch size
+display_freq = 100      # Frequency of displaying the training results
+learning_rate = 0.001   # The optimization initial learning rate
+
+h1 = 200                # number of nodes in the 1st hidden layer
+
+# weight and bias wrappers
+def weight_variable(name, shape):
+    """
+    Create a weight variable with appropriate initialization
+    :param name: weight name
+    :param shape: weight shape
+    :return: initialized weight variable
+    """
+    initer = tf.truncated_normal_initializer(stddev=0.01)
+    return tf.get_variable('W_' + name,
+                           dtype=tf.float32,
+                           shape=shape,
+                           initializer=initer)
+
+
+def bias_variable(name, shape):
+    """
+    Create a bias variable with appropriate initialization
+    :param name: bias variable name
+    :param shape: bias variable shape
+    :return: initialized bias variable
+    """
+    initial = tf.constant(0., shape=shape, dtype=tf.float32)
+    return tf.get_variable('b_' + name,
+                           dtype=tf.float32,
+                           initializer=initial)
+
+def fc_layer(x, num_units, name, use_relu=True):
+    """
+    Create a fully-connected layer
+    :param x: input from previous layer
+    :param num_units: number of hidden units in the fully-connected layer
+    :param name: layer name
+    :param use_relu: boolean to add ReLU non-linearity (or not)
+    :return: The output array
+    """
+    in_dim = x.get_shape()[1]
+    W = weight_variable(name, shape=[in_dim, num_units])
+    b = bias_variable(name, [num_units])
+    layer = tf.matmul(x, W)
+    layer += b
+    if use_relu:
+        layer = tf.nn.relu(layer)
+    return layer
+
+# Create the graph for the linear model
+# Placeholders for inputs (x) and outputs(y)
+x = tf.placeholder(tf.float32, shape=[None, img_size_flat], name='X')
+y = tf.placeholder(tf.float32, shape=[None, n_classes], name='Y')
+
+# Create a fully-connected layer with h1 nodes as hidden layer
+fc1 = fc_layer(x, h1, 'FC1', use_relu=True)
+# Create a fully-connected layer with n_classes nodes as output layer
+output_logits = fc_layer(fc1, n_classes, 'OUT', use_relu=False)
+
+# Define the loss function, optimizer, and accuracy
+logits = tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=output_logits)
+loss = tf.reduce_mean(logits, name='loss')
+optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate, name='Adam-op').minimize(loss)
+correct_prediction = tf.equal(tf.argmax(output_logits, 1), tf.argmax(y, 1), name='correct_pred')
+accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name='accuracy')
+
+# Network predictions
+cls_prediction = tf.argmax(output_logits, axis=1, name='predictions')
+
+# export graph
+#tf.train.export_meta_graph(filename='neural_network.meta', graph=tf.get_default_graph(), clear_extraneous_savers= True, as_text = True)
+
+# Create the op for initializing all variables
+init = tf.global_variables_initializer()
+
+# Create an interactive session (to keep the session in the other cells)
+sess = tf.InteractiveSession()
+# Initialize all variables
+sess.run(init)
+# Number of training iterations in each epoch
+num_tr_iter = int(len(y_train) / batch_size)
+for epoch in range(epochs):
+    print('Training epoch: {}'.format(epoch + 1))
+    # Randomly shuffle the training data at the beginning of each epoch
+    x_train, y_train = randomize(x_train, y_train)
+    for iteration in range(num_tr_iter):
+        start = iteration * batch_size
+        end = (iteration + 1) * batch_size
+        x_batch, y_batch = get_next_batch(x_train, y_train, start, end)
+
+        # Run optimization op (backprop)
+        feed_dict_batch = {x: x_batch, y: y_batch}
+        sess.run(optimizer, feed_dict=feed_dict_batch)
+
+        if iteration % display_freq == 0:
+            # Calculate and display the batch loss and accuracy
+            loss_batch, acc_batch = sess.run([loss, accuracy],
+                                             feed_dict=feed_dict_batch)
+
+            print("iter {0:3d}:\t Loss={1:.2f},\tTraining Accuracy={2:.01%}".
+                  format(iteration, loss_batch, acc_batch))
+
+    # Run validation after every epoch
+    feed_dict_valid = {x: x_valid[:1000], y: y_valid[:1000]}
+    loss_valid, acc_valid = sess.run([loss, accuracy], feed_dict=feed_dict_valid)
+    print('---------------------------------------------------------')
+    print("Epoch: {0}, validation loss: {1:.2f}, validation accuracy: {2:.01%}".
+          format(epoch + 1, loss_valid, acc_valid))
+    print('---------------------------------------------------------')
\ No newline at end of file