diff --git a/README.md b/README.md
index 7f7d14a4..9cf23da2 100644
--- a/README.md
+++ b/README.md
@@ -28,8 +28,14 @@ In comparison to other projects, like for instance TensorFlowSharp which only pr
Install TF.NET and TensorFlow binary through NuGet.
```sh
+### install tensorflow C# binding
PM> Install-Package TensorFlow.NET
+
+### Install tensorflow binary
+### For CPU version
PM> Install-Package SciSharp.TensorFlow.Redist
+### For GPU version (CUDA and cuDNN are required)
+PM> Install-Package SciSharp.TensorFlow.Redist-Windows-GPU
```
Import TF.NET.
diff --git a/src/SciSharp.TensorFlow.Redist/README.md b/src/SciSharp.TensorFlow.Redist/README.md
index 3f75c4cf..5bdf82a1 100644
--- a/src/SciSharp.TensorFlow.Redist/README.md
+++ b/src/SciSharp.TensorFlow.Redist/README.md
@@ -1,8 +1,14 @@
## SciSharp.TensorFlow.Redist ##
-`SciSharp.TensorFlow.Redist` is a migration from [Microsoft.ML.TensorFlow.Redist](https://github.com/dotnet/machinelearning/tree/release/1.2/src/Redist/Microsoft.ML.TensorFlow.Redist). [ML.NET](https://github.com/dotnet/machinelearning) team will not maintain the package since [ML.NET](https://www.nuget.org/packages/Microsoft.ML) v1.4.0 going forward.
+`SciSharp.TensorFlow.Redist` is a migration from [Microsoft.ML.TensorFlow.Redist](https://github.com/dotnet/machinelearning/tree/release/1.2/src/Redist/Microsoft.ML.TensorFlow.Redist). [ML.NET](https://github.com/dotnet/machinelearning) team will not maintain the package since [ML.NET](https://www.nuget.org/packages/Microsoft.ML) v1.3.0 going forward.
+* CPU version for all platforms (Windows, Linux, OSX)
+```powershell
+PM> Install-Package SciSharp.TensorFlow.Redist
+```
+
+* GPU version for Windows
```powershell
PM> Install-Package SciSharp.TensorFlow.Redist
```
@@ -16,7 +22,7 @@ Related merged [commits](https://github.com/SciSharp/TensorFlow.NET/commit/854a5
On Windows, the tar command does not support extracting archives with symlinks. So when `dotnet pack` runs on Windows it will only package the Windows binaries.
1. Run `dotnet pack` under `src/SciSharp.TensorFlow.Redist` directory in Linux.
-2. Run `nuget push SciSharp.TensorFlow.Redist.1.14.0.nupkg -k APIKEY -s https://api.nuget.org/v3/index.json`
+2. Run `dotnet nuget push SciSharp.TensorFlow.Redist.1.14.0.nupkg -k APIKEY -s https://api.nuget.org/v3/index.json`
diff --git a/src/SciSharp.TensorFlow.Redist/Redist.nuspec b/src/SciSharp.TensorFlow.Redist/Redist-CPU.nuspec
similarity index 89%
rename from src/SciSharp.TensorFlow.Redist/Redist.nuspec
rename to src/SciSharp.TensorFlow.Redist/Redist-CPU.nuspec
index d2527c8b..11919e8c 100644
--- a/src/SciSharp.TensorFlow.Redist/Redist.nuspec
+++ b/src/SciSharp.TensorFlow.Redist/Redist-CPU.nuspec
@@ -9,7 +9,7 @@
LICENSE.txt
https://aka.ms/deprecateLicenseUrl
https://www.tensorflow.org/
- $packageId$ contains the TensorFlow C library version $version$ redistributed as a NuGet package.
+ $packageId$ contains the TensorFlow C library CPU version $version$ redistributed as a NuGet package.
https://github.com/tensorflow/tensorflow/releases/tag/v$version$
Copyright 2019 The TensorFlow Authors. All rights reserved.
TensorFlow
diff --git a/src/SciSharp.TensorFlow.Redist/Redist-Windows-GPU.nuspec b/src/SciSharp.TensorFlow.Redist/Redist-Windows-GPU.nuspec
new file mode 100644
index 00000000..f010c96b
--- /dev/null
+++ b/src/SciSharp.TensorFlow.Redist/Redist-Windows-GPU.nuspec
@@ -0,0 +1,26 @@
+
+
+
+ $packageId$
+ $version$
+ The TensorFlow Authors
+ The TensorFlow Authors
+ true
+ LICENSE.txt
+ https://aka.ms/deprecateLicenseUrl
+ https://www.tensorflow.org/
+ $packageId$ contains the TensorFlow C library GPU version $version$ redistributed as a NuGet package.
+ https://github.com/tensorflow/tensorflow/releases/tag/v$version$
+ Copyright 2019 The TensorFlow Authors. All rights reserved.
+ TensorFlow
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/src/SciSharp.TensorFlow.Redist/SciSharp.TensorFlow.Redist.nupkgproj b/src/SciSharp.TensorFlow.Redist/SciSharp.TensorFlow.Redist-CPU.nupkgproj
similarity index 99%
rename from src/SciSharp.TensorFlow.Redist/SciSharp.TensorFlow.Redist.nupkgproj
rename to src/SciSharp.TensorFlow.Redist/SciSharp.TensorFlow.Redist-CPU.nupkgproj
index a0ca0a0a..6a225ede 100644
--- a/src/SciSharp.TensorFlow.Redist/SciSharp.TensorFlow.Redist.nupkgproj
+++ b/src/SciSharp.TensorFlow.Redist/SciSharp.TensorFlow.Redist-CPU.nupkgproj
@@ -17,7 +17,7 @@
true
false
- Redist.nuspec
+ Redist-CPU.nuspec
packageId=$(PackageId);version=$(PackageVersion)
$(ProjDir)
diff --git a/src/SciSharp.TensorFlow.Redist/SciSharp.TensorFlow.Redist-Windows-GPU.nupkgproj b/src/SciSharp.TensorFlow.Redist/SciSharp.TensorFlow.Redist-Windows-GPU.nupkgproj
new file mode 100644
index 00000000..08fd9386
--- /dev/null
+++ b/src/SciSharp.TensorFlow.Redist/SciSharp.TensorFlow.Redist-Windows-GPU.nupkgproj
@@ -0,0 +1,187 @@
+
+
+
+ $(MSBuildThisFileDirectory)
+ $(ProjDir)bin\
+ $(ProjDir)obj\
+
+ x64
+ netstandard2.0
+ 1.14.0
+ 1
+
+ $(BinDir)packages\
+ $(MSBuildProjectName)
+ $(TensorFlowVersion)
+
+ true
+ false
+
+ Redist-Windows-GPU.nuspec
+ packageId=$(PackageId);version=$(PackageVersion)
+ $(ProjDir)
+
+ CopyFilesFromArchive
+
+ win
+ linux
+ osx
+ $(PackageRid)-$(TargetArchitecture)
+
+
+
+
+ false
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ <_downloadFiles Include="@(TensorFlowArchive);@(AdditionalDownloadFile)" Url="%(Identity)" DestinationFile="%(DownloadFile)" />
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ @(FilesWithHashes->'%(FileHash)')
+ $([System.IO.File]::ReadAllText('%(LocalShaFile)').Replace("%0A", "").Replace("%0D", ""))
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ <_fileFromArchive Include="%(TensorFlowArchive.FilesFromArchive)" ExtractDirectory="%(TensorFlowArchive.ExtractDirectory)" Runtime="%(TensorFlowArchive.Runtime)" />
+ <_fileFromArchive DestinationFile="%(FileName)%(Extension)"/>
+ <_fileFromArchive PackagePath="runtimes\%(_fileFromArchive.Runtime)\native\%(_fileFromArchive.DestinationFile)" />
+
+
+ <_fileFromArchive Condition="'%(DestinationFile)' == 'LICENSE'" PackagePath="THIRD_PARTY_NOTICES.txt" Runtime="" />
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/src/SciSharp.TensorFlow.Redist/libtensorflow-gpu-windows-x86_64-1.14.0.zip.sha b/src/SciSharp.TensorFlow.Redist/libtensorflow-gpu-windows-x86_64-1.14.0.zip.sha
new file mode 100644
index 00000000..739129b1
--- /dev/null
+++ b/src/SciSharp.TensorFlow.Redist/libtensorflow-gpu-windows-x86_64-1.14.0.zip.sha
@@ -0,0 +1 @@
+850A27858FA951DF77A78CD1BD78B54F6EE2532DD5A49F0579A7B02C795C62F0212F20177EAEA2BD77BD451A57FBBD1348362492F9E14BFE5CA5028C71711293
diff --git a/src/TensorFlowHub/MnistDataSet.cs b/src/TensorFlowHub/MnistDataSet.cs
index e0717ccb..accc57e1 100644
--- a/src/TensorFlowHub/MnistDataSet.cs
+++ b/src/TensorFlowHub/MnistDataSet.cs
@@ -27,5 +27,54 @@ namespace Tensorflow.Hub
labels.astype(dataType);
Labels = labels;
}
+
+ public (NDArray, NDArray) GetNextBatch(int batch_size, bool fake_data = false, bool shuffle = true)
+ {
+ var start = IndexInEpoch;
+ // Shuffle for the first epoch
+ if(EpochsCompleted == 0 && start == 0 && shuffle)
+ {
+ var perm0 = np.arange(NumOfExamples);
+ np.random.shuffle(perm0);
+ Data = Data[perm0];
+ Labels = Labels[perm0];
+ }
+
+ // Go to the next epoch
+ if (start + batch_size > NumOfExamples)
+ {
+ // Finished epoch
+ EpochsCompleted += 1;
+
+ // Get the rest examples in this epoch
+ var rest_num_examples = NumOfExamples - start;
+ //var images_rest_part = _images[np.arange(start, _num_examples)];
+ //var labels_rest_part = _labels[np.arange(start, _num_examples)];
+ // Shuffle the data
+ if (shuffle)
+ {
+ var perm = np.arange(NumOfExamples);
+ np.random.shuffle(perm);
+ Data = Data[perm];
+ Labels = Labels[perm];
+ }
+
+ start = 0;
+ IndexInEpoch = batch_size - rest_num_examples;
+ var end = IndexInEpoch;
+ var images_new_part = Data[np.arange(start, end)];
+ var labels_new_part = Labels[np.arange(start, end)];
+
+ /*return (np.concatenate(new float[][] { images_rest_part.Data(), images_new_part.Data() }, axis: 0),
+ np.concatenate(new float[][] { labels_rest_part.Data(), labels_new_part.Data() }, axis: 0));*/
+ return (images_new_part, labels_new_part);
+ }
+ else
+ {
+ IndexInEpoch += batch_size;
+ var end = IndexInEpoch;
+ return (Data[np.arange(start, end)], Labels[np.arange(start, end)]);
+ }
+ }
}
}
diff --git a/src/TensorFlowHub/MnistModelLoader.cs b/src/TensorFlowHub/MnistModelLoader.cs
index 161c02e9..2fb0aa42 100644
--- a/src/TensorFlowHub/MnistModelLoader.cs
+++ b/src/TensorFlowHub/MnistModelLoader.cs
@@ -15,14 +15,27 @@ namespace Tensorflow.Hub
private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz";
private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz";
- public static async Task> LoadAsync(string trainDir, bool oneHot = false)
+ public static async Task> LoadAsync(string trainDir, bool oneHot = false, int? trainSize = null, int? validationSize = null, int? testSize = null)
{
var loader = new MnistModelLoader();
- return await loader.LoadAsync(new ModelLoadSetting
+
+ var setting = new ModelLoadSetting
{
TrainDir = trainDir,
- OneHot = oneHot
- });
+ OneHot = oneHot,
+ TrainSize = trainSize
+ };
+
+ if (trainSize.HasValue)
+ setting.TrainSize = trainSize.Value;
+
+ if (validationSize.HasValue)
+ setting.ValidationSize = validationSize.Value;
+
+ if (testSize.HasValue)
+ setting.TestSize = testSize.Value;
+
+ return await loader.LoadAsync(setting);
}
public async Task> LoadAsync(ModelLoadSetting setting)
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs
index 0b9c7f3e..a5a9b674 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs
@@ -15,6 +15,7 @@
******************************************************************************/
using System;
+using System.Linq;
using static Tensorflow.Python;
namespace Tensorflow
@@ -63,17 +64,37 @@ namespace Tensorflow
public static Tensor operator *(long constant, Tensor tensor) => BinaryOpWrapper("mul", constant, tensor);
public static Tensor operator *(ulong constant, Tensor tensor) => BinaryOpWrapper("mul", constant, tensor);
- public static Tensor operator /(Tensor x, Tensor y) => BinaryOpWrapper("truediv", x, y);
- public static Tensor operator /(Tensor x, float y) => BinaryOpWrapper("truediv", x, y);
+ private static readonly TF_DataType[] _intTfDataTypes = {
+ TF_DataType.TF_INT8, TF_DataType.TF_INT16, TF_DataType.TF_INT32, TF_DataType.TF_INT64,
+ TF_DataType.TF_QINT8, TF_DataType.TF_QINT16, TF_DataType.TF_QINT32,
+ TF_DataType.TF_UINT8, TF_DataType.TF_UINT16, TF_DataType.TF_UINT32, TF_DataType.TF_UINT64
+ };
+ public static Tensor operator /(double x, Tensor y) => BinaryOpWrapper("truediv", x, y);
public static Tensor operator /(float x, Tensor y) => BinaryOpWrapper("truediv", x, y);
+ public static Tensor operator /(int x, Tensor y) => BinaryOpWrapper("floordiv", x, y);
+ public static Tensor operator /(Tensor x, Tensor y) =>
+ _intTfDataTypes.Contains(x._dtype)
+ ? BinaryOpWrapper("floordiv", x, y)
+ : BinaryOpWrapper("truediv", x, y);
+ public static Tensor operator /(Tensor x, int y) => BinaryOpWrapper("floordiv", x, y);
+ public static Tensor operator /(Tensor x, float y) => BinaryOpWrapper("truediv", x, y);
public static Tensor operator /(Tensor x, double y) => BinaryOpWrapper("truediv", x, y);
public static Tensor operator %(Tensor x, Tensor y) => BinaryOpWrapper("mod", x, y);
+ public static Tensor operator >(double x, Tensor y) => gen_math_ops.greater(x, y);
+ public static Tensor operator >(float x, Tensor y) => gen_math_ops.greater(x, y);
+ public static Tensor operator >(int x, Tensor y) => gen_math_ops.greater(x, y);
+ public static Tensor operator >(Tensor x, Tensor y) => gen_math_ops.greater(x, y);
public static Tensor operator >(Tensor x, int y) => gen_math_ops.greater(x, y);
public static Tensor operator >=(Tensor x, Tensor y) => gen_math_ops.greater_equal(x, y);
public static Tensor operator >(Tensor x, float y) => gen_math_ops.greater(x, y);
public static Tensor operator >(Tensor x, double y) => gen_math_ops.greater(x, y);
+
+ public static Tensor operator <(double x, Tensor y) => gen_math_ops.less(x, y);
+ public static Tensor operator <(float x, Tensor y) => gen_math_ops.less(x, y);
+ public static Tensor operator <(int x, Tensor y) => gen_math_ops.less(x, y);
+ public static Tensor operator <(Tensor x, Tensor y) => gen_math_ops.less(x, y);
public static Tensor operator <(Tensor x, int y) => gen_math_ops.less(x, y);
public static Tensor operator <=(Tensor x, Tensor y) => gen_math_ops.less_equal(x, y);
public static Tensor operator <(Tensor x, float y) => gen_math_ops.less(x, y);
@@ -99,6 +120,9 @@ namespace Tensorflow
case "add":
result = gen_math_ops.add(x1, y1, name: scope);
break;
+ case "floordiv":
+ result = gen_math_ops.floor_div(x1, y1, name: scope);
+ break;
case "truediv":
result = gen_math_ops.real_div(x1, y1, name: scope);
break;
diff --git a/tensorflowlib/README.md b/tensorflowlib/README.md
index 63cba815..318e5dc9 100644
--- a/tensorflowlib/README.md
+++ b/tensorflowlib/README.md
@@ -16,6 +16,8 @@ Here are some pre-built TensorFlow binaries you can use for each platform:
- CPU-only: https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-windows-x86_64-1.14.0.zip
- GPU-enabled: https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-windows-x86_64-1.14.0.zip
+
+
### Run in Linux
`Install-Package TensorFlow.NET`
@@ -31,10 +33,21 @@ sudo apt install libgdiplus
More information about [System.Drawing on Linux]().
+
+
### Run in Mac OS
-### GPU Tensorflow for windows
-Before running verify you installed CUDA and cuDNN
+
+
+### Tensorflow GPU for Windows
+
+Before running verify you installed CUDA and cuDNN (TensorFlow v1.14 is compatible with CUDA v10.0 and cuDNN v7.4), and make sure the corresponding cuda version is compatible.
+
+```powershell
+PM> Install-Package SciSharp.TensorFlow.Redist-Windows-GPU
+```
+
+
### Build from source for Windows
diff --git a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs b/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs
index 9756db61..536fcfc7 100644
--- a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs
+++ b/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs
@@ -16,6 +16,7 @@
using NumSharp;
using System;
+using System.Diagnostics;
using Tensorflow;
using TensorFlowNET.Examples.Utility;
using static Tensorflow.Python;
@@ -144,6 +145,8 @@ namespace TensorFlowNET.Examples.ImageProcess
float loss_val = 100.0f;
float accuracy_val = 0f;
+ var sw = new Stopwatch();
+ sw.Start();
foreach (var epoch in range(epochs))
{
print($"Training epoch: {epoch + 1}");
@@ -165,7 +168,8 @@ namespace TensorFlowNET.Examples.ImageProcess
var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, x_batch), new FeedItem(y, y_batch));
loss_val = result[0];
accuracy_val = result[1];
- print($"iter {iteration.ToString("000")}: Loss={loss_val.ToString("0.0000")}, Training Accuracy={accuracy_val.ToString("P")}");
+ print($"iter {iteration.ToString("000")}: Loss={loss_val.ToString("0.0000")}, Training Accuracy={accuracy_val.ToString("P")} {sw.ElapsedMilliseconds}ms");
+ sw.Restart();
}
}
diff --git a/test/TensorFlowNET.UnitTest/OperationsTest.cs b/test/TensorFlowNET.UnitTest/OperationsTest.cs
index 37f8d450..358f3fb9 100644
--- a/test/TensorFlowNET.UnitTest/OperationsTest.cs
+++ b/test/TensorFlowNET.UnitTest/OperationsTest.cs
@@ -467,5 +467,679 @@ namespace TensorFlowNET.UnitTest
}
#endregion
}
+
+ private IEnumerable MultiplyArray(IReadOnlyCollection first, IReadOnlyCollection second)
+ {
+ if(first.Count != second.Count)
+ throw new ArgumentException("Arrays should be of equal size!");
+
+ var firstEnumerator = first.GetEnumerator();
+ var secondEnumerator = second.GetEnumerator();
+ var result = new List();
+ while (firstEnumerator.MoveNext())
+ {
+ secondEnumerator.MoveNext();
+ result.Add(firstEnumerator.Current * secondEnumerator.Current);
+ }
+
+ firstEnumerator.Dispose();
+ secondEnumerator.Dispose();
+
+ return result;
+ }
+ private IEnumerable MultiplyArray(IReadOnlyCollection first, IReadOnlyCollection second)
+ {
+ if(first.Count != second.Count)
+ throw new ArgumentException("Arrays should be of equal size!");
+
+ var firstEnumerator = first.GetEnumerator();
+ var secondEnumerator = second.GetEnumerator();
+ var result = new List();
+ while (firstEnumerator.MoveNext())
+ {
+ secondEnumerator.MoveNext();
+ result.Add(firstEnumerator.Current * secondEnumerator.Current);
+ }
+
+ firstEnumerator.Dispose();
+ secondEnumerator.Dispose();
+
+ return result;
+ }
+ private IEnumerable MultiplyArray(IReadOnlyCollection first, IReadOnlyCollection second)
+ {
+ if(first.Count != second.Count)
+ throw new ArgumentException("Arrays should be of equal size!");
+
+ var firstEnumerator = first.GetEnumerator();
+ var secondEnumerator = second.GetEnumerator();
+ var result = new List();
+ while (firstEnumerator.MoveNext())
+ {
+ secondEnumerator.MoveNext();
+ result.Add(firstEnumerator.Current * secondEnumerator.Current);
+ }
+
+ firstEnumerator.Dispose();
+ secondEnumerator.Dispose();
+
+ return result;
+ }
+
+ [TestMethod]
+ public void mulOpTests()
+ {
+ const int rows = 2; // to avoid broadcasting effect
+ const int cols = 10;
+
+ #region intTest
+ const int firstIntVal = 2;
+ const int secondIntVal = 3;
+
+ var firstIntFeed = Enumerable.Repeat(firstIntVal, rows * cols).ToArray();
+ var secondIntFeed = Enumerable.Repeat(secondIntVal, rows * cols).ToArray();
+ var intResult = MultiplyArray(firstIntFeed, secondIntFeed).Sum();
+
+ var a = tf.placeholder(tf.int32, shape: new TensorShape(rows, cols));
+ var b = tf.placeholder(tf.int32, shape: new TensorShape(rows, cols));
+ var c = tf.reduce_sum(tf.reduce_sum(tf.multiply(a, b), 1));
+
+ using (var sess = tf.Session())
+ {
+ var o = sess.run(c,
+ new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))),
+ new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols))));
+ Assert.AreEqual((int)o, intResult);
+ }
+
+ // Testing `operator *(Tensor x, Tensor y)
+ c = tf.reduce_sum(tf.reduce_sum(a * b, 1));
+ using (var sess = tf.Session())
+ {
+ var o = sess.run(c,
+ new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))),
+ new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols))));
+ Assert.AreEqual((int)o, intResult);
+ }
+
+ // Testing `operator *(Tensor x, int y)
+ c = tf.reduce_sum(tf.reduce_sum(a * secondIntVal, 1));
+ using (var sess = tf.Session())
+ {
+ var o = sess.run(c,
+ new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))));
+ Assert.AreEqual((int)o, intResult);
+ }
+
+ // Testing `operator *(int x, Tensor y)
+ c = tf.reduce_sum(tf.reduce_sum(firstIntVal * b, 1));
+ using (var sess = tf.Session())
+ {
+ var o = sess.run(c,
+ new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols))));
+ Assert.AreEqual((int)o, intResult);
+ }
+ #endregion
+
+ #region floatTest
+ const float firstFloatVal = 2.0f;
+ const float secondFloatVal = 3.0f;
+
+ var firstFloatFeed = Enumerable.Repeat(firstFloatVal, rows * cols).ToArray();
+ var secondFloatFeed = Enumerable.Repeat(secondFloatVal, rows * cols).ToArray();
+ var floatResult = MultiplyArray(firstFloatFeed, secondFloatFeed).Sum();
+
+ a = tf.placeholder(tf.float32, shape: new TensorShape(rows, cols));
+ b = tf.placeholder(tf.float32, shape: new TensorShape(rows, cols));
+ c = tf.reduce_sum(tf.reduce_sum(tf.multiply(a, b), 1));
+
+ using (var sess = tf.Session())
+ {
+ var o = sess.run(c,
+ new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))),
+ new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols))));
+ Assert.AreEqual((float)o, floatResult);
+ }
+
+ // Testing `operator *(Tensor x, Tensor y)
+ c = tf.reduce_sum(tf.reduce_sum(a * b, 1));
+ using (var sess = tf.Session())
+ {
+ var o = sess.run(c,
+ new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))),
+ new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols))));
+ Assert.AreEqual((float)o, floatResult);
+ }
+
+ // Testing `operator *(Tensor x, float y)
+ c = tf.reduce_sum(tf.reduce_sum(a * secondFloatVal, 1));
+ using (var sess = tf.Session())
+ {
+ var o = sess.run(c,
+ new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))));
+ Assert.AreEqual((float)o, floatResult);
+ }
+
+ // Testing `operator *(float x, Tensor y)
+ c = tf.reduce_sum(tf.reduce_sum(firstFloatVal * b, 1));
+ using (var sess = tf.Session())
+ {
+ var o = sess.run(c,
+ new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols))));
+ Assert.AreEqual((float)o, floatResult);
+ }
+ #endregion
+
+ #region doubleTest
+ const double firstDoubleVal = 2.0;
+ const double secondDoubleVal = 3.0;
+
+ var firstDoubleFeed = Enumerable.Repeat(firstDoubleVal, rows * cols).ToArray();
+ var secondDoubleFeed = Enumerable.Repeat(secondDoubleVal, rows * cols).ToArray();
+ var doubleResult = MultiplyArray(firstDoubleFeed, secondDoubleFeed).Sum();
+
+ a = tf.placeholder(tf.float64, shape: new TensorShape(rows, cols));
+ b = tf.placeholder(tf.float64, shape: new TensorShape(rows, cols));
+ c = tf.reduce_sum(tf.reduce_sum(tf.multiply(a, b), 1));
+
+ using (var sess = tf.Session())
+ {
+ var o = sess.run(c,
+ new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))),
+ new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols))));
+ Assert.AreEqual((double)o, doubleResult);
+ }
+
+ // Testing `operator *(Tensor x, Tensor y)
+ c = tf.reduce_sum(tf.reduce_sum(a * b, 1));
+ using (var sess = tf.Session())
+ {
+ var o = sess.run(c,
+ new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))),
+ new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols))));
+ Assert.AreEqual((double)o, doubleResult);
+ }
+
+ // Testing `operator *(Tensor x, double y)
+ c = tf.reduce_sum(tf.reduce_sum(a * secondFloatVal, 1));
+ using (var sess = tf.Session())
+ {
+ var o = sess.run(c,
+ new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))));
+ Assert.AreEqual((double)o, doubleResult);
+ }
+
+ // Testing `operator *(double x, Tensor y)
+ c = tf.reduce_sum(tf.reduce_sum(firstFloatVal * b, 1));
+ using (var sess = tf.Session())
+ {
+ var o = sess.run(c,
+ new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols))));
+ Assert.AreEqual((double) o, doubleResult);
+ }
+ #endregion
+ }
+
+ [TestMethod]
+ public void divOpTests()
+ {
+ const int rows = 2; // to avoid broadcasting effect
+ const int cols = 10;
+
+ #region intTest
+ const int firstIntVal = 6;
+ const int secondIntVal = 3;
+
+ var firstIntFeed = Enumerable.Repeat(firstIntVal, rows * cols).ToArray();
+ var secondIntFeed = Enumerable.Repeat(secondIntVal, rows * cols).ToArray();
+ var intResult = (int)(firstIntFeed.Sum() / (float)secondIntVal);
+
+ var a = tf.placeholder(tf.int32, shape: new TensorShape(rows, cols));
+ var b = tf.placeholder(tf.int32, shape: new TensorShape(rows, cols));
+ var c = tf.reduce_sum(tf.reduce_sum(gen_math_ops.floor_div(a, b), 1));
+
+ using (var sess = tf.Session())
+ {
+ var o = sess.run(c,
+ new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))),
+ new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols))));
+ Assert.AreEqual((int)o, intResult);
+ }
+
+ // Testing `operator /(Tensor x, Tensor y)
+ c = tf.reduce_sum(tf.reduce_sum(a / b, 1));
+ using (var sess = tf.Session())
+ {
+ var o = sess.run(c,
+ new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))),
+ new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols))));
+ Assert.AreEqual((int)o, intResult);
+ }
+
+ // Testing `operator /(Tensor x, int y)
+ c = tf.reduce_sum(tf.reduce_sum(a / secondIntVal, 1));
+ using (var sess = tf.Session())
+ {
+ var o = sess.run(c,
+ new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))));
+ Assert.AreEqual((int)o, intResult);
+ }
+
+ // Testing `operator /(int x, Tensor y)
+ c = tf.reduce_sum(tf.reduce_sum(firstIntVal / b, 1));
+ using (var sess = tf.Session())
+ {
+ var o = sess.run(c,
+ new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols))));
+ Assert.AreEqual((int)o, intResult);
+ }
+ #endregion
+
+ #region floatTest
+ const float firstFloatVal = 6.0f;
+ const float secondFloatVal = 3.0f;
+
+ var firstFloatFeed = Enumerable.Repeat(firstFloatVal, rows * cols).ToArray();
+ var secondFloatFeed = Enumerable.Repeat(secondFloatVal, rows * cols).ToArray();
+ var floatResult = MultiplyArray(firstFloatFeed, secondFloatFeed.Select(x => 1/x).ToArray()).Sum();
+
+ a = tf.placeholder(tf.float32, shape: new TensorShape(rows, cols));
+ b = tf.placeholder(tf.float32, shape: new TensorShape(rows, cols));
+ c = tf.reduce_sum(tf.reduce_sum(tf.divide(a, b), 1));
+
+ using (var sess = tf.Session())
+ {
+ var o = sess.run(c,
+ new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))),
+ new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols))));
+ Assert.AreEqual((float)o, floatResult);
+ }
+
+ // Testing `operator /(Tensor x, Tensor y)
+ c = tf.reduce_sum(tf.reduce_sum(a / b, 1));
+ using (var sess = tf.Session())
+ {
+ var o = sess.run(c,
+ new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))),
+ new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols))));
+ Assert.AreEqual((float)o, floatResult);
+ }
+
+ // Testing `operator /(Tensor x, float y)
+ c = tf.reduce_sum(tf.reduce_sum(a / secondFloatVal, 1));
+ using (var sess = tf.Session())
+ {
+ var o = sess.run(c,
+ new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))));
+ Assert.AreEqual((float)o, floatResult);
+ }
+
+ // Testing `operator /(float x, Tensor y)
+ c = tf.reduce_sum(tf.reduce_sum(firstFloatVal / b, 1));
+ using (var sess = tf.Session())
+ {
+ var o = sess.run(c,
+ new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols))));
+ Assert.AreEqual((float)o, floatResult);
+ }
+ #endregion
+
+ #region doubleTest
+ const double firstDoubleVal = 6.0;
+ const double secondDoubleVal = 3.0;
+
+ var firstDoubleFeed = Enumerable.Repeat(firstDoubleVal, rows * cols).ToArray();
+ var secondDoubleFeed = Enumerable.Repeat(secondDoubleVal, rows * cols).ToArray();
+ var doubleResult = MultiplyArray(firstDoubleFeed, secondDoubleFeed.Select(x => 1/x).ToArray()).Sum();
+
+ a = tf.placeholder(tf.float64, shape: new TensorShape(rows, cols));
+ b = tf.placeholder(tf.float64, shape: new TensorShape(rows, cols));
+ c = tf.reduce_sum(tf.reduce_sum(tf.divide(a, b), 1));
+
+ using (var sess = tf.Session())
+ {
+ var o = sess.run(c,
+ new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))),
+ new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols))));
+ Assert.AreEqual((double)o, doubleResult);
+ }
+
+ // Testing `operator /(Tensor x, Tensor y)
+ c = tf.reduce_sum(tf.reduce_sum(a / b, 1));
+ using (var sess = tf.Session())
+ {
+ var o = sess.run(c,
+ new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))),
+ new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols))));
+ Assert.AreEqual((double)o, doubleResult);
+ }
+
+ // Testing `operator /(Tensor x, double y)
+ c = tf.reduce_sum(tf.reduce_sum(a / secondFloatVal, 1));
+ using (var sess = tf.Session())
+ {
+ var o = sess.run(c,
+ new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))));
+ Assert.AreEqual((double)o, doubleResult);
+ }
+
+ // Testing `operator /(double x, Tensor y)
+ c = tf.reduce_sum(tf.reduce_sum(firstFloatVal / b, 1));
+ using (var sess = tf.Session())
+ {
+ var o = sess.run(c,
+ new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols))));
+ Assert.AreEqual((double)o, doubleResult);
+ }
+ #endregion
+ }
+
+ [TestMethod]
+ public void greaterThanOpTests()
+ {
+ const int rows = 2; // to avoid broadcasting effect
+ const int cols = 10;
+
+ #region intTest
+ const int intThreshold = 10;
+
+ var firstIntFeed = Enumerable.Range(0, rows * cols).ToArray();
+ var secondIntFeed = Enumerable.Repeat(intThreshold, rows * cols).ToArray();
+ var intResult = firstIntFeed.Count(elem => elem > intThreshold);
+ var intResultTwo = firstIntFeed.Count(elem => elem < intThreshold);
+
+ var a = tf.placeholder(tf.int32, shape: new TensorShape(rows, cols));
+ var b = tf.placeholder(tf.int32, shape: new TensorShape(rows, cols));
+ var c = tf.reduce_sum(tf.reduce_sum(tf.cast(tf.greater(a, b), tf.int32), 1));
+
+ using (var sess = tf.Session())
+ {
+ var o = sess.run(c,
+ new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))),
+ new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols))));
+ Assert.AreEqual((int)o, intResult);
+ }
+
+ // Testing `operator >(Tensor x, Tensor y)
+ c = tf.reduce_sum(tf.reduce_sum(tf.cast(a > b, tf.int32), 1));
+ using (var sess = tf.Session())
+ {
+ var o = sess.run(c,
+ new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))),
+ new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols))));
+ Assert.AreEqual((int)o, intResult);
+ }
+
+ // Testing `operator >(Tensor x, int y)
+ c = tf.reduce_sum(tf.reduce_sum(tf.cast(a > intThreshold, tf.int32), 1));
+ using (var sess = tf.Session())
+ {
+ var o = sess.run(c,
+ new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))));
+ Assert.AreEqual((int)o, intResult);
+ }
+
+ // Testing `operator >(int x, Tensor y)
+ c = tf.reduce_sum(tf.reduce_sum(tf.cast(intThreshold > a, tf.int32), 1));
+ using (var sess = tf.Session())
+ {
+ var o = sess.run(c,
+ new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))));
+ Assert.AreEqual((int)o, intResultTwo);
+ }
+ #endregion
+
+ #region floatTest
+ const float floatThreshold = 10.0f;
+
+ var firstFloatFeed = Enumerable.Range(0, rows * cols).Select(elem => (float)elem).ToArray();
+ var secondFloatFeed = Enumerable.Repeat(floatThreshold, rows * cols).ToArray();
+ var floatResult = firstFloatFeed.Count(elem => elem > floatThreshold);
+ var floatResultTwo = firstFloatFeed.Count(elem => elem < floatThreshold);
+
+ a = tf.placeholder(tf.float32, shape: new TensorShape(rows, cols));
+ b = tf.placeholder(tf.float32, shape: new TensorShape(rows, cols));
+ c = tf.reduce_sum(tf.reduce_sum(tf.cast(tf.greater(a, b), tf.int32), 1));
+
+ using (var sess = tf.Session())
+ {
+ var o = sess.run(c,
+ new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))),
+ new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols))));
+ Assert.AreEqual((int)o, floatResult);
+ }
+
+ // Testing `operator >(Tensor x, Tensor y)
+ c = tf.reduce_sum(tf.reduce_sum(tf.cast(a > b, tf.int32), 1));
+ using (var sess = tf.Session())
+ {
+ var o = sess.run(c,
+ new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))),
+ new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols))));
+ Assert.AreEqual((int)o, floatResult);
+ }
+
+ // Testing `operator >(Tensor x, float y)
+ c = tf.reduce_sum(tf.reduce_sum(tf.cast(a > floatThreshold, tf.int32), 1));
+ using (var sess = tf.Session())
+ {
+ var o = sess.run(c,
+ new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))));
+ Assert.AreEqual((int)o, floatResult);
+ }
+
+ // Testing `operator >(float x, Tensor y)
+ c = tf.reduce_sum(tf.reduce_sum(tf.cast(floatThreshold > a, tf.int32), 1));
+ using (var sess = tf.Session())
+ {
+ var o = sess.run(c,
+ new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))));
+ Assert.AreEqual((int)o, floatResultTwo);
+ }
+ #endregion
+
+ #region doubleTest
+ const double doubleThreshold = 10.0;
+
+ var firstDoubleFeed = Enumerable.Repeat(0, rows * cols).Select(elem => (double)elem).ToArray();
+ var secondDoubleFeed = Enumerable.Repeat(doubleThreshold, rows * cols).ToArray();
+ var doubleResult = firstDoubleFeed.Count(elem => elem > doubleThreshold);
+ var doubleResultTwo = firstDoubleFeed.Count(elem => elem < doubleThreshold);
+
+ a = tf.placeholder(tf.float64, shape: new TensorShape(rows, cols));
+ b = tf.placeholder(tf.float64, shape: new TensorShape(rows, cols));
+ c = tf.reduce_sum(tf.reduce_sum(tf.cast(tf.greater(a, b), tf.int32), 1));
+
+ using (var sess = tf.Session())
+ {
+ var o = sess.run(c,
+ new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))),
+ new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols))));
+ Assert.AreEqual((int)o, doubleResult);
+ }
+
+ // Testing `operator >(Tensor x, Tensor y)
+ c = tf.reduce_sum(tf.reduce_sum(tf.cast(a > b, tf.int32), 1));
+ using (var sess = tf.Session())
+ {
+ var o = sess.run(c,
+ new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))),
+ new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols))));
+ Assert.AreEqual((int)o, doubleResult);
+ }
+
+ // Testing `operator >(Tensor x, double y)
+ c = tf.reduce_sum(tf.reduce_sum(tf.cast(a > doubleThreshold, tf.int32), 1));
+ using (var sess = tf.Session())
+ {
+ var o = sess.run(c,
+ new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))));
+ Assert.AreEqual((int)o, doubleResult);
+ }
+
+ // Testing `operator >(double x, Tensor y)
+ c = tf.reduce_sum(tf.reduce_sum(tf.cast(doubleThreshold > a, tf.int32), 1));
+ using (var sess = tf.Session())
+ {
+ var o = sess.run(c,
+ new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))));
+ Assert.AreEqual((int)o, doubleResultTwo);
+ }
+ #endregion
+ }
+
+ [TestMethod]
+ public void lessThanOpTests()
+ {
+ const int rows = 2; // to avoid broadcasting effect
+ const int cols = 10;
+
+ #region intTest
+ const int intThreshold = 10;
+
+ var firstIntFeed = Enumerable.Range(0, rows * cols).ToArray();
+ var secondIntFeed = Enumerable.Repeat(intThreshold, rows * cols).ToArray();
+ var intResult = firstIntFeed.Count(elem => elem < intThreshold);
+ var intResultTwo = firstIntFeed.Count(elem => elem > intThreshold);
+
+ var a = tf.placeholder(tf.int32, shape: new TensorShape(rows, cols));
+ var b = tf.placeholder(tf.int32, shape: new TensorShape(rows, cols));
+ var c = tf.reduce_sum(tf.reduce_sum(tf.cast(tf.less(a, b), tf.int32), 1));
+
+ using (var sess = tf.Session())
+ {
+ var o = sess.run(c,
+ new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))),
+ new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols))));
+ Assert.AreEqual((int)o, intResult);
+ }
+
+ // Testing `operator <(Tensor x, Tensor y)
+ c = tf.reduce_sum(tf.reduce_sum(tf.cast(a < b, tf.int32), 1));
+ using (var sess = tf.Session())
+ {
+ var o = sess.run(c,
+ new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))),
+ new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols))));
+ Assert.AreEqual((int)o, intResult);
+ }
+
+ // Testing `operator <(Tensor x, int y)
+ c = tf.reduce_sum(tf.reduce_sum(tf.cast(a < intThreshold, tf.int32), 1));
+ using (var sess = tf.Session())
+ {
+ var o = sess.run(c,
+ new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))));
+ Assert.AreEqual((int)o, intResult);
+ }
+
+ // Testing `operator <(int x, Tensor y)
+ c = tf.reduce_sum(tf.reduce_sum(tf.cast(intThreshold < a, tf.int32), 1));
+ using (var sess = tf.Session())
+ {
+ var o = sess.run(c,
+ new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))));
+ Assert.AreEqual((int)o, intResultTwo);
+ }
+ #endregion
+
+ #region floatTest
+ const float floatThreshold = 10.0f;
+
+ var firstFloatFeed = Enumerable.Range(0, rows * cols).Select(elem => (float)elem).ToArray();
+ var secondFloatFeed = Enumerable.Repeat(floatThreshold, rows * cols).ToArray();
+ var floatResult = firstFloatFeed.Count(elem => elem < floatThreshold);
+ var floatResultTwo = firstFloatFeed.Count(elem => elem > floatThreshold);
+
+ a = tf.placeholder(tf.float32, shape: new TensorShape(rows, cols));
+ b = tf.placeholder(tf.float32, shape: new TensorShape(rows, cols));
+ c = tf.reduce_sum(tf.reduce_sum(tf.cast(tf.less(a, b), tf.int32), 1));
+
+ using (var sess = tf.Session())
+ {
+ var o = sess.run(c,
+ new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))),
+ new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols))));
+ Assert.AreEqual((int)o, floatResult);
+ }
+
+ // Testing `operator <(Tensor x, Tensor y)
+ c = tf.reduce_sum(tf.reduce_sum(tf.cast(a < b, tf.int32), 1));
+ using (var sess = tf.Session())
+ {
+ var o = sess.run(c,
+ new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))),
+ new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols))));
+ Assert.AreEqual((int)o, floatResult);
+ }
+
+ // Testing `operator <(Tensor x, float y)
+ c = tf.reduce_sum(tf.reduce_sum(tf.cast(a < floatThreshold, tf.int32), 1));
+ using (var sess = tf.Session())
+ {
+ var o = sess.run(c,
+ new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))));
+ Assert.AreEqual((int)o, floatResult);
+ }
+
+ // Testing `operator <(float x, Tensor y)
+ c = tf.reduce_sum(tf.reduce_sum(tf.cast(floatThreshold < a, tf.int32), 1));
+ using (var sess = tf.Session())
+ {
+ var o = sess.run(c,
+ new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))));
+ Assert.AreEqual((int)o, floatResultTwo);
+ }
+ #endregion
+
+ #region doubleTest
+ const double doubleThreshold = 10.0;
+
+ var firstDoubleFeed = Enumerable.Repeat(0, rows * cols).Select(elem => (double)elem).ToArray();
+ var secondDoubleFeed = Enumerable.Repeat(doubleThreshold, rows * cols).ToArray();
+ var doubleResult = firstDoubleFeed.Count(elem => elem < doubleThreshold);
+ var doubleResultTwo = firstDoubleFeed.Count(elem => elem > doubleThreshold);
+
+ a = tf.placeholder(tf.float64, shape: new TensorShape(rows, cols));
+ b = tf.placeholder(tf.float64, shape: new TensorShape(rows, cols));
+ c = tf.reduce_sum(tf.reduce_sum(tf.cast(tf.less(a, b), tf.int32), 1));
+
+ using (var sess = tf.Session())
+ {
+ var o = sess.run(c,
+ new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))),
+ new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols))));
+ Assert.AreEqual((int)o, doubleResult);
+ }
+
+ // Testing `operator <(Tensor x, Tensor y)
+ c = tf.reduce_sum(tf.reduce_sum(tf.cast(a < b, tf.int32), 1));
+ using (var sess = tf.Session())
+ {
+ var o = sess.run(c,
+ new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))),
+ new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols))));
+ Assert.AreEqual((int)o, doubleResult);
+ }
+
+ // Testing `operator <(Tensor x, double y)
+ c = tf.reduce_sum(tf.reduce_sum(tf.cast(a < doubleThreshold, tf.int32), 1));
+ using (var sess = tf.Session())
+ {
+ var o = sess.run(c,
+ new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))));
+ Assert.AreEqual((int)o, doubleResult);
+ }
+
+ // Testing `operator <(double x, Tensor y)
+ c = tf.reduce_sum(tf.reduce_sum(tf.cast(doubleThreshold < a, tf.int32), 1));
+ using (var sess = tf.Session())
+ {
+ var o = sess.run(c,
+ new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))));
+ Assert.AreEqual((int)o, doubleResultTwo);
+ }
+ #endregion
+ }
}
}