diff --git a/src/TensorFlowNET.Core/APIs/tf.nn.cs b/src/TensorFlowNET.Core/APIs/tf.nn.cs
index bfa71f31..8683809f 100644
--- a/src/TensorFlowNET.Core/APIs/tf.nn.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.nn.cs
@@ -3,6 +3,7 @@ using System.Collections.Generic;
using System.Text;
using Tensorflow.Operations;
using Tensorflow.Operations.Activation;
+using static Tensorflow.Python;
namespace Tensorflow
{
@@ -101,6 +102,25 @@ namespace Tensorflow
Tensor logits = null, string name = null)
=> nn_ops.sparse_softmax_cross_entropy_with_logits(labels: labels, logits: logits, name: name);
+ /// <summary>
+ /// Computes softmax cross entropy between `logits` and `labels`.
+ /// </summary>
+ /// <param name="labels"></param>
+ /// <param name="logits"></param>
+ /// <param name="dim"></param>
+ /// <param name="name"></param>
+ /// <returns></returns>
+ public static Tensor softmax_cross_entropy_with_logits(Tensor labels, Tensor logits, int dim = -1, string name = null)
+ {
+ with(ops.name_scope(name, "softmax_cross_entropy_with_logits_sg", new { logits, labels }), scope =>
+ {
+ name = scope;
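+ // As in TF's v1 op, gradients flow only into `logits`: wrapping `labels`
+ // in stop_gradient treats them as constants during backprop.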
+ labels = array_ops.stop_gradient(labels, name: "labels_stop_gradient");
+ });
+
+ return softmax_cross_entropy_with_logits_v2(labels, logits, axis: dim, name: name);
+ }
+
public static Tensor softmax_cross_entropy_with_logits_v2(Tensor labels, Tensor logits, int axis = -1, string name = null)
=> nn_ops.softmax_cross_entropy_with_logits_v2_helper(labels, logits, axis: axis, name: name);
}
diff --git a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs
index 00be4db8..84941e8a 100644
--- a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs
+++ b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs
@@ -94,14 +94,28 @@ namespace Tensorflow
if (attrs.ContainsKey(input_arg.TypeAttr))
dtype = (DataType)attrs[input_arg.TypeAttr];
else
- if (values is Tensor[] values1)
- dtype = values1[0].dtype.as_datatype_enum();
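+ // Infer the dtype from the first Tensor found among the input values.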
+ switch (values)
+ {
+ case Tensor[] values1:
+ dtype = values1[0].dtype.as_datatype_enum();
+ break;
+ case object[] values1:
+ foreach(var t in values1)
+ if(t is Tensor tensor)
+ {
+ dtype = tensor.dtype.as_datatype_enum();
+ break;
+ }
+ break;
+ default:
+ throw new NotImplementedException($"can't infer the dtype for {values.GetType()}");
+ }
if (dtype == DataType.DtInvalid && default_type_attr_map.ContainsKey(input_arg.TypeAttr))
default_dtype = (DataType)default_type_attr_map[input_arg.TypeAttr];
}
- if(input_arg.IsRef && dtype != DataType.DtInvalid)
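+ // Only non-ref inputs are normalized to their base dtype; ref inputs keep the ref type.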
+ if(!input_arg.IsRef && dtype != DataType.DtInvalid)
dtype = dtype.as_base_dtype();
values = ops.internal_convert_n_to_tensor(values,
diff --git a/src/TensorFlowNET.Core/Operations/Operation.Output.cs b/src/TensorFlowNET.Core/Operations/Operation.Output.cs
index 52cfa2ad..635009eb 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.Output.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.Output.cs
@@ -17,9 +17,7 @@ namespace Tensorflow
private Tensor[] _outputs;
public Tensor[] outputs => _outputs;
-#if GRAPH_SERIALIZE
- [JsonIgnore]
-#endif
+
public Tensor output => _outputs.FirstOrDefault();
public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle);
diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs
index c5bd77f6..3c1c632e 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.cs
@@ -1,7 +1,4 @@
using Google.Protobuf.Collections;
-#if GRAPH_SERIALIZE
-using Newtonsoft.Json;
-#endif
using System;
using System.Collections.Generic;
using System.Linq;
@@ -37,21 +34,11 @@ namespace Tensorflow
private Graph _graph;
public string type => OpType;
-#if GRAPH_SERIALIZE
- [JsonIgnore]
public Graph graph => _graph;
- [JsonIgnore]
public int _id => _id_value;
- [JsonIgnore]
public int _id_value;
- [JsonIgnore]
public Operation op => this;
-#else
- public Graph graph => _graph;
- public int _id => _id_value;
- public int _id_value;
- public Operation op => this;
-#endif
+
public TF_DataType dtype => TF_DataType.DtInvalid;
private Status status = new Status();
@@ -60,9 +47,6 @@ namespace Tensorflow
public string Device => c_api.StringPiece(c_api.TF_OperationDevice(_handle));
private NodeDef _node_def;
-#if GRAPH_SERIALIZE
- [JsonIgnore]
-#endif
public NodeDef node_def
{
get
diff --git a/src/TensorFlowNET.Core/Operations/array_ops.py.cs b/src/TensorFlowNET.Core/Operations/array_ops.py.cs
index 262752b2..e2ce9446 100644
--- a/src/TensorFlowNET.Core/Operations/array_ops.py.cs
+++ b/src/TensorFlowNET.Core/Operations/array_ops.py.cs
@@ -492,13 +492,18 @@ namespace Tensorflow
{
return with(ops.name_scope(name), scope => {
var t = ops.convert_to_tensor(axis, name: "concat_dim", dtype: TF_DataType.TF_INT32);
- return identity(values[0], name = scope);
+ return identity(values[0], name: scope);
});
}
return gen_array_ops.concat_v2(values, axis, name: name);
}
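+
+ /// <summary>
+ /// Concatenates a mixed array of values (e.g. int[] constants alongside Tensors) along `axis`.
+ /// </summary>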
+ public static Tensor concat(object[] values, int axis, string name = "concat")
+ {
+ return gen_array_ops.concat_v2(values, axis, name: name);
+ }
+
public static Tensor gather(Tensor @params, Tensor indices, string name = null, int axis = 0)
=> gen_array_ops.gather_v2(@params, indices, axis, name: name);
diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
index 8308d48d..70e710a6 100644
--- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
@@ -19,7 +19,7 @@ namespace Tensorflow
/// <param name="axis"></param>
/// <param name="name"></param>
/// <returns></returns>
- public static Tensor concat_v2(Tensor[] values, int axis, string name = null)
+ public static Tensor concat_v2<T>(T[] values, int axis, string name = null)
{
var _op = _op_def_lib._apply_op_helper("ConcatV2", name: name, args: new { values, axis });
diff --git a/src/TensorFlowNET.Core/Operations/nn_ops.cs b/src/TensorFlowNET.Core/Operations/nn_ops.cs
index f3c39f9b..0954b673 100644
--- a/src/TensorFlowNET.Core/Operations/nn_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/nn_ops.cs
@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
+using System.Linq;
using System.Text;
using Tensorflow.Operations;
using static Tensorflow.Python;
@@ -159,8 +160,9 @@ namespace Tensorflow
int axis = -1,
string name = null)
{
- return Python.with(ops.name_scope(name, "softmax_cross_entropy_with_logits", new { }), scope =>
+ return with(ops.name_scope(name, "softmax_cross_entropy_with_logits", new { logits, labels }), scope =>
{
+ name = scope;
var precise_logits = logits;
var input_rank = array_ops.rank(precise_logits);
var shape = logits.TensorShape;
@@ -170,6 +172,10 @@ namespace Tensorflow
var input_shape = array_ops.shape(precise_logits);
+ // Make precise_logits and labels into matrices.
+ precise_logits = _flatten_outer_dims(precise_logits);
+ labels = _flatten_outer_dims(labels);
+
// Do the actual op computation.
// The second output tensor contains the gradients. We use it in
// _CrossEntropyGrad() in nn_grad but not here.
@@ -186,5 +192,50 @@ namespace Tensorflow
return cost;
});
}
+
+ /// <summary>
+ /// Flattens the outer dimensions of `logits` and keeps its last dimension.
+ /// </summary>
+ /// <param name="logits"></param>
+ /// <returns></returns>
+ private static Tensor _flatten_outer_dims(Tensor logits)
+ {
+ var rank = array_ops.rank(logits);
+ var last_dim_size = array_ops.slice(array_ops.shape(logits),
+ new[] { math_ops.subtract(rank, 1) },
+ new[] { 1 });
+
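+ // Build the target shape [-1, last_dim] so all outer dimensions collapse into one.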
+ var new_shape = array_ops.concat(new[] { new[] { -1 }, (object)last_dim_size }, 0);
+ var output = array_ops.reshape(logits, new_shape);
+
+ // Set output shape if known.
+ // if not context.executing_eagerly():
+ var shape = logits.TensorShape;
+ if(shape != null && shape.NDim > 0)
+ {
+ var product = 1;
+ var product_valid = true;
+ foreach(var d in shape.Dimensions.Take(shape.NDim - 1))
+ {
+ if(d == -1)
+ {
+ product_valid = false;
+ break;
+ }
+ else
+ {
+ product *= d;
+ }
+ }
+
+ if (product_valid)
+ {
+ var output_shape = new[] { product };
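+ // Setting the computed static shape on the output is not implemented yet,
+ // hence the placeholder exception below.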
+ throw new NotImplementedException("_flatten_outer_dims product_valid");
+ }
+ }
+
+ return output;
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs
index 8918c03c..3ca6ed9b 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensor.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs
@@ -22,21 +22,11 @@ namespace Tensorflow
private int _id;
private Operation _op;
-#if GRAPH_SERIALIZE
- [JsonIgnore]
- public int Id => _id;
- [JsonIgnore]
- public Graph graph => op?.graph;
- [JsonIgnore]
- public Operation op => _op;
- [JsonIgnore]
- public Tensor[] outputs => op.outputs;
-#else
+
public int Id => _id;
public Graph graph => op?.graph;
public Operation op => _op;
public Tensor[] outputs => op.outputs;
-#endif
/// <summary>
/// The string name of this tensor.
@@ -50,18 +40,12 @@ namespace Tensorflow
private TF_DataType _dtype = TF_DataType.DtInvalid;
public TF_DataType dtype => _handle == IntPtr.Zero ? _dtype : c_api.TF_TensorType(_handle);
-#if GRAPH_SERIALIZE
- [JsonIgnore]
-#endif
+
public ulong bytesize => _handle == IntPtr.Zero ? 0 : c_api.TF_TensorByteSize(_handle);
-#if GRAPH_SERIALIZE
- [JsonIgnore]
-#endif
+
public ulong itemsize => _handle == IntPtr.Zero ? 0 : c_api.TF_DataTypeSize(dtype);
public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / itemsize;
-#if GRAPH_SERIALIZE
- [JsonIgnore]
-#endif
+
public IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle);
public int num_consumers(TF_Output oper_out) => _handle == IntPtr.Zero ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out);
@@ -70,9 +54,6 @@ namespace Tensorflow
/// <summary>
/// Used to keep another pointer alive during implicit operations.
/// </summary>
-#if GRAPH_SERIALIZE
- [JsonIgnore]
-#endif
public object Tag { get; set; }
public int[] shape
@@ -140,9 +121,7 @@ namespace Tensorflow
}
}
}
-#if GRAPH_SERIALIZE
- [JsonIgnore]
-#endif
+
public int NDims => rank;
public string Device => op.Device;
diff --git a/src/TensorFlowNET.Core/Train/AdamOptimizer.cs b/src/TensorFlowNET.Core/Train/AdamOptimizer.cs
index 15557679..8364e86c 100644
--- a/src/TensorFlowNET.Core/Train/AdamOptimizer.cs
+++ b/src/TensorFlowNET.Core/Train/AdamOptimizer.cs
@@ -110,7 +110,7 @@ namespace Tensorflow.Train
var update_beta2 = beta2_power.assign(beta2_power * _beta2_t, use_locking: _use_locking);
operations.Add(update_beta1);
- operations.Add(update_beta1);
+ operations.Add(update_beta2);
});
return control_flow_ops.group(operations.ToArray(), name: name_scope);
diff --git a/test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionNN.cs b/test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionNN.cs
index 9e30e421..d57cda11 100644
--- a/test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionNN.cs
+++ b/test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionNN.cs
@@ -49,8 +49,6 @@ namespace TensorFlowNET.Examples.ImageProcess
public Graph BuildGraph()
{
- var g = tf.Graph();
-
// Placeholders for inputs (x) and outputs (y)
x = tf.placeholder(tf.float32, shape: (-1, img_size_flat), name: "X");
y = tf.placeholder(tf.float32, shape: (-1, n_classes), name: "Y");
@@ -60,7 +58,8 @@ namespace TensorFlowNET.Examples.ImageProcess
// Create a fully-connected layer with n_classes nodes as output layer
var output_logits = fc_layer(fc1, n_classes, "OUT", use_relu: false);
// Define the loss function, optimizer, and accuracy
- loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels: y, logits: output_logits), name: "loss");
+ var cross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels: y, logits: output_logits);
+ loss = tf.reduce_mean(cross_entropy, name: "loss");
optimizer = tf.train.AdamOptimizer(learning_rate: learning_rate, name: "Adam-op").minimize(loss);
var correct_prediction = tf.equal(tf.argmax(output_logits, 1), tf.argmax(y, 1), name: "correct_pred");
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name: "accuracy");
@@ -68,7 +67,7 @@ namespace TensorFlowNET.Examples.ImageProcess
// Network predictions
var cls_prediction = tf.argmax(output_logits, axis: 1, name: "predictions");
- return g;
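+ // All ops above were registered on the default graph, so return that graph.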
+ return tf.get_default_graph();
}
private Tensor fc_layer(Tensor x, int num_units, string name, bool use_relu = true)
@@ -93,16 +92,10 @@ namespace TensorFlowNET.Examples.ImageProcess
return layer;
}
- public Graph ImportGraph()
- {
- throw new NotImplementedException();
- }
-
- public bool Predict()
- {
- throw new NotImplementedException();
- }
+ public Graph ImportGraph() => throw new NotImplementedException();
+ public bool Predict() => throw new NotImplementedException();
+
public void PrepareData()
{
mnist = MnistDataSet.read_data_sets("mnist", one_hot: true);
@@ -112,7 +105,6 @@ namespace TensorFlowNET.Examples.ImageProcess
{
// Number of training iterations in each epoch
var num_tr_iter = mnist.train.labels.len / batch_size;
-
return with(tf.Session(), sess =>
{
var init = tf.global_variables_initializer();
@@ -153,10 +145,9 @@ namespace TensorFlowNET.Examples.ImageProcess
print("---------------------------------------------------------");
print($"Epoch: {epoch + 1}, validation loss: {loss_val.ToString("0.0000")}, validation accuracy: {accuracy_val.ToString("P")}");
print("---------------------------------------------------------");
-
}
- return accuracy_val > 0.9;
+ return accuracy_val > 0.95;
});
}
diff --git a/test/TensorFlowNET.Examples/python/neural_network.py b/test/TensorFlowNET.Examples/python/neural_network.py
new file mode 100644
index 00000000..ac9597ca
--- /dev/null
+++ b/test/TensorFlowNET.Examples/python/neural_network.py
@@ -0,0 +1,164 @@
+
+# imports
+import tensorflow as tf
+import numpy as np
+import matplotlib.pyplot as plt
+
+img_h = img_w = 28 # MNIST images are 28x28
+img_size_flat = img_h * img_w # 28x28=784, the total number of pixels
+n_classes = 10 # Number of classes, one class per digit
+
+def load_data(mode='train'):
+ """
+ Function to (download and) load the MNIST data
+ :param mode: train or test
+ :return: images and the corresponding labels
+ """
+ from tensorflow.examples.tutorials.mnist import input_data
+ mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
+ if mode == 'train':
+ x_train, y_train, x_valid, y_valid = mnist.train.images, mnist.train.labels, \
+ mnist.validation.images, mnist.validation.labels
+ return x_train, y_train, x_valid, y_valid
+ elif mode == 'test':
+ x_test, y_test = mnist.test.images, mnist.test.labels
+ return x_test, y_test
+
+def randomize(x, y):
+ """ Randomizes the order of data samples and their corresponding labels"""
+ permutation = np.random.permutation(y.shape[0])
+ shuffled_x = x[permutation, :]
+ shuffled_y = y[permutation]
+ return shuffled_x, shuffled_y
+
+def get_next_batch(x, y, start, end):
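+ """ Returns the batch of samples and labels in the half-open range [start, end)"""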
+ x_batch = x[start:end]
+ y_batch = y[start:end]
+ return x_batch, y_batch
+
+# Load MNIST data
+x_train, y_train, x_valid, y_valid = load_data(mode='train')
+print("Size of:")
+print("- Training-set:\t\t{}".format(len(y_train)))
+print("- Validation-set:\t{}".format(len(y_valid)))
+
+print('x_train:\t{}'.format(x_train.shape))
+print('y_train:\t{}'.format(y_train.shape))
+print('x_valid:\t{}'.format(x_valid.shape))
+print('y_valid:\t{}'.format(y_valid.shape))
+
+print(y_valid[:5, :])
+
+# Hyper-parameters
+epochs = 10 # Total number of training epochs
+batch_size = 100 # Training batch size
+display_freq = 100 # Frequency of displaying the training results
+learning_rate = 0.001 # The optimization initial learning rate
+
+h1 = 200 # number of nodes in the 1st hidden layer
+
+# weight and bias wrappers
+def weight_variable(name, shape):
+ """
+ Create a weight variable with appropriate initialization
+ :param name: weight name
+ :param shape: weight shape
+ :return: initialized weight variable
+ """
+ initer = tf.truncated_normal_initializer(stddev=0.01)
+ return tf.get_variable('W_' + name,
+ dtype=tf.float32,
+ shape=shape,
+ initializer=initer)
+
+
+def bias_variable(name, shape):
+ """
+ Create a bias variable with appropriate initialization
+ :param name: bias variable name
+ :param shape: bias variable shape
+ :return: initialized bias variable
+ """
+ initial = tf.constant(0., shape=shape, dtype=tf.float32)
+ return tf.get_variable('b_' + name,
+ dtype=tf.float32,
+ initializer=initial)
+
+def fc_layer(x, num_units, name, use_relu=True):
+ """
+ Create a fully-connected layer
+ :param x: input from previous layer
+ :param num_units: number of hidden units in the fully-connected layer
+ :param name: layer name
+ :param use_relu: boolean to add ReLU non-linearity (or not)
+ :return: The output array
+ """
+ in_dim = x.get_shape()[1]
+ W = weight_variable(name, shape=[in_dim, num_units])
+ b = bias_variable(name, [num_units])
+ layer = tf.matmul(x, W)
+ layer += b
+ if use_relu:
+ layer = tf.nn.relu(layer)
+ return layer
+
+# Create the graph for the linear model
+# Placeholders for inputs (x) and outputs (y)
+x = tf.placeholder(tf.float32, shape=[None, img_size_flat], name='X')
+y = tf.placeholder(tf.float32, shape=[None, n_classes], name='Y')
+
+# Create a fully-connected layer with h1 nodes as hidden layer
+fc1 = fc_layer(x, h1, 'FC1', use_relu=True)
+# Create a fully-connected layer with n_classes nodes as output layer
+output_logits = fc_layer(fc1, n_classes, 'OUT', use_relu=False)
+
+# Define the loss function, optimizer, and accuracy
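+# Note: softmax_cross_entropy_with_logits expects raw, unscaled logits and
+# applies the softmax internally for numerical stability.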
+cross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=output_logits)
+loss = tf.reduce_mean(cross_entropy, name='loss')
+optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate, name='Adam-op').minimize(loss)
+correct_prediction = tf.equal(tf.argmax(output_logits, 1), tf.argmax(y, 1), name='correct_pred')
+accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name='accuracy')
+
+# Network predictions
+cls_prediction = tf.argmax(output_logits, axis=1, name='predictions')
+
+# export graph
+#tf.train.export_meta_graph(filename='neural_network.meta', graph=tf.get_default_graph(), clear_extraneous_savers= True, as_text = True)
+
+# Create the op for initializing all variables
+init = tf.global_variables_initializer()
+
+# Create an interactive session (to keep the session in the other cells)
+sess = tf.InteractiveSession()
+# Initialize all variables
+sess.run(init)
+# Number of training iterations in each epoch
+num_tr_iter = int(len(y_train) / batch_size)
+for epoch in range(epochs):
+ print('Training epoch: {}'.format(epoch + 1))
+ # Randomly shuffle the training data at the beginning of each epoch
+ x_train, y_train = randomize(x_train, y_train)
+ for iteration in range(num_tr_iter):
+ start = iteration * batch_size
+ end = (iteration + 1) * batch_size
+ x_batch, y_batch = get_next_batch(x_train, y_train, start, end)
+
+ # Run optimization op (backprop)
+ feed_dict_batch = {x: x_batch, y: y_batch}
+ sess.run(optimizer, feed_dict=feed_dict_batch)
+
+ if iteration % display_freq == 0:
+ # Calculate and display the batch loss and accuracy
+ loss_batch, acc_batch = sess.run([loss, accuracy],
+ feed_dict=feed_dict_batch)
+
+ print("iter {0:3d}:\t Loss={1:.2f},\tTraining Accuracy={2:.01%}".
+ format(iteration, loss_batch, acc_batch))
+
+ # Run validation after every epoch
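+ # Only the first 1,000 validation samples are used, to keep evaluation cheap.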
+ feed_dict_valid = {x: x_valid[:1000], y: y_valid[:1000]}
+ loss_valid, acc_valid = sess.run([loss, accuracy], feed_dict=feed_dict_valid)
+ print('---------------------------------------------------------')
+ print("Epoch: {0}, validation loss: {1:.2f}, validation accuracy: {2:.01%}".
+ format(epoch + 1, loss_valid, acc_valid))
+ print('---------------------------------------------------------')
\ No newline at end of file