Browse Source

Fix conv_net.load_weights #956

tags/v0.100.4-load-saved-model
Haiping Chen 2 years ago
parent
commit
2584b684e4
3 changed files with 183 additions and 1 deletions
  1. +1
    -1
      src/TensorFlowNET.Keras/Engine/Layer.Layers.cs
  2. +28
    -0
      src/TensorFlowNET.Keras/Engine/Model.Build.cs
  3. +154
    -0
      src/python/subclassing.py

+ 1
- 1
src/TensorFlowNET.Keras/Engine/Layer.Layers.cs View File

@@ -10,7 +10,7 @@ namespace Tensorflow.Keras.Engine

protected void StackLayers(params ILayer[] layers)
{
_layers.AddRange(layers);
_self_tracked_trackables.AddRange(layers);
}

public virtual Shape ComputeOutputShape(Shape input_shape)


+ 28
- 0
src/TensorFlowNET.Keras/Engine/Model.Build.cs View File

@@ -0,0 +1,28 @@
using System;
using System.Linq;
using Tensorflow.Graphs;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Losses;
using Tensorflow.Keras.Optimizers;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

namespace Tensorflow.Keras.Engine
{
public partial class Model
{
public override void build(Shape input_shape)
{
var graph = tf.executing_eagerly() ? new FuncGraph("build_graph") : keras.backend.get_graph();

graph.as_default();

var x = tf.placeholder(DType, input_shape);
Call(x, training: false);

graph.Exit();

base.build(input_shape);
}
}
}

+ 154
- 0
src/python/subclassing.py View File

@@ -0,0 +1,154 @@
from __future__ import absolute_import, division, print_function

import tensorflow as tf
from tensorflow.keras import Model, layers
import numpy as np

# MNIST dataset parameters.
num_classes = 10 # total classes (0-9 digits).

# Training parameters.
learning_rate = 0.001
training_steps = 100
batch_size = 128
display_step = 10

# Network parameters.
conv1_filters = 32 # number of filters for 1st conv layer.
conv2_filters = 64 # number of filters for 2nd conv layer.
fc1_units = 1024 # number of neurons for 1st fully-connected layer.

# Prepare MNIST data.
from tensorflow.keras.datasets import mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Convert to float32.
x_train, x_test = np.array(x_train, np.float32), np.array(x_test, np.float32)
# Normalize images value from [0, 255] to [0, 1].
x_train, x_test = x_train / 255., x_test / 255.

# Use tf.data API to shuffle and batch data.
train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_data = train_data.repeat().shuffle(5000).batch(batch_size).prefetch(1)

# Create TF Model.
class ConvNet(Model):
# Set layers.
def __init__(self):
super(ConvNet, self).__init__()
# Convolution Layer with 32 filters and a kernel size of 5.
self.conv1 = layers.Conv2D(32, kernel_size=5, activation=tf.nn.relu)
# Max Pooling (down-sampling) with kernel size of 2 and strides of 2.
self.maxpool1 = layers.MaxPool2D(2, strides=2)

# Convolution Layer with 64 filters and a kernel size of 3.
self.conv2 = layers.Conv2D(64, kernel_size=3, activation=tf.nn.relu)
# Max Pooling (down-sampling) with kernel size of 2 and strides of 2.
self.maxpool2 = layers.MaxPool2D(2, strides=2)

# Flatten the data to a 1-D vector for the fully connected layer.
self.flatten = layers.Flatten()

# Fully connected layer.
self.fc1 = layers.Dense(1024)
# Apply Dropout (if is_training is False, dropout is not applied).
self.dropout = layers.Dropout(rate=0.5)

# Output layer, class prediction.
self.out = layers.Dense(num_classes)

# Set forward pass.
def call(self, x, is_training=False):
x = tf.reshape(x, [-1, 28, 28, 1])
x = self.conv1(x)
x = self.maxpool1(x)
x = self.conv2(x)
x = self.maxpool2(x)
x = self.flatten(x)
x = self.fc1(x)
x = self.dropout(x, training=is_training)
x = self.out(x)
if not is_training:
# tf cross entropy expect logits without softmax, so only
# apply softmax when not training.
x = tf.nn.softmax(x)
return x
'''
# Build neural network model.
conv_net = ConvNet()

# Cross-Entropy Loss.
# Note that this will apply 'softmax' to the logits.
def cross_entropy_loss(x, y):
# Convert labels to int 64 for tf cross-entropy function.
y = tf.cast(y, tf.int64)
# Apply softmax to logits and compute cross-entropy.
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=x)
# Average loss across the batch.
return tf.reduce_mean(loss)

# Accuracy metric.
def accuracy(y_pred, y_true):
# Predicted class is the index of highest score in prediction vector (i.e. argmax).
correct_prediction = tf.equal(tf.argmax(y_pred, 1), tf.cast(y_true, tf.int64))
return tf.reduce_mean(tf.cast(correct_prediction, tf.float32), axis=-1)

# Stochastic gradient descent optimizer.
optimizer = tf.optimizers.Adam(learning_rate)

# Optimization process.
def run_optimization(x, y):
# Wrap computation inside a GradientTape for automatic differentiation.
with tf.GradientTape() as g:
# Forward pass.
pred = conv_net(x, is_training=True)
# Compute loss.
loss = cross_entropy_loss(pred, y)
# Variables to update, i.e. trainable variables.
trainable_variables = conv_net.trainable_variables

# Compute gradients.
gradients = g.gradient(loss, trainable_variables)
# Update W and b following gradients.
optimizer.apply_gradients(zip(gradients, trainable_variables))

# Run training for the given number of steps.

for step, (batch_x, batch_y) in enumerate(train_data.take(training_steps), 1):
# Run the optimization to update W and b values.
run_optimization(batch_x, batch_y)
if step % display_step == 0:
pred = conv_net(batch_x)
loss = cross_entropy_loss(pred, batch_y)
acc = accuracy(pred, batch_y)
print("step: %i, loss: %f, accuracy: %f" % (step, loss, acc))

# Test model on validation set.
pred = conv_net(x_test)
print("Test Accuracy: %f" % accuracy(pred, y_test))

conv_net.save_weights('weights.h5')
'''

conv_net = ConvNet()
conv_net.build(x_test.shape)
conv_net.load_weights('weights.h5')
# Test model on validation set.
pred = conv_net(x_test)
# print("Test Accuracy: %f" % accuracy(pred, y_test))

# Visualize predictions.
import matplotlib.pyplot as plt

# Predict 5 images from validation set.
n_images = 5
test_images = x_test[:n_images]
predictions = conv_net(test_images)

# Display image and model prediction.
for i in range(n_images):
plt.imshow(np.reshape(test_images[i], [28, 28]), cmap='gray')
plt.show()
print("Model prediction: %i" % np.argmax(predictions.numpy()[i]))

Loading…
Cancel
Save