diff --git a/README.md b/README.md index 7f7d14a4..9cf23da2 100644 --- a/README.md +++ b/README.md @@ -28,8 +28,14 @@ In comparison to other projects, like for instance TensorFlowSharp which only pr Install TF.NET and TensorFlow binary through NuGet. ```sh +### install tensorflow C# binding PM> Install-Package TensorFlow.NET + +### Install tensorflow binary +### For CPU version PM> Install-Package SciSharp.TensorFlow.Redist +### For GPU version (CUDA and cuDNN are required) +PM> Install-Package SciSharp.TensorFlow.Redist-Windows-GPU ``` Import TF.NET. diff --git a/src/SciSharp.TensorFlow.Redist/README.md b/src/SciSharp.TensorFlow.Redist/README.md index 3f75c4cf..5bdf82a1 100644 --- a/src/SciSharp.TensorFlow.Redist/README.md +++ b/src/SciSharp.TensorFlow.Redist/README.md @@ -1,8 +1,14 @@ ## SciSharp.TensorFlow.Redist ## -`SciSharp.TensorFlow.Redist` is a migration from [Microsoft.ML.TensorFlow.Redist](https://github.com/dotnet/machinelearning/tree/release/1.2/src/Redist/Microsoft.ML.TensorFlow.Redist). [ML.NET](https://github.com/dotnet/machinelearning) team will not maintain the package since [ML.NET](https://www.nuget.org/packages/Microsoft.ML) v1.4.0 going forward. +`SciSharp.TensorFlow.Redist` is a migration from [Microsoft.ML.TensorFlow.Redist](https://github.com/dotnet/machinelearning/tree/release/1.2/src/Redist/Microsoft.ML.TensorFlow.Redist). [ML.NET](https://github.com/dotnet/machinelearning) team will not maintain the package since [ML.NET](https://www.nuget.org/packages/Microsoft.ML) v1.3.0 going forward. +* CPU version for all platforms (Windows, Linux, OSX) +```powershell +PM> Install-Package SciSharp.TensorFlow.Redist +``` + +* GPU version for Windows ```powershell PM> Install-Package SciSharp.TensorFlow.Redist ``` @@ -16,7 +22,7 @@ Related merged [commits](https://github.com/SciSharp/TensorFlow.NET/commit/854a5 On Windows, the tar command does not support extracting archives with symlinks. So when `dotnet pack` runs on Windows it will only package the Windows binaries. 1. Run `dotnet pack` under `src/SciSharp.TensorFlow.Redist` directory in Linux. -2. Run `nuget push SciSharp.TensorFlow.Redist.1.14.0.nupkg -k APIKEY -s https://api.nuget.org/v3/index.json` +2. Run `dotnet nuget push SciSharp.TensorFlow.Redist.1.14.0.nupkg -k APIKEY -s https://api.nuget.org/v3/index.json` diff --git a/src/SciSharp.TensorFlow.Redist/Redist.nuspec b/src/SciSharp.TensorFlow.Redist/Redist-CPU.nuspec similarity index 89% rename from src/SciSharp.TensorFlow.Redist/Redist.nuspec rename to src/SciSharp.TensorFlow.Redist/Redist-CPU.nuspec index d2527c8b..11919e8c 100644 --- a/src/SciSharp.TensorFlow.Redist/Redist.nuspec +++ b/src/SciSharp.TensorFlow.Redist/Redist-CPU.nuspec @@ -9,7 +9,7 @@ LICENSE.txt https://aka.ms/deprecateLicenseUrl https://www.tensorflow.org/ - $packageId$ contains the TensorFlow C library version $version$ redistributed as a NuGet package. + $packageId$ contains the TensorFlow C library CPU version $version$ redistributed as a NuGet package. https://github.com/tensorflow/tensorflow/releases/tag/v$version$ Copyright 2019 The TensorFlow Authors. All rights reserved. TensorFlow diff --git a/src/SciSharp.TensorFlow.Redist/Redist-Windows-GPU.nuspec b/src/SciSharp.TensorFlow.Redist/Redist-Windows-GPU.nuspec new file mode 100644 index 00000000..f010c96b --- /dev/null +++ b/src/SciSharp.TensorFlow.Redist/Redist-Windows-GPU.nuspec @@ -0,0 +1,26 @@ + + + + $packageId$ + $version$ + The TensorFlow Authors + The TensorFlow Authors + true + LICENSE.txt + https://aka.ms/deprecateLicenseUrl + https://www.tensorflow.org/ + $packageId$ contains the TensorFlow C library GPU version $version$ redistributed as a NuGet package. + https://github.com/tensorflow/tensorflow/releases/tag/v$version$ + Copyright 2019 The TensorFlow Authors. All rights reserved. + TensorFlow + + + + + + + + + + + \ No newline at end of file diff --git a/src/SciSharp.TensorFlow.Redist/SciSharp.TensorFlow.Redist.nupkgproj b/src/SciSharp.TensorFlow.Redist/SciSharp.TensorFlow.Redist-CPU.nupkgproj similarity index 99% rename from src/SciSharp.TensorFlow.Redist/SciSharp.TensorFlow.Redist.nupkgproj rename to src/SciSharp.TensorFlow.Redist/SciSharp.TensorFlow.Redist-CPU.nupkgproj index a0ca0a0a..6a225ede 100644 --- a/src/SciSharp.TensorFlow.Redist/SciSharp.TensorFlow.Redist.nupkgproj +++ b/src/SciSharp.TensorFlow.Redist/SciSharp.TensorFlow.Redist-CPU.nupkgproj @@ -17,7 +17,7 @@ true false - Redist.nuspec + Redist-CPU.nuspec packageId=$(PackageId);version=$(PackageVersion) $(ProjDir) diff --git a/src/SciSharp.TensorFlow.Redist/SciSharp.TensorFlow.Redist-Windows-GPU.nupkgproj b/src/SciSharp.TensorFlow.Redist/SciSharp.TensorFlow.Redist-Windows-GPU.nupkgproj new file mode 100644 index 00000000..08fd9386 --- /dev/null +++ b/src/SciSharp.TensorFlow.Redist/SciSharp.TensorFlow.Redist-Windows-GPU.nupkgproj @@ -0,0 +1,187 @@ + + + + $(MSBuildThisFileDirectory) + $(ProjDir)bin\ + $(ProjDir)obj\ + + x64 + netstandard2.0 + 1.14.0 + 1 + + $(BinDir)packages\ + $(MSBuildProjectName) + $(TensorFlowVersion) + + true + false + + Redist-Windows-GPU.nuspec + packageId=$(PackageId);version=$(PackageVersion) + $(ProjDir) + + CopyFilesFromArchive + + win + linux + osx + $(PackageRid)-$(TargetArchitecture) + + + + + false + + + + + + + + + + + + + + + + + + + + + + + + + + <_downloadFiles Include="@(TensorFlowArchive);@(AdditionalDownloadFile)" Url="%(Identity)" DestinationFile="%(DownloadFile)" /> + + + + + + + + + + + + + + + + + + + + + + @(FilesWithHashes->'%(FileHash)') + $([System.IO.File]::ReadAllText('%(LocalShaFile)').Replace("%0A", "").Replace("%0D", "")) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + <_fileFromArchive Include="%(TensorFlowArchive.FilesFromArchive)" ExtractDirectory="%(TensorFlowArchive.ExtractDirectory)" Runtime="%(TensorFlowArchive.Runtime)" /> + <_fileFromArchive DestinationFile="%(FileName)%(Extension)"/> + <_fileFromArchive PackagePath="runtimes\%(_fileFromArchive.Runtime)\native\%(_fileFromArchive.DestinationFile)" /> + + + <_fileFromArchive Condition="'%(DestinationFile)' == 'LICENSE'" PackagePath="THIRD_PARTY_NOTICES.txt" Runtime="" /> + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/SciSharp.TensorFlow.Redist/libtensorflow-gpu-windows-x86_64-1.14.0.zip.sha b/src/SciSharp.TensorFlow.Redist/libtensorflow-gpu-windows-x86_64-1.14.0.zip.sha new file mode 100644 index 00000000..739129b1 --- /dev/null +++ b/src/SciSharp.TensorFlow.Redist/libtensorflow-gpu-windows-x86_64-1.14.0.zip.sha @@ -0,0 +1 @@ +850A27858FA951DF77A78CD1BD78B54F6EE2532DD5A49F0579A7B02C795C62F0212F20177EAEA2BD77BD451A57FBBD1348362492F9E14BFE5CA5028C71711293 diff --git a/src/TensorFlowHub/MnistDataSet.cs b/src/TensorFlowHub/MnistDataSet.cs index e0717ccb..accc57e1 100644 --- a/src/TensorFlowHub/MnistDataSet.cs +++ b/src/TensorFlowHub/MnistDataSet.cs @@ -27,5 +27,54 @@ namespace Tensorflow.Hub labels.astype(dataType); Labels = labels; } + + public (NDArray, NDArray) GetNextBatch(int batch_size, bool fake_data = false, bool shuffle = true) + { + var start = IndexInEpoch; + // Shuffle for the first epoch + if(EpochsCompleted == 0 && start == 0 && shuffle) + { + var perm0 = np.arange(NumOfExamples); + np.random.shuffle(perm0); + Data = Data[perm0]; + Labels = Labels[perm0]; + } + + // Go to the next epoch + if (start + batch_size > NumOfExamples) + { + // Finished epoch + EpochsCompleted += 1; + + // Get the rest examples in this epoch + var rest_num_examples = NumOfExamples - start; + //var images_rest_part = _images[np.arange(start, _num_examples)]; + //var labels_rest_part = _labels[np.arange(start, _num_examples)]; + // Shuffle the data + if (shuffle) + { + var perm = np.arange(NumOfExamples); + np.random.shuffle(perm); + Data = Data[perm]; + Labels = Labels[perm]; + } + + start = 0; + IndexInEpoch = batch_size - rest_num_examples; + var end = IndexInEpoch; + var images_new_part = Data[np.arange(start, end)]; + var labels_new_part = Labels[np.arange(start, end)]; + + /*return (np.concatenate(new float[][] { images_rest_part.Data(), images_new_part.Data() }, axis: 0), + np.concatenate(new float[][] { labels_rest_part.Data(), labels_new_part.Data() }, axis: 0));*/ + return (images_new_part, labels_new_part); + } + else + { + IndexInEpoch += batch_size; + var end = IndexInEpoch; + return (Data[np.arange(start, end)], Labels[np.arange(start, end)]); + } + } } } diff --git a/src/TensorFlowHub/MnistModelLoader.cs b/src/TensorFlowHub/MnistModelLoader.cs index 161c02e9..2fb0aa42 100644 --- a/src/TensorFlowHub/MnistModelLoader.cs +++ b/src/TensorFlowHub/MnistModelLoader.cs @@ -15,14 +15,27 @@ namespace Tensorflow.Hub private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz"; private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz"; - public static async Task> LoadAsync(string trainDir, bool oneHot = false) + public static async Task> LoadAsync(string trainDir, bool oneHot = false, int? trainSize = null, int? validationSize = null, int? testSize = null) { var loader = new MnistModelLoader(); - return await loader.LoadAsync(new ModelLoadSetting + + var setting = new ModelLoadSetting { TrainDir = trainDir, - OneHot = oneHot - }); + OneHot = oneHot, + TrainSize = trainSize + }; + + if (trainSize.HasValue) + setting.TrainSize = trainSize.Value; + + if (validationSize.HasValue) + setting.ValidationSize = validationSize.Value; + + if (testSize.HasValue) + setting.TestSize = testSize.Value; + + return await loader.LoadAsync(setting); } public async Task> LoadAsync(ModelLoadSetting setting) diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs index 0b9c7f3e..a5a9b674 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs @@ -15,6 +15,7 @@ ******************************************************************************/ using System; +using System.Linq; using static Tensorflow.Python; namespace Tensorflow @@ -63,17 +64,37 @@ namespace Tensorflow public static Tensor operator *(long constant, Tensor tensor) => BinaryOpWrapper("mul", constant, tensor); public static Tensor operator *(ulong constant, Tensor tensor) => BinaryOpWrapper("mul", constant, tensor); - public static Tensor operator /(Tensor x, Tensor y) => BinaryOpWrapper("truediv", x, y); - public static Tensor operator /(Tensor x, float y) => BinaryOpWrapper("truediv", x, y); + private static readonly TF_DataType[] _intTfDataTypes = { + TF_DataType.TF_INT8, TF_DataType.TF_INT16, TF_DataType.TF_INT32, TF_DataType.TF_INT64, + TF_DataType.TF_QINT8, TF_DataType.TF_QINT16, TF_DataType.TF_QINT32, + TF_DataType.TF_UINT8, TF_DataType.TF_UINT16, TF_DataType.TF_UINT32, TF_DataType.TF_UINT64 + }; + public static Tensor operator /(double x, Tensor y) => BinaryOpWrapper("truediv", x, y); public static Tensor operator /(float x, Tensor y) => BinaryOpWrapper("truediv", x, y); + public static Tensor operator /(int x, Tensor y) => BinaryOpWrapper("floordiv", x, y); + public static Tensor operator /(Tensor x, Tensor y) => + _intTfDataTypes.Contains(x._dtype) + ? BinaryOpWrapper("floordiv", x, y) + : BinaryOpWrapper("truediv", x, y); + public static Tensor operator /(Tensor x, int y) => BinaryOpWrapper("floordiv", x, y); + public static Tensor operator /(Tensor x, float y) => BinaryOpWrapper("truediv", x, y); public static Tensor operator /(Tensor x, double y) => BinaryOpWrapper("truediv", x, y); public static Tensor operator %(Tensor x, Tensor y) => BinaryOpWrapper("mod", x, y); + public static Tensor operator >(double x, Tensor y) => gen_math_ops.greater(x, y); + public static Tensor operator >(float x, Tensor y) => gen_math_ops.greater(x, y); + public static Tensor operator >(int x, Tensor y) => gen_math_ops.greater(x, y); + public static Tensor operator >(Tensor x, Tensor y) => gen_math_ops.greater(x, y); public static Tensor operator >(Tensor x, int y) => gen_math_ops.greater(x, y); public static Tensor operator >=(Tensor x, Tensor y) => gen_math_ops.greater_equal(x, y); public static Tensor operator >(Tensor x, float y) => gen_math_ops.greater(x, y); public static Tensor operator >(Tensor x, double y) => gen_math_ops.greater(x, y); + + public static Tensor operator <(double x, Tensor y) => gen_math_ops.less(x, y); + public static Tensor operator <(float x, Tensor y) => gen_math_ops.less(x, y); + public static Tensor operator <(int x, Tensor y) => gen_math_ops.less(x, y); + public static Tensor operator <(Tensor x, Tensor y) => gen_math_ops.less(x, y); public static Tensor operator <(Tensor x, int y) => gen_math_ops.less(x, y); public static Tensor operator <=(Tensor x, Tensor y) => gen_math_ops.less_equal(x, y); public static Tensor operator <(Tensor x, float y) => gen_math_ops.less(x, y); @@ -99,6 +120,9 @@ namespace Tensorflow case "add": result = gen_math_ops.add(x1, y1, name: scope); break; + case "floordiv": + result = gen_math_ops.floor_div(x1, y1, name: scope); + break; case "truediv": result = gen_math_ops.real_div(x1, y1, name: scope); break; diff --git a/tensorflowlib/README.md b/tensorflowlib/README.md index 63cba815..318e5dc9 100644 --- a/tensorflowlib/README.md +++ b/tensorflowlib/README.md @@ -16,6 +16,8 @@ Here are some pre-built TensorFlow binaries you can use for each platform: - CPU-only: https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-windows-x86_64-1.14.0.zip - GPU-enabled: https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-windows-x86_64-1.14.0.zip + + ### Run in Linux `Install-Package TensorFlow.NET` @@ -31,10 +33,21 @@ sudo apt install libgdiplus More information about [System.Drawing on Linux](). + + ### Run in Mac OS -### GPU Tensorflow for windows -Before running verify you installed CUDA and cuDNN + + +### Tensorflow GPU for Windows + +Before running verify you installed CUDA and cuDNN (TensorFlow v1.14 is compatible with CUDA v10.0 and cuDNN v7.4), and make sure the corresponding cuda version is compatible. + +```powershell +PM> Install-Package SciSharp.TensorFlow.Redist-Windows-GPU +``` + + ### Build from source for Windows diff --git a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs b/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs index 9756db61..536fcfc7 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs @@ -16,6 +16,7 @@ using NumSharp; using System; +using System.Diagnostics; using Tensorflow; using TensorFlowNET.Examples.Utility; using static Tensorflow.Python; @@ -144,6 +145,8 @@ namespace TensorFlowNET.Examples.ImageProcess float loss_val = 100.0f; float accuracy_val = 0f; + var sw = new Stopwatch(); + sw.Start(); foreach (var epoch in range(epochs)) { print($"Training epoch: {epoch + 1}"); @@ -165,7 +168,8 @@ namespace TensorFlowNET.Examples.ImageProcess var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, x_batch), new FeedItem(y, y_batch)); loss_val = result[0]; accuracy_val = result[1]; - print($"iter {iteration.ToString("000")}: Loss={loss_val.ToString("0.0000")}, Training Accuracy={accuracy_val.ToString("P")}"); + print($"iter {iteration.ToString("000")}: Loss={loss_val.ToString("0.0000")}, Training Accuracy={accuracy_val.ToString("P")} {sw.ElapsedMilliseconds}ms"); + sw.Restart(); } } diff --git a/test/TensorFlowNET.UnitTest/OperationsTest.cs b/test/TensorFlowNET.UnitTest/OperationsTest.cs index 37f8d450..358f3fb9 100644 --- a/test/TensorFlowNET.UnitTest/OperationsTest.cs +++ b/test/TensorFlowNET.UnitTest/OperationsTest.cs @@ -467,5 +467,679 @@ namespace TensorFlowNET.UnitTest } #endregion } + + private IEnumerable MultiplyArray(IReadOnlyCollection first, IReadOnlyCollection second) + { + if(first.Count != second.Count) + throw new ArgumentException("Arrays should be of equal size!"); + + var firstEnumerator = first.GetEnumerator(); + var secondEnumerator = second.GetEnumerator(); + var result = new List(); + while (firstEnumerator.MoveNext()) + { + secondEnumerator.MoveNext(); + result.Add(firstEnumerator.Current * secondEnumerator.Current); + } + + firstEnumerator.Dispose(); + secondEnumerator.Dispose(); + + return result; + } + private IEnumerable MultiplyArray(IReadOnlyCollection first, IReadOnlyCollection second) + { + if(first.Count != second.Count) + throw new ArgumentException("Arrays should be of equal size!"); + + var firstEnumerator = first.GetEnumerator(); + var secondEnumerator = second.GetEnumerator(); + var result = new List(); + while (firstEnumerator.MoveNext()) + { + secondEnumerator.MoveNext(); + result.Add(firstEnumerator.Current * secondEnumerator.Current); + } + + firstEnumerator.Dispose(); + secondEnumerator.Dispose(); + + return result; + } + private IEnumerable MultiplyArray(IReadOnlyCollection first, IReadOnlyCollection second) + { + if(first.Count != second.Count) + throw new ArgumentException("Arrays should be of equal size!"); + + var firstEnumerator = first.GetEnumerator(); + var secondEnumerator = second.GetEnumerator(); + var result = new List(); + while (firstEnumerator.MoveNext()) + { + secondEnumerator.MoveNext(); + result.Add(firstEnumerator.Current * secondEnumerator.Current); + } + + firstEnumerator.Dispose(); + secondEnumerator.Dispose(); + + return result; + } + + [TestMethod] + public void mulOpTests() + { + const int rows = 2; // to avoid broadcasting effect + const int cols = 10; + + #region intTest + const int firstIntVal = 2; + const int secondIntVal = 3; + + var firstIntFeed = Enumerable.Repeat(firstIntVal, rows * cols).ToArray(); + var secondIntFeed = Enumerable.Repeat(secondIntVal, rows * cols).ToArray(); + var intResult = MultiplyArray(firstIntFeed, secondIntFeed).Sum(); + + var a = tf.placeholder(tf.int32, shape: new TensorShape(rows, cols)); + var b = tf.placeholder(tf.int32, shape: new TensorShape(rows, cols)); + var c = tf.reduce_sum(tf.reduce_sum(tf.multiply(a, b), 1)); + + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResult); + } + + // Testing `operator *(Tensor x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(a * b, 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResult); + } + + // Testing `operator *(Tensor x, int y) + c = tf.reduce_sum(tf.reduce_sum(a * secondIntVal, 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResult); + } + + // Testing `operator *(int x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(firstIntVal * b, 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResult); + } + #endregion + + #region floatTest + const float firstFloatVal = 2.0f; + const float secondFloatVal = 3.0f; + + var firstFloatFeed = Enumerable.Repeat(firstFloatVal, rows * cols).ToArray(); + var secondFloatFeed = Enumerable.Repeat(secondFloatVal, rows * cols).ToArray(); + var floatResult = MultiplyArray(firstFloatFeed, secondFloatFeed).Sum(); + + a = tf.placeholder(tf.float32, shape: new TensorShape(rows, cols)); + b = tf.placeholder(tf.float32, shape: new TensorShape(rows, cols)); + c = tf.reduce_sum(tf.reduce_sum(tf.multiply(a, b), 1)); + + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((float)o, floatResult); + } + + // Testing `operator *(Tensor x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(a * b, 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((float)o, floatResult); + } + + // Testing `operator *(Tensor x, float y) + c = tf.reduce_sum(tf.reduce_sum(a * secondFloatVal, 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((float)o, floatResult); + } + + // Testing `operator *(float x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(firstFloatVal * b, 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((float)o, floatResult); + } + #endregion + + #region doubleTest + const double firstDoubleVal = 2.0; + const double secondDoubleVal = 3.0; + + var firstDoubleFeed = Enumerable.Repeat(firstDoubleVal, rows * cols).ToArray(); + var secondDoubleFeed = Enumerable.Repeat(secondDoubleVal, rows * cols).ToArray(); + var doubleResult = MultiplyArray(firstDoubleFeed, secondDoubleFeed).Sum(); + + a = tf.placeholder(tf.float64, shape: new TensorShape(rows, cols)); + b = tf.placeholder(tf.float64, shape: new TensorShape(rows, cols)); + c = tf.reduce_sum(tf.reduce_sum(tf.multiply(a, b), 1)); + + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((double)o, doubleResult); + } + + // Testing `operator *(Tensor x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(a * b, 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((double)o, doubleResult); + } + + // Testing `operator *(Tensor x, double y) + c = tf.reduce_sum(tf.reduce_sum(a * secondFloatVal, 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((double)o, doubleResult); + } + + // Testing `operator *(double x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(firstFloatVal * b, 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((double) o, doubleResult); + } + #endregion + } + + [TestMethod] + public void divOpTests() + { + const int rows = 2; // to avoid broadcasting effect + const int cols = 10; + + #region intTest + const int firstIntVal = 6; + const int secondIntVal = 3; + + var firstIntFeed = Enumerable.Repeat(firstIntVal, rows * cols).ToArray(); + var secondIntFeed = Enumerable.Repeat(secondIntVal, rows * cols).ToArray(); + var intResult = (int)(firstIntFeed.Sum() / (float)secondIntVal); + + var a = tf.placeholder(tf.int32, shape: new TensorShape(rows, cols)); + var b = tf.placeholder(tf.int32, shape: new TensorShape(rows, cols)); + var c = tf.reduce_sum(tf.reduce_sum(gen_math_ops.floor_div(a, b), 1)); + + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResult); + } + + // Testing `operator /(Tensor x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(a / b, 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResult); + } + + // Testing `operator /(Tensor x, int y) + c = tf.reduce_sum(tf.reduce_sum(a / secondIntVal, 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResult); + } + + // Testing `operator /(int x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(firstIntVal / b, 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResult); + } + #endregion + + #region floatTest + const float firstFloatVal = 6.0f; + const float secondFloatVal = 3.0f; + + var firstFloatFeed = Enumerable.Repeat(firstFloatVal, rows * cols).ToArray(); + var secondFloatFeed = Enumerable.Repeat(secondFloatVal, rows * cols).ToArray(); + var floatResult = MultiplyArray(firstFloatFeed, secondFloatFeed.Select(x => 1/x).ToArray()).Sum(); + + a = tf.placeholder(tf.float32, shape: new TensorShape(rows, cols)); + b = tf.placeholder(tf.float32, shape: new TensorShape(rows, cols)); + c = tf.reduce_sum(tf.reduce_sum(tf.divide(a, b), 1)); + + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((float)o, floatResult); + } + + // Testing `operator /(Tensor x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(a / b, 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((float)o, floatResult); + } + + // Testing `operator /(Tensor x, float y) + c = tf.reduce_sum(tf.reduce_sum(a / secondFloatVal, 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((float)o, floatResult); + } + + // Testing `operator /(float x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(firstFloatVal / b, 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((float)o, floatResult); + } + #endregion + + #region doubleTest + const double firstDoubleVal = 6.0; + const double secondDoubleVal = 3.0; + + var firstDoubleFeed = Enumerable.Repeat(firstDoubleVal, rows * cols).ToArray(); + var secondDoubleFeed = Enumerable.Repeat(secondDoubleVal, rows * cols).ToArray(); + var doubleResult = MultiplyArray(firstDoubleFeed, secondDoubleFeed.Select(x => 1/x).ToArray()).Sum(); + + a = tf.placeholder(tf.float64, shape: new TensorShape(rows, cols)); + b = tf.placeholder(tf.float64, shape: new TensorShape(rows, cols)); + c = tf.reduce_sum(tf.reduce_sum(tf.divide(a, b), 1)); + + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((double)o, doubleResult); + } + + // Testing `operator /(Tensor x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(a / b, 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((double)o, doubleResult); + } + + // Testing `operator /(Tensor x, double y) + c = tf.reduce_sum(tf.reduce_sum(a / secondFloatVal, 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((double)o, doubleResult); + } + + // Testing `operator /(double x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(firstFloatVal / b, 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((double)o, doubleResult); + } + #endregion + } + + [TestMethod] + public void greaterThanOpTests() + { + const int rows = 2; // to avoid broadcasting effect + const int cols = 10; + + #region intTest + const int intThreshold = 10; + + var firstIntFeed = Enumerable.Range(0, rows * cols).ToArray(); + var secondIntFeed = Enumerable.Repeat(intThreshold, rows * cols).ToArray(); + var intResult = firstIntFeed.Count(elem => elem > intThreshold); + var intResultTwo = firstIntFeed.Count(elem => elem < intThreshold); + + var a = tf.placeholder(tf.int32, shape: new TensorShape(rows, cols)); + var b = tf.placeholder(tf.int32, shape: new TensorShape(rows, cols)); + var c = tf.reduce_sum(tf.reduce_sum(tf.cast(tf.greater(a, b), tf.int32), 1)); + + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResult); + } + + // Testing `operator >(Tensor x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a > b, tf.int32), 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResult); + } + + // Testing `operator >(Tensor x, int y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a > intThreshold, tf.int32), 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResult); + } + + // Testing `operator >(int x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(intThreshold > a, tf.int32), 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResultTwo); + } + #endregion + + #region floatTest + const float floatThreshold = 10.0f; + + var firstFloatFeed = Enumerable.Range(0, rows * cols).Select(elem => (float)elem).ToArray(); + var secondFloatFeed = Enumerable.Repeat(floatThreshold, rows * cols).ToArray(); + var floatResult = firstFloatFeed.Count(elem => elem > floatThreshold); + var floatResultTwo = firstFloatFeed.Count(elem => elem < floatThreshold); + + a = tf.placeholder(tf.float32, shape: new TensorShape(rows, cols)); + b = tf.placeholder(tf.float32, shape: new TensorShape(rows, cols)); + c = tf.reduce_sum(tf.reduce_sum(tf.cast(tf.greater(a, b), tf.int32), 1)); + + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, floatResult); + } + + // Testing `operator >(Tensor x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a > b, tf.int32), 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, floatResult); + } + + // Testing `operator >(Tensor x, float y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a > floatThreshold, tf.int32), 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, floatResult); + } + + // Testing `operator >(float x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(floatThreshold > a, tf.int32), 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, floatResultTwo); + } + #endregion + + #region doubleTest + const double doubleThreshold = 10.0; + + var firstDoubleFeed = Enumerable.Repeat(0, rows * cols).Select(elem => (double)elem).ToArray(); + var secondDoubleFeed = Enumerable.Repeat(doubleThreshold, rows * cols).ToArray(); + var doubleResult = firstDoubleFeed.Count(elem => elem > doubleThreshold); + var doubleResultTwo = firstDoubleFeed.Count(elem => elem < doubleThreshold); + + a = tf.placeholder(tf.float64, shape: new TensorShape(rows, cols)); + b = tf.placeholder(tf.float64, shape: new TensorShape(rows, cols)); + c = tf.reduce_sum(tf.reduce_sum(tf.cast(tf.greater(a, b), tf.int32), 1)); + + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, doubleResult); + } + + // Testing `operator >(Tensor x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a > b, tf.int32), 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, doubleResult); + } + + // Testing `operator >(Tensor x, double y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a > doubleThreshold, tf.int32), 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, doubleResult); + } + + // Testing `operator >(double x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(doubleThreshold > a, tf.int32), 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, doubleResultTwo); + } + #endregion + } + + [TestMethod] + public void lessThanOpTests() + { + const int rows = 2; // to avoid broadcasting effect + const int cols = 10; + + #region intTest + const int intThreshold = 10; + + var firstIntFeed = Enumerable.Range(0, rows * cols).ToArray(); + var secondIntFeed = Enumerable.Repeat(intThreshold, rows * cols).ToArray(); + var intResult = firstIntFeed.Count(elem => elem < intThreshold); + var intResultTwo = firstIntFeed.Count(elem => elem > intThreshold); + + var a = tf.placeholder(tf.int32, shape: new TensorShape(rows, cols)); + var b = tf.placeholder(tf.int32, shape: new TensorShape(rows, cols)); + var c = tf.reduce_sum(tf.reduce_sum(tf.cast(tf.less(a, b), tf.int32), 1)); + + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResult); + } + + // Testing `operator <(Tensor x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a < b, tf.int32), 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResult); + } + + // Testing `operator <(Tensor x, int y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a < intThreshold, tf.int32), 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResult); + } + + // Testing `operator <(int x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(intThreshold < a, tf.int32), 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResultTwo); + } + #endregion + + #region floatTest + const float floatThreshold = 10.0f; + + var firstFloatFeed = Enumerable.Range(0, rows * cols).Select(elem => (float)elem).ToArray(); + var secondFloatFeed = Enumerable.Repeat(floatThreshold, rows * cols).ToArray(); + var floatResult = firstFloatFeed.Count(elem => elem < floatThreshold); + var floatResultTwo = firstFloatFeed.Count(elem => elem > floatThreshold); + + a = tf.placeholder(tf.float32, shape: new TensorShape(rows, cols)); + b = tf.placeholder(tf.float32, shape: new TensorShape(rows, cols)); + c = tf.reduce_sum(tf.reduce_sum(tf.cast(tf.less(a, b), tf.int32), 1)); + + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, floatResult); + } + + // Testing `operator <(Tensor x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a < b, tf.int32), 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, floatResult); + } + + // Testing `operator <(Tensor x, float y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a < floatThreshold, tf.int32), 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, floatResult); + } + + // Testing `operator <(float x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(floatThreshold < a, tf.int32), 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, floatResultTwo); + } + #endregion + + #region doubleTest + const double doubleThreshold = 10.0; + + var firstDoubleFeed = Enumerable.Repeat(0, rows * cols).Select(elem => (double)elem).ToArray(); + var secondDoubleFeed = Enumerable.Repeat(doubleThreshold, rows * cols).ToArray(); + var doubleResult = firstDoubleFeed.Count(elem => elem < doubleThreshold); + var doubleResultTwo = firstDoubleFeed.Count(elem => elem > doubleThreshold); + + a = tf.placeholder(tf.float64, shape: new TensorShape(rows, cols)); + b = tf.placeholder(tf.float64, shape: new TensorShape(rows, cols)); + c = tf.reduce_sum(tf.reduce_sum(tf.cast(tf.less(a, b), tf.int32), 1)); + + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, doubleResult); + } + + // Testing `operator <(Tensor x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a < b, tf.int32), 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, doubleResult); + } + + // Testing `operator <(Tensor x, double y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a < doubleThreshold, tf.int32), 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, doubleResult); + } + + // Testing `operator <(double x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(doubleThreshold < a, tf.int32), 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, doubleResultTwo); + } + #endregion + } } }