Browse Source

Examples: added NeuralNetXor

tags/v0.9
Meinrad Recheis 6 years ago
parent
commit
ad1f1e43bc
7 changed files with 180 additions and 5 deletions
  1. +3
    -0
      src/TensorFlowNET.Core/APIs/tf.math.cs
  2. +3
    -1
      src/TensorFlowNET.Core/APIs/tf.nn.cs
  3. +8
    -0
      src/TensorFlowNET.Core/APIs/tf.random.cs
  4. +56
    -0
      src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
  5. +96
    -0
      test/TensorFlowNET.Examples/NeuralNetXor.cs
  6. +4
    -4
      test/TensorFlowNET.Examples/TextClassification/cnn_models/VdCnn.cs
  7. +10
    -0
      test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs

+ 3
- 0
src/TensorFlowNET.Core/APIs/tf.math.cs View File

@@ -298,5 +298,8 @@ namespace Tensorflow

public static Tensor argmax(Tensor input, int axis = -1, string name = null, int? dimension = null, TF_DataType output_type = TF_DataType.TF_INT64)
=> gen_math_ops.arg_max(input, axis, name: name, output_type: output_type);

public static Tensor square(Tensor x, string name = null)
=> gen_math_ops.square(x, name: name);
}
}

+ 3
- 1
src/TensorFlowNET.Core/APIs/tf.nn.cs View File

@@ -26,7 +26,9 @@ namespace Tensorflow
partition_strategy: partition_strategy,
name: name);

public static IActivation relu => new relu();
public static IActivation relu() => new relu();

public static Tensor relu(Tensor features, string name = null) => gen_nn_ops.relu(features, name);

public static Tensor[] fused_batch_norm(Tensor x,
RefVariable scale,


+ 8
- 0
src/TensorFlowNET.Core/APIs/tf.random.cs View File

@@ -29,5 +29,13 @@ namespace Tensorflow
TF_DataType dtype = TF_DataType.TF_FLOAT,
int? seed = null,
string name = null) => random_ops.random_uniform(shape, minval, maxval, dtype, seed, name);

public static Tensor truncated_normal(int[] shape,
float mean = 0.0f,
float stddev = 1.0f,
TF_DataType dtype = TF_DataType.TF_FLOAT,
int? seed = null,
string name = null)
=> random_ops.truncated_normal(shape, mean, stddev, dtype, seed, name);
}
}

+ 56
- 0
src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs View File

@@ -1,6 +1,8 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Operations.Activation;

namespace Tensorflow.Operations
{
@@ -173,5 +175,59 @@ namespace Tensorflow.Operations

return (_op.outputs[0], _op.outputs[1]);
}

/// <summary>
/// Computes rectified linear: `max(features, 0)`.
/// </summary>
/// <param name="features">A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `uint8`, `int16`, `int8`, `int64`, `bfloat16`, `uint16`, `half`, `uint32`, `uint64`, `qint8`.</param>
/// <param name="name">A name for the operation (optional).</param>
/// <returns>A `Tensor`. Has the same type as `features`.</returns>
public static Tensor relu(Tensor features, string name = null)
{
//_ctx = _context._context
//if _ctx is not None and _ctx._eager_context.is_eager:
// try:
// _result = _pywrap_tensorflow.TFE_Py_FastPathExecute(
// _ctx._context_handle, _ctx._eager_context.device_name, "Relu", name,
// _ctx._post_execution_callbacks, features)
// return _result
// except _core._FallbackException:
// try:
// return relu_eager_fallback(
// features, name=name, ctx=_ctx)
// except _core._SymbolicException:
// pass # Add nodes to the TensorFlow graph.
// except (TypeError, ValueError):
// result = _dispatch.dispatch(
// relu, features=features, name=name)
// if result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
// return result
// raise
// except _core._NotOkStatusException as e:
// if name is not None:
// message = e.message + " name: " + name
// else:
// message = e.message
// _six.raise_from(_core._status_to_exception(e.code, message), None)
//# Add nodes to the TensorFlow graph.
//try:
OpDefLibrary _op_def_lib = new OpDefLibrary();
var _op = _op_def_lib._apply_op_helper("Relu", name: name, args: new { features });
return _op.outputs[0];
//except (TypeError, ValueError):
// result = _dispatch.dispatch(
// relu, features=features, name=name)
// if result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
// return result
// raise
// var _result = _op.outputs.ToArray();
//_inputs_flat = _op.inputs
//_attrs = ("T", _op.get_attr("T"))
//_execute.record_gradient(
// "Relu", _inputs_flat, _attrs, _result, name)
//_result, = _result
// return _result;
}
}
}

+ 96
- 0
test/TensorFlowNET.Examples/NeuralNetXor.cs View File

@@ -0,0 +1,96 @@
using System;
using System.Collections.Generic;
using System.Text;
using NumSharp;
using Tensorflow;

