From 6c8c2e5ec9274eafebb86ed0f216e63be6586ca2 Mon Sep 17 00:00:00 2001 From: Eli Belash Date: Thu, 22 Aug 2019 06:58:50 +0300 Subject: [PATCH] Performance optimization, refactoring and revamping. (#362) * Refactored DisposableObject * Added different build directory for TensorflowNET.Examples.GPU * _FetchHandler: Switched to NPTypeCode * gfile.cs, Walk(...): Handle case when directory top doesn't exist. * Tensor.Creation: Perf-opted when creating tensor from NDArray of string * Graph.cs: refactor and added docs * Tensor.Creation.cs: perf-ops * Tensor.Explicit.cs: perf-ops * Copied globals.regen from NumSharp - Added supported_numericals_TF_DataType * Tensor perf-ops and cleanup, Revamped dtypes.cs, some renames. - Cleanup and docs to all Tensor.cs files - Changed all uses of System.Convert to NumSharp.Utilities.Converts - Added all missing types in dtypes.cs - Renamed tensor.Data to tensor.ToArray, added obsolete message - Renamed tensor.Data() to tensor.BufferToArray(), added obsolete message - Made GraphKeys to use const string instead allocating strings at every use of GraphKeys. * Tensor: Added guards for explicit casts. * Tensor: Added explicit cast to string * Tensor.ToArray(): Added support for cases when tensor is scalar. * Tensor.BufferToArray(): Fixed to use long instead of int. * TensorShape: Revamped and documented. * BaseSession: Added Session.run(ITensorOrOperation fetche, params FeedItem[] feed_dict) * Tensor: renamed _dtype to _override_dtype - Fixed all locations _dtype is used incorrectly. * Fixed unit tests * Tensor.Operations: Reverted commit * DisposableObject: sorted internal_dispose to properly handle Dispose() calls * Tensor.DisposeUnmanagedResources: Nullify _handle after delete. * TensorShape.this[...]: fixed guard check. * DisposableObject #362 --- src/TensorFlowNET.Core/APIs/c_api.cs | 2 +- src/TensorFlowNET.Core/Binding.Util.cs | 14 +- src/TensorFlowNET.Core/Buffers/Buffer.cs | 2 +- src/TensorFlowNET.Core/DisposableObject.cs | 37 +- .../Eager/ContextOptions.cs | 3 +- src/TensorFlowNET.Core/Graphs/Graph.cs | 103 ++--- .../Graphs/ImportGraphDefOptions.cs | 2 +- src/TensorFlowNET.Core/IO/gfile.cs | 4 + .../ControlFlows/ControlFlowContext.cs | 2 +- .../Operations/NnOps/rnn.cs | 2 +- .../Operations/array_ops.py.cs | 2 +- src/TensorFlowNET.Core/Operations/nn_ops.cs | 2 +- .../Sessions/BaseSession.cs | 15 +- .../Sessions/SessionOptions.cs | 2 +- .../Sessions/_FetchHandler.cs | 23 +- src/TensorFlowNET.Core/Status/Status.cs | 6 +- .../Tensors/Tensor.Creation.cs | 139 ++++--- .../Tensors/Tensor.Explicit.cs | 113 ++++-- .../Tensors/Tensor.Operators.cs | 7 +- src/TensorFlowNET.Core/Tensors/Tensor.cs | 377 +++++++++++++----- src/TensorFlowNET.Core/Tensors/TensorShape.cs | 108 ++++- src/TensorFlowNET.Core/Tensors/dtypes.cs | 99 +++-- src/TensorFlowNET.Core/Tensors/tensor_util.cs | 15 +- src/TensorFlowNET.Core/globals.regen | 38 ++ src/TensorFlowNET.Core/ops.GraphKeys.cs | 84 +++- .../TensorFlowNET.Examples.GPU.csproj | 8 + test/TensorFlowNET.UnitTest/ConstantTest.cs | 34 +- test/TensorFlowNET.UnitTest/GradientTest.cs | 6 +- test/TensorFlowNET.UnitTest/OperationsTest.cs | 222 +++++------ .../TensorFlowNET.UnitTest/PlaceholderTest.cs | 2 +- test/TensorFlowNET.UnitTest/SessionTest.cs | 4 +- test/TensorFlowNET.UnitTest/TensorTest.cs | 2 +- test/TensorFlowNET.UnitTest/VariableTest.cs | 13 +- .../nn_test/ZeroFractionTest.cs | 2 +- 34 files changed, 994 insertions(+), 500 deletions(-) create mode 100644 src/TensorFlowNET.Core/globals.regen diff --git a/src/TensorFlowNET.Core/APIs/c_api.cs b/src/TensorFlowNET.Core/APIs/c_api.cs index 1656edd0..adf0b86f 100644 --- a/src/TensorFlowNET.Core/APIs/c_api.cs +++ b/src/TensorFlowNET.Core/APIs/c_api.cs @@ -59,6 +59,6 @@ namespace Tensorflow } [DllImport(TensorFlowLibName)] - public static unsafe extern IntPtr TF_Version(); + public static extern IntPtr TF_Version(); } } diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs index da2cdf6e..bfbfa4ec 100644 --- a/src/TensorFlowNET.Core/Binding.Util.cs +++ b/src/TensorFlowNET.Core/Binding.Util.cs @@ -308,15 +308,14 @@ namespace Tensorflow public static IEnumerable TupleToEnumerable(object tuple) { Type t = tuple.GetType(); - if(t.IsGenericType && (t.FullName.StartsWith("System.Tuple") || t.FullName.StartsWith("System.ValueTuple"))) + if (t.IsGenericType && (t.FullName.StartsWith("System.Tuple") || t.FullName.StartsWith("System.ValueTuple"))) { var flds = t.GetFields(); - for(int i = 0; i < flds.Length;i++) + for (int i = 0; i < flds.Length; i++) { yield return flds[i].GetValue(tuple); } - } - else + } else { throw new System.Exception("Expected Tuple."); } @@ -329,12 +328,9 @@ namespace Tensorflow public static bool isinstance(object Item1, object tuple) { - var tup = TupleToEnumerable(tuple); - foreach(var t in tup) - { - if(isinstance(Item1, (Type)t)) + foreach (var t in TupleToEnumerable(tuple)) + if (isinstance(Item1, (Type) t)) return true; - } return false; } } diff --git a/src/TensorFlowNET.Core/Buffers/Buffer.cs b/src/TensorFlowNET.Core/Buffers/Buffer.cs index dbe576b8..396fb311 100644 --- a/src/TensorFlowNET.Core/Buffers/Buffer.cs +++ b/src/TensorFlowNET.Core/Buffers/Buffer.cs @@ -66,7 +66,7 @@ namespace Tensorflow return buffer.Data; } - protected override void DisposeUnManagedState(IntPtr handle) + protected override void DisposeUnmanagedResources(IntPtr handle) => c_api.TF_DeleteBuffer(handle); } } diff --git a/src/TensorFlowNET.Core/DisposableObject.cs b/src/TensorFlowNET.Core/DisposableObject.cs index 7e416e6d..688ac92c 100644 --- a/src/TensorFlowNET.Core/DisposableObject.cs +++ b/src/TensorFlowNET.Core/DisposableObject.cs @@ -29,18 +29,10 @@ namespace Tensorflow protected DisposableObject() { } - public DisposableObject(IntPtr handle) - { - _handle = handle; - } - - protected virtual void DisposeManagedState() - { - } + protected DisposableObject(IntPtr handle) + => _handle = handle; - protected abstract void DisposeUnManagedState(IntPtr handle); - - protected virtual void Dispose(bool disposing) + private void internal_dispose(bool disposing) { if (disposing) { @@ -48,30 +40,43 @@ namespace Tensorflow if (_handle != IntPtr.Zero) { // dispose managed state (managed objects). - DisposeManagedState(); + DisposeManagedResources(); // set large fields to null. - DisposeUnManagedState(_handle); + DisposeUnmanagedResources(_handle); _handle = IntPtr.Zero; } } } + /// + /// Dispose any managed resources. + /// + /// Equivalent to what you would perform inside + protected virtual void DisposeManagedResources() + { + } + + /// + /// Dispose any unmanaged resources related to given . + /// + protected abstract void DisposeUnmanagedResources(IntPtr handle); + // override a finalizer only if Dispose(bool disposing) above has code to free unmanaged resources. ~DisposableObject() { // Do not change this code. Put cleanup code in Dispose(bool disposing) above. - Dispose(false); + internal_dispose(false); } // This code added to correctly implement the disposable pattern. public void Dispose() { // Do not change this code. Put cleanup code in Dispose(bool disposing) above. - Dispose(true); + internal_dispose(true); // uncomment the following line if the finalizer is overridden above. GC.SuppressFinalize(this); } } -} +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Eager/ContextOptions.cs b/src/TensorFlowNET.Core/Eager/ContextOptions.cs index 4bffddf6..4bdf04b3 100644 --- a/src/TensorFlowNET.Core/Eager/ContextOptions.cs +++ b/src/TensorFlowNET.Core/Eager/ContextOptions.cs @@ -1,8 +1,9 @@ using System; +using System.IO; namespace Tensorflow.Eager { - public class ContextOptions : IDisposable + public class ContextOptions : IDisposable //TODO! Eli: Shouldn't this inherieting DisposableObject? { private IntPtr _handle; diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 77926dca..07dc117e 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -23,57 +23,58 @@ using static Tensorflow.Binding; namespace Tensorflow { + /* + A TensorFlow computation, represented as a dataflow graph. + + A `Graph` contains a set of + `tf.Operation` objects, + which represent units of computation; and + `tf.Tensor` objects, which represent + the units of data that flow between operations. + + A default `Graph` is always registered, and accessible by calling + `tf.get_default_graph`. + To add an operation to the default graph, simply call one of the functions + that defines a new `Operation`: + + ```python + c = tf.constant(4.0) + assert c.graph is tf.get_default_graph() + ``` + + Another typical usage involves the + `tf.Graph.as_default` + context manager, which overrides the current default graph for the + lifetime of the context: + + ```python + g = tf.Graph() + with g.as_default(): + # Define operations and tensors in `g`. + c = tf.constant(30.0) + assert c.graph is g + ``` + + Important note: This class *is not* thread-safe for graph construction. All + operations should be created from a single thread, or external + synchronization must be provided. Unless otherwise specified, all methods + are not thread-safe. + + A `Graph` instance supports an arbitrary number of "collections" + that are identified by name. For convenience when building a large + graph, collections can store groups of related objects: for + example, the `tf.Variable` uses a collection (named + `tf.GraphKeys.GLOBAL_VARIABLES`) for + all variables that are created during the construction of a graph. The caller + may define additional collections by specifying a new name. + */ + /// - /// TensorFlow uses a dataflow graph to represent your computation in terms of the dependencies between individual operations. - /// This leads to a low-level programming model in which you first define the dataflow graph, - /// then create a TensorFlow session to run parts of the graph across a set of local and remote devices. - /// https://www.tensorflow.org/guide/graphs + /// TensorFlow uses a dataflow graph to represent your computation in terms of the dependencies between individual operations. + /// This leads to a low-level programming model in which you first define the dataflow graph, + /// then create a TensorFlow session to run parts of the graph across a set of local and remote devices. /// - /* - A TensorFlow computation, represented as a dataflow graph. - - A `Graph` contains a set of - `tf.Operation` objects, - which represent units of computation; and - `tf.Tensor` objects, which represent - the units of data that flow between operations. - - A default `Graph` is always registered, and accessible by calling - `tf.get_default_graph`. - To add an operation to the default graph, simply call one of the functions - that defines a new `Operation`: - - ```python - c = tf.constant(4.0) - assert c.graph is tf.get_default_graph() - ``` - - Another typical usage involves the - `tf.Graph.as_default` - context manager, which overrides the current default graph for the - lifetime of the context: - - ```python - g = tf.Graph() - with g.as_default(): - # Define operations and tensors in `g`. - c = tf.constant(30.0) - assert c.graph is g - ``` - - Important note: This class *is not* thread-safe for graph construction. All - operations should be created from a single thread, or external - synchronization must be provided. Unless otherwise specified, all methods - are not thread-safe. - - A `Graph` instance supports an arbitrary number of "collections" - that are identified by name. For convenience when building a large - graph, collections can store groups of related objects: for - example, the `tf.Variable` uses a collection (named - `tf.GraphKeys.GLOBAL_VARIABLES`) for - all variables that are created during the construction of a graph. The caller - may define additional collections by specifying a new name. - */ + /// https://www.tensorflow.org/guide/graphs

https://www.tensorflow.org/api_docs/python/tf/Graph
public partial class Graph : DisposableObject, IEnumerable { private Dictionary _nodes_by_id; @@ -439,12 +440,12 @@ namespace Tensorflow _unfetchable_ops.Add(op); } - protected override void DisposeManagedState() + protected override void DisposeManagedResources() { ops.default_graph_stack.remove(this); } - protected override void DisposeUnManagedState(IntPtr handle) + protected override void DisposeUnmanagedResources(IntPtr handle) { c_api.TF_DeleteGraph(handle); } diff --git a/src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs b/src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs index 97720206..bdcaf60c 100644 --- a/src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs +++ b/src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs @@ -37,7 +37,7 @@ namespace Tensorflow c_api.TF_ImportGraphDefOptionsAddReturnOutput(_handle, name, index); } - protected override void DisposeUnManagedState(IntPtr handle) + protected override void DisposeUnmanagedResources(IntPtr handle) => c_api.TF_DeleteImportGraphDefOptions(handle); public static implicit operator IntPtr(ImportGraphDefOptions opts) => opts._handle; diff --git a/src/TensorFlowNET.Core/IO/gfile.cs b/src/TensorFlowNET.Core/IO/gfile.cs index 930dd652..a7303bf6 100644 --- a/src/TensorFlowNET.Core/IO/gfile.cs +++ b/src/TensorFlowNET.Core/IO/gfile.cs @@ -16,6 +16,7 @@ using System.Collections.Generic; using System.IO; +using System.Linq; namespace Tensorflow.IO { @@ -28,6 +29,9 @@ namespace Tensorflow.IO /// Traverse in order if True, post order if False. public IEnumerable<(string, string[], string[])> Walk(string top, bool in_order = true) { + if (!Directory.Exists(top)) + return Enumerable.Empty<(string, string[], string[])>(); + return walk_v2(top, in_order); } diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs index 2c05e36a..2a76c52c 100644 --- a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs +++ b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs @@ -141,7 +141,7 @@ namespace Tensorflow.Operations data, frame_name, is_constant, parallel_iterations, name: name); if (use_input_shape) - result.SetShape(data.TensorShape); + result.set_shape(data.TensorShape); return result; } diff --git a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs index 3198942b..1b68d1cd 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs @@ -233,7 +233,7 @@ namespace Tensorflow.Operations dims.AddRange(x_static_shape.dims.Skip(2)); var shape = new TensorShape(dims.ToArray()); - x_t.SetShape(shape); + x_t.set_shape(shape); return x_t; } diff --git a/src/TensorFlowNET.Core/Operations/array_ops.py.cs b/src/TensorFlowNET.Core/Operations/array_ops.py.cs index d3213250..92fe2e3c 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.py.cs @@ -351,7 +351,7 @@ namespace Tensorflow var input_shape = tensor_util.to_shape(input_tensor.shape); if (optimize && input_tensor.NDims > -1 && input_shape.is_fully_defined()) { - var nd = np.array(input_tensor.shape).astype(out_type.as_numpy_datatype()); + var nd = np.array(input_tensor.shape).astype(out_type.as_numpy_dtype()); return constant_op.constant(nd, name: name); } } diff --git a/src/TensorFlowNET.Core/Operations/nn_ops.cs b/src/TensorFlowNET.Core/Operations/nn_ops.cs index cbf55861..63e0fca1 100644 --- a/src/TensorFlowNET.Core/Operations/nn_ops.cs +++ b/src/TensorFlowNET.Core/Operations/nn_ops.cs @@ -98,7 +98,7 @@ namespace Tensorflow // float to be selected, hence we use a >= comparison. var keep_mask = random_tensor >= rate; var ret = x * scale * math_ops.cast(keep_mask, x.dtype); - ret.SetShape(x.TensorShape); + ret.set_shape(x.TensorShape); return ret; }); } diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index efe2afd4..4c5f2be3 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -49,7 +49,7 @@ namespace Tensorflow // dispose newOpts if (opts == null) - c_api.TF_DeleteSessionOptions(newOpts); + newOpts.Dispose(); status.Check(true); } @@ -64,6 +64,11 @@ namespace Tensorflow return _run(fetche, feed_dict)[0]; } + public virtual NDArray run(ITensorOrOperation fetche, params FeedItem[] feed_dict) + { + return _run(fetche, feed_dict)[0]; + } + public virtual (NDArray, NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) { var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4 }, feed_dict); @@ -273,7 +278,7 @@ namespace Tensorflow { var tensor = new Tensor(output); NDArray nd = null; - Type type = tensor.dtype.as_numpy_datatype(); + Type type = tensor.dtype.as_numpy_dtype(); var ndims = tensor.shape; var offset = c_api.TF_TensorData(output); @@ -285,7 +290,7 @@ namespace Tensorflow nd = NDArray.Scalar(*(bool*)offset); break; case TF_DataType.TF_STRING: - var bytes = tensor.Data(); + var bytes = tensor.BufferToArray(); // wired, don't know why we have to start from offset 9. // length in the begin var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]); @@ -324,7 +329,7 @@ namespace Tensorflow nd = np.array(bools).reshape(ndims); break; case TF_DataType.TF_STRING: - var bytes = tensor.Data(); + var bytes = tensor.BufferToArray(); // wired, don't know why we have to start from offset 9. // length in the begin var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]); @@ -396,7 +401,7 @@ namespace Tensorflow Dispose(); } - protected override void DisposeUnManagedState(IntPtr handle) + protected override void DisposeUnmanagedResources(IntPtr handle) { using (var status = new Status()) { diff --git a/src/TensorFlowNET.Core/Sessions/SessionOptions.cs b/src/TensorFlowNET.Core/Sessions/SessionOptions.cs index 8e0a0a74..ed99b7fe 100644 --- a/src/TensorFlowNET.Core/Sessions/SessionOptions.cs +++ b/src/TensorFlowNET.Core/Sessions/SessionOptions.cs @@ -32,7 +32,7 @@ namespace Tensorflow _handle = handle; } - protected override void DisposeUnManagedState(IntPtr handle) + protected override void DisposeUnmanagedResources(IntPtr handle) => c_api.TF_DeleteSessionOptions(handle); public void SetConfig(ConfigProto config) diff --git a/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs b/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs index a46decb1..e1a77d90 100644 --- a/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs +++ b/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs @@ -17,6 +17,7 @@ using NumSharp; using System; using System.Collections.Generic; +using NumSharp.Backends; namespace Tensorflow { @@ -71,18 +72,18 @@ namespace Tensorflow { if(tensor_values.Length > 0) { - switch (tensor_values[0].dtype.Name) + switch (tensor_values[0].typecode) { - case "Int32": + case NPTypeCode.Int32: full_values.Add(float.NaN); break; - case "Single": + case NPTypeCode.Single: full_values.Add(float.NaN); break; - case "String": + case NPTypeCode.String: full_values.Add(float.NaN); break; - case "Char": + case NPTypeCode.Char: full_values.Add(float.NaN); break; default: @@ -100,21 +101,21 @@ namespace Tensorflow j += 1; if (value.ndim == 0) { - switch (value.dtype.Name) + switch (value.typecode) { - case "Int16": + case NPTypeCode.Int16: full_values.Add(value.GetValue(0)); break; - case "Int32": + case NPTypeCode.Int32: full_values.Add(value.GetValue(0)); break; - case "Int64": + case NPTypeCode.Int64: full_values.Add(value.GetValue(0)); break; - case "Single": + case NPTypeCode.Single: full_values.Add(value.GetValue(0)); break; - case "Double": + case NPTypeCode.Double: full_values.Add(value.GetValue(0)); break; /*case "String": diff --git a/src/TensorFlowNET.Core/Status/Status.cs b/src/TensorFlowNET.Core/Status/Status.cs index 7eb2d7e3..2bdd806a 100644 --- a/src/TensorFlowNET.Core/Status/Status.cs +++ b/src/TensorFlowNET.Core/Status/Status.cs @@ -50,7 +50,7 @@ namespace Tensorflow /// public void Check(bool throwException = false) { - if(Code != TF_Code.TF_OK) + if (Code != TF_Code.TF_OK) { Console.WriteLine(Message); if (throwException) @@ -65,7 +65,7 @@ namespace Tensorflow return status._handle; } - protected override void DisposeUnManagedState(IntPtr handle) + protected override void DisposeUnmanagedResources(IntPtr handle) => c_api.TF_DeleteStatus(handle); } -} +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs index 63fda866..73f116ec 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs @@ -16,11 +16,13 @@ using NumSharp; using System; +using System.Diagnostics.CodeAnalysis; using System.Linq; using System.Numerics; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Text; +using NumSharp.Backends; using NumSharp.Backends.Unmanaged; using static Tensorflow.c_api; @@ -462,7 +464,7 @@ namespace Tensorflow *v = value; _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(Complex)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(Complex), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); IsMemoryOwner=true; - } + } #endif /// @@ -477,7 +479,7 @@ namespace Tensorflow IntPtr tensor = c_api.TF_TensorData(handle); Marshal.WriteInt64(tensor, 0); - fixed (byte* src = &buffer[0]) + fixed (byte* src = buffer) c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(Int64)), size, status); _handle = handle; status.Check(true); @@ -486,35 +488,55 @@ namespace Tensorflow public unsafe Tensor(NDArray nd, TF_DataType? tensorDType = null) { // todo: handle nd of type "String" here too - if (tensorDType == TF_DataType.TF_STRING && nd.dtype.Name == "Byte") + if (tensorDType == TF_DataType.TF_STRING && nd.typecode == NPTypeCode.Byte) { - var buffer = nd.ToArray(); - var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length); - var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8)); - - IntPtr tensor = c_api.TF_TensorData(handle); - Marshal.WriteInt64(tensor, 0); + if (nd.Unsafe.Storage.Shape.IsContiguous) + { + var bytesLength = (UIntPtr)nd.size; + var size = c_api.TF_StringEncodedSize(bytesLength); + var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr) ((ulong) size + 8)); + + IntPtr tensor = c_api.TF_TensorData(handle); + Marshal.WriteInt64(tensor, 0); + + var status = new Status(); + c_api.TF_StringEncode((byte*) nd.Unsafe.Address, bytesLength, (sbyte*) (tensor + sizeof(Int64)), size, status); + + status.Check(true); + _handle = handle; + IsMemoryOwner = false; + } + else + { + var buffer = nd.ToArray(); + var size = c_api.TF_StringEncodedSize((UIntPtr) buffer.Length); + var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr) ((ulong) size + 8)); + + IntPtr tensor = c_api.TF_TensorData(handle); + Marshal.WriteInt64(tensor, 0); + + var status = new Status(); + fixed (byte* src = buffer) + c_api.TF_StringEncode(src, (UIntPtr) buffer.Length, (sbyte*) (tensor + sizeof(Int64)), size, status); + + status.Check(true); + _handle = handle; + IsMemoryOwner = false; + } - var status = new Status(); - fixed (byte* src = &buffer[0]) - c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(Int64)), size, status); - - status.Check(true); - _handle=handle; - IsMemoryOwner = false; return; } + _handle = CreateTensorFromNDArray(nd, tensorDType); IsMemoryOwner = true; } private unsafe IntPtr CreateTensorFromNDArray(NDArray nd, TF_DataType? given_dtype) { - if (nd.dtype.Name == "String") - throw new NotImplementedException("Support for NDArray of type string not implemented yet"); + if (nd.dtype.Name == "String") + throw new NotImplementedException("Support for NDArray of type string not implemented yet"); IArraySlice arraySlice; - var shape = nd.Unsafe.Storage.Shape; - if (shape.IsSliced || shape.IsBroadcasted) + if (nd.Unsafe.Storage.Shape.IsContiguous == false) { // the memory is NOT contiguous, so we have to copy the view into a contiguous memory block. arraySlice = nd.CloneData(); @@ -527,51 +549,52 @@ namespace Tensorflow this.Tag = arraySlice; // keep a reference to the memory block to make sure it is not disposed while TF is using it var ptr = new IntPtr(arraySlice.Address); int num_bytes = (nd.size * nd.dtypesize); - var dtype = given_dtype ?? ToTFDataType(nd.dtype); + var dtype = given_dtype ?? nd.dtype.as_dtype(); var handle = TF_NewTensor(dtype, dims: nd.shape.Select(i=>(long)i).ToArray(), num_dims: nd.ndim, data: ptr, len: (UIntPtr)num_bytes, deallocator: _nothingDeallocator, ref _deallocatorArgs); IsMemoryOwner = false; return handle; - } - - public unsafe Tensor(byte[][] buffer, long[] shape) - { - int size = 0; - foreach (var b in buffer) - { - size += (int)TF_StringEncodedSize((UIntPtr)b.Length); - } - int totalSize = size + buffer.Length * 8; - ulong offset = 0; - IntPtr handle = TF_AllocateTensor(TF_DataType.TF_STRING, shape, shape.Length, (UIntPtr)totalSize); - - // Clear offset table - IntPtr pOffset = TF_TensorData(handle); - IntPtr dst = pOffset + buffer.Length * 8; - IntPtr dstLimit = pOffset + totalSize; - for (int i = 0; i < buffer.Length; i++) - { - Marshal.WriteInt64(pOffset, (long)offset); - using (var status = new Status()) - { - fixed (byte* src = &buffer[i][0]) - { - var written = TF_StringEncode(src, (UIntPtr)buffer[i].Length, (sbyte*)dst, (UIntPtr)(dstLimit.ToInt64() - dst.ToInt64()), status); - status.Check(true); - pOffset += 8; - dst += (int)written; - offset += written; - } - } - } - - _handle = handle; + + } + + public unsafe Tensor(byte[][] buffer, long[] shape) + { + int size = 0; + foreach (var b in buffer) + { + size += (int)TF_StringEncodedSize((UIntPtr)b.Length); + } + int totalSize = size + buffer.Length * 8; + ulong offset = 0; + IntPtr handle = TF_AllocateTensor(TF_DataType.TF_STRING, shape, shape.Length, (UIntPtr)totalSize); + + // Clear offset table + IntPtr pOffset = TF_TensorData(handle); + IntPtr dst = pOffset + buffer.Length * 8; + IntPtr dstLimit = pOffset + totalSize; + for (int i = 0; i < buffer.Length; i++) + { + Marshal.WriteInt64(pOffset, (long)offset); + using (var status = new Status()) + { + fixed (byte* src = &buffer[i][0]) + { + var written = TF_StringEncode(src, (UIntPtr)buffer[i].Length, (sbyte*)dst, (UIntPtr)(dstLimit.ToInt64() - dst.ToInt64()), status); + status.Check(true); + pOffset += 8; + dst += (int)written; + offset += written; + } + } + } + + _handle = handle; } public Tensor(Operation op, int value_index, TF_DataType dtype) { _op = op; _value_index = value_index; - _dtype = dtype; + _override_dtype = dtype; _id = ops.uid(); } @@ -589,11 +612,11 @@ namespace Tensorflow /// specified dimensions. /// [MethodImpl(MethodImplOptions.AggressiveInlining)] + [SuppressMessage("ReSharper", "LocalVariableHidesMember")] protected unsafe IntPtr CreateTensorWithoutCopying(TF_DataType dt, long[] shape, Array data, int element_size) { - if (dt == TF_DataType.TF_STRING && data is byte[]) + if (dt == TF_DataType.TF_STRING && data is byte[] buffer) { - var buffer = (byte[])data; var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length); var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8)); diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Explicit.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Explicit.cs index 6db60b4a..6d7f20f1 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Explicit.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Explicit.cs @@ -1,4 +1,5 @@ using System; +using System.Runtime.CompilerServices; namespace Tensorflow { @@ -6,86 +7,142 @@ namespace Tensorflow { public static explicit operator bool(Tensor tensor) { - EnsureScalar(tensor); - return tensor.Data()[0]; + unsafe + { + EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_BOOL); + return *(bool*) tensor.buffer; + } } public static explicit operator sbyte(Tensor tensor) { - EnsureScalar(tensor); - return tensor.Data()[0]; + unsafe + { + EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_INT8); + return *(sbyte*) tensor.buffer; + } } public static explicit operator byte(Tensor tensor) { - EnsureScalar(tensor); - return tensor.Data()[0]; + unsafe + { + EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_UINT8); + return *(byte*) tensor.buffer; + } } public static explicit operator ushort(Tensor tensor) { - EnsureScalar(tensor); - return tensor.Data()[0]; + unsafe + { + EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_UINT16); + return *(ushort*) tensor.buffer; + } } public static explicit operator short(Tensor tensor) { - EnsureScalar(tensor); - return tensor.Data()[0]; + unsafe + { + EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_INT16); + return *(short*) tensor.buffer; + } } public static explicit operator int(Tensor tensor) { - EnsureScalar(tensor); - return tensor.Data()[0]; + unsafe + { + EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_INT32); + return *(int*) tensor.buffer; + } } public static explicit operator uint(Tensor tensor) { - EnsureScalar(tensor); - return tensor.Data()[0]; + unsafe + { + EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_UINT32); + return *(uint*) tensor.buffer; + } } public static explicit operator long(Tensor tensor) { - EnsureScalar(tensor); - return tensor.Data()[0]; + unsafe + { + EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_INT64); + return *(long*) tensor.buffer; + } } public static explicit operator ulong(Tensor tensor) { - EnsureScalar(tensor); - return tensor.Data()[0]; + unsafe + { + EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_UINT64); + return *(ulong*) tensor.buffer; + } } public static explicit operator float(Tensor tensor) { - EnsureScalar(tensor); - return tensor.Data()[0]; + unsafe + { + EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_FLOAT); + return *(float*) tensor.buffer; + } } public static explicit operator double(Tensor tensor) { - EnsureScalar(tensor); - return tensor.Data()[0]; + unsafe + { + EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_DOUBLE); + return *(double*) tensor.buffer; + } + } + + public static explicit operator string(Tensor tensor) + { + unsafe + { + EnsureScalar(tensor); + EnsureDType(tensor, TF_DataType.TF_STRING); + return new string((char*) tensor.buffer, 0, (int) tensor.size); + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static void EnsureDType(Tensor tensor, TF_DataType @is) + { + if (tensor.dtype != @is) + throw new InvalidCastException($"Unable to cast scalar tensor {tensor.dtype} to {@is}"); } + [MethodImpl(MethodImplOptions.AggressiveInlining)] private static void EnsureScalar(Tensor tensor) { if (tensor == null) - { throw new ArgumentNullException(nameof(tensor)); - } if (tensor.TensorShape.ndim != 0) - { throw new ArgumentException("Tensor must have 0 dimensions in order to convert to scalar"); - } if (tensor.TensorShape.size != 1) - { throw new ArgumentException("Tensor must have size 1 in order to convert to scalar"); - } } } diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs index eb912eb9..4b15864f 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs @@ -69,11 +69,12 @@ namespace Tensorflow TF_DataType.TF_QINT8, TF_DataType.TF_QINT16, TF_DataType.TF_QINT32, TF_DataType.TF_UINT8, TF_DataType.TF_UINT16, TF_DataType.TF_UINT32, TF_DataType.TF_UINT64 }; + public static Tensor operator /(double x, Tensor y) => BinaryOpWrapper("truediv", x, y); public static Tensor operator /(float x, Tensor y) => BinaryOpWrapper("truediv", x, y); public static Tensor operator /(int x, Tensor y) => BinaryOpWrapper("floordiv", x, y); public static Tensor operator /(Tensor x, Tensor y) => - _intTfDataTypes.Contains(x._dtype) + _intTfDataTypes.Contains(x.dtype) ? BinaryOpWrapper("floordiv", x, y) : BinaryOpWrapper("truediv", x, y); public static Tensor operator /(Tensor x, int y) => BinaryOpWrapper("floordiv", x, y); @@ -122,8 +123,7 @@ namespace Tensorflow if (y is Tensor tr) dtype = tr.dtype.as_base_dtype(); - var namescope = ops.name_scope(null, name, new { x, y }); - return tf_with(namescope, scope => + return tf_with(ops.name_scope(null, name, new { x, y }), scope => { Tensor result = null; var x1 = ops.convert_to_tensor(x, dtype: dtype, name: "x"); @@ -155,7 +155,6 @@ namespace Tensorflow return result; }); - } } } diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index d52b9422..8ac6c73e 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -17,9 +17,16 @@ using NumSharp; using System; using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Globalization; using System.Linq; +using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Text; +using System.Threading.Tasks; +using NumSharp.Backends; +using NumSharp.Backends.Unmanaged; +using NumSharp.Utilities; using Tensorflow.Framework; using static Tensorflow.Binding; @@ -29,42 +36,68 @@ namespace Tensorflow /// A tensor is a generalization of vectors and matrices to potentially higher dimensions. /// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes. /// + [SuppressMessage("ReSharper", "ConvertToAutoProperty")] public partial class Tensor : DisposableObject, ITensorOrOperation, _TensorLike { - private int _id; - private Operation _op; + private readonly int _id; + private readonly Operation _op; + private readonly int _value_index; + private TF_Output? _tf_output; + private readonly TF_DataType _override_dtype; public int Id => _id; + + /// + /// The Graph that contains this tensor. + /// public Graph graph => op?.graph; + + /// + /// The Operation that produces this tensor as an output. + /// public Operation op => _op; + public Tensor[] outputs => op.outputs; /// - /// The string name of this tensor. + /// The string name of this tensor. /// public string name => $"{(op == null ? "" : $"{op.name}:{_value_index}")}"; - private int _value_index; + /// + /// The index of this tensor in the outputs of its Operation. + /// public int value_index => _value_index; - private TF_DataType _dtype = TF_DataType.DtInvalid; - public TF_DataType dtype => _handle == IntPtr.Zero ? _dtype : c_api.TF_TensorType(_handle); + /// + /// The DType of elements in this tensor. + /// + public TF_DataType dtype => _handle == IntPtr.Zero ? _override_dtype : c_api.TF_TensorType(_handle); public ulong bytesize => _handle == IntPtr.Zero ? 0 : c_api.TF_TensorByteSize(_handle); - public ulong itemsize => _handle == IntPtr.Zero ? 0 : c_api.TF_DataTypeSize(dtype); public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / itemsize; - public IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle); public int num_consumers(TF_Output oper_out) => _handle == IntPtr.Zero ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out); + public int NDims => rank; - private TF_Output? _tf_output; + /// + /// The name of the device on which this tensor will be produced, or null. + /// + public string Device => op.Device; + + public int[] dims => shape; /// - /// used for keep other pointer when do implicit operating + /// Used for keep other pointer when do implicit operating /// public object Tag { get; set; } + + /// + /// Returns the shape of a tensor. + /// + /// https://www.tensorflow.org/api_docs/python/tf/shape public int[] shape { get @@ -76,14 +109,13 @@ namespace Tensorflow var status = new Status(); c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, status); status.Check(); - } - else + } else { for (int i = 0; i < rank; i++) dims[i] = c_api.TF_Dim(_handle, i); } - return dims.Select(x => Convert.ToInt32(x)).ToArray(); + return dims.Select(x => ((IConvertible) x).ToInt32(CultureInfo.InvariantCulture)).ToArray(); } set @@ -93,38 +125,52 @@ namespace Tensorflow if (value == null) c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), null, -1, status); else - c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), value.Select(x => Convert.ToInt64(x)).ToArray(), value.Length, status); + c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), value.Select(Convert.ToInt64).ToArray(), value.Length, status); } } public int[] _shape_tuple() { - if (shape == null) return null; - return shape.Select(x => (int)x).ToArray(); + return (int[]) shape.Clone(); } public TensorShape TensorShape => tensor_util.to_shape(shape); - public void SetShape(TensorShape shape) + /// + /// Updates the shape of this tensor. + /// + public void set_shape(TensorShape shape) { - this.shape = shape.dims; + this.shape = (int[]) shape.dims.Clone(); } + /// + /// Updates the shape of this tensor. + /// + [Obsolete("Please use set_shape(TensorShape shape) instead.", false)] + public void SetShape(TensorShape shape) + { + this.shape = (int[]) shape.dims.Clone(); + } + + /// + /// Updates the shape of this tensor. + /// public void set_shape(Tensor shape) { + // ReSharper disable once MergeConditionalExpression this.shape = shape is null ? null : shape.shape; } - public int[] dims => shape; - /// - /// number of dimensions - /// 0 Scalar (magnitude only) - /// 1 Vector (magnitude and direction) - /// 2 Matrix (table of numbers) - /// 3 3-Tensor (cube of numbers) + /// number of dimensions

+ /// 0 Scalar (magnitude only)

+ /// 1 Vector (magnitude and direction)

+ /// 2 Matrix (table of numbers)

+ /// 3 3-Tensor (cube of numbers)

/// n n-Tensor (you get the idea) ///
+ /// https://www.tensorflow.org/api_docs/python/tf/rank public int rank { get @@ -137,17 +183,15 @@ namespace Tensorflow status.Check(); return ndim; } - else - { - return c_api.TF_NumDims(_handle); - } + + return c_api.TF_NumDims(_handle); } } - public int NDims => rank; - - public string Device => op.Device; - + /// + /// Returns a list of Operations that consume this tensor. + /// + /// public Operation[] consumers() { var output = _as_tf_output(); @@ -157,37 +201,191 @@ namespace Tensorflow public TF_Output _as_tf_output() { - if(!_tf_output.HasValue) + if (!_tf_output.HasValue) _tf_output = new TF_Output(op, value_index); return _tf_output.Value; } - public T[] Data() + [Obsolete("Please use ToArray() instead.", false)] + public T[] Data() where T : unmanaged + { + return ToArray(); + } + + /// + /// + /// + /// + /// + /// When is string + public T[] ToArray() where T : unmanaged { - // Column major order - // https://en.wikipedia.org/wiki/File:Row_and_column_major_order.svg - // matrix:[[1, 2, 3], [4, 5, 6]] - // index: 0 2 4 1 3 5 - // result: 1 4 2 5 3 6 - var data = new T[size]; - - for (ulong i = 0; i < size; i++) + //when T is string + if (typeof(T) == typeof(string)) { - data[i] = Marshal.PtrToStructure(buffer + (int)(i * itemsize)); + if (dtype != TF_DataType.TF_STRING) + throw new ArgumentException($"Given <{typeof(T).Name}> can't be converted to string."); + + return (T[]) (object) StringData(); } - return data; + //Are the types matching? + if (typeof(T).as_dtype() == dtype) + { + if (NDims == 0 && size == 1) //is it a scalar? + { + unsafe + { + return new T[] {*(T*) buffer}; + } + } + + //types match, no need to perform cast + var ret = new T[size]; + unsafe + { + var len = (long) size; + fixed (T* dstRet = ret) + { + T* dst = dstRet; //local stack copy + if (typeof(T).IsPrimitive) + { + var src = (T*) buffer; + len *= ((long) itemsize); + System.Buffer.MemoryCopy(src, dst, len, len); + } else + { + var itemsize = (long) this.itemsize; + var buffer = this.buffer.ToInt64(); + Parallel.For(0L, len, i => dst[i] = Marshal.PtrToStructure(new IntPtr(buffer + i * itemsize))); + } + } + } + + return ret; + } else + { + + //types do not match, need to perform cast + if (NDims == 0 && size == 1) //is it a scalar? + { + unsafe + { +#if _REGEN + #region Compute + switch (dtype.as_numpy_dtype().GetTypeCode()) + { + %foreach supported_dtypes,supported_dtypes_lowercase% + case NPTypeCode.#1: return new T[] {Converts.ChangeType(*(#2*) buffer, NPTypeCode.#1)}; + % + case NPTypeCode.String: return new T[] {Converts.ChangeType((string)this, NPTypeCode.String)}; + default: + throw new NotSupportedException(); + } + #endregion +#else + #region Compute + switch (dtype.as_numpy_dtype()?.GetTypeCode()) + { + case NPTypeCode.Boolean: return new T[] {Converts.ChangeType(*(bool*) buffer, NPTypeCode.Boolean)}; + case NPTypeCode.Byte: return new T[] {Converts.ChangeType(*(byte*) buffer, NPTypeCode.Byte)}; + case NPTypeCode.Int16: return new T[] {Converts.ChangeType(*(short*) buffer, NPTypeCode.Int16)}; + case NPTypeCode.UInt16: return new T[] {Converts.ChangeType(*(ushort*) buffer, NPTypeCode.UInt16)}; + case NPTypeCode.Int32: return new T[] {Converts.ChangeType(*(int*) buffer, NPTypeCode.Int32)}; + case NPTypeCode.UInt32: return new T[] {Converts.ChangeType(*(uint*) buffer, NPTypeCode.UInt32)}; + case NPTypeCode.Int64: return new T[] {Converts.ChangeType(*(long*) buffer, NPTypeCode.Int64)}; + case NPTypeCode.UInt64: return new T[] {Converts.ChangeType(*(ulong*) buffer, NPTypeCode.UInt64)}; + case NPTypeCode.Char: return new T[] {Converts.ChangeType(*(char*) buffer, NPTypeCode.Char)}; + case NPTypeCode.Double: return new T[] {Converts.ChangeType(*(double*) buffer, NPTypeCode.Double)}; + case NPTypeCode.Single: return new T[] {Converts.ChangeType(*(float*) buffer, NPTypeCode.Single)}; + case NPTypeCode.String: return new T[] {Converts.ChangeType((string)this, NPTypeCode.String)}; + default: + throw new NotSupportedException(); + } + #endregion +#endif + } + } + + var ret = new T[size]; + unsafe + { + var len = (long) size; + fixed (T* dstRet = ret) + { + T* dst = dstRet; //local stack copy + +#if _REGEN + #region Compute + switch (dtype.as_numpy_dtype().GetTypeCode()) + { + %foreach supported_dtypes,supported_dtypes_lowercase% + case NPTypeCode.#1: new UnmanagedMemoryBlock<#2>((#2*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + % + default: + throw new NotSupportedException(); + } + #endregion +#else + #region Compute + switch (dtype.as_numpy_dtype().GetTypeCode()) + { + case NPTypeCode.Boolean: new UnmanagedMemoryBlock((bool*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.Byte: new UnmanagedMemoryBlock((byte*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.Int16: new UnmanagedMemoryBlock((short*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.UInt16: new UnmanagedMemoryBlock((ushort*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.Int32: new UnmanagedMemoryBlock((int*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.UInt32: new UnmanagedMemoryBlock((uint*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.Int64: new UnmanagedMemoryBlock((long*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.UInt64: new UnmanagedMemoryBlock((ulong*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.Char: new UnmanagedMemoryBlock((char*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.Double: new UnmanagedMemoryBlock((double*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.Single: new UnmanagedMemoryBlock((float*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break; + case NPTypeCode.String: throw new NotSupportedException("Unable to convert from string to other dtypes"); //TODO! this should call Converts.To + default: + throw new NotSupportedException(); + } + #endregion +#endif + + } + } + + return ret; + } } + /// + /// Copies the memory of current buffer onto newly allocated array. + /// + /// + [Obsolete("Please use set_shape(TensorShape shape) instead.", false)] public byte[] Data() { - var data = new byte[bytesize]; - Marshal.Copy(buffer, data, 0, (int)bytesize); - return data; + return BufferToArray(); + } + + /// + /// Copies the memory of current buffer onto newly allocated array. + /// + /// + public byte[] BufferToArray() + { + unsafe + { + // ReSharper disable once LocalVariableHidesMember + var bytesize = (long) this.bytesize; + var data = new byte[bytesize]; + fixed (byte* dst = data) + System.Buffer.MemoryCopy(buffer.ToPointer(), dst, bytesize, bytesize); + + return data; + } } - public unsafe string[] StringData() + /// Used internally in ToArray<T> + private unsafe string[] StringData() { // // TF_STRING tensors are encoded with a table of 8-byte offsets followed by TF_StringEncode-encoded bytes. @@ -199,19 +397,19 @@ namespace Tensorflow var buffer = new byte[size][]; var src = c_api.TF_TensorData(_handle); - var srcLen = (IntPtr)(src.ToInt64() + (long)bytesize); - src += (int)(size * 8); + var srcLen = (IntPtr) (src.ToInt64() + (long) bytesize); + src += (int) (size * 8); for (int i = 0; i < buffer.Length; i++) { using (var status = new Status()) { IntPtr dst = IntPtr.Zero; UIntPtr dstLen = UIntPtr.Zero; - var read = c_api.TF_StringDecode((byte*)src, (UIntPtr)(srcLen.ToInt64() - src.ToInt64()), (byte**)&dst, &dstLen, status); + var read = c_api.TF_StringDecode((byte*) src, (UIntPtr) (srcLen.ToInt64() - src.ToInt64()), (byte**) &dst, &dstLen, status); status.Check(true); - buffer[i] = new byte[(int)dstLen]; + buffer[i] = new byte[(int) dstLen]; Marshal.Copy(dst, buffer[i], 0, buffer[i].Length); - src += (int)read; + src += (int) read; } } @@ -229,51 +427,29 @@ namespace Tensorflow } /// - /// Evaluates this tensor in a `Session`. + /// Evaluates this tensor in a `Session`. /// /// A dictionary that maps `Tensor` objects to feed values. - /// The `Session` to be used to evaluate this tensor. - /// + /// A array corresponding to the value of this tensor. public NDArray eval(params FeedItem[] feed_dict) { return ops._eval_using_default_session(this, feed_dict, graph); } + /// + /// Evaluates this tensor in a `Session`. + /// + /// A dictionary that maps `Tensor` objects to feed values. + /// The `Session` to be used to evaluate this tensor. + /// A array corresponding to the value of this tensor. public NDArray eval(Session session, FeedItem[] feed_dict = null) { return ops._eval_using_default_session(this, feed_dict, graph, session); } - public TF_DataType ToTFDataType(Type type) - { - switch (type.Name) - { - case "Char": - return TF_DataType.TF_UINT8; - case "Int16": - return TF_DataType.TF_INT16; - case "Int32": - return TF_DataType.TF_INT32; - case "Int64": - return TF_DataType.TF_INT64; - case "Single": - return TF_DataType.TF_FLOAT; - case "Double": - return TF_DataType.TF_DOUBLE; - case "Byte": - return TF_DataType.TF_UINT8; - case "String": - return TF_DataType.TF_STRING; - case "Boolean": - return TF_DataType.TF_BOOL; - default: - throw new NotImplementedException("ToTFDataType error"); - } - } - public Tensor slice(Slice slice) { - var slice_spec = new int[] { slice.Start.Value }; + var slice_spec = new int[] {slice.Start.Value}; var begin = new List(); var end = new List(); var strides = new List(); @@ -289,26 +465,26 @@ namespace Tensorflow if (slice.Stop.HasValue) { end.Add(slice.Stop.Value); - } - else + } else { end.Add(0); end_mask |= (1 << index); } + strides.Add(slice.Step); index += 1; } - return tf_with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope => + return tf_with(ops.name_scope(null, "strided_slice", new {begin, end, strides}), scope => { string name = scope; if (begin != null) { var (packed_begin, packed_end, packed_strides) = (array_ops.stack(begin.ToArray()), - array_ops.stack(end.ToArray()), - array_ops.stack(strides.ToArray())); + array_ops.stack(end.ToArray()), + array_ops.stack(strides.ToArray())); return gen_array_ops.strided_slice( this, @@ -320,7 +496,6 @@ namespace Tensorflow shrink_axis_mask: shrink_axis_mask, new_axis_mask: new_axis_mask, ellipsis_mask: ellipsis_mask, - name: name); } @@ -330,7 +505,7 @@ namespace Tensorflow public Tensor slice(int start) { - var slice_spec = new int[] { start }; + var slice_spec = new int[] {start}; var begin = new List(); var end = new List(); var strides = new List(); @@ -349,15 +524,15 @@ namespace Tensorflow index += 1; } - return tf_with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope => + return tf_with(ops.name_scope(null, "strided_slice", new {begin, end, strides}), scope => { string name = scope; if (begin != null) { var (packed_begin, packed_end, packed_strides) = (array_ops.stack(begin.ToArray()), - array_ops.stack(end.ToArray()), - array_ops.stack(strides.ToArray())); + array_ops.stack(end.ToArray()), + array_ops.stack(strides.ToArray())); return gen_array_ops.strided_slice( this, @@ -369,7 +544,6 @@ namespace Tensorflow shrink_axis_mask: shrink_axis_mask, new_axis_mask: new_axis_mask, ellipsis_mask: ellipsis_mask, - name: name); } @@ -392,15 +566,12 @@ namespace Tensorflow return $"tf.Tensor '{name}' shape=({string.Join(",", shape)}) dtype={dtype}"; } - protected override void DisposeManagedState() - { - } - - protected override void DisposeUnManagedState(IntPtr handle) + protected override void DisposeUnmanagedResources(IntPtr handle) { - if(handle != IntPtr.Zero) + if (handle != IntPtr.Zero) { c_api.TF_DeleteTensor(handle); + _handle = IntPtr.Zero; } } @@ -417,4 +588,4 @@ namespace Tensorflow public int tensor_int_val { get; set; } } -} +} \ No newline at end of file diff --git a/src/TensorFlowNET.Core/Tensors/TensorShape.cs b/src/TensorFlowNET.Core/Tensors/TensorShape.cs index 13258f79..cf62ce04 100644 --- a/src/TensorFlowNET.Core/Tensors/TensorShape.cs +++ b/src/TensorFlowNET.Core/Tensors/TensorShape.cs @@ -1,35 +1,84 @@ using NumSharp; using System; +using System.Diagnostics.CodeAnalysis; using System.Linq; +using System.Runtime.CompilerServices; namespace Tensorflow { /// - /// Represents the shape of a `Tensor`. + /// Represents the shape of a `Tensor`. /// + /// https://www.tensorflow.org/api_docs/python/tf/TensorShape public class TensorShape { - private Shape shape; + private readonly Shape shape; + + /// + /// Returns a list of Dimensions, or None if the shape is unspecified. + /// public int[] dims => shape.Dimensions; + + /// + /// Returns the rank of this shape. + /// public int ndim => shape.NDim; + + /// + /// Returns the rank of this shape. + /// + public int rank => shape.NDim; + + /// + /// Returns the size this shape represents. + /// public int size => shape.Size; public TensorShape(TensorShapeProto proto) { if (proto.UnknownRank) return; + switch (proto.Dim.Count) + { + case 0: shape = new Shape(new int[0]); break; + case 1: shape = Shape.Vector((int) proto.Dim[0].Size); break; + case 2: shape = Shape.Matrix((int) proto.Dim[0].Size, (int) proto.Dim[1].Size); break; + default: + var protodims = proto.Dim; + var len = protodims.Count; + var dims = new int[len]; + for (int i = 0; i < len; i++) + dims[i] = (int) protodims[i].Size; + - shape.reshape(proto.Dim.Select(x => (int)x.Size).ToArray()); + shape = new Shape(dims); break; + } } public TensorShape(params int[] dims) { - shape = new Shape(dims); + switch (dims.Length) + { + case 0: shape = new Shape(new int[0]); break; + case 1: shape = Shape.Vector((int) dims[0]); break; + case 2: shape = Shape.Matrix(dims[0], dims[1]); break; + default: shape = new Shape(dims); break; + } } + /// + /// + /// + /// + /// + /// When is not an Index. + [SuppressMessage("ReSharper", "PossibleInvalidOperationException")] public TensorShape this[Slice slice] { get { + if (slice.Start.HasValue == false || slice.Length.HasValue == false) + throw new ArgumentException("Slice must has Start and Length."); + return new TensorShape(dims.Skip(slice.Start.Value) .Take(slice.Length.Value) .ToArray()); @@ -37,7 +86,7 @@ namespace Tensorflow } /// - /// Returns True iff `self` is fully defined in every dimension. + /// Returns True iff `self` is fully defined in every dimension. /// /// public bool is_fully_defined() @@ -50,6 +99,7 @@ namespace Tensorflow throw new NotImplementedException("TensorShape is_compatible_with"); } + [SuppressMessage("ReSharper", "ParameterHidesMember")] public TensorShape with_rank_at_least(int rank) { if (rank != ndim) @@ -59,35 +109,63 @@ namespace Tensorflow } /// - /// Returns the concatenation of the dimension in `self` and `other`. + /// Returns the concatenation of the dimension in `self` and `other`. /// /// /// - public TensorShape concatenate(int[] other_) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public TensorShape concatenate(int[] other) { - var other = new TensorShape(other_); + return concatenate(new TensorShape(other)); + } - if (ndim < 0 || other.ndim < 0) + /// + /// Returns the concatenation of the dimension in `self` and `other`. + /// + /// + /// + public TensorShape concatenate(TensorShape other) + { + var otherShape = other; + + if (ndim < 0 || otherShape.ndim < 0) return new TensorShape(); else { - var concatenate_dims = new int[ndim + other.ndim]; + var concatenate_dims = new int[ndim + otherShape.ndim]; for (int i = 0; i < ndim; i++) concatenate_dims[i] = dims[i]; - for (int i = 0; i < other.ndim; i++) - concatenate_dims[ndim + i] = other.dims[i]; + for (int i = 0; i < otherShape.ndim; i++) + concatenate_dims[ndim + i] = otherShape.dims[i]; return new TensorShape(concatenate_dims); } } - public static implicit operator TensorShape(Shape shape) => new TensorShape(shape.Dimensions); - public static implicit operator Shape(TensorShape shape) => new Shape(shape.dims); + public static implicit operator TensorShape(Shape shape) => new TensorShape((int[]) shape.Dimensions.Clone()); + public static implicit operator Shape(TensorShape shape) => new Shape((int[]) shape.dims.Clone()); + + public static implicit operator int[](TensorShape shape) => (int[])shape.dims.Clone(); //we clone to avoid any changes public static implicit operator TensorShape(int[] dims) => new TensorShape(dims); - public static implicit operator int[](TensorShape shape) => shape.dims; + + public static explicit operator int(TensorShape shape) => shape.size; + public static explicit operator TensorShape(int dim) => new TensorShape(dim); + + public static explicit operator (int, int)(TensorShape shape) => shape.dims.Length == 2 ? (shape.dims[0], shape.dims[1]) : (0, 0); public static implicit operator TensorShape((int, int) dims) => new TensorShape(dims.Item1, dims.Item2); + + public static explicit operator (int, int, int)(TensorShape shape) => shape.dims.Length == 3 ? (shape.dims[0], shape.dims[1], shape.dims[2]) : (0, 0, 0); public static implicit operator TensorShape((int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3); + + public static explicit operator (int, int, int, int)(TensorShape shape) => shape.dims.Length == 4 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3]) : (0, 0, 0, 0); public static implicit operator TensorShape((int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4); + + public static explicit operator (int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 5 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4]) : (0, 0, 0, 0, 0); + public static implicit operator TensorShape((int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5); + + public static explicit operator (int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 6 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5]) : (0, 0, 0, 0, 0, 0); + public static implicit operator TensorShape((int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6); + } } diff --git a/src/TensorFlowNET.Core/Tensors/dtypes.cs b/src/TensorFlowNET.Core/Tensors/dtypes.cs index 807dc6f5..37f1ca61 100644 --- a/src/TensorFlowNET.Core/Tensors/dtypes.cs +++ b/src/TensorFlowNET.Core/Tensors/dtypes.cs @@ -15,6 +15,8 @@ ******************************************************************************/ using System; +using System.Numerics; +using NumSharp.Backends; namespace Tensorflow { @@ -23,35 +25,100 @@ namespace Tensorflow public static TF_DataType int8 = TF_DataType.TF_INT8; public static TF_DataType int32 = TF_DataType.TF_INT32; public static TF_DataType int64 = TF_DataType.TF_INT64; + public static TF_DataType uint8 = TF_DataType.TF_UINT8; + public static TF_DataType uint32 = TF_DataType.TF_UINT32; + public static TF_DataType uint64 = TF_DataType.TF_UINT64; public static TF_DataType float32 = TF_DataType.TF_FLOAT; // is that float32? public static TF_DataType float16 = TF_DataType.TF_HALF; public static TF_DataType float64 = TF_DataType.TF_DOUBLE; - public static Type as_numpy_datatype(this TF_DataType type) + /// + /// + /// + /// + /// equivalent to , if none exists, returns null. + public static Type as_numpy_dtype(this TF_DataType type) { switch (type) { case TF_DataType.TF_BOOL: return typeof(bool); + case TF_DataType.TF_UINT8: + return typeof(byte); case TF_DataType.TF_INT64: return typeof(long); + case TF_DataType.TF_UINT64: + return typeof(ulong); case TF_DataType.TF_INT32: return typeof(int); + case TF_DataType.TF_UINT32: + return typeof(uint); case TF_DataType.TF_INT16: return typeof(short); + case TF_DataType.TF_UINT16: + return typeof(ushort); case TF_DataType.TF_FLOAT: return typeof(float); case TF_DataType.TF_DOUBLE: return typeof(double); case TF_DataType.TF_STRING: return typeof(string); + case TF_DataType.TF_COMPLEX128: + case TF_DataType.TF_COMPLEX64: //64 is also TF_COMPLEX + return typeof(Complex); default: return null; } } - // "sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex" - public static TF_DataType as_dtype(Type type, TF_DataType? dtype = null) + /// + /// + /// + /// + /// + /// When has no equivalent + public static NPTypeCode as_numpy_typecode(this TF_DataType type) + { + switch (type) + { + case TF_DataType.TF_BOOL: + return NPTypeCode.Boolean; + case TF_DataType.TF_UINT8: + return NPTypeCode.Byte; + case TF_DataType.TF_INT64: + return NPTypeCode.Int64; + case TF_DataType.TF_INT32: + return NPTypeCode.Int32; + case TF_DataType.TF_INT16: + return NPTypeCode.Int16; + case TF_DataType.TF_UINT64: + return NPTypeCode.UInt64; + case TF_DataType.TF_UINT32: + return NPTypeCode.UInt32; + case TF_DataType.TF_UINT16: + return NPTypeCode.UInt16; + case TF_DataType.TF_FLOAT: + return NPTypeCode.Single; + case TF_DataType.TF_DOUBLE: + return NPTypeCode.Double; + case TF_DataType.TF_STRING: + return NPTypeCode.String; + case TF_DataType.TF_COMPLEX128: + case TF_DataType.TF_COMPLEX64: //64 is also TF_COMPLEX + return NPTypeCode.Complex; + default: + throw new NotSupportedException($"Unable to convert {type} to a NumSharp typecode."); + } + } + + /// + /// + /// + /// + /// + /// + /// When has no equivalent + public static TF_DataType as_dtype(this Type type, TF_DataType? dtype = null) { switch (type.Name) { @@ -98,7 +165,7 @@ namespace Tensorflow dtype = TF_DataType.TF_BOOL; break; default: - throw new Exception("as_dtype Not Implemented"); + throw new NotSupportedException($"Unable to convert {type} to a NumSharp typecode."); } return dtype.Value; @@ -106,16 +173,7 @@ namespace Tensorflow public static DataType as_datatype_enum(this TF_DataType type) { - DataType dtype = DataType.DtInvalid; - - switch (type) - { - default: - Enum.TryParse(((int)type).ToString(), out dtype); - break; - } - - return dtype; + return Enum.TryParse(((int) type).ToString(), out DataType dtype) ? dtype : DataType.DtInvalid; } public static TF_DataType as_base_dtype(this TF_DataType type) @@ -132,7 +190,7 @@ namespace Tensorflow public static Type as_numpy_dtype(this DataType type) { - return type.as_tf_dtype().as_numpy_datatype(); + return type.as_tf_dtype().as_numpy_dtype(); } public static DataType as_base_dtype(this DataType type) @@ -144,16 +202,7 @@ namespace Tensorflow public static TF_DataType as_tf_dtype(this DataType type) { - TF_DataType dtype = TF_DataType.DtInvalid; - - switch (type) - { - default: - Enum.TryParse(((int)type).ToString(), out dtype); - break; - } - - return dtype; + return Enum.TryParse(((int) type).ToString(), out TF_DataType dtype) ? dtype : TF_DataType.DtInvalid; } public static TF_DataType as_ref(this TF_DataType type) diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index ded105c7..43848da6 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -17,6 +17,7 @@ using NumSharp; using System; using System.Linq; +using NumSharp.Utilities; namespace Tensorflow { @@ -109,7 +110,7 @@ namespace Tensorflow // We first convert value to a numpy array or scalar. NDArray nparray = null; - var np_dt = dtype.as_numpy_datatype(); + var np_dt = dtype.as_numpy_dtype(); if (values is NDArray nd) { @@ -188,37 +189,37 @@ namespace Tensorflow if (values.GetType().IsArray) nparray = np.array((int[])values, np_dt); else - nparray = Convert.ToInt32(values); + nparray = Converts.ToInt32(values); break; case "Int64": if (values.GetType().IsArray) nparray = np.array((int[])values, np_dt); else - nparray = Convert.ToInt64(values); + nparray = Converts.ToInt64(values); break; case "Single": if (values.GetType().IsArray) nparray = np.array((float[])values, np_dt); else - nparray = Convert.ToSingle(values); + nparray = Converts.ToSingle(values); break; case "Double": if (values.GetType().IsArray) nparray = np.array((double[])values, np_dt); else - nparray = Convert.ToDouble(values); + nparray = Converts.ToDouble(values); break; case "String": if (values.GetType().IsArray) nparray = np.array((string[])values, np_dt); else - nparray = NDArray.FromString(Convert.ToString(values)); + nparray = NDArray.FromString(Converts.ToString(values)); break; case "Boolean": if (values.GetType().IsArray) nparray = np.array((bool[])values, np_dt); else - nparray = Convert.ToBoolean(values); + nparray = Converts.ToBoolean(values); break; default: throw new NotImplementedException($"make_tensor_proto: Support for type {np_dt.Name} Not Implemented"); diff --git a/src/TensorFlowNET.Core/globals.regen b/src/TensorFlowNET.Core/globals.regen new file mode 100644 index 00000000..146155b3 --- /dev/null +++ b/src/TensorFlowNET.Core/globals.regen @@ -0,0 +1,38 @@ +%all_dtypes = ["NDArray","Complex","Boolean","Byte","Int16","UInt16","Int32","UInt32","Int64","UInt64","Char","Double","Single","String"] +%all_dtypes_lowercase = ["NDArray","Complex","bool","byte","short","ushort","int","uint","long","ulong","char","double","float","string"] + +%supported_primitives = ["Boolean","Byte","Int16","UInt16","Int32","UInt32","Int64","UInt64","Char","Double","Single","String"] +%supported_primitives_lowercase = ["bool","byte","short","ushort","int","uint","long","ulong","char","double","float","string"] + +%supported_numericals = ["Byte","Int16","UInt16","Int32","UInt32","Int64","UInt64","Char","Double","Single"] +%supported_numericals_lowercase = ["byte","short","ushort","int","uint","long","ulong","char","double","float"] +%supported_numericals_defaultvals = ["0","0","0","0","0u","0L","0UL","'\0'","0d","0f"] +%supported_numericals_onevales = ["1","1","1","1","1u","1L","1UL",1,"1d","1f"] +%supported_numericals_TF_DataType = ["TF_DataType.TF_UINT8","TF_DataType.TF_INT16","TF_DataType.TF_UINT16","TF_DataType.TF_INT32","TF_DataType.TF_UINT32","TF_DataType.TF_INT64","TF_DataType.TF_UINT64","TF_DataType.TF_UINT8","TF_DataType.TF_DOUBLE","TF_DataType.TF_FLOAT"] + +//this is the type we use in summerizing/reducting: +%supported_numericals_accumulatingType = ["UInt32","Int32","UInt32","Int32","UInt32","Int64","UInt64","UInt32","Double","Single"] +%supported_numericals_accumulatingType_defaultvals = ["0","0","0","0","0u","0L","0UL","'\0'","0d","0f"] + +%supported_numericals_signed = ["Int16","Int32","Int64","Double","Single"] +%supported_numericals_signed_lowercase = ["short","int","long","double","float"] +%supported_numericals_signed_defaultvals = ["0","0","0L","0d","0f"] +%supported_numericals_signed_onevales = ["1","1","1L","1d","1f"] + +%supported_numericals_unsigned = ["Byte","UInt16","UInt32","UInt64","Char"] +%supported_numericals_unsigned_lowercase = ["byte","ushort","uint","ulong","char"] +%supported_numericals_unsigned_defaultvals = ["0","0","0U","0UL","'\0'"] +%supported_numericals_unsigned_onevales = ["1","1","1U","1UL","'\1'"] + +%supported_dtypes = ["Boolean","Byte","Int16","UInt16","Int32","UInt32","Int64","UInt64","Char","Double","Single"] +%supported_numericals_TF_DataType = ["TF_DataType.TF_UINT8","TF_DataType.TF_INT16","TF_DataType.TF_UINT16","TF_DataType.TF_INT32","TF_DataType.TF_UINT32","TF_DataType.TF_INT64","TF_DataType.TF_UINT64","TF_DataType.TF_UINT8","TF_DataType.TF_DOUBLE","TF_DataType.TF_FLOAT"] + +%supported_dtypes_lowercase = ["bool","byte","short","ushort","int","uint","long","ulong","char","double","float"] +%supported_dtypes_defaultvals = [false,"0","0","0","0","0u","0L","0UL","'\0'","0d","0f"] +%supported_dtypes_onevales = [true,"1","1","1","1","1u","1L","1UL","'\1'","1d","1f"] +%supported_dtypes_dtype = ["bool","uint8","int16","uint16","int32","uint32","int64","uint64","uint8","float64","float32"] + +//this is the type we use in summerizing/reducting: +%supported_dtypes_accumulatingType = ["Int32","UInt32","Int32","UInt32","Int32","UInt32","Int64","UInt64","UInt32","Double","Single"] +%supported_dtypes_accumulatingType_defaultvals = [false, "0","0","0","0u","0L","0UL","'\0'","0d","0f"] + diff --git a/src/TensorFlowNET.Core/ops.GraphKeys.cs b/src/TensorFlowNET.Core/ops.GraphKeys.cs index 17b095a4..c5a06433 100644 --- a/src/TensorFlowNET.Core/ops.GraphKeys.cs +++ b/src/TensorFlowNET.Core/ops.GraphKeys.cs @@ -29,55 +29,111 @@ namespace Tensorflow /// public class GraphKeys { + #region const + + + /// + /// the subset of `Variable` objects that will be trained by an optimizer. + /// + public const string TRAINABLE_VARIABLES_ = "trainable_variables"; + + /// + /// Trainable resource-style variables. + /// + public const string TRAINABLE_RESOURCE_VARIABLES_ = "trainable_resource_variables"; + + /// + /// Key for streaming model ports. + /// + public const string _STREAMING_MODEL_PORTS_ = "streaming_model_ports"; + + /// + /// Key to collect losses + /// + public const string LOSSES_ = "losses"; + + /// + /// Key to collect Variable objects that are global (shared across machines). + /// Default collection for all variables, except local ones. + /// + public const string GLOBAL_VARIABLES_ = "variables"; + + public const string TRAIN_OP_ = "train_op"; + + public const string GLOBAL_STEP_ = "global_step"; + + public string[] _VARIABLE_COLLECTIONS_ = new string[] { "variables", "trainable_variables", "model_variables" }; + /// + /// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing. + /// + public const string SAVEABLE_OBJECTS_ = "saveable_objects"; + /// + /// Key to collect update_ops + /// + public const string UPDATE_OPS_ = "update_ops"; + + // Key to collect summaries. + public const string SUMMARIES_ = "summaries"; + + // Used to store v2 summary names. + public const string _SUMMARY_COLLECTION_ = "_SUMMARY_V2"; + + // Key for control flow context. + public const string COND_CONTEXT_ = "cond_context"; + public const string WHILE_CONTEXT_ = "while_context"; + + #endregion + + /// /// the subset of `Variable` objects that will be trained by an optimizer. /// - public string TRAINABLE_VARIABLES = "trainable_variables"; + public string TRAINABLE_VARIABLES => TRAINABLE_VARIABLES_; /// /// Trainable resource-style variables. /// - public string TRAINABLE_RESOURCE_VARIABLES = "trainable_resource_variables"; + public string TRAINABLE_RESOURCE_VARIABLES => TRAINABLE_RESOURCE_VARIABLES_; /// /// Key for streaming model ports. /// - public string _STREAMING_MODEL_PORTS = "streaming_model_ports"; + public string _STREAMING_MODEL_PORTS => _STREAMING_MODEL_PORTS_; /// /// Key to collect losses /// - public string LOSSES = "losses"; + public string LOSSES => LOSSES_; /// /// Key to collect Variable objects that are global (shared across machines). /// Default collection for all variables, except local ones. /// - public string GLOBAL_VARIABLES = "variables"; + public string GLOBAL_VARIABLES => GLOBAL_VARIABLES_; - public string TRAIN_OP = "train_op"; + public string TRAIN_OP => TRAIN_OP_; - public string GLOBAL_STEP = "global_step"; + public string GLOBAL_STEP => GLOBAL_STEP_; - public string[] _VARIABLE_COLLECTIONS = new string[] { "variables", "trainable_variables", "model_variables" }; + public string[] _VARIABLE_COLLECTIONS => _VARIABLE_COLLECTIONS_; /// /// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing. /// - public string SAVEABLE_OBJECTS = "saveable_objects"; + public string SAVEABLE_OBJECTS => SAVEABLE_OBJECTS_; /// /// Key to collect update_ops /// - public string UPDATE_OPS = "update_ops"; + public string UPDATE_OPS => UPDATE_OPS_; // Key to collect summaries. - public string SUMMARIES = "summaries"; + public string SUMMARIES => SUMMARIES_; // Used to store v2 summary names. - public string _SUMMARY_COLLECTION = "_SUMMARY_V2"; + public string _SUMMARY_COLLECTION => _SUMMARY_COLLECTION_; // Key for control flow context. - public string COND_CONTEXT = "cond_context"; - public string WHILE_CONTEXT = "while_context"; + public string COND_CONTEXT => COND_CONTEXT_; + public string WHILE_CONTEXT => WHILE_CONTEXT_; } } } diff --git a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.GPU.csproj b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.GPU.csproj index 1bd3d530..55e9b27d 100644 --- a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.GPU.csproj +++ b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.GPU.csproj @@ -6,6 +6,14 @@ false + + bin\debug-gpu + + + + bin\release-gpu + + diff --git a/test/TensorFlowNET.UnitTest/ConstantTest.cs b/test/TensorFlowNET.UnitTest/ConstantTest.cs index c1d4c9e5..b532e558 100644 --- a/test/TensorFlowNET.UnitTest/ConstantTest.cs +++ b/test/TensorFlowNET.UnitTest/ConstantTest.cs @@ -98,9 +98,9 @@ namespace TensorFlowNET.UnitTest { var result = sess.run(tensor); - Assert.AreEqual(result[0].shape[0], 3); - Assert.AreEqual(result[0].shape[1], 2); - Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0, 0, 0 }, result[0].Data())); + Assert.AreEqual(result.shape[0], 3); + Assert.AreEqual(result.shape[1], 2); + Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0, 0, 0 }, result.Data())); } // big size @@ -109,13 +109,13 @@ namespace TensorFlowNET.UnitTest { var result = sess.run(tensor); - Assert.AreEqual(result[0].shape[0], 200); - Assert.AreEqual(result[0].shape[1], 100); + Assert.AreEqual(result.shape[0], 200); + Assert.AreEqual(result.shape[1], 100); - var data = result[0].Data(); + var data = result.Data(); Assert.AreEqual(0, data[0]); Assert.AreEqual(0, data[500]); - Assert.AreEqual(0, data[result[0].size - 1]); + Assert.AreEqual(0, data[result.size - 1]); } } @@ -127,9 +127,9 @@ namespace TensorFlowNET.UnitTest { var result = sess.run(ones); - Assert.AreEqual(result[0].shape[0], 3); - Assert.AreEqual(result[0].shape[1], 2); - Assert.IsTrue(new[] { 1, 1, 1, 1, 1, 1 }.SequenceEqual(result[0].Data())); + Assert.AreEqual(result.shape[0], 3); + Assert.AreEqual(result.shape[1], 2); + Assert.IsTrue(new[] { 1, 1, 1, 1, 1, 1 }.SequenceEqual(result.Data())); } } @@ -142,9 +142,9 @@ namespace TensorFlowNET.UnitTest { var result = sess.run(halfes); - Assert.AreEqual(result[0].shape[0], 3); - Assert.AreEqual(result[0].shape[1], 2); - Assert.IsTrue(new[] { .5, .5, .5, .5, .5, .5 }.SequenceEqual(result[0].Data())); + Assert.AreEqual(result.shape[0], 3); + Assert.AreEqual(result.shape[1], 2); + Assert.IsTrue(new[] { .5, .5, .5, .5, .5, .5 }.SequenceEqual(result.Data())); } } @@ -161,10 +161,10 @@ namespace TensorFlowNET.UnitTest using (var sess = tf.Session()) { var result = sess.run(tensor); - var data = result[0].Data(); + var data = result.Data(); - Assert.AreEqual(result[0].shape[0], 2); - Assert.AreEqual(result[0].shape[1], 3); + Assert.AreEqual(result.shape[0], 2); + Assert.AreEqual(result.shape[1], 3); Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 1, 1, 2, 1, 3 }, data)); } } @@ -177,7 +177,7 @@ namespace TensorFlowNET.UnitTest var c = a * b; var sess = tf.Session(); - double result = sess.run(c)[0]; + double result = sess.run(c); sess.close(); Assert.AreEqual(6.0, result); diff --git a/test/TensorFlowNET.UnitTest/GradientTest.cs b/test/TensorFlowNET.UnitTest/GradientTest.cs index b52bc1cf..c8e57ba4 100644 --- a/test/TensorFlowNET.UnitTest/GradientTest.cs +++ b/test/TensorFlowNET.UnitTest/GradientTest.cs @@ -41,7 +41,7 @@ namespace TensorFlowNET.UnitTest var grad = tf.gradients(y, x); Assert.AreEqual(grad[0].name, "gradients/AddN:0"); - float r = sess.run(grad[0])[0]; + float r = sess.run(grad[0]); Assert.AreEqual(r, 1.4f); } } @@ -57,7 +57,7 @@ namespace TensorFlowNET.UnitTest var grad = tf.gradients(y, x); Assert.AreEqual(grad[0].name, "gradients/AddN:0"); - float r = sess.run(grad[0])[0]; + float r = sess.run(grad[0]); Assert.AreEqual(r, 14.700001f); }); } @@ -94,7 +94,7 @@ namespace TensorFlowNET.UnitTest using (var sess = tf.Session(graph)) { - var r = sess.run(slice)[0]; + var r = sess.run(slice); Assert.IsTrue(Enumerable.SequenceEqual(r.shape, new[] { 2, 1, 2 })); Assert.IsTrue(Enumerable.SequenceEqual(r[0].GetData(), new[] { 11, 13 })); diff --git a/test/TensorFlowNET.UnitTest/OperationsTest.cs b/test/TensorFlowNET.UnitTest/OperationsTest.cs index 4c6ae3d0..0caa5259 100644 --- a/test/TensorFlowNET.UnitTest/OperationsTest.cs +++ b/test/TensorFlowNET.UnitTest/OperationsTest.cs @@ -44,7 +44,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, 3.0f), new FeedItem(b, 2.0f)); - Assert.AreEqual((float)o[0], 5.0f); + Assert.AreEqual((float)o, 5.0f); } } @@ -58,7 +58,7 @@ namespace TensorFlowNET.UnitTest using (var sess = tf.Session()) { var o = sess.run(c); - Assert.AreEqual((float)o[0], 9.0f); + Assert.AreEqual((float)o, 9.0f); } } @@ -72,7 +72,7 @@ namespace TensorFlowNET.UnitTest using (var sess = tf.Session()) { var o = sess.run(b); - Assert.IsTrue(o[0].array_equal(check)); + Assert.IsTrue(o.array_equal(check)); } } @@ -86,7 +86,7 @@ namespace TensorFlowNET.UnitTest using (var sess = tf.Session()) { var o = sess.run(b); - Assert.IsTrue(o[0].array_equal(check)); + Assert.IsTrue(o.array_equal(check)); } } @@ -100,7 +100,7 @@ namespace TensorFlowNET.UnitTest using (var sess = tf.Session()) { var o = sess.run(b); - Assert.IsTrue(o[0].array_equal(check)); + Assert.IsTrue(o.array_equal(check)); } b = tf.cumsum(a, exclusive: true); @@ -109,7 +109,7 @@ namespace TensorFlowNET.UnitTest using (var sess = tf.Session()) { var o = sess.run(b); - Assert.IsTrue(o[0].array_equal(check)); + Assert.IsTrue(o.array_equal(check)); } b = tf.cumsum(a, reverse: true); @@ -118,7 +118,7 @@ namespace TensorFlowNET.UnitTest using (var sess = tf.Session()) { var o = sess.run(b); - Assert.IsTrue(o[0].array_equal(check)); + Assert.IsTrue(o.array_equal(check)); } b = tf.cumsum(a, exclusive:true, reverse: true); @@ -127,7 +127,7 @@ namespace TensorFlowNET.UnitTest using (var sess = tf.Session()) { var o = sess.run(b); - Assert.IsTrue(o[0].array_equal(check)); + Assert.IsTrue(o.array_equal(check)); } } @@ -143,7 +143,7 @@ namespace TensorFlowNET.UnitTest using (var sess = tf.Session()) { var o = sess.run(d); - Assert.IsTrue(o[0].array_equal(check)); + Assert.IsTrue(o.array_equal(check)); } d = tf.cast(tf.logical_not(b), tf.int32); @@ -152,7 +152,7 @@ namespace TensorFlowNET.UnitTest using (var sess = tf.Session()) { var o = sess.run(d); - Assert.IsTrue(o[0].array_equal(check)); + Assert.IsTrue(o.array_equal(check)); } d = tf.cast(tf.logical_or(b, c), tf.int32); @@ -161,7 +161,7 @@ namespace TensorFlowNET.UnitTest using (var sess = tf.Session()) { var o = sess.run(d); - Assert.IsTrue(o[0].array_equal(check)); + Assert.IsTrue(o.array_equal(check)); } d = tf.cast(tf.logical_xor(b, c), tf.int32); @@ -170,7 +170,7 @@ namespace TensorFlowNET.UnitTest using (var sess = tf.Session()) { var o = sess.run(d); - Assert.IsTrue(o[0].array_equal(check)); + Assert.IsTrue(o.array_equal(check)); } } @@ -197,7 +197,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } // Testing `operator +(Tensor x, Tensor y)` @@ -207,7 +207,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } // Testing `operator +(Tensor x, int y)` @@ -216,7 +216,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } // Testing `operator +(int x, Tensor y)` @@ -225,7 +225,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } #endregion @@ -246,7 +246,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((float)o[0], floatResult); + Assert.AreEqual((float)o, floatResult); } // Testing `operator +(Tensor x, Tensor y) @@ -256,7 +256,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((float)o[0], floatResult); + Assert.AreEqual((float)o, floatResult); } // Testing `operator +(Tensor x, float y) @@ -265,7 +265,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((float)o[0], floatResult); + Assert.AreEqual((float)o, floatResult); } // Testing `operator +(float x, Tensor y) @@ -274,7 +274,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((float)o[0], floatResult); + Assert.AreEqual((float)o, floatResult); } #endregion @@ -295,7 +295,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((double)o[0], doubleResult); + Assert.AreEqual((double)o, doubleResult); } // Testing `operator +(Tensor x, Tensor y) @@ -305,7 +305,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((double)o[0], doubleResult); + Assert.AreEqual((double)o, doubleResult); } // Testing `operator +(Tensor x, double y) @@ -314,7 +314,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((double)o[0], doubleResult); + Assert.AreEqual((double)o, doubleResult); } // Testing `operator +(double x, Tensor y) @@ -323,7 +323,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((double)o[0], doubleResult); + Assert.AreEqual((double)o, doubleResult); } #endregion } @@ -352,7 +352,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } // Testing `operator -(Tensor x, Tensor y) @@ -362,7 +362,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } // Testing `operator -(Tensor x, int y) @@ -371,7 +371,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } // Testing `operator -(int x, Tensor y) @@ -380,7 +380,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], Math.Abs(intResult)); + Assert.AreEqual((int)o, Math.Abs(intResult)); } // Testing `operator -(Tensor x) @@ -389,7 +389,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResultTwo); + Assert.AreEqual((int)o, intResultTwo); } #endregion @@ -411,7 +411,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((float)o[0], floatResult); + Assert.AreEqual((float)o, floatResult); } // Testing `operator -(Tensor x, Tensor y) @@ -421,7 +421,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((float)o[0], floatResult); + Assert.AreEqual((float)o, floatResult); } // Testing `operator -(Tensor x, float y) @@ -430,7 +430,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((float)o[0], floatResult); + Assert.AreEqual((float)o, floatResult); } // Testing `operator -(float x, Tensor y) @@ -439,7 +439,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((float)o[0], Math.Abs(floatResult)); + Assert.AreEqual((float)o, Math.Abs(floatResult)); } // Testing `operator -(Tensor x) @@ -448,7 +448,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((float)o[0], floatResultTwo); + Assert.AreEqual((float)o, floatResultTwo); } #endregion @@ -470,7 +470,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((double)o[0], doubleResult); + Assert.AreEqual((double)o, doubleResult); } // Testing `operator -(Tensor x, Tensor y) @@ -480,7 +480,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((double)o[0], doubleResult); + Assert.AreEqual((double)o, doubleResult); } // Testing `operator -(Tensor x, double y) @@ -489,7 +489,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((double)o[0], doubleResult); + Assert.AreEqual((double)o, doubleResult); } // Testing `operator -(double x, Tensor y) @@ -498,7 +498,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((double)o[0], Math.Abs(doubleResult)); + Assert.AreEqual((double)o, Math.Abs(doubleResult)); } // Testing `operator -(Tensor x) @@ -507,7 +507,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((double)o[0], doubleResultTwo); + Assert.AreEqual((double)o, doubleResultTwo); } #endregion } @@ -593,7 +593,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } // Testing `operator *(Tensor x, Tensor y) @@ -603,7 +603,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } // Testing `operator *(Tensor x, int y) @@ -612,7 +612,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } // Testing `operator *(int x, Tensor y) @@ -621,7 +621,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } #endregion @@ -642,7 +642,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((float)o[0], floatResult); + Assert.AreEqual((float)o, floatResult); } // Testing `operator *(Tensor x, Tensor y) @@ -652,7 +652,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((float)o[0], floatResult); + Assert.AreEqual((float)o, floatResult); } // Testing `operator *(Tensor x, float y) @@ -661,7 +661,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((float)o[0], floatResult); + Assert.AreEqual((float)o, floatResult); } // Testing `operator *(float x, Tensor y) @@ -670,7 +670,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((float)o[0], floatResult); + Assert.AreEqual((float)o, floatResult); } #endregion @@ -691,7 +691,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((double)o[0], doubleResult); + Assert.AreEqual((double)o, doubleResult); } // Testing `operator *(Tensor x, Tensor y) @@ -701,7 +701,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((double)o[0], doubleResult); + Assert.AreEqual((double)o, doubleResult); } // Testing `operator *(Tensor x, double y) @@ -710,7 +710,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((double)o[0], doubleResult); + Assert.AreEqual((double)o, doubleResult); } // Testing `operator *(double x, Tensor y) @@ -719,7 +719,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((double)o[0], doubleResult); + Assert.AreEqual((double)o, doubleResult); } #endregion } @@ -747,7 +747,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } // Testing `operator /(Tensor x, Tensor y) @@ -757,7 +757,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } // Testing `operator /(Tensor x, int y) @@ -766,7 +766,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } // Testing `operator /(int x, Tensor y) @@ -775,7 +775,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } #endregion @@ -796,7 +796,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((float)o[0], floatResult); + Assert.AreEqual((float)o, floatResult); } // Testing `operator /(Tensor x, Tensor y) @@ -806,7 +806,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((float)o[0], floatResult); + Assert.AreEqual((float)o, floatResult); } // Testing `operator /(Tensor x, float y) @@ -815,7 +815,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((float)o[0], floatResult); + Assert.AreEqual((float)o, floatResult); } // Testing `operator /(float x, Tensor y) @@ -824,7 +824,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((float)o[0], floatResult); + Assert.AreEqual((float)o, floatResult); } #endregion @@ -845,7 +845,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((double)o[0], doubleResult); + Assert.AreEqual((double)o, doubleResult); } // Testing `operator /(Tensor x, Tensor y) @@ -855,7 +855,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((double)o[0], doubleResult); + Assert.AreEqual((double)o, doubleResult); } // Testing `operator /(Tensor x, double y) @@ -864,7 +864,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((double)o[0], doubleResult); + Assert.AreEqual((double)o, doubleResult); } // Testing `operator /(double x, Tensor y) @@ -873,7 +873,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((double)o[0], doubleResult); + Assert.AreEqual((double)o, doubleResult); } #endregion } @@ -901,7 +901,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } // Testing `operator >(Tensor x, Tensor y) @@ -911,7 +911,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } // Testing `operator >(Tensor x, int y) @@ -920,7 +920,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } // Testing `operator >(int x, Tensor y) @@ -929,7 +929,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResultTwo); + Assert.AreEqual((int)o, intResultTwo); } #endregion @@ -950,7 +950,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], floatResult); + Assert.AreEqual((int)o, floatResult); } // Testing `operator >(Tensor x, Tensor y) @@ -960,7 +960,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], floatResult); + Assert.AreEqual((int)o, floatResult); } // Testing `operator >(Tensor x, float y) @@ -969,7 +969,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], floatResult); + Assert.AreEqual((int)o, floatResult); } // Testing `operator >(float x, Tensor y) @@ -978,7 +978,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], floatResultTwo); + Assert.AreEqual((int)o, floatResultTwo); } #endregion @@ -999,7 +999,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], doubleResult); + Assert.AreEqual((int)o, doubleResult); } // Testing `operator >(Tensor x, Tensor y) @@ -1009,7 +1009,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], doubleResult); + Assert.AreEqual((int)o, doubleResult); } // Testing `operator >(Tensor x, double y) @@ -1018,7 +1018,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], doubleResult); + Assert.AreEqual((int)o, doubleResult); } // Testing `operator >(double x, Tensor y) @@ -1027,7 +1027,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], doubleResultTwo); + Assert.AreEqual((int)o, doubleResultTwo); } #endregion } @@ -1055,7 +1055,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } // Testing `operator <(Tensor x, Tensor y) @@ -1065,7 +1065,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } // Testing `operator <(Tensor x, int y) @@ -1074,7 +1074,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } // Testing `operator <(int x, Tensor y) @@ -1083,7 +1083,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResultTwo); + Assert.AreEqual((int)o, intResultTwo); } #endregion @@ -1104,7 +1104,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], floatResult); + Assert.AreEqual((int)o, floatResult); } // Testing `operator <(Tensor x, Tensor y) @@ -1114,7 +1114,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], floatResult); + Assert.AreEqual((int)o, floatResult); } // Testing `operator <(Tensor x, float y) @@ -1123,7 +1123,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], floatResult); + Assert.AreEqual((int)o, floatResult); } // Testing `operator <(float x, Tensor y) @@ -1132,7 +1132,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], floatResultTwo); + Assert.AreEqual((int)o, floatResultTwo); } #endregion @@ -1153,7 +1153,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], doubleResult); + Assert.AreEqual((int)o, doubleResult); } // Testing `operator <(Tensor x, Tensor y) @@ -1163,7 +1163,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], doubleResult); + Assert.AreEqual((int)o, doubleResult); } // Testing `operator <(Tensor x, double y) @@ -1172,7 +1172,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], doubleResult); + Assert.AreEqual((int)o, doubleResult); } // Testing `operator <(double x, Tensor y) @@ -1181,7 +1181,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], doubleResultTwo); + Assert.AreEqual((int)o, doubleResultTwo); } #endregion } @@ -1209,7 +1209,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } // Testing `operator >=(Tensor x, Tensor y) @@ -1219,7 +1219,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } // Testing `operator >=(Tensor x, int y) @@ -1228,7 +1228,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } // Testing `operator >=(int x, Tensor y) @@ -1237,7 +1237,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResultTwo); + Assert.AreEqual((int)o, intResultTwo); } #endregion @@ -1258,7 +1258,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], floatResult); + Assert.AreEqual((int)o, floatResult); } // Testing `operator >=(Tensor x, Tensor y) @@ -1268,7 +1268,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], floatResult); + Assert.AreEqual((int)o, floatResult); } // Testing `operator >=(Tensor x, float y) @@ -1277,7 +1277,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], floatResult); + Assert.AreEqual((int)o, floatResult); } // Testing `operator >=(float x, Tensor y) @@ -1286,7 +1286,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], floatResultTwo); + Assert.AreEqual((int)o, floatResultTwo); } #endregion @@ -1307,7 +1307,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], doubleResult); + Assert.AreEqual((int)o, doubleResult); } // Testing `operator >=(Tensor x, Tensor y) @@ -1317,7 +1317,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], doubleResult); + Assert.AreEqual((int)o, doubleResult); } // Testing `operator >=(Tensor x, double y) @@ -1326,7 +1326,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], doubleResult); + Assert.AreEqual((int)o, doubleResult); } // Testing `operator >=(double x, Tensor y) @@ -1335,7 +1335,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], doubleResultTwo); + Assert.AreEqual((int)o, doubleResultTwo); } #endregion } @@ -1363,7 +1363,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } // Testing `operator <=(Tensor x, Tensor y) @@ -1373,7 +1373,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } // Testing `operator <=(Tensor x, int y) @@ -1382,7 +1382,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResult); + Assert.AreEqual((int)o, intResult); } // Testing `operator <=(int x, Tensor y) @@ -1391,7 +1391,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], intResultTwo); + Assert.AreEqual((int)o, intResultTwo); } #endregion @@ -1412,7 +1412,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], floatResult); + Assert.AreEqual((int)o, floatResult); } // Testing `operator <=(Tensor x, Tensor y) @@ -1422,7 +1422,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], floatResult); + Assert.AreEqual((int)o, floatResult); } // Testing `operator <=(Tensor x, float y) @@ -1431,7 +1431,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], floatResult); + Assert.AreEqual((int)o, floatResult); } // Testing `operator <=(float x, Tensor y) @@ -1440,7 +1440,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], floatResultTwo); + Assert.AreEqual((int)o, floatResultTwo); } #endregion @@ -1461,7 +1461,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], doubleResult); + Assert.AreEqual((int)o, doubleResult); } // Testing `operator <=(Tensor x, Tensor y) @@ -1471,7 +1471,7 @@ namespace TensorFlowNET.UnitTest var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], doubleResult); + Assert.AreEqual((int)o, doubleResult); } // Testing `operator <=(Tensor x, double y) @@ -1480,7 +1480,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], doubleResult); + Assert.AreEqual((int)o, doubleResult); } // Testing `operator <=(double x, Tensor y) @@ -1489,7 +1489,7 @@ namespace TensorFlowNET.UnitTest { var o = sess.run(c, new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); - Assert.AreEqual((int)o[0], doubleResultTwo); + Assert.AreEqual((int)o, doubleResultTwo); } #endregion } diff --git a/test/TensorFlowNET.UnitTest/PlaceholderTest.cs b/test/TensorFlowNET.UnitTest/PlaceholderTest.cs index 14b16c23..5135bd25 100644 --- a/test/TensorFlowNET.UnitTest/PlaceholderTest.cs +++ b/test/TensorFlowNET.UnitTest/PlaceholderTest.cs @@ -17,7 +17,7 @@ namespace TensorFlowNET.UnitTest { var result = sess.run(y, new FeedItem(x, 2)); - Assert.AreEqual((int)result[0], 6); + Assert.AreEqual((int)result, 6); } } } diff --git a/test/TensorFlowNET.UnitTest/SessionTest.cs b/test/TensorFlowNET.UnitTest/SessionTest.cs index 9c8485ec..8fd4dc8a 100644 --- a/test/TensorFlowNET.UnitTest/SessionTest.cs +++ b/test/TensorFlowNET.UnitTest/SessionTest.cs @@ -45,7 +45,7 @@ namespace TensorFlowNET.UnitTest EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype); EXPECT_EQ(0, outTensor.NDims); ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize); - var output_contents = outTensor.Data(); + var output_contents = outTensor.ToArray(); EXPECT_EQ(3 + 2, output_contents[0]); // Add another operation to the graph. @@ -66,7 +66,7 @@ namespace TensorFlowNET.UnitTest EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype); EXPECT_EQ(0, outTensor.NDims); // scalar ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize); - output_contents = outTensor.Data(); + output_contents = outTensor.ToArray(); EXPECT_EQ(-(7 + 2), output_contents[0]); // Clean up diff --git a/test/TensorFlowNET.UnitTest/TensorTest.cs b/test/TensorFlowNET.UnitTest/TensorTest.cs index 07da9dca..11557f14 100644 --- a/test/TensorFlowNET.UnitTest/TensorTest.cs +++ b/test/TensorFlowNET.UnitTest/TensorTest.cs @@ -112,7 +112,7 @@ namespace TensorFlowNET.UnitTest var nd = np.array(1f, 2f, 3f, 4f, 5f, 6f).reshape(2, 3); var tensor = new Tensor(nd); - var array = tensor.Data(); + var array = tensor.ToArray(); EXPECT_EQ(tensor.dtype, TF_DataType.TF_FLOAT); EXPECT_EQ(tensor.rank, nd.ndim); diff --git a/test/TensorFlowNET.UnitTest/VariableTest.cs b/test/TensorFlowNET.UnitTest/VariableTest.cs index 4c5ddd7a..7673cac8 100644 --- a/test/TensorFlowNET.UnitTest/VariableTest.cs +++ b/test/TensorFlowNET.UnitTest/VariableTest.cs @@ -1,4 +1,5 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; +using NumSharp; using Tensorflow; using static Tensorflow.Binding; @@ -16,7 +17,7 @@ namespace TensorFlowNET.UnitTest { session.run(x.initializer); var result = session.run(x); - Assert.AreEqual(10, (int)result[0]); + Assert.AreEqual(10, (int)result); } } @@ -81,7 +82,7 @@ namespace TensorFlowNET.UnitTest using (var session = tf.Session()) { session.run(model); - int result = session.run(y)[0]; + int result = session.run(y); Assert.AreEqual(result, 4); } } @@ -97,12 +98,12 @@ namespace TensorFlowNET.UnitTest var sess = tf.Session(graph); sess.run(init); - var result = sess.run(variable); - Assert.IsTrue((int)result[0] == 31); + NDArray result = sess.run(variable); + Assert.IsTrue((int)result == 31); var assign = variable.assign(12); result = sess.run(assign); - Assert.IsTrue((int)result[0] == 12); + Assert.IsTrue((int)result == 12); } [TestMethod] @@ -139,7 +140,7 @@ namespace TensorFlowNET.UnitTest for(int i = 0; i < 5; i++) { x = x + 1; - result = session.run(x)[0]; + result = session.run(x); print(result); } } diff --git a/test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs b/test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs index 1fd7d3aa..3a5515d9 100644 --- a/test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs +++ b/test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs @@ -31,7 +31,7 @@ namespace TensorFlowNET.UnitTest.nn_test var y_np = this._ZeroFraction(x_np); var x_tf = constant_op.constant(x_np); - x_tf.SetShape(x_shape); + x_tf.set_shape(x_shape); var y_tf = nn_impl.zero_fraction(x_tf); var y_tf_np = self.evaluate(y_tf);