From ee2bbbc1019c27c67b0cef410672a7180c82ab46 Mon Sep 17 00:00:00 2001 From: degtiadr Date: Thu, 30 May 2019 22:13:38 +0200 Subject: [PATCH 1/3] Inconsistency in handling of DT_FLOAT and DT_DOUBLE types by gradient calculation --- src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs b/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs index 3c6dac91..0d1e6c8a 100644 --- a/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs +++ b/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs @@ -368,7 +368,7 @@ namespace Tensorflow if (y.dtype.is_complex()) throw new TypeAccessException($"Gradients of complex tensors must set grad_ys (y.dtype = {y.dtype})"); var shape = array_ops.shape(y); - var constant = constant_op.constant(1.0f, name: $"grad_ys_{i}"); + var constant = constant_op.constant(y.dtype == TF_DataType.TF_DOUBLE ? (object)1.0 : (object)1.0f, name: $"grad_ys_{i}"); var fill = gen_array_ops.fill(shape, constant); new_grad_ys.Add(fill); } From 43caba84a04b492345642f34465156f81796ead6 Mon Sep 17 00:00:00 2001 From: degtiadr Date: Thu, 30 May 2019 22:35:58 +0200 Subject: [PATCH 2/3] F# example created --- TensorFlow.NET.sln | 6 + .../FunctionApproximation.fs | 104 ++++++++++++++++++ TensorFlowNET.Examples.FSharp/Program.fs | 8 ++ .../TensorFlowNET.Examples.FSharp.fsproj | 17 +++ 4 files changed, 135 insertions(+) create mode 100644 TensorFlowNET.Examples.FSharp/FunctionApproximation.fs create mode 100644 TensorFlowNET.Examples.FSharp/Program.fs create mode 100644 TensorFlowNET.Examples.FSharp/TensorFlowNET.Examples.FSharp.fsproj diff --git a/TensorFlow.NET.sln b/TensorFlow.NET.sln index 7982345e..0b647158 100644 --- a/TensorFlow.NET.sln +++ b/TensorFlow.NET.sln @@ -15,6 +15,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Keras.Example", "test\Keras EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Keras.UnitTest", "test\KerasNET.Test\Keras.UnitTest.csproj", "{A5839A45-A117-4BEA-898B-DE1ED6E0D58F}" EndProject +Project("{F2A71F9B-5D33-465A-A702-920D77279786}") = "TensorFlowNET.Examples.FSharp", "TensorFlowNET.Examples.FSharp\TensorFlowNET.Examples.FSharp.fsproj", "{B3B14578-1BC4-41AE-9116-6A6B2CE5E182}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -45,6 +47,10 @@ Global {A5839A45-A117-4BEA-898B-DE1ED6E0D58F}.Debug|Any CPU.Build.0 = Debug|Any CPU {A5839A45-A117-4BEA-898B-DE1ED6E0D58F}.Release|Any CPU.ActiveCfg = Release|Any CPU {A5839A45-A117-4BEA-898B-DE1ED6E0D58F}.Release|Any CPU.Build.0 = Release|Any CPU + {B3B14578-1BC4-41AE-9116-6A6B2CE5E182}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {B3B14578-1BC4-41AE-9116-6A6B2CE5E182}.Debug|Any CPU.Build.0 = Debug|Any CPU + {B3B14578-1BC4-41AE-9116-6A6B2CE5E182}.Release|Any CPU.ActiveCfg = Release|Any CPU + {B3B14578-1BC4-41AE-9116-6A6B2CE5E182}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/TensorFlowNET.Examples.FSharp/FunctionApproximation.fs b/TensorFlowNET.Examples.FSharp/FunctionApproximation.fs new file mode 100644 index 00000000..44e5c7a7 --- /dev/null +++ b/TensorFlowNET.Examples.FSharp/FunctionApproximation.fs @@ -0,0 +1,104 @@ +module FunctionApproximation + +//reduced example from https://github.com/tirthajyoti/Machine-Learning-with-Python/blob/master/Function%20Approximation%20by%20Neural%20Network/Function%20approximation%20by%20linear%20model%20and%20deep%20network.ipynb + +open NumSharp +open Tensorflow +open System + + +let run()= + + let N_points = 75 // Number of points for constructing function + let x_min = 1.0 // Min of the range of x (feature) + let x_max = 15.0 // Max of the range of x (feature) + let noise_mean = 0.0 // Mean of the Gaussian noise adder + let noise_sd = 10.0 // Std.Dev of the Gaussian noise adder + + let linspace points = [| for i in 0 .. (points - 1) -> x_min + (x_max - x_min)/(float)points * (float)i |] + + let func_trans(xAr:float []) = + xAr + |>Array.map (fun (x:float) -> (20.0 * x+3.0 * System.Math.Pow(x,2.0)+0.1 * System.Math.Pow(x,3.0))*sin(x)*exp(-0.1*x)) + + let X_raw = linspace N_points + let Y_raw = func_trans(X_raw) + let X_mtr = Array2D.init X_raw.Length 1 (fun i j -> X_raw.[i]) + let X = np.array(X_mtr) + + let noise_x = np.random.normal(noise_mean,noise_sd,N_points) + let y = np.array(Y_raw)+noise_x + + let X_train = X + let y_train = y + + let learning_rate = 0.00001 + let training_epochs = 35000 + + let n_input = 1 // Number of features + let n_output = 1 // Regression output is a number only + let n_hidden_layer_1 = 25 // Hidden layer 1 + let n_hidden_layer_2 = 25 // Hidden layer 2 + + let x = tf.placeholder(tf.float64, new TensorShape(N_points,n_input)) + let y = tf.placeholder(tf.float64, new TensorShape(n_output)) + + + let weights = dict[ + "hidden_layer_1", tf.Variable(tf.random_normal([|n_input; n_hidden_layer_1|],dtype=tf.float64)) + "hidden_layer_2", tf.Variable(tf.random_normal([|n_hidden_layer_1; n_hidden_layer_2|],dtype=tf.float64)) + "out", tf.Variable(tf.random_normal([|n_hidden_layer_2; n_output|],dtype=tf.float64)) + ] + let biases = dict[ + "hidden_layer_1", tf.Variable(tf.random_normal([|n_hidden_layer_1|],dtype=tf.float64)) + "hidden_layer_2", tf.Variable(tf.random_normal([|n_hidden_layer_2|],dtype=tf.float64)) + "out", tf.Variable(tf.random_normal([|n_output|],dtype=tf.float64)) + ] + + + // Hidden layer with RELU activation + + let layer_1 = tf.add(tf.matmul(x, weights.["hidden_layer_1"]._AsTensor()),biases.["hidden_layer_1"]) + let layer_1 = tf.nn.relu(layer_1) + + let layer_2 = tf.add(tf.matmul(layer_1, weights.["hidden_layer_2"]._AsTensor()),biases.["hidden_layer_2"]) + let layer_2 = tf.nn.relu(layer_2) + + // Output layer with linear activation + let ops = tf.add(tf.matmul(layer_2, weights.["out"]._AsTensor()), biases.["out"]) + + // Define loss and optimizer + let cost = tf.reduce_mean(tf.square(tf.squeeze(ops)-y)) + + let gs = tf.Variable(1, trainable= false, name= "global_step") + + let optimizer = tf.train.GradientDescentOptimizer(learning_rate=(float32)learning_rate).minimize(cost,global_step = gs) + + let init = tf.global_variables_initializer() + + + Tensorflow.Python.``with``(tf.Session(), fun (sess:Session) -> + sess.run(init) |> ignore + // Loop over epochs + for epoch in [0..training_epochs] do + // Run optimization process (backprop) and cost function (to get loss value) + + let result=sess.run([|optimizer:>ITensorOrOperation; gs._AsTensor():>ITensorOrOperation; cost:>ITensorOrOperation|], new FeedItem(x, X_train), new FeedItem(y, y_train)) + + + let loss_value = (double) result.[2]; + + let step = (int) result.[1]; + + if epoch % 1000 = 0 then + sprintf "Step %d loss: %f" step loss_value |> Console.WriteLine + let w=sess.run(weights |> Array.ofSeq |> Array.map (fun pair -> pair.Value)) + let b = sess.run(biases |> Array.ofSeq |> Array.map (fun pair -> pair.Value)) + let yhat=sess.run([|ops:>ITensorOrOperation|],new FeedItem(x,X_train)) + for i in [0..(N_points-1)] do + sprintf "pred %f real: %f" ((double)(yhat.[0].[i].[0])) ((double)Y_raw.[i]) |> Console.WriteLine + ) + + + + diff --git a/TensorFlowNET.Examples.FSharp/Program.fs b/TensorFlowNET.Examples.FSharp/Program.fs new file mode 100644 index 00000000..3cbe7ea9 --- /dev/null +++ b/TensorFlowNET.Examples.FSharp/Program.fs @@ -0,0 +1,8 @@ +// Learn more about F# at http://fsharp.org + +open System + +[] +let main argv = + FunctionApproximation.run() + 0 // return an integer exit code diff --git a/TensorFlowNET.Examples.FSharp/TensorFlowNET.Examples.FSharp.fsproj b/TensorFlowNET.Examples.FSharp/TensorFlowNET.Examples.FSharp.fsproj new file mode 100644 index 00000000..7414399f --- /dev/null +++ b/TensorFlowNET.Examples.FSharp/TensorFlowNET.Examples.FSharp.fsproj @@ -0,0 +1,17 @@ + + + + Exe + netcoreapp2.2 + + + + + + + + + + + + From 917ff43e777885ecffd2c20eee85d16790ca59cb Mon Sep 17 00:00:00 2001 From: degtiadr Date: Thu, 30 May 2019 22:46:01 +0200 Subject: [PATCH 3/3] F# example project structure adjusted --- TensorFlow.NET.sln | 10 +++++----- .../FunctionApproximation.fs | 0 .../TensorFlowNET.Examples.FSharp}/Program.fs | 0 .../TensorFlowNET.Examples.FSharp.fsproj | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) rename {TensorFlowNET.Examples.FSharp => test/TensorFlowNET.Examples.FSharp}/FunctionApproximation.fs (100%) rename {TensorFlowNET.Examples.FSharp => test/TensorFlowNET.Examples.FSharp}/Program.fs (100%) rename {TensorFlowNET.Examples.FSharp => test/TensorFlowNET.Examples.FSharp}/TensorFlowNET.Examples.FSharp.fsproj (78%) diff --git a/TensorFlow.NET.sln b/TensorFlow.NET.sln index 0b647158..51125309 100644 --- a/TensorFlow.NET.sln +++ b/TensorFlow.NET.sln @@ -15,7 +15,7 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Keras.Example", "test\Keras EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Keras.UnitTest", "test\KerasNET.Test\Keras.UnitTest.csproj", "{A5839A45-A117-4BEA-898B-DE1ED6E0D58F}" EndProject -Project("{F2A71F9B-5D33-465A-A702-920D77279786}") = "TensorFlowNET.Examples.FSharp", "TensorFlowNET.Examples.FSharp\TensorFlowNET.Examples.FSharp.fsproj", "{B3B14578-1BC4-41AE-9116-6A6B2CE5E182}" +Project("{6EC3EE1D-3C4E-46DD-8F32-0CC8E7565705}") = "TensorFlowNET.Examples.FSharp", "test\TensorFlowNET.Examples.FSharp\TensorFlowNET.Examples.FSharp.fsproj", "{62BC3801-F0D3-44A9-A0AC-712F40C8F961}" EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution @@ -47,10 +47,10 @@ Global {A5839A45-A117-4BEA-898B-DE1ED6E0D58F}.Debug|Any CPU.Build.0 = Debug|Any CPU {A5839A45-A117-4BEA-898B-DE1ED6E0D58F}.Release|Any CPU.ActiveCfg = Release|Any CPU {A5839A45-A117-4BEA-898B-DE1ED6E0D58F}.Release|Any CPU.Build.0 = Release|Any CPU - {B3B14578-1BC4-41AE-9116-6A6B2CE5E182}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {B3B14578-1BC4-41AE-9116-6A6B2CE5E182}.Debug|Any CPU.Build.0 = Debug|Any CPU - {B3B14578-1BC4-41AE-9116-6A6B2CE5E182}.Release|Any CPU.ActiveCfg = Release|Any CPU - {B3B14578-1BC4-41AE-9116-6A6B2CE5E182}.Release|Any CPU.Build.0 = Release|Any CPU + {62BC3801-F0D3-44A9-A0AC-712F40C8F961}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {62BC3801-F0D3-44A9-A0AC-712F40C8F961}.Debug|Any CPU.Build.0 = Debug|Any CPU + {62BC3801-F0D3-44A9-A0AC-712F40C8F961}.Release|Any CPU.ActiveCfg = Release|Any CPU + {62BC3801-F0D3-44A9-A0AC-712F40C8F961}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/TensorFlowNET.Examples.FSharp/FunctionApproximation.fs b/test/TensorFlowNET.Examples.FSharp/FunctionApproximation.fs similarity index 100% rename from TensorFlowNET.Examples.FSharp/FunctionApproximation.fs rename to test/TensorFlowNET.Examples.FSharp/FunctionApproximation.fs diff --git a/TensorFlowNET.Examples.FSharp/Program.fs b/test/TensorFlowNET.Examples.FSharp/Program.fs similarity index 100% rename from TensorFlowNET.Examples.FSharp/Program.fs rename to test/TensorFlowNET.Examples.FSharp/Program.fs diff --git a/TensorFlowNET.Examples.FSharp/TensorFlowNET.Examples.FSharp.fsproj b/test/TensorFlowNET.Examples.FSharp/TensorFlowNET.Examples.FSharp.fsproj similarity index 78% rename from TensorFlowNET.Examples.FSharp/TensorFlowNET.Examples.FSharp.fsproj rename to test/TensorFlowNET.Examples.FSharp/TensorFlowNET.Examples.FSharp.fsproj index 7414399f..f54474d4 100644 --- a/TensorFlowNET.Examples.FSharp/TensorFlowNET.Examples.FSharp.fsproj +++ b/test/TensorFlowNET.Examples.FSharp/TensorFlowNET.Examples.FSharp.fsproj @@ -11,7 +11,7 @@ - +