@@ -142,13 +142,13 @@ Example runner will download all the required files like training data and model
 * [Logistic Regression](test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs)
 * [Nearest Neighbor](test/TensorFlowNET.Examples/BasicModels/NearestNeighbor.cs)
 * [Naive Bayes Classification](test/TensorFlowNET.Examples/BasicModels/NaiveBayesClassifier.cs)
+* [Full Connected Neural Network](test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionNN.cs)
 * [Image Recognition](test/TensorFlowNET.Examples/ImageProcess)
 * [K-means Clustering](test/TensorFlowNET.Examples/BasicModels/KMeansClustering.cs)
 * [NN XOR](test/TensorFlowNET.Examples/BasicModels/NeuralNetXor.cs)
 * [Object Detection](test/TensorFlowNET.Examples/ImageProcess/ObjectDetection.cs)
 * [Text Classification](test/TensorFlowNET.Examples/TextProcess/BinaryTextClassification.cs)
 * [CNN Text Classification](test/TensorFlowNET.Examples/TextProcess/cnn_models/VdCnn.cs)
 * [Named Entity Recognition](test/TensorFlowNET.Examples/TextProcess/NER)
 * [Transfer Learning for Image Classification in InceptionV3](test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs)
@@ -40,7 +40,7 @@ namespace Keras.Layers
             var dot = tf.matmul(x, W);
             if (this.activation != null)
                 dot = activation.Activate(dot);
-            Console.WriteLine("Calling Layer \"" + name + "(" + np.array(dot.GetShape().Dimensions).ToString() + ")\" ...");
+            Console.WriteLine("Calling Layer \"" + name + "(" + np.array(dot.TensorShape.Dimensions).ToString() + ")\" ...");
             return dot;
         }
         public TensorShape __shape__()
@@ -65,7 +65,7 @@ namespace Keras
             #endregion
             #region Model Graph Form Layer Stack
-            var flow_shape = features.GetShape();
+            var flow_shape = features.TensorShape;
             Flow = features;
             for (int i = 0; i < layer_stack.Count; i++)
             {
@@ -37,7 +37,7 @@ namespace Tensorflow.Framework
         public static bool has_fully_defined_shape(Tensor tensor)
         {
-            return tensor.GetShape().is_fully_defined();
+            return tensor.TensorShape.is_fully_defined();
         }
     }
 }
@@ -161,7 +161,7 @@ namespace Tensorflow.Keras.Layers
             if (_dtype == TF_DataType.DtInvalid)
                 _dtype = input.dtype;
-            var input_shapes = input.GetShape();
+            var input_shapes = input.TensorShape;
             build(input_shapes);
             built = true;
         }
@@ -118,8 +118,8 @@ namespace Tensorflow
             if(weights > 0)
             {
                 var weights_tensor = ops.convert_to_tensor(weights);
-                var labels_rank = labels.GetShape().NDim;
-                var weights_shape = weights_tensor.GetShape();
+                var labels_rank = labels.TensorShape.NDim;
+                var weights_shape = weights_tensor.TensorShape;
                 var weights_rank = weights_shape.NDim;
                 if (labels_rank > -1 && weights_rank > -1)
@@ -18,7 +18,7 @@ namespace Tensorflow.Operations
             string data_format = null)
         {
             var dilation_rate_tensor = ops.convert_to_tensor(dilation_rate, TF_DataType.TF_INT32, name: "dilation_rate");
-            var rate_shape = dilation_rate_tensor.GetShape();
+            var rate_shape = dilation_rate_tensor.TensorShape;
             var num_spatial_dims = rate_shape.Dimensions[0];
             int starting_spatial_dim = -1;
             if (!string.IsNullOrEmpty(data_format) && data_format.StartsWith("NC"))
@@ -24,9 +24,9 @@ namespace Tensorflow
             {
                 predictions = ops.convert_to_tensor(predictions);
                 labels = ops.convert_to_tensor(labels);
-                var predictions_shape = predictions.GetShape();
+                var predictions_shape = predictions.TensorShape;
                 var predictions_rank = predictions_shape.NDim;
-                var labels_shape = labels.GetShape();
+                var labels_shape = labels.TensorShape;
                 var labels_rank = labels_shape.NDim;
                 if(labels_rank > -1 && predictions_rank > -1)
                 {
@@ -83,7 +83,7 @@ namespace Tensorflow
                 // float to be selected, hence we use a >= comparison.
                 var keep_mask = random_tensor >= rate;
                 var ret = x * scale * math_ops.cast(keep_mask, x.dtype);
-                ret.SetShape(x.GetShape());
+                ret.SetShape(x.TensorShape);
                 return ret;
             });
         }
@@ -131,14 +131,14 @@ namespace Tensorflow
             var precise_logits = logits.dtype == TF_DataType.TF_HALF ? math_ops.cast(logits, dtypes.float32) : logits;
             // Store label shape for result later.
-            var labels_static_shape = labels.GetShape();
+            var labels_static_shape = labels.TensorShape;
             var labels_shape = array_ops.shape(labels);
             /*bool static_shapes_fully_defined = (
                 labels_static_shape.is_fully_defined() &&
                 logits.get_shape()[:-1].is_fully_defined());*/
             // Check if no reshapes are required.
-            if(logits.GetShape().NDim == 2)
+            if(logits.TensorShape.NDim == 2)
             {
                 var (cost, _) = gen_nn_ops.sparse_softmax_cross_entropy_with_logits(
                     precise_logits, labels, name: name);
@@ -163,7 +163,7 @@ namespace Tensorflow
             {
                 var precise_logits = logits;
                 var input_rank = array_ops.rank(precise_logits);
-                var shape = logits.GetShape();
+                var shape = logits.TensorShape;
                 if (axis != -1)
                     throw new NotImplementedException("softmax_cross_entropy_with_logits_v2_helper axis != -1");
@@ -16,8 +16,8 @@ namespace Tensorflow
                 weights, dtype: values.dtype.as_base_dtype(), name: "weights");
             // Try static check for exact match.
-            var weights_shape = weights.GetShape();
-            var values_shape = values.GetShape();
+            var weights_shape = weights.TensorShape;
+            var values_shape = values.TensorShape;
             if (weights_shape.is_fully_defined() &&
                 values_shape.is_fully_defined())
                 return weights;
@@ -1,4 +1,5 @@
-using System;
+using NumSharp;
+using System;
 using System.Collections.Generic;
 using System.Text;
@@ -110,10 +110,7 @@ namespace Tensorflow
             return shape.Select(x => (int)x).ToArray();
         }
-        public TensorShape GetShape()
-        {
-            return tensor_util.to_shape(shape);
-        }
+        public TensorShape TensorShape => tensor_util.to_shape(shape);
         public void SetShape(Shape shape)
         {
@@ -37,5 +37,7 @@ namespace Tensorflow
         {
             throw new NotImplementedException("TensorShape is_compatible_with");
         }
+        public static implicit operator TensorShape((int, int) dims) => new TensorShape(dims.Item1, dims.Item2);
     }
 }
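For orientation, here is a minimal sketch (not part of the diff) of how calling code reads once `GetShape()` becomes the `TensorShape` property and the `(int, int)` tuple conversion above is available; `t` is assumed to be any existing `Tensor`:

```csharp
// Illustrative only; assumes an existing Tensor `t`.
TensorShape shape = t.TensorShape;   // formerly t.GetShape()
var rank = shape.NDim;               // static rank, -1 when unknown
var dims = shape.Dimensions;         // per-dimension sizes

// The new implicit operator lets a value tuple stand in for a 2-D TensorShape,
// which is what allows shape: (-1, img_size_flat) in the MNIST example below.
TensorShape twoD = (-1, 784);
```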
@@ -20,6 +20,15 @@ namespace Tensorflow
                                                         verify_shape: verify_shape,
                                                         allow_broadcast: false);
+        public static Tensor constant(float value,
+            int shape,
+            string name = "Const") => constant_op._constant_impl(value,
+                                                        tf.float32,
+                                                        new int[] { shape },
+                                                        name,
+                                                        verify_shape: false,
+                                                        allow_broadcast: false);
         public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) => array_ops.zeros(shape, dtype, name);
         public static Tensor size(Tensor input,
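A small usage sketch (illustrative, not from the diff) of the new scalar `constant` overload: it builds a float32 tensor of the given 1-D shape filled with the value, which is how `fc_layer` in the example further below constructs its bias initializer.

```csharp
// Illustrative: a [200]-shaped float32 tensor filled with 0f,
// mirroring `var initial = tf.constant(0f, num_units);` in fc_layer below.
var bias_init = tf.constant(0f, 200);
```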
@@ -10,9 +10,11 @@ namespace Tensorflow
     {
         public static class train
         {
-            public static Optimizer GradientDescentOptimizer(float learning_rate) => new GradientDescentOptimizer(learning_rate);
+            public static Optimizer GradientDescentOptimizer(float learning_rate)
+                => new GradientDescentOptimizer(learning_rate);
-            public static Optimizer AdamOptimizer(float learning_rate) => new AdamOptimizer(learning_rate);
+            public static Optimizer AdamOptimizer(float learning_rate, string name = null)
+                => new AdamOptimizer(learning_rate, name: name);
             public static Saver Saver(VariableV1[] var_list = null) => new Saver(var_list: var_list);
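A brief sketch (illustrative only) of the extended factory: the optional `name` parameter is simply forwarded to the underlying optimizer, so existing call sites that pass no name keep compiling.

```csharp
// Illustrative: both forms are valid after this change.
var adam_default = tf.train.AdamOptimizer(0.001f);                  // name defaults to null
var adam_named   = tf.train.AdamOptimizer(0.001f, name: "Adam-op"); // named, as used in DigitRecognitionNN below
var sgd          = tf.train.GradientDescentOptimizer(0.01f);
```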
@@ -153,7 +153,7 @@ namespace Tensorflow
             // Manually overrides the variable's shape with the initial value's.
             if (validate_shape)
             {
-                var initial_value_shape = _initial_value.GetShape();
+                var initial_value_shape = _initial_value.TensorShape;
                 if (!initial_value_shape.is_fully_defined())
                     throw new ValueError($"initial_value must have a shape specified: {_initial_value}");
             }
@@ -15,8 +15,8 @@ namespace Keras.Test
         {
             var dense_1 = new Dense(1, name: "dense_1", activation: tf.nn.relu());
             var input = new Tensor(np.array(new int[] { 3 }));
-            dense_1.__build__(input.GetShape());
-            var outputShape = dense_1.output_shape(input.GetShape());
+            dense_1.__build__(input.TensorShape);
+            var outputShape = dense_1.output_shape(input.TensorShape);
             var a = (int[])(outputShape.Dimensions);
             var b = (int[])(new int[] { 1 });
             var _a = np.array(a);
@@ -1,14 +1,17 @@
-using System;
+using NumSharp;
+using System;
 using System.Collections.Generic;
 using System.Text;
 using Tensorflow;
 using TensorFlowNET.Examples.Utility;
+using static Tensorflow.Python;
 namespace TensorFlowNET.Examples.ImageProcess
 {
     /// <summary>
     /// Neural Network classifier for Hand Written Digits
-    /// Sample Neural Network architecture with two layers implemented for classifying MNIST digits
+    /// Sample Neural Network architecture with two layers implemented for classifying MNIST digits.
+    /// Use Stochastic Gradient Descent (SGD) optimizer.
     /// http://www.easy-tensorflow.com/tf-tutorials/neural-networks
     /// </summary>
     public class DigitRecognitionNN : IExample
@@ -22,24 +25,74 @@ namespace TensorFlowNET.Examples.ImageProcess
         const int img_w = 28;
         int img_size_flat = img_h * img_w; // 784, the total number of pixels
         int n_classes = 10; // Number of classes, one class per digit
-        int training_epochs = 10;
-        int? train_size = null;
-        int validation_size = 5000;
-        int? test_size = null;
+        // Hyper-parameters
+        int epochs = 10;
         int batch_size = 100;
+        float learning_rate = 0.001f;
+        int h1 = 200; // number of nodes in the 1st hidden layer
         Datasets mnist;
+        Tensor x, y;
+        Tensor loss, accuracy;
+        Operation optimizer;
+        int display_freq = 100;
         public bool Run()
         {
             PrepareData();
+            BuildGraph();
+            Train();
             return true;
         }
         public Graph BuildGraph()
         {
-            throw new NotImplementedException();
+            var g = tf.Graph();
+            // Placeholders for inputs (x) and outputs(y)
+            x = tf.placeholder(tf.float32, shape: (-1, img_size_flat), name: "X");
+            y = tf.placeholder(tf.float32, shape: (-1, n_classes), name: "Y");
+            // Create a fully-connected layer with h1 nodes as hidden layer
+            var fc1 = fc_layer(x, h1, "FC1", use_relu: true);
+            // Create a fully-connected layer with n_classes nodes as output layer
+            var output_logits = fc_layer(fc1, n_classes, "OUT", use_relu: false);
+            // Define the loss function, optimizer, and accuracy
+            loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels: y, logits: output_logits), name: "loss");
+            optimizer = tf.train.AdamOptimizer(learning_rate: learning_rate, name: "Adam-op").minimize(loss);
+            var correct_prediction = tf.equal(tf.argmax(output_logits, 1), tf.argmax(y, 1), name: "correct_pred");
+            accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name: "accuracy");
+            // Network predictions
+            var cls_prediction = tf.argmax(output_logits, axis: 1, name: "predictions");
+            return g;
         }
+        private Tensor fc_layer(Tensor x, int num_units, string name, bool use_relu = true)
+        {
+            var in_dim = x.shape[1];
+            var initer = tf.truncated_normal_initializer(stddev: 0.01f);
+            var W = tf.get_variable("W_" + name,
+                dtype: tf.float32,
+                shape: (in_dim, num_units),
+                initializer: initer);
+            var initial = tf.constant(0f, num_units);
+            var b = tf.get_variable("b_" + name,
+                dtype: tf.float32,
+                initializer: initial);
+            var layer = tf.matmul(x, W) + b;
+            if (use_relu)
+                layer = tf.nn.relu(layer);
+            return layer;
+        }
         public Graph ImportGraph()
         {
             throw new NotImplementedException();
@@ -52,12 +105,82 @@ namespace TensorFlowNET.Examples.ImageProcess
         public void PrepareData()
         {
-            mnist = MnistDataSet.read_data_sets("mnist", one_hot: true, train_size: train_size, validation_size: validation_size, test_size: test_size);
+            mnist = MnistDataSet.read_data_sets("mnist", one_hot: true);
         }
         public bool Train()
         {
-            throw new NotImplementedException();
+            // Number of training iterations in each epoch
+            var num_tr_iter = mnist.train.labels.len / batch_size;
+            return with(tf.Session(), sess =>
+            {
+                var init = tf.global_variables_initializer();
+                sess.run(init);
+                float loss_val = 100.0f;
+                float accuracy_val = 0f;
+                foreach (var epoch in range(epochs))
+                {
+                    print($"Training epoch: {epoch + 1}");
+                    // Randomly shuffle the training data at the beginning of each epoch
+                    var (x_train, y_train) = randomize(mnist.train.images, mnist.train.labels);
+                    foreach (var iteration in range(num_tr_iter))
+                    {
+                        var start = iteration * batch_size;
+                        var end = (iteration + 1) * batch_size;
+                        var (x_batch, y_batch) = get_next_batch(x_train, y_train, start, end);
+                        // Run optimization op (backprop)
+                        sess.run(optimizer, new FeedItem(x, x_batch), new FeedItem(y, y_batch));
+                        if (iteration % display_freq == 0)
+                        {
+                            // Calculate and display the batch loss and accuracy
+                            var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, x_batch), new FeedItem(y, y_batch));
+                            loss_val = result[0];
+                            accuracy_val = result[1];
+                            print($"iter {iteration.ToString("000")}: Loss={loss_val.ToString("0.0000")}, Training Accuracy={accuracy_val.ToString("P")}");
+                        }
+                    }
+                    // Run validation after every epoch
+                    var results1 = sess.run(new[] { loss, accuracy }, new FeedItem(x, mnist.validation.images), new FeedItem(y, mnist.validation.labels));
+                    loss_val = results1[0];
+                    accuracy_val = results1[1];
+                    print("---------------------------------------------------------");
+                    print($"Epoch: {epoch + 1}, validation loss: {loss_val.ToString("0.0000")}, validation accuracy: {accuracy_val.ToString("P")}");
+                    print("---------------------------------------------------------");
+                }
+                return accuracy_val > 0.9;
+            });
+        }
+        private (NDArray, NDArray) randomize(NDArray x, NDArray y)
+        {
+            var perm = np.random.permutation(y.shape[0]);
+            np.random.shuffle(perm);
+            return (mnist.train.images[perm], mnist.train.labels[perm]);
+        }
+        /// <summary>
+        /// selects a few number of images determined by the batch_size variable (if you don't know why, read about Stochastic Gradient Method)
+        /// </summary>
+        /// <param name="x"></param>
+        /// <param name="y"></param>
+        /// <param name="start"></param>
+        /// <param name="end"></param>
+        /// <returns></returns>
+        private (NDArray, NDArray) get_next_batch(NDArray x, NDArray y, int start, int end)
+        {
+            var x_batch = x[$"{start}:{end}"];
+            var y_batch = y[$"{start}:{end}"];
+            return (x_batch, y_batch);
         }
     }
 }
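For context (not part of the diff): with the placeholders above, data flows X (?, 784) → FC1 (?, 200) → OUT (?, 10), and assuming the standard 55,000-image MNIST training split with batch_size = 100, num_tr_iter works out to 550 iterations per epoch. A hypothetical driver, if the example were invoked directly rather than through the example runner:

```csharp
// Hypothetical usage sketch; normally the TensorFlowNET.Examples runner calls Run().
var example = new TensorFlowNET.Examples.ImageProcess.DigitRecognitionNN();
bool finished = example.Run();   // PrepareData() -> BuildGraph() -> Train()
Console.WriteLine($"DigitRecognitionNN finished: {finished}");
```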
@@ -264,12 +264,12 @@ namespace TensorFlowNET.Examples.ImageProcess
         private (Operation, Tensor, Tensor, Tensor, Tensor) add_final_retrain_ops(int class_count, string final_tensor_name,
             Tensor bottleneck_tensor, bool quantize_layer, bool is_training)
         {
-            var (batch_size, bottleneck_tensor_size) = (bottleneck_tensor.GetShape().Dimensions[0], bottleneck_tensor.GetShape().Dimensions[1]);
+            var (batch_size, bottleneck_tensor_size) = (bottleneck_tensor.TensorShape.Dimensions[0], bottleneck_tensor.TensorShape.Dimensions[1]);
             with(tf.name_scope("input"), scope =>
             {
                 bottleneck_input = tf.placeholder_with_default(
                     bottleneck_tensor,
-                    shape: bottleneck_tensor.GetShape().Dimensions,
+                    shape: bottleneck_tensor.TensorShape.Dimensions,
                     name: "BottleneckInputPlaceholder");
                 ground_truth_input = tf.placeholder(tf.int64, new TensorShape(batch_size), name: "GroundTruthInput");