From 2584b684e4ba978d43da0ec645d9b238871870fe Mon Sep 17 00:00:00 2001
From: Haiping Chen
Date: Fri, 30 Dec 2022 09:18:55 -0600
Subject: [PATCH] Fix conv_net.load_weights #956

---
 .../Engine/Layer.Layers.cs                    |   2 +-
 src/TensorFlowNET.Keras/Engine/Model.Build.cs |  28 ++++
 src/python/subclassing.py                     | 154 ++++++++++++++++++
 3 files changed, 183 insertions(+), 1 deletion(-)
 create mode 100644 src/TensorFlowNET.Keras/Engine/Model.Build.cs
 create mode 100644 src/python/subclassing.py

diff --git a/src/TensorFlowNET.Keras/Engine/Layer.Layers.cs b/src/TensorFlowNET.Keras/Engine/Layer.Layers.cs
index f38750c2..488c55cb 100644
--- a/src/TensorFlowNET.Keras/Engine/Layer.Layers.cs
+++ b/src/TensorFlowNET.Keras/Engine/Layer.Layers.cs
@@ -10,7 +10,7 @@ namespace Tensorflow.Keras.Engine
 
         protected void StackLayers(params ILayer[] layers)
         {
-            _layers.AddRange(layers);
+            _self_tracked_trackables.AddRange(layers);
         }
 
         public virtual Shape ComputeOutputShape(Shape input_shape)
diff --git a/src/TensorFlowNET.Keras/Engine/Model.Build.cs b/src/TensorFlowNET.Keras/Engine/Model.Build.cs
new file mode 100644
index 00000000..1e0a880a
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Engine/Model.Build.cs
@@ -0,0 +1,28 @@
+using System;
+using System.Linq;
+using Tensorflow.Graphs;
+using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Keras.Losses;
+using Tensorflow.Keras.Optimizers;
+using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
+
+namespace Tensorflow.Keras.Engine
+{
+    public partial class Model
+    {
+        public override void build(Shape input_shape)
+        {
+            var graph = tf.executing_eagerly() ? new FuncGraph("build_graph") : keras.backend.get_graph();
+
+            graph.as_default();
+
+            var x = tf.placeholder(DType, input_shape);
+            Call(x, training: false);
+
+            graph.Exit();
+
+            base.build(input_shape);
+        }
+    }
+}
diff --git a/src/python/subclassing.py b/src/python/subclassing.py
new file mode 100644
index 00000000..fe2674ec
--- /dev/null
+++ b/src/python/subclassing.py
@@ -0,0 +1,154 @@
+from __future__ import absolute_import, division, print_function
+
+import tensorflow as tf
+from tensorflow.keras import Model, layers
+import numpy as np
+
+# MNIST dataset parameters.
+num_classes = 10 # total classes (0-9 digits).
+
+# Training parameters.
+learning_rate = 0.001
+training_steps = 100
+batch_size = 128
+display_step = 10
+
+# Network parameters.
+conv1_filters = 32 # number of filters for 1st conv layer.
+conv2_filters = 64 # number of filters for 2nd conv layer.
+fc1_units = 1024 # number of neurons for 1st fully-connected layer.
+
+# Prepare MNIST data.
+from tensorflow.keras.datasets import mnist
+(x_train, y_train), (x_test, y_test) = mnist.load_data()
+# Convert to float32.
+x_train, x_test = np.array(x_train, np.float32), np.array(x_test, np.float32)
+# Normalize image values from [0, 255] to [0, 1].
+x_train, x_test = x_train / 255., x_test / 255.
+
+# Use tf.data API to shuffle and batch data.
+train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))
+train_data = train_data.repeat().shuffle(5000).batch(batch_size).prefetch(1)
+
+# Create TF Model.
+class ConvNet(Model):
+    # Set layers.
+    def __init__(self):
+        super(ConvNet, self).__init__()
+        # Convolution Layer with 32 filters and a kernel size of 5.
+        self.conv1 = layers.Conv2D(conv1_filters, kernel_size=5, activation=tf.nn.relu)
+        # Max Pooling (down-sampling) with kernel size of 2 and strides of 2.
+        self.maxpool1 = layers.MaxPool2D(2, strides=2)
+
+        # Convolution Layer with 64 filters and a kernel size of 3.
+        self.conv2 = layers.Conv2D(conv2_filters, kernel_size=3, activation=tf.nn.relu)
+        # Max Pooling (down-sampling) with kernel size of 2 and strides of 2.
+        self.maxpool2 = layers.MaxPool2D(2, strides=2)
+
+        # Flatten the data to a 1-D vector for the fully connected layer.
+        self.flatten = layers.Flatten()
+
+        # Fully connected layer.
+        self.fc1 = layers.Dense(fc1_units)
+        # Apply Dropout (if is_training is False, dropout is not applied).
+        self.dropout = layers.Dropout(rate=0.5)
+
+        # Output layer, class prediction.
+        self.out = layers.Dense(num_classes)
+
+    # Set forward pass.
+    def call(self, x, is_training=False):
+        x = tf.reshape(x, [-1, 28, 28, 1])
+        x = self.conv1(x)
+        x = self.maxpool1(x)
+        x = self.conv2(x)
+        x = self.maxpool2(x)
+        x = self.flatten(x)
+        x = self.fc1(x)
+        x = self.dropout(x, training=is_training)
+        x = self.out(x)
+        if not is_training:
+            # tf cross entropy expects logits without softmax, so only
+            # apply softmax when not training.
+            x = tf.nn.softmax(x)
+        return x
+'''
+# Build neural network model.
+conv_net = ConvNet()
+
+# Cross-Entropy Loss.
+# Note that this will apply 'softmax' to the logits.
+def cross_entropy_loss(x, y):
+    # Convert labels to int64 for tf cross-entropy function.
+    y = tf.cast(y, tf.int64)
+    # Apply softmax to logits and compute cross-entropy.
+    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=x)
+    # Average loss across the batch.
+    return tf.reduce_mean(loss)
+
+# Accuracy metric.
+def accuracy(y_pred, y_true):
+    # Predicted class is the index of the highest score in the prediction vector (i.e. argmax).
+    correct_prediction = tf.equal(tf.argmax(y_pred, 1), tf.cast(y_true, tf.int64))
+    return tf.reduce_mean(tf.cast(correct_prediction, tf.float32), axis=-1)
+
+# Adam optimizer.
+optimizer = tf.optimizers.Adam(learning_rate)
+
+# Optimization process.
+def run_optimization(x, y):
+    # Wrap computation inside a GradientTape for automatic differentiation.
+    with tf.GradientTape() as g:
+        # Forward pass.
+        pred = conv_net(x, is_training=True)
+        # Compute loss.
+        loss = cross_entropy_loss(pred, y)
+
+    # Variables to update, i.e. trainable variables.
+    trainable_variables = conv_net.trainable_variables
+
+    # Compute gradients.
+    gradients = g.gradient(loss, trainable_variables)
+
+    # Update W and b following gradients.
+    optimizer.apply_gradients(zip(gradients, trainable_variables))
+
+# Run training for the given number of steps.
+
+for step, (batch_x, batch_y) in enumerate(train_data.take(training_steps), 1):
+    # Run the optimization to update W and b values.
+    run_optimization(batch_x, batch_y)
+
+    if step % display_step == 0:
+        pred = conv_net(batch_x)
+        loss = cross_entropy_loss(pred, batch_y)
+        acc = accuracy(pred, batch_y)
+        print("step: %i, loss: %f, accuracy: %f" % (step, loss, acc))
+
+# Test model on validation set.
+pred = conv_net(x_test)
+print("Test Accuracy: %f" % accuracy(pred, y_test))
+
+conv_net.save_weights('weights.h5')
+'''
+
+conv_net = ConvNet()
+conv_net.build(x_test.shape)
+conv_net.load_weights('weights.h5')
+# Test model on validation set.
+pred = conv_net(x_test)
+# print("Test Accuracy: %f" % accuracy(pred, y_test))
+
+# Visualize predictions.
+import matplotlib.pyplot as plt
+
+# Predict 5 images from validation set.
+n_images = 5
+test_images = x_test[:n_images]
+predictions = conv_net(test_images)
+
+# Display image and model prediction.
+for i in range(n_images):
+    plt.imshow(np.reshape(test_images[i], [28, 28]), cmap='gray')
+    plt.show()
+    print("Model prediction: %i" % np.argmax(predictions.numpy()[i]))
\ No newline at end of file
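
Loading subclassed-model weights from C# then follows the same build-then-load
order as the Python script: load_weights can only restore into variables that
already exist, and the new Model.build override creates them by tracing Call()
on a placeholder. A minimal C# usage sketch, assuming a ConvNet Model subclass
analogous to the Python class above exists on the C# side (the class name is
illustrative; the shape and file name mirror src/python/subclassing.py):

    using Tensorflow;
    using static Tensorflow.Binding;

    // A subclassed model whose layers are registered via StackLayers;
    // the Layer.Layers.cs change makes them land in _self_tracked_trackables
    // so they (and their weights) are tracked as sub-layers.
    var conv_net = new ConvNet();

    // Trace Call() on a placeholder so every layer creates its variables
    // up front; mirrors conv_net.build(x_test.shape) in the Python script.
    conv_net.build(new Shape(10000, 28, 28));

    // Restore the weights exported by the Python script.
    conv_net.load_weights("weights.h5");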