@@ -298,5 +298,8 @@ namespace Tensorflow | |||
public static Tensor argmax(Tensor input, int axis = -1, string name = null, int? dimension = null, TF_DataType output_type = TF_DataType.TF_INT64) | |||
=> gen_math_ops.arg_max(input, axis, name: name, output_type: output_type); | |||
public static Tensor square(Tensor x, string name = null) | |||
=> gen_math_ops.square(x, name: name); | |||
} | |||
} |
@@ -26,7 +26,9 @@ namespace Tensorflow | |||
partition_strategy: partition_strategy, | |||
name: name); | |||
public static IActivation relu => new relu(); | |||
public static IActivation relu() => new relu(); | |||
public static Tensor relu(Tensor features, string name = null) => gen_nn_ops.relu(features, name); | |||
public static Tensor[] fused_batch_norm(Tensor x, | |||
RefVariable scale, | |||
@@ -29,5 +29,13 @@ namespace Tensorflow | |||
TF_DataType dtype = TF_DataType.TF_FLOAT, | |||
int? seed = null, | |||
string name = null) => random_ops.random_uniform(shape, minval, maxval, dtype, seed, name); | |||
public static Tensor truncated_normal(int[] shape, | |||
float mean = 0.0f, | |||
float stddev = 1.0f, | |||
TF_DataType dtype = TF_DataType.TF_FLOAT, | |||
int? seed = null, | |||
string name = null) | |||
=> random_ops.truncated_normal(shape, mean, stddev, dtype, seed, name); | |||
} | |||
} |
@@ -1,6 +1,8 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using Tensorflow.Operations.Activation; | |||
namespace Tensorflow.Operations | |||
{ | |||
@@ -173,5 +175,59 @@ namespace Tensorflow.Operations | |||
return (_op.outputs[0], _op.outputs[1]); | |||
} | |||
/// <summary> | |||
/// Computes rectified linear: `max(features, 0)`. | |||
/// </summary> | |||
/// <param name="features">A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `uint8`, `int16`, `int8`, `int64`, `bfloat16`, `uint16`, `half`, `uint32`, `uint64`, `qint8`.</param> | |||
/// <param name="name">A name for the operation (optional).</param> | |||
/// <returns>A `Tensor`. Has the same type as `features`.</returns> | |||
public static Tensor relu(Tensor features, string name = null) | |||
{ | |||
//_ctx = _context._context | |||
//if _ctx is not None and _ctx._eager_context.is_eager: | |||
// try: | |||
// _result = _pywrap_tensorflow.TFE_Py_FastPathExecute( | |||
// _ctx._context_handle, _ctx._eager_context.device_name, "Relu", name, | |||
// _ctx._post_execution_callbacks, features) | |||
// return _result | |||
// except _core._FallbackException: | |||
// try: | |||
// return relu_eager_fallback( | |||
// features, name=name, ctx=_ctx) | |||
// except _core._SymbolicException: | |||
// pass # Add nodes to the TensorFlow graph. | |||
// except (TypeError, ValueError): | |||
// result = _dispatch.dispatch( | |||
// relu, features=features, name=name) | |||
// if result is not _dispatch.OpDispatcher.NOT_SUPPORTED: | |||
// return result | |||
// raise | |||
// except _core._NotOkStatusException as e: | |||
// if name is not None: | |||
// message = e.message + " name: " + name | |||
// else: | |||
// message = e.message | |||
// _six.raise_from(_core._status_to_exception(e.code, message), None) | |||
//# Add nodes to the TensorFlow graph. | |||
//try: | |||
OpDefLibrary _op_def_lib = new OpDefLibrary(); | |||
var _op = _op_def_lib._apply_op_helper("Relu", name: name, args: new { features }); | |||
return _op.outputs[0]; | |||
//except (TypeError, ValueError): | |||
// result = _dispatch.dispatch( | |||
// relu, features=features, name=name) | |||
// if result is not _dispatch.OpDispatcher.NOT_SUPPORTED: | |||
// return result | |||
// raise | |||
// var _result = _op.outputs.ToArray(); | |||
//_inputs_flat = _op.inputs | |||
//_attrs = ("T", _op.get_attr("T")) | |||
//_execute.record_gradient( | |||
// "Relu", _inputs_flat, _attrs, _result, name) | |||
//_result, = _result | |||
// return _result; | |||
} | |||
} | |||
} |
@@ -0,0 +1,96 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using NumSharp; | |||
using Tensorflow; | |||
namespace TensorFlowNET.Examples | |||
{ | |||
/// <summary> | |||
/// Simple vanilla neural net solving the famous XOR problem | |||
/// https://github.com/amygdala/tensorflow-workshop/blob/master/workshop_sections/getting_started/xor/README.md | |||
/// </summary> | |||
public class NeuralNetXor : Python, IExample | |||
{ | |||
public int Priority => 2; | |||
public bool Enabled { get; set; } = true; | |||
public string Name => "NN XOR"; | |||
public int num_steps = 5000; | |||
private (Operation, Tensor, RefVariable) make_graph(Tensor features,Tensor labels, int num_hidden = 8) | |||
{ | |||
var stddev = 1 / Math.Sqrt(2); | |||
var hidden_weights = tf.Variable(tf.truncated_normal(new int []{2, num_hidden}, stddev: (float) stddev )); | |||
// Shape [4, num_hidden] | |||
var hidden_activations = tf.nn.relu(tf.matmul(features, hidden_weights)); | |||
var output_weights = tf.Variable(tf.truncated_normal( | |||
new[] {num_hidden, 1}, | |||
stddev: (float) (1 / Math.Sqrt(num_hidden)) | |||
)); | |||
// Shape [4, 1] | |||
var logits = tf.matmul(hidden_activations, output_weights); | |||
// Shape [4] | |||
var predictions = tf.sigmoid(tf.squeeze(logits)); | |||
var loss = tf.reduce_mean(tf.square(predictions - tf.cast(labels, tf.float32))); | |||
var gs = tf.Variable(0, trainable: false); | |||
var train_op = tf.train.GradientDescentOptimizer(0.2f).minimize(loss, global_step: gs); | |||
return (train_op, loss, gs); | |||
} | |||
public bool Run() | |||
{ | |||
var graph = tf.Graph(); | |||
var init=with(graph.as_default(), g => | |||
{ | |||
var features = tf.placeholder(tf.float32, new TensorShape(4, 2)); | |||
var labels = tf.placeholder(tf.int32, new TensorShape(4)); | |||
var (train_op, loss, gs) = make_graph(features, labels); | |||
return tf.global_variables_initializer(); | |||
}); | |||
// Start tf session | |||
with<Session>(tf.Session(), sess => | |||
{ | |||
init.run(); | |||
var step = 0; | |||
var xy = np.array(new bool[,] | |||
{ | |||
{true, false, }, | |||
{true, true, }, | |||
{false, false, }, | |||
{false, true, }, | |||
}, dtype: np.float32); | |||
var y_ = np.array(new[] {true, false, false, true}, dtype: np.int32); | |||
while (step < num_steps) | |||
{ | |||
// original python: | |||
//_, step, loss_value = sess.run( | |||
// [train_op, gs, loss], | |||
// feed_dict={features: xy, labels: y_} | |||
// ) | |||
// TODO: how the hell to port that to c#? | |||
// var ( _, step, loss_value) = sess.run(new object[] {train_op, gs, loss},feed_dict: new {"features": xy, "labels": y_}); | |||
} | |||
//tf.logging.info('Final loss is: {}'.format(loss_value)) | |||
//Console.WriteLine($"Final loss is: {loss_value}"); | |||
}); | |||
return true; | |||
} | |||
public void PrepareData() | |||
{ | |||
} | |||
} | |||
} |
@@ -66,7 +66,7 @@ namespace TensorFlowNET.Examples.TextClassification | |||
filters: num_filters[0], | |||
kernel_size: new int[] { filter_sizes[0], embedding_size }, | |||
kernel_initializer: cnn_initializer, | |||
activation: tf.nn.relu); | |||
activation: tf.nn.relu()); | |||
conv0 = tf.transpose(conv0, new int[] { 0, 1, 3, 2 }); | |||
}); | |||
@@ -99,12 +99,12 @@ namespace TensorFlowNET.Examples.TextClassification | |||
// ============= Fully Connected Layers ============= | |||
with(tf.name_scope("fc-1"), scope => | |||
{ | |||
fc1_out = tf.layers.dense(h_flat, 2048, activation: tf.nn.relu, kernel_initializer: fc_initializer); | |||
fc1_out = tf.layers.dense(h_flat, 2048, activation: tf.nn.relu(), kernel_initializer: fc_initializer); | |||
}); | |||
with(tf.name_scope("fc-2"), scope => | |||
{ | |||
fc2_out = tf.layers.dense(fc1_out, 2048, activation: tf.nn.relu, kernel_initializer: fc_initializer); | |||
fc2_out = tf.layers.dense(fc1_out, 2048, activation: tf.nn.relu(), kernel_initializer: fc_initializer); | |||
}); | |||
with(tf.name_scope("fc-3"), scope => | |||
@@ -148,7 +148,7 @@ namespace TensorFlowNET.Examples.TextClassification | |||
// batch normalization | |||
conv = tf.layers.batch_normalization(conv, training: is_training); | |||
// relu | |||
conv = tf.nn.relu.Activate(conv); | |||
conv = tf.nn.relu(conv); | |||
conv = tf.transpose(conv, new int[] { 0, 1, 3, 2 }); | |||
}); | |||
} | |||
@@ -108,6 +108,16 @@ namespace TensorFlowNET.ExamplesTests | |||
tf.Graph().as_default(); | |||
new TextClassificationWithMovieReviews() { Enabled = true }.Run(); | |||
} | |||
[Ignore] | |||
[TestMethod] | |||
public void NeuralNetXor() | |||
{ | |||
tf.Graph().as_default(); | |||
new NeuralNetXor() { Enabled = true }.Run(); | |||
} | |||
} | |||
} |