namespace TensorFlowNET.Examples
{
/// <summary>
/// Simple vanilla neural net solving the famous XOR problem
/// https://github.com/amygdala/tensorflow-workshop/blob/master/workshop_sections/getting_started/xor/README.md
/// </summary>
public class NeuralNetXor : Python, IExample
{
public int Priority => 2;
public bool Enabled { get; set; } = true;
public string Name => "NN XOR";

public int num_steps = 5000;

private (Operation, Tensor, RefVariable) make_graph(Tensor features,Tensor labels, int num_hidden = 8)
{
var stddev = 1 / Math.Sqrt(2);
var hidden_weights = tf.Variable(tf.truncated_normal(new int []{2, num_hidden}, stddev: (float) stddev ));

// Shape [4, num_hidden]
var hidden_activations = tf.nn.relu(tf.matmul(features, hidden_weights));

var output_weights = tf.Variable(tf.truncated_normal(
new[] {num_hidden, 1},
stddev: (float) (1 / Math.Sqrt(num_hidden))
));

// Shape [4, 1]
var logits = tf.matmul(hidden_activations, output_weights);

// Shape [4]
var predictions = tf.sigmoid(tf.squeeze(logits));
var loss = tf.reduce_mean(tf.square(predictions - tf.cast(labels, tf.float32)));

var gs = tf.Variable(0, trainable: false);
var train_op = tf.train.GradientDescentOptimizer(0.2f).minimize(loss, global_step: gs);

return (train_op, loss, gs);
}

public bool Run()
{

var graph = tf.Graph();

var init=with(graph.as_default(), g =>
{
var features = tf.placeholder(tf.float32, new TensorShape(4, 2));
var labels = tf.placeholder(tf.int32, new TensorShape(4));

var (train_op, loss, gs) = make_graph(features, labels);
return tf.global_variables_initializer();
});

// Start tf session
with<Session>(tf.Session(), sess =>
{
init.run();
var step = 0;
var xy = np.array(new bool[,]
{
{true, false, },
{true, true, },
{false, false, },
{false, true, },
}, dtype: np.float32);

var y_ = np.array(new[] {true, false, false, true}, dtype: np.int32);
while (step < num_steps)
{
// original python:
//_, step, loss_value = sess.run(
// [train_op, gs, loss],
// feed_dict={features: xy, labels: y_}
// )
// TODO: how the hell to port that to c#?
// var ( _, step, loss_value) = sess.run(new object[] {train_op, gs, loss},feed_dict: new {"features": xy, "labels": y_});
}
//tf.logging.info('Final loss is: {}'.format(loss_value))
//Console.WriteLine($"Final loss is: {loss_value}");

});
return true;
}

public void PrepareData()
{
}
}
}

+ 4
- 4
test/TensorFlowNET.Examples/TextClassification/cnn_models/VdCnn.cs View File

@@ -66,7 +66,7 @@ namespace TensorFlowNET.Examples.TextClassification
filters: num_filters[0],
kernel_size: new int[] { filter_sizes[0], embedding_size },
kernel_initializer: cnn_initializer,
activation: tf.nn.relu);
activation: tf.nn.relu());

conv0 = tf.transpose(conv0, new int[] { 0, 1, 3, 2 });
});
@@ -99,12 +99,12 @@ namespace TensorFlowNET.Examples.TextClassification
// ============= Fully Connected Layers =============
with(tf.name_scope("fc-1"), scope =>
{
fc1_out = tf.layers.dense(h_flat, 2048, activation: tf.nn.relu, kernel_initializer: fc_initializer);
fc1_out = tf.layers.dense(h_flat, 2048, activation: tf.nn.relu(), kernel_initializer: fc_initializer);
});

with(tf.name_scope("fc-2"), scope =>
{
fc2_out = tf.layers.dense(fc1_out, 2048, activation: tf.nn.relu, kernel_initializer: fc_initializer);
fc2_out = tf.layers.dense(fc1_out, 2048, activation: tf.nn.relu(), kernel_initializer: fc_initializer);
});

with(tf.name_scope("fc-3"), scope =>
@@ -148,7 +148,7 @@ namespace TensorFlowNET.Examples.TextClassification
// batch normalization
conv = tf.layers.batch_normalization(conv, training: is_training);
// relu
conv = tf.nn.relu.Activate(conv);
conv = tf.nn.relu(conv);
conv = tf.transpose(conv, new int[] { 0, 1, 3, 2 });
});
}


+ 10
- 0
test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs View File

@@ -108,6 +108,16 @@ namespace TensorFlowNET.ExamplesTests
tf.Graph().as_default();
new TextClassificationWithMovieReviews() { Enabled = true }.Run();
}
[Ignore]
[TestMethod]
public void NeuralNetXor()
{
tf.Graph().as_default();
new NeuralNetXor() { Enabled = true }.Run();
}
}
}

Loading…
Cancel
Save