diff --git a/src/TensorFlowNET.Core/APIs/c_api.cs b/src/TensorFlowNET.Core/APIs/c_api.cs
index 1656edd0..adf0b86f 100644
--- a/src/TensorFlowNET.Core/APIs/c_api.cs
+++ b/src/TensorFlowNET.Core/APIs/c_api.cs
@@ -59,6 +59,6 @@ namespace Tensorflow
}
[DllImport(TensorFlowLibName)]
- public static unsafe extern IntPtr TF_Version();
+ public static extern IntPtr TF_Version();
}
}
diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs
index da2cdf6e..bfbfa4ec 100644
--- a/src/TensorFlowNET.Core/Binding.Util.cs
+++ b/src/TensorFlowNET.Core/Binding.Util.cs
@@ -308,15 +308,14 @@ namespace Tensorflow
public static IEnumerable TupleToEnumerable(object tuple)
{
Type t = tuple.GetType();
- if(t.IsGenericType && (t.FullName.StartsWith("System.Tuple") || t.FullName.StartsWith("System.ValueTuple")))
+ if (t.IsGenericType && (t.FullName.StartsWith("System.Tuple") || t.FullName.StartsWith("System.ValueTuple")))
{
var flds = t.GetFields();
- for(int i = 0; i < flds.Length;i++)
+ for (int i = 0; i < flds.Length; i++)
{
yield return flds[i].GetValue(tuple);
}
- }
- else
+ } else
{
throw new System.Exception("Expected Tuple.");
}
@@ -329,12 +328,9 @@ namespace Tensorflow
public static bool isinstance(object Item1, object tuple)
{
- var tup = TupleToEnumerable(tuple);
- foreach(var t in tup)
- {
- if(isinstance(Item1, (Type)t))
+ foreach (var t in TupleToEnumerable(tuple))
+ if (isinstance(Item1, (Type) t))
return true;
- }
return false;
}
}
diff --git a/src/TensorFlowNET.Core/Buffers/Buffer.cs b/src/TensorFlowNET.Core/Buffers/Buffer.cs
index dbe576b8..396fb311 100644
--- a/src/TensorFlowNET.Core/Buffers/Buffer.cs
+++ b/src/TensorFlowNET.Core/Buffers/Buffer.cs
@@ -66,7 +66,7 @@ namespace Tensorflow
return buffer.Data;
}
- protected override void DisposeUnManagedState(IntPtr handle)
+ protected override void DisposeUnmanagedResources(IntPtr handle)
=> c_api.TF_DeleteBuffer(handle);
}
}
diff --git a/src/TensorFlowNET.Core/DisposableObject.cs b/src/TensorFlowNET.Core/DisposableObject.cs
index 7e416e6d..688ac92c 100644
--- a/src/TensorFlowNET.Core/DisposableObject.cs
+++ b/src/TensorFlowNET.Core/DisposableObject.cs
@@ -29,18 +29,10 @@ namespace Tensorflow
protected DisposableObject() { }
- public DisposableObject(IntPtr handle)
- {
- _handle = handle;
- }
-
- protected virtual void DisposeManagedState()
- {
- }
+ protected DisposableObject(IntPtr handle)
+ => _handle = handle;
- protected abstract void DisposeUnManagedState(IntPtr handle);
-
- protected virtual void Dispose(bool disposing)
+ private void internal_dispose(bool disposing)
{
if (disposing)
{
@@ -48,30 +40,43 @@ namespace Tensorflow
if (_handle != IntPtr.Zero)
{
// dispose managed state (managed objects).
- DisposeManagedState();
+ DisposeManagedResources();
// set large fields to null.
- DisposeUnManagedState(_handle);
+ DisposeUnmanagedResources(_handle);
_handle = IntPtr.Zero;
}
}
}
+ ///
+ /// Dispose any managed resources.
+ ///
+ /// Equivalent to what you would perform inside
+ protected virtual void DisposeManagedResources()
+ {
+ }
+
+ ///
+ /// Dispose any unmanaged resources related to given .
+ ///
+ protected abstract void DisposeUnmanagedResources(IntPtr handle);
+
// override a finalizer only if Dispose(bool disposing) above has code to free unmanaged resources.
~DisposableObject()
{
// Do not change this code. Put cleanup code in Dispose(bool disposing) above.
- Dispose(false);
+ internal_dispose(false);
}
// This code added to correctly implement the disposable pattern.
public void Dispose()
{
// Do not change this code. Put cleanup code in Dispose(bool disposing) above.
- Dispose(true);
+ internal_dispose(true);
// uncomment the following line if the finalizer is overridden above.
GC.SuppressFinalize(this);
}
}
-}
+}
\ No newline at end of file
diff --git a/src/TensorFlowNET.Core/Eager/ContextOptions.cs b/src/TensorFlowNET.Core/Eager/ContextOptions.cs
index 4bffddf6..4bdf04b3 100644
--- a/src/TensorFlowNET.Core/Eager/ContextOptions.cs
+++ b/src/TensorFlowNET.Core/Eager/ContextOptions.cs
@@ -1,8 +1,9 @@
using System;
+using System.IO;
namespace Tensorflow.Eager
{
- public class ContextOptions : IDisposable
+ public class ContextOptions : IDisposable //TODO! Eli: Shouldn't this inherieting DisposableObject?
{
private IntPtr _handle;
diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs
index 77926dca..07dc117e 100644
--- a/src/TensorFlowNET.Core/Graphs/Graph.cs
+++ b/src/TensorFlowNET.Core/Graphs/Graph.cs
@@ -23,57 +23,58 @@ using static Tensorflow.Binding;
namespace Tensorflow
{
+ /*
+ A TensorFlow computation, represented as a dataflow graph.
+
+ A `Graph` contains a set of
+ `tf.Operation` objects,
+ which represent units of computation; and
+ `tf.Tensor` objects, which represent
+ the units of data that flow between operations.
+
+ A default `Graph` is always registered, and accessible by calling
+ `tf.get_default_graph`.
+ To add an operation to the default graph, simply call one of the functions
+ that defines a new `Operation`:
+
+ ```python
+ c = tf.constant(4.0)
+ assert c.graph is tf.get_default_graph()
+ ```
+
+ Another typical usage involves the
+ `tf.Graph.as_default`
+ context manager, which overrides the current default graph for the
+ lifetime of the context:
+
+ ```python
+ g = tf.Graph()
+ with g.as_default():
+ # Define operations and tensors in `g`.
+ c = tf.constant(30.0)
+ assert c.graph is g
+ ```
+
+ Important note: This class *is not* thread-safe for graph construction. All
+ operations should be created from a single thread, or external
+ synchronization must be provided. Unless otherwise specified, all methods
+ are not thread-safe.
+
+ A `Graph` instance supports an arbitrary number of "collections"
+ that are identified by name. For convenience when building a large
+ graph, collections can store groups of related objects: for
+ example, the `tf.Variable` uses a collection (named
+ `tf.GraphKeys.GLOBAL_VARIABLES`) for
+ all variables that are created during the construction of a graph. The caller
+ may define additional collections by specifying a new name.
+ */
+
///
- /// TensorFlow uses a dataflow graph to represent your computation in terms of the dependencies between individual operations.
- /// This leads to a low-level programming model in which you first define the dataflow graph,
- /// then create a TensorFlow session to run parts of the graph across a set of local and remote devices.
- /// https://www.tensorflow.org/guide/graphs
+ /// TensorFlow uses a dataflow graph to represent your computation in terms of the dependencies between individual operations.
+ /// This leads to a low-level programming model in which you first define the dataflow graph,
+ /// then create a TensorFlow session to run parts of the graph across a set of local and remote devices.
///
- /*
- A TensorFlow computation, represented as a dataflow graph.
-
- A `Graph` contains a set of
- `tf.Operation` objects,
- which represent units of computation; and
- `tf.Tensor` objects, which represent
- the units of data that flow between operations.
-
- A default `Graph` is always registered, and accessible by calling
- `tf.get_default_graph`.
- To add an operation to the default graph, simply call one of the functions
- that defines a new `Operation`:
-
- ```python
- c = tf.constant(4.0)
- assert c.graph is tf.get_default_graph()
- ```
-
- Another typical usage involves the
- `tf.Graph.as_default`
- context manager, which overrides the current default graph for the
- lifetime of the context:
-
- ```python
- g = tf.Graph()
- with g.as_default():
- # Define operations and tensors in `g`.
- c = tf.constant(30.0)
- assert c.graph is g
- ```
-
- Important note: This class *is not* thread-safe for graph construction. All
- operations should be created from a single thread, or external
- synchronization must be provided. Unless otherwise specified, all methods
- are not thread-safe.
-
- A `Graph` instance supports an arbitrary number of "collections"
- that are identified by name. For convenience when building a large
- graph, collections can store groups of related objects: for
- example, the `tf.Variable` uses a collection (named
- `tf.GraphKeys.GLOBAL_VARIABLES`) for
- all variables that are created during the construction of a graph. The caller
- may define additional collections by specifying a new name.
- */
+ /// https://www.tensorflow.org/guide/graphs
https://www.tensorflow.org/api_docs/python/tf/Graph
public partial class Graph : DisposableObject, IEnumerable
{
private Dictionary _nodes_by_id;
@@ -439,12 +440,12 @@ namespace Tensorflow
_unfetchable_ops.Add(op);
}
- protected override void DisposeManagedState()
+ protected override void DisposeManagedResources()
{
ops.default_graph_stack.remove(this);
}
- protected override void DisposeUnManagedState(IntPtr handle)
+ protected override void DisposeUnmanagedResources(IntPtr handle)
{
c_api.TF_DeleteGraph(handle);
}
diff --git a/src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs b/src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs
index 97720206..bdcaf60c 100644
--- a/src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs
+++ b/src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs
@@ -37,7 +37,7 @@ namespace Tensorflow
c_api.TF_ImportGraphDefOptionsAddReturnOutput(_handle, name, index);
}
- protected override void DisposeUnManagedState(IntPtr handle)
+ protected override void DisposeUnmanagedResources(IntPtr handle)
=> c_api.TF_DeleteImportGraphDefOptions(handle);
public static implicit operator IntPtr(ImportGraphDefOptions opts) => opts._handle;
diff --git a/src/TensorFlowNET.Core/IO/gfile.cs b/src/TensorFlowNET.Core/IO/gfile.cs
index 930dd652..a7303bf6 100644
--- a/src/TensorFlowNET.Core/IO/gfile.cs
+++ b/src/TensorFlowNET.Core/IO/gfile.cs
@@ -16,6 +16,7 @@
using System.Collections.Generic;
using System.IO;
+using System.Linq;
namespace Tensorflow.IO
{
@@ -28,6 +29,9 @@ namespace Tensorflow.IO
/// Traverse in order if True, post order if False.
public IEnumerable<(string, string[], string[])> Walk(string top, bool in_order = true)
{
+ if (!Directory.Exists(top))
+ return Enumerable.Empty<(string, string[], string[])>();
+
return walk_v2(top, in_order);
}
diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
index 2c05e36a..2a76c52c 100644
--- a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
+++ b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
@@ -141,7 +141,7 @@ namespace Tensorflow.Operations
data, frame_name, is_constant, parallel_iterations, name: name);
if (use_input_shape)
- result.SetShape(data.TensorShape);
+ result.set_shape(data.TensorShape);
return result;
}
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
index 3198942b..1b68d1cd 100644
--- a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
@@ -233,7 +233,7 @@ namespace Tensorflow.Operations
dims.AddRange(x_static_shape.dims.Skip(2));
var shape = new TensorShape(dims.ToArray());
- x_t.SetShape(shape);
+ x_t.set_shape(shape);
return x_t;
}
diff --git a/src/TensorFlowNET.Core/Operations/array_ops.py.cs b/src/TensorFlowNET.Core/Operations/array_ops.py.cs
index d3213250..92fe2e3c 100644
--- a/src/TensorFlowNET.Core/Operations/array_ops.py.cs
+++ b/src/TensorFlowNET.Core/Operations/array_ops.py.cs
@@ -351,7 +351,7 @@ namespace Tensorflow
var input_shape = tensor_util.to_shape(input_tensor.shape);
if (optimize && input_tensor.NDims > -1 && input_shape.is_fully_defined())
{
- var nd = np.array(input_tensor.shape).astype(out_type.as_numpy_datatype());
+ var nd = np.array(input_tensor.shape).astype(out_type.as_numpy_dtype());
return constant_op.constant(nd, name: name);
}
}
diff --git a/src/TensorFlowNET.Core/Operations/nn_ops.cs b/src/TensorFlowNET.Core/Operations/nn_ops.cs
index cbf55861..63e0fca1 100644
--- a/src/TensorFlowNET.Core/Operations/nn_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/nn_ops.cs
@@ -98,7 +98,7 @@ namespace Tensorflow
// float to be selected, hence we use a >= comparison.
var keep_mask = random_tensor >= rate;
var ret = x * scale * math_ops.cast(keep_mask, x.dtype);
- ret.SetShape(x.TensorShape);
+ ret.set_shape(x.TensorShape);
return ret;
});
}
diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs
index efe2afd4..4c5f2be3 100644
--- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs
+++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs
@@ -49,7 +49,7 @@ namespace Tensorflow
// dispose newOpts
if (opts == null)
- c_api.TF_DeleteSessionOptions(newOpts);
+ newOpts.Dispose();
status.Check(true);
}
@@ -64,6 +64,11 @@ namespace Tensorflow
return _run(fetche, feed_dict)[0];
}
+ public virtual NDArray run(ITensorOrOperation fetche, params FeedItem[] feed_dict)
+ {
+ return _run(fetche, feed_dict)[0];
+ }
+
public virtual (NDArray, NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict)
{
var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4 }, feed_dict);
@@ -273,7 +278,7 @@ namespace Tensorflow
{
var tensor = new Tensor(output);
NDArray nd = null;
- Type type = tensor.dtype.as_numpy_datatype();
+ Type type = tensor.dtype.as_numpy_dtype();
var ndims = tensor.shape;
var offset = c_api.TF_TensorData(output);
@@ -285,7 +290,7 @@ namespace Tensorflow
nd = NDArray.Scalar(*(bool*)offset);
break;
case TF_DataType.TF_STRING:
- var bytes = tensor.Data();
+ var bytes = tensor.BufferToArray();
// wired, don't know why we have to start from offset 9.
// length in the begin
var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]);
@@ -324,7 +329,7 @@ namespace Tensorflow
nd = np.array(bools).reshape(ndims);
break;
case TF_DataType.TF_STRING:
- var bytes = tensor.Data();
+ var bytes = tensor.BufferToArray();
// wired, don't know why we have to start from offset 9.
// length in the begin
var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]);
@@ -396,7 +401,7 @@ namespace Tensorflow
Dispose();
}
- protected override void DisposeUnManagedState(IntPtr handle)
+ protected override void DisposeUnmanagedResources(IntPtr handle)
{
using (var status = new Status())
{
diff --git a/src/TensorFlowNET.Core/Sessions/SessionOptions.cs b/src/TensorFlowNET.Core/Sessions/SessionOptions.cs
index 8e0a0a74..ed99b7fe 100644
--- a/src/TensorFlowNET.Core/Sessions/SessionOptions.cs
+++ b/src/TensorFlowNET.Core/Sessions/SessionOptions.cs
@@ -32,7 +32,7 @@ namespace Tensorflow
_handle = handle;
}
- protected override void DisposeUnManagedState(IntPtr handle)
+ protected override void DisposeUnmanagedResources(IntPtr handle)
=> c_api.TF_DeleteSessionOptions(handle);
public void SetConfig(ConfigProto config)
diff --git a/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs b/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs
index a46decb1..e1a77d90 100644
--- a/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs
+++ b/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs
@@ -17,6 +17,7 @@
using NumSharp;
using System;
using System.Collections.Generic;
+using NumSharp.Backends;
namespace Tensorflow
{
@@ -71,18 +72,18 @@ namespace Tensorflow
{
if(tensor_values.Length > 0)
{
- switch (tensor_values[0].dtype.Name)
+ switch (tensor_values[0].typecode)
{
- case "Int32":
+ case NPTypeCode.Int32:
full_values.Add(float.NaN);
break;
- case "Single":
+ case NPTypeCode.Single:
full_values.Add(float.NaN);
break;
- case "String":
+ case NPTypeCode.String:
full_values.Add(float.NaN);
break;
- case "Char":
+ case NPTypeCode.Char:
full_values.Add(float.NaN);
break;
default:
@@ -100,21 +101,21 @@ namespace Tensorflow
j += 1;
if (value.ndim == 0)
{
- switch (value.dtype.Name)
+ switch (value.typecode)
{
- case "Int16":
+ case NPTypeCode.Int16:
full_values.Add(value.GetValue(0));
break;
- case "Int32":
+ case NPTypeCode.Int32:
full_values.Add(value.GetValue(0));
break;
- case "Int64":
+ case NPTypeCode.Int64:
full_values.Add(value.GetValue(0));
break;
- case "Single":
+ case NPTypeCode.Single:
full_values.Add(value.GetValue(0));
break;
- case "Double":
+ case NPTypeCode.Double:
full_values.Add(value.GetValue(0));
break;
/*case "String":
diff --git a/src/TensorFlowNET.Core/Status/Status.cs b/src/TensorFlowNET.Core/Status/Status.cs
index 7eb2d7e3..2bdd806a 100644
--- a/src/TensorFlowNET.Core/Status/Status.cs
+++ b/src/TensorFlowNET.Core/Status/Status.cs
@@ -50,7 +50,7 @@ namespace Tensorflow
///
public void Check(bool throwException = false)
{
- if(Code != TF_Code.TF_OK)
+ if (Code != TF_Code.TF_OK)
{
Console.WriteLine(Message);
if (throwException)
@@ -65,7 +65,7 @@ namespace Tensorflow
return status._handle;
}
- protected override void DisposeUnManagedState(IntPtr handle)
+ protected override void DisposeUnmanagedResources(IntPtr handle)
=> c_api.TF_DeleteStatus(handle);
}
-}
+}
\ No newline at end of file
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
index 63fda866..73f116ec 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
@@ -16,11 +16,13 @@
using NumSharp;
using System;
+using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Text;
+using NumSharp.Backends;
using NumSharp.Backends.Unmanaged;
using static Tensorflow.c_api;
@@ -462,7 +464,7 @@ namespace Tensorflow
*v = value;
_handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(Complex)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(Complex), deallocator: _hGlobalDeallocator, ref _deallocatorArgs);
IsMemoryOwner=true;
- }
+ }
#endif
///
@@ -477,7 +479,7 @@ namespace Tensorflow
IntPtr tensor = c_api.TF_TensorData(handle);
Marshal.WriteInt64(tensor, 0);
- fixed (byte* src = &buffer[0])
+ fixed (byte* src = buffer)
c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(Int64)), size, status);
_handle = handle;
status.Check(true);
@@ -486,35 +488,55 @@ namespace Tensorflow
public unsafe Tensor(NDArray nd, TF_DataType? tensorDType = null)
{
// todo: handle nd of type "String" here too
- if (tensorDType == TF_DataType.TF_STRING && nd.dtype.Name == "Byte")
+ if (tensorDType == TF_DataType.TF_STRING && nd.typecode == NPTypeCode.Byte)
{
- var buffer = nd.ToArray();
- var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length);
- var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8));
-
- IntPtr tensor = c_api.TF_TensorData(handle);
- Marshal.WriteInt64(tensor, 0);
+ if (nd.Unsafe.Storage.Shape.IsContiguous)
+ {
+ var bytesLength = (UIntPtr)nd.size;
+ var size = c_api.TF_StringEncodedSize(bytesLength);
+ var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr) ((ulong) size + 8));
+
+ IntPtr tensor = c_api.TF_TensorData(handle);
+ Marshal.WriteInt64(tensor, 0);
+
+ var status = new Status();
+ c_api.TF_StringEncode((byte*) nd.Unsafe.Address, bytesLength, (sbyte*) (tensor + sizeof(Int64)), size, status);
+
+ status.Check(true);
+ _handle = handle;
+ IsMemoryOwner = false;
+ }
+ else
+ {
+ var buffer = nd.ToArray();
+ var size = c_api.TF_StringEncodedSize((UIntPtr) buffer.Length);
+ var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr) ((ulong) size + 8));
+
+ IntPtr tensor = c_api.TF_TensorData(handle);
+ Marshal.WriteInt64(tensor, 0);
+
+ var status = new Status();
+ fixed (byte* src = buffer)
+ c_api.TF_StringEncode(src, (UIntPtr) buffer.Length, (sbyte*) (tensor + sizeof(Int64)), size, status);
+
+ status.Check(true);
+ _handle = handle;
+ IsMemoryOwner = false;
+ }
- var status = new Status();
- fixed (byte* src = &buffer[0])
- c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(Int64)), size, status);
-
- status.Check(true);
- _handle=handle;
- IsMemoryOwner = false;
return;
}
+
_handle = CreateTensorFromNDArray(nd, tensorDType);
IsMemoryOwner = true;
}
private unsafe IntPtr CreateTensorFromNDArray(NDArray nd, TF_DataType? given_dtype)
{
- if (nd.dtype.Name == "String")
- throw new NotImplementedException("Support for NDArray of type string not implemented yet");
+ if (nd.dtype.Name == "String")
+ throw new NotImplementedException("Support for NDArray of type string not implemented yet");
IArraySlice arraySlice;
- var shape = nd.Unsafe.Storage.Shape;
- if (shape.IsSliced || shape.IsBroadcasted)
+ if (nd.Unsafe.Storage.Shape.IsContiguous == false)
{
// the memory is NOT contiguous, so we have to copy the view into a contiguous memory block.
arraySlice = nd.CloneData();
@@ -527,51 +549,52 @@ namespace Tensorflow
this.Tag = arraySlice; // keep a reference to the memory block to make sure it is not disposed while TF is using it
var ptr = new IntPtr(arraySlice.Address);
int num_bytes = (nd.size * nd.dtypesize);
- var dtype = given_dtype ?? ToTFDataType(nd.dtype);
+ var dtype = given_dtype ?? nd.dtype.as_dtype();
var handle = TF_NewTensor(dtype, dims: nd.shape.Select(i=>(long)i).ToArray(), num_dims: nd.ndim, data: ptr, len: (UIntPtr)num_bytes, deallocator: _nothingDeallocator, ref _deallocatorArgs);
IsMemoryOwner = false;
return handle;
- }
-
- public unsafe Tensor(byte[][] buffer, long[] shape)
- {
- int size = 0;
- foreach (var b in buffer)
- {
- size += (int)TF_StringEncodedSize((UIntPtr)b.Length);
- }
- int totalSize = size + buffer.Length * 8;
- ulong offset = 0;
- IntPtr handle = TF_AllocateTensor(TF_DataType.TF_STRING, shape, shape.Length, (UIntPtr)totalSize);
-
- // Clear offset table
- IntPtr pOffset = TF_TensorData(handle);
- IntPtr dst = pOffset + buffer.Length * 8;
- IntPtr dstLimit = pOffset + totalSize;
- for (int i = 0; i < buffer.Length; i++)
- {
- Marshal.WriteInt64(pOffset, (long)offset);
- using (var status = new Status())
- {
- fixed (byte* src = &buffer[i][0])
- {
- var written = TF_StringEncode(src, (UIntPtr)buffer[i].Length, (sbyte*)dst, (UIntPtr)(dstLimit.ToInt64() - dst.ToInt64()), status);
- status.Check(true);
- pOffset += 8;
- dst += (int)written;
- offset += written;
- }
- }
- }
-
- _handle = handle;
+
+ }
+
+ public unsafe Tensor(byte[][] buffer, long[] shape)
+ {
+ int size = 0;
+ foreach (var b in buffer)
+ {
+ size += (int)TF_StringEncodedSize((UIntPtr)b.Length);
+ }
+ int totalSize = size + buffer.Length * 8;
+ ulong offset = 0;
+ IntPtr handle = TF_AllocateTensor(TF_DataType.TF_STRING, shape, shape.Length, (UIntPtr)totalSize);
+
+ // Clear offset table
+ IntPtr pOffset = TF_TensorData(handle);
+ IntPtr dst = pOffset + buffer.Length * 8;
+ IntPtr dstLimit = pOffset + totalSize;
+ for (int i = 0; i < buffer.Length; i++)
+ {
+ Marshal.WriteInt64(pOffset, (long)offset);
+ using (var status = new Status())
+ {
+ fixed (byte* src = &buffer[i][0])
+ {
+ var written = TF_StringEncode(src, (UIntPtr)buffer[i].Length, (sbyte*)dst, (UIntPtr)(dstLimit.ToInt64() - dst.ToInt64()), status);
+ status.Check(true);
+ pOffset += 8;
+ dst += (int)written;
+ offset += written;
+ }
+ }
+ }
+
+ _handle = handle;
}
public Tensor(Operation op, int value_index, TF_DataType dtype)
{
_op = op;
_value_index = value_index;
- _dtype = dtype;
+ _override_dtype = dtype;
_id = ops.uid();
}
@@ -589,11 +612,11 @@ namespace Tensorflow
/// specified dimensions.
///
[MethodImpl(MethodImplOptions.AggressiveInlining)]
+ [SuppressMessage("ReSharper", "LocalVariableHidesMember")]
protected unsafe IntPtr CreateTensorWithoutCopying(TF_DataType dt, long[] shape, Array data, int element_size)
{
- if (dt == TF_DataType.TF_STRING && data is byte[])
+ if (dt == TF_DataType.TF_STRING && data is byte[] buffer)
{
- var buffer = (byte[])data;
var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length);
var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8));
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Explicit.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Explicit.cs
index 6db60b4a..6d7f20f1 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensor.Explicit.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.Explicit.cs
@@ -1,4 +1,5 @@
using System;
+using System.Runtime.CompilerServices;
namespace Tensorflow
{
@@ -6,86 +7,142 @@ namespace Tensorflow
{
public static explicit operator bool(Tensor tensor)
{
- EnsureScalar(tensor);
- return tensor.Data()[0];
+ unsafe
+ {
+ EnsureScalar(tensor);
+ EnsureDType(tensor, TF_DataType.TF_BOOL);
+ return *(bool*) tensor.buffer;
+ }
}
public static explicit operator sbyte(Tensor tensor)
{
- EnsureScalar(tensor);
- return tensor.Data()[0];
+ unsafe
+ {
+ EnsureScalar(tensor);
+ EnsureDType(tensor, TF_DataType.TF_INT8);
+ return *(sbyte*) tensor.buffer;
+ }
}
public static explicit operator byte(Tensor tensor)
{
- EnsureScalar(tensor);
- return tensor.Data()[0];
+ unsafe
+ {
+ EnsureScalar(tensor);
+ EnsureDType(tensor, TF_DataType.TF_UINT8);
+ return *(byte*) tensor.buffer;
+ }
}
public static explicit operator ushort(Tensor tensor)
{
- EnsureScalar(tensor);
- return tensor.Data()[0];
+ unsafe
+ {
+ EnsureScalar(tensor);
+ EnsureDType(tensor, TF_DataType.TF_UINT16);
+ return *(ushort*) tensor.buffer;
+ }
}
public static explicit operator short(Tensor tensor)
{
- EnsureScalar(tensor);
- return tensor.Data()[0];
+ unsafe
+ {
+ EnsureScalar(tensor);
+ EnsureDType(tensor, TF_DataType.TF_INT16);
+ return *(short*) tensor.buffer;
+ }
}
public static explicit operator int(Tensor tensor)
{
- EnsureScalar(tensor);
- return tensor.Data()[0];
+ unsafe
+ {
+ EnsureScalar(tensor);
+ EnsureDType(tensor, TF_DataType.TF_INT32);
+ return *(int*) tensor.buffer;
+ }
}
public static explicit operator uint(Tensor tensor)
{
- EnsureScalar(tensor);
- return tensor.Data()[0];
+ unsafe
+ {
+ EnsureScalar(tensor);
+ EnsureDType(tensor, TF_DataType.TF_UINT32);
+ return *(uint*) tensor.buffer;
+ }
}
public static explicit operator long(Tensor tensor)
{
- EnsureScalar(tensor);
- return tensor.Data()[0];
+ unsafe
+ {
+ EnsureScalar(tensor);
+ EnsureDType(tensor, TF_DataType.TF_INT64);
+ return *(long*) tensor.buffer;
+ }
}
public static explicit operator ulong(Tensor tensor)
{
- EnsureScalar(tensor);
- return tensor.Data()[0];
+ unsafe
+ {
+ EnsureScalar(tensor);
+ EnsureDType(tensor, TF_DataType.TF_UINT64);
+ return *(ulong*) tensor.buffer;
+ }
}
public static explicit operator float(Tensor tensor)
{
- EnsureScalar(tensor);
- return tensor.Data()[0];
+ unsafe
+ {
+ EnsureScalar(tensor);
+ EnsureDType(tensor, TF_DataType.TF_FLOAT);
+ return *(float*) tensor.buffer;
+ }
}
public static explicit operator double(Tensor tensor)
{
- EnsureScalar(tensor);
- return tensor.Data()[0];
+ unsafe
+ {
+ EnsureScalar(tensor);
+ EnsureDType(tensor, TF_DataType.TF_DOUBLE);
+ return *(double*) tensor.buffer;
+ }
+ }
+
+ public static explicit operator string(Tensor tensor)
+ {
+ unsafe
+ {
+ EnsureScalar(tensor);
+ EnsureDType(tensor, TF_DataType.TF_STRING);
+ return new string((char*) tensor.buffer, 0, (int) tensor.size);
+ }
+ }
+
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ private static void EnsureDType(Tensor tensor, TF_DataType @is)
+ {
+ if (tensor.dtype != @is)
+ throw new InvalidCastException($"Unable to cast scalar tensor {tensor.dtype} to {@is}");
}
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
private static void EnsureScalar(Tensor tensor)
{
if (tensor == null)
- {
throw new ArgumentNullException(nameof(tensor));
- }
if (tensor.TensorShape.ndim != 0)
- {
throw new ArgumentException("Tensor must have 0 dimensions in order to convert to scalar");
- }
if (tensor.TensorShape.size != 1)
- {
throw new ArgumentException("Tensor must have size 1 in order to convert to scalar");
- }
}
}
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs
index eb912eb9..4b15864f 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs
@@ -69,11 +69,12 @@ namespace Tensorflow
TF_DataType.TF_QINT8, TF_DataType.TF_QINT16, TF_DataType.TF_QINT32,
TF_DataType.TF_UINT8, TF_DataType.TF_UINT16, TF_DataType.TF_UINT32, TF_DataType.TF_UINT64
};
+
public static Tensor operator /(double x, Tensor y) => BinaryOpWrapper("truediv", x, y);
public static Tensor operator /(float x, Tensor y) => BinaryOpWrapper("truediv", x, y);
public static Tensor operator /(int x, Tensor y) => BinaryOpWrapper("floordiv", x, y);
public static Tensor operator /(Tensor x, Tensor y) =>
- _intTfDataTypes.Contains(x._dtype)
+ _intTfDataTypes.Contains(x.dtype)
? BinaryOpWrapper("floordiv", x, y)
: BinaryOpWrapper("truediv", x, y);
public static Tensor operator /(Tensor x, int y) => BinaryOpWrapper("floordiv", x, y);
@@ -122,8 +123,7 @@ namespace Tensorflow
if (y is Tensor tr)
dtype = tr.dtype.as_base_dtype();
- var namescope = ops.name_scope(null, name, new { x, y });
- return tf_with(namescope, scope =>
+ return tf_with(ops.name_scope(null, name, new { x, y }), scope =>
{
Tensor result = null;
var x1 = ops.convert_to_tensor(x, dtype: dtype, name: "x");
@@ -155,7 +155,6 @@ namespace Tensorflow
return result;
});
-
}
}
}
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs
index d52b9422..8ac6c73e 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensor.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs
@@ -17,9 +17,16 @@
using NumSharp;
using System;
using System.Collections.Generic;
+using System.Diagnostics.CodeAnalysis;
+using System.Globalization;
using System.Linq;
+using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Text;
+using System.Threading.Tasks;
+using NumSharp.Backends;
+using NumSharp.Backends.Unmanaged;
+using NumSharp.Utilities;
using Tensorflow.Framework;
using static Tensorflow.Binding;
@@ -29,42 +36,68 @@ namespace Tensorflow
/// A tensor is a generalization of vectors and matrices to potentially higher dimensions.
/// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes.
///
+ [SuppressMessage("ReSharper", "ConvertToAutoProperty")]
public partial class Tensor : DisposableObject, ITensorOrOperation, _TensorLike
{
- private int _id;
- private Operation _op;
+ private readonly int _id;
+ private readonly Operation _op;
+ private readonly int _value_index;
+ private TF_Output? _tf_output;
+ private readonly TF_DataType _override_dtype;
public int Id => _id;
+
+ ///
+ /// The Graph that contains this tensor.
+ ///
public Graph graph => op?.graph;
+
+ ///
+ /// The Operation that produces this tensor as an output.
+ ///
public Operation op => _op;
+
public Tensor[] outputs => op.outputs;
///
- /// The string name of this tensor.
+ /// The string name of this tensor.
///
public string name => $"{(op == null ? "" : $"{op.name}:{_value_index}")}";
- private int _value_index;
+ ///
+ /// The index of this tensor in the outputs of its Operation.
+ ///
public int value_index => _value_index;
- private TF_DataType _dtype = TF_DataType.DtInvalid;
- public TF_DataType dtype => _handle == IntPtr.Zero ? _dtype : c_api.TF_TensorType(_handle);
+ ///
+ /// The DType of elements in this tensor.
+ ///
+ public TF_DataType dtype => _handle == IntPtr.Zero ? _override_dtype : c_api.TF_TensorType(_handle);
public ulong bytesize => _handle == IntPtr.Zero ? 0 : c_api.TF_TensorByteSize(_handle);
-
public ulong itemsize => _handle == IntPtr.Zero ? 0 : c_api.TF_DataTypeSize(dtype);
public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / itemsize;
-
public IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle);
public int num_consumers(TF_Output oper_out) => _handle == IntPtr.Zero ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out);
+ public int NDims => rank;
- private TF_Output? _tf_output;
+ ///
+ /// The name of the device on which this tensor will be produced, or null.
+ ///
+ public string Device => op.Device;
+
+ public int[] dims => shape;
///
- /// used for keep other pointer when do implicit operating
+ /// Used for keep other pointer when do implicit operating
///
public object Tag { get; set; }
+
+ ///
+ /// Returns the shape of a tensor.
+ ///
+ /// https://www.tensorflow.org/api_docs/python/tf/shape
public int[] shape
{
get
@@ -76,14 +109,13 @@ namespace Tensorflow
var status = new Status();
c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, status);
status.Check();
- }
- else
+ } else
{
for (int i = 0; i < rank; i++)
dims[i] = c_api.TF_Dim(_handle, i);
}
- return dims.Select(x => Convert.ToInt32(x)).ToArray();
+ return dims.Select(x => ((IConvertible) x).ToInt32(CultureInfo.InvariantCulture)).ToArray();
}
set
@@ -93,38 +125,52 @@ namespace Tensorflow
if (value == null)
c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), null, -1, status);
else
- c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), value.Select(x => Convert.ToInt64(x)).ToArray(), value.Length, status);
+ c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), value.Select(Convert.ToInt64).ToArray(), value.Length, status);
}
}
public int[] _shape_tuple()
{
- if (shape == null) return null;
- return shape.Select(x => (int)x).ToArray();
+ return (int[]) shape.Clone();
}
public TensorShape TensorShape => tensor_util.to_shape(shape);
- public void SetShape(TensorShape shape)
+ ///
+ /// Updates the shape of this tensor.
+ ///
+ public void set_shape(TensorShape shape)
{
- this.shape = shape.dims;
+ this.shape = (int[]) shape.dims.Clone();
}
+ ///
+ /// Updates the shape of this tensor.
+ ///
+ [Obsolete("Please use set_shape(TensorShape shape) instead.", false)]
+ public void SetShape(TensorShape shape)
+ {
+ this.shape = (int[]) shape.dims.Clone();
+ }
+
+ ///
+ /// Updates the shape of this tensor.
+ ///
public void set_shape(Tensor shape)
{
+ // ReSharper disable once MergeConditionalExpression
this.shape = shape is null ? null : shape.shape;
}
- public int[] dims => shape;
-
///
- /// number of dimensions
- /// 0 Scalar (magnitude only)
- /// 1 Vector (magnitude and direction)
- /// 2 Matrix (table of numbers)
- /// 3 3-Tensor (cube of numbers)
+ /// number of dimensions
+ /// 0 Scalar (magnitude only)
+ /// 1 Vector (magnitude and direction)
+ /// 2 Matrix (table of numbers)
+ /// 3 3-Tensor (cube of numbers)
/// n n-Tensor (you get the idea)
///
+ /// https://www.tensorflow.org/api_docs/python/tf/rank
public int rank
{
get
@@ -137,17 +183,15 @@ namespace Tensorflow
status.Check();
return ndim;
}
- else
- {
- return c_api.TF_NumDims(_handle);
- }
+
+ return c_api.TF_NumDims(_handle);
}
}
- public int NDims => rank;
-
- public string Device => op.Device;
-
+ ///
+ /// Returns a list of Operations that consume this tensor.
+ ///
+ ///
public Operation[] consumers()
{
var output = _as_tf_output();
@@ -157,37 +201,191 @@ namespace Tensorflow
public TF_Output _as_tf_output()
{
- if(!_tf_output.HasValue)
+ if (!_tf_output.HasValue)
_tf_output = new TF_Output(op, value_index);
return _tf_output.Value;
}
- public T[] Data()
+ [Obsolete("Please use ToArray() instead.", false)]
+ public T[] Data() where T : unmanaged
+ {
+ return ToArray();
+ }
+
+ ///
+ ///
+ ///
+ ///
+ ///
+ /// When is string
+ public T[] ToArray() where T : unmanaged
{
- // Column major order
- // https://en.wikipedia.org/wiki/File:Row_and_column_major_order.svg
- // matrix:[[1, 2, 3], [4, 5, 6]]
- // index: 0 2 4 1 3 5
- // result: 1 4 2 5 3 6
- var data = new T[size];
-
- for (ulong i = 0; i < size; i++)
+ //when T is string
+ if (typeof(T) == typeof(string))
{
- data[i] = Marshal.PtrToStructure(buffer + (int)(i * itemsize));
+ if (dtype != TF_DataType.TF_STRING)
+ throw new ArgumentException($"Given <{typeof(T).Name}> can't be converted to string.");
+
+ return (T[]) (object) StringData();
}
- return data;
+ //Are the types matching?
+ if (typeof(T).as_dtype() == dtype)
+ {
+ if (NDims == 0 && size == 1) //is it a scalar?
+ {
+ unsafe
+ {
+ return new T[] {*(T*) buffer};
+ }
+ }
+
+ //types match, no need to perform cast
+ var ret = new T[size];
+ unsafe
+ {
+ var len = (long) size;
+ fixed (T* dstRet = ret)
+ {
+ T* dst = dstRet; //local stack copy
+ if (typeof(T).IsPrimitive)
+ {
+ var src = (T*) buffer;
+ len *= ((long) itemsize);
+ System.Buffer.MemoryCopy(src, dst, len, len);
+ } else
+ {
+ var itemsize = (long) this.itemsize;
+ var buffer = this.buffer.ToInt64();
+ Parallel.For(0L, len, i => dst[i] = Marshal.PtrToStructure(new IntPtr(buffer + i * itemsize)));
+ }
+ }
+ }
+
+ return ret;
+ } else
+ {
+
+ //types do not match, need to perform cast
+ if (NDims == 0 && size == 1) //is it a scalar?
+ {
+ unsafe
+ {
+#if _REGEN
+ #region Compute
+ switch (dtype.as_numpy_dtype().GetTypeCode())
+ {
+ %foreach supported_dtypes,supported_dtypes_lowercase%
+ case NPTypeCode.#1: return new T[] {Converts.ChangeType(*(#2*) buffer, NPTypeCode.#1)};
+ %
+ case NPTypeCode.String: return new T[] {Converts.ChangeType((string)this, NPTypeCode.String)};
+ default:
+ throw new NotSupportedException();
+ }
+ #endregion
+#else
+ #region Compute
+ switch (dtype.as_numpy_dtype()?.GetTypeCode())
+ {
+ case NPTypeCode.Boolean: return new T[] {Converts.ChangeType(*(bool*) buffer, NPTypeCode.Boolean)};
+ case NPTypeCode.Byte: return new T[] {Converts.ChangeType(*(byte*) buffer, NPTypeCode.Byte)};
+ case NPTypeCode.Int16: return new T[] {Converts.ChangeType(*(short*) buffer, NPTypeCode.Int16)};
+ case NPTypeCode.UInt16: return new T[] {Converts.ChangeType(*(ushort*) buffer, NPTypeCode.UInt16)};
+ case NPTypeCode.Int32: return new T[] {Converts.ChangeType(*(int*) buffer, NPTypeCode.Int32)};
+ case NPTypeCode.UInt32: return new T[] {Converts.ChangeType(*(uint*) buffer, NPTypeCode.UInt32)};
+ case NPTypeCode.Int64: return new T[] {Converts.ChangeType(*(long*) buffer, NPTypeCode.Int64)};
+ case NPTypeCode.UInt64: return new T[] {Converts.ChangeType(*(ulong*) buffer, NPTypeCode.UInt64)};
+ case NPTypeCode.Char: return new T[] {Converts.ChangeType(*(char*) buffer, NPTypeCode.Char)};
+ case NPTypeCode.Double: return new T[] {Converts.ChangeType(*(double*) buffer, NPTypeCode.Double)};
+ case NPTypeCode.Single: return new T[] {Converts.ChangeType(*(float*) buffer, NPTypeCode.Single)};
+ case NPTypeCode.String: return new T[] {Converts.ChangeType((string)this, NPTypeCode.String)};
+ default:
+ throw new NotSupportedException();
+ }
+ #endregion
+#endif
+ }
+ }
+
+ var ret = new T[size];
+ unsafe
+ {
+ var len = (long) size;
+ fixed (T* dstRet = ret)
+ {
+ T* dst = dstRet; //local stack copy
+
+#if _REGEN
+ #region Compute
+ switch (dtype.as_numpy_dtype().GetTypeCode())
+ {
+ %foreach supported_dtypes,supported_dtypes_lowercase%
+ case NPTypeCode.#1: new UnmanagedMemoryBlock<#2>((#2*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break;
+ %
+ default:
+ throw new NotSupportedException();
+ }
+ #endregion
+#else
+ #region Compute
+ switch (dtype.as_numpy_dtype().GetTypeCode())
+ {
+ case NPTypeCode.Boolean: new UnmanagedMemoryBlock((bool*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break;
+ case NPTypeCode.Byte: new UnmanagedMemoryBlock((byte*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break;
+ case NPTypeCode.Int16: new UnmanagedMemoryBlock((short*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break;
+ case NPTypeCode.UInt16: new UnmanagedMemoryBlock((ushort*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break;
+ case NPTypeCode.Int32: new UnmanagedMemoryBlock((int*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break;
+ case NPTypeCode.UInt32: new UnmanagedMemoryBlock((uint*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break;
+ case NPTypeCode.Int64: new UnmanagedMemoryBlock((long*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break;
+ case NPTypeCode.UInt64: new UnmanagedMemoryBlock((ulong*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break;
+ case NPTypeCode.Char: new UnmanagedMemoryBlock((char*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break;
+ case NPTypeCode.Double: new UnmanagedMemoryBlock((double*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break;
+ case NPTypeCode.Single: new UnmanagedMemoryBlock((float*) buffer, len).CastTo(new UnmanagedMemoryBlock(dst, len), null, null); break;
+ case NPTypeCode.String: throw new NotSupportedException("Unable to convert from string to other dtypes"); //TODO! this should call Converts.To
+ default:
+ throw new NotSupportedException();
+ }
+ #endregion
+#endif
+
+ }
+ }
+
+ return ret;
+ }
}
+ ///
+ /// Copies the memory of current buffer onto newly allocated array.
+ ///
+ ///
+ [Obsolete("Please use set_shape(TensorShape shape) instead.", false)]
public byte[] Data()
{
- var data = new byte[bytesize];
- Marshal.Copy(buffer, data, 0, (int)bytesize);
- return data;
+ return BufferToArray();
+ }
+
+ ///
+ /// Copies the memory of current buffer onto newly allocated array.
+ ///
+ ///
+ public byte[] BufferToArray()
+ {
+ unsafe
+ {
+ // ReSharper disable once LocalVariableHidesMember
+ var bytesize = (long) this.bytesize;
+ var data = new byte[bytesize];
+ fixed (byte* dst = data)
+ System.Buffer.MemoryCopy(buffer.ToPointer(), dst, bytesize, bytesize);
+
+ return data;
+ }
}
- public unsafe string[] StringData()
+ /// Used internally in ToArray<T>
+ private unsafe string[] StringData()
{
//
// TF_STRING tensors are encoded with a table of 8-byte offsets followed by TF_StringEncode-encoded bytes.
@@ -199,19 +397,19 @@ namespace Tensorflow
var buffer = new byte[size][];
var src = c_api.TF_TensorData(_handle);
- var srcLen = (IntPtr)(src.ToInt64() + (long)bytesize);
- src += (int)(size * 8);
+ var srcLen = (IntPtr) (src.ToInt64() + (long) bytesize);
+ src += (int) (size * 8);
for (int i = 0; i < buffer.Length; i++)
{
using (var status = new Status())
{
IntPtr dst = IntPtr.Zero;
UIntPtr dstLen = UIntPtr.Zero;
- var read = c_api.TF_StringDecode((byte*)src, (UIntPtr)(srcLen.ToInt64() - src.ToInt64()), (byte**)&dst, &dstLen, status);
+ var read = c_api.TF_StringDecode((byte*) src, (UIntPtr) (srcLen.ToInt64() - src.ToInt64()), (byte**) &dst, &dstLen, status);
status.Check(true);
- buffer[i] = new byte[(int)dstLen];
+ buffer[i] = new byte[(int) dstLen];
Marshal.Copy(dst, buffer[i], 0, buffer[i].Length);
- src += (int)read;
+ src += (int) read;
}
}
@@ -229,51 +427,29 @@ namespace Tensorflow
}
///
- /// Evaluates this tensor in a `Session`.
+ /// Evaluates this tensor in a `Session`.
///
/// A dictionary that maps `Tensor` objects to feed values.
- /// The `Session` to be used to evaluate this tensor.
- ///
+ /// A array corresponding to the value of this tensor.
public NDArray eval(params FeedItem[] feed_dict)
{
return ops._eval_using_default_session(this, feed_dict, graph);
}
+ ///
+ /// Evaluates this tensor in a `Session`.
+ ///
+ /// A dictionary that maps `Tensor` objects to feed values.
+ /// The `Session` to be used to evaluate this tensor.
+ /// A array corresponding to the value of this tensor.
public NDArray eval(Session session, FeedItem[] feed_dict = null)
{
return ops._eval_using_default_session(this, feed_dict, graph, session);
}
- public TF_DataType ToTFDataType(Type type)
- {
- switch (type.Name)
- {
- case "Char":
- return TF_DataType.TF_UINT8;
- case "Int16":
- return TF_DataType.TF_INT16;
- case "Int32":
- return TF_DataType.TF_INT32;
- case "Int64":
- return TF_DataType.TF_INT64;
- case "Single":
- return TF_DataType.TF_FLOAT;
- case "Double":
- return TF_DataType.TF_DOUBLE;
- case "Byte":
- return TF_DataType.TF_UINT8;
- case "String":
- return TF_DataType.TF_STRING;
- case "Boolean":
- return TF_DataType.TF_BOOL;
- default:
- throw new NotImplementedException("ToTFDataType error");
- }
- }
-
public Tensor slice(Slice slice)
{
- var slice_spec = new int[] { slice.Start.Value };
+ var slice_spec = new int[] {slice.Start.Value};
var begin = new List();
var end = new List();
var strides = new List();
@@ -289,26 +465,26 @@ namespace Tensorflow
if (slice.Stop.HasValue)
{
end.Add(slice.Stop.Value);
- }
- else
+ } else
{
end.Add(0);
end_mask |= (1 << index);
}
+
strides.Add(slice.Step);
index += 1;
}
- return tf_with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope =>
+ return tf_with(ops.name_scope(null, "strided_slice", new {begin, end, strides}), scope =>
{
string name = scope;
if (begin != null)
{
var (packed_begin, packed_end, packed_strides) =
(array_ops.stack(begin.ToArray()),
- array_ops.stack(end.ToArray()),
- array_ops.stack(strides.ToArray()));
+ array_ops.stack(end.ToArray()),
+ array_ops.stack(strides.ToArray()));
return gen_array_ops.strided_slice(
this,
@@ -320,7 +496,6 @@ namespace Tensorflow
shrink_axis_mask: shrink_axis_mask,
new_axis_mask: new_axis_mask,
ellipsis_mask: ellipsis_mask,
-
name: name);
}
@@ -330,7 +505,7 @@ namespace Tensorflow
public Tensor slice(int start)
{
- var slice_spec = new int[] { start };
+ var slice_spec = new int[] {start};
var begin = new List();
var end = new List();
var strides = new List();
@@ -349,15 +524,15 @@ namespace Tensorflow
index += 1;
}
- return tf_with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope =>
+ return tf_with(ops.name_scope(null, "strided_slice", new {begin, end, strides}), scope =>
{
string name = scope;
if (begin != null)
{
var (packed_begin, packed_end, packed_strides) =
(array_ops.stack(begin.ToArray()),
- array_ops.stack(end.ToArray()),
- array_ops.stack(strides.ToArray()));
+ array_ops.stack(end.ToArray()),
+ array_ops.stack(strides.ToArray()));
return gen_array_ops.strided_slice(
this,
@@ -369,7 +544,6 @@ namespace Tensorflow
shrink_axis_mask: shrink_axis_mask,
new_axis_mask: new_axis_mask,
ellipsis_mask: ellipsis_mask,
-
name: name);
}
@@ -392,15 +566,12 @@ namespace Tensorflow
return $"tf.Tensor '{name}' shape=({string.Join(",", shape)}) dtype={dtype}";
}
- protected override void DisposeManagedState()
- {
- }
-
- protected override void DisposeUnManagedState(IntPtr handle)
+ protected override void DisposeUnmanagedResources(IntPtr handle)
{
- if(handle != IntPtr.Zero)
+ if (handle != IntPtr.Zero)
{
c_api.TF_DeleteTensor(handle);
+ _handle = IntPtr.Zero;
}
}
@@ -417,4 +588,4 @@ namespace Tensorflow
public int tensor_int_val { get; set; }
}
-}
+}
\ No newline at end of file
diff --git a/src/TensorFlowNET.Core/Tensors/TensorShape.cs b/src/TensorFlowNET.Core/Tensors/TensorShape.cs
index 13258f79..cf62ce04 100644
--- a/src/TensorFlowNET.Core/Tensors/TensorShape.cs
+++ b/src/TensorFlowNET.Core/Tensors/TensorShape.cs
@@ -1,35 +1,84 @@
using NumSharp;
using System;
+using System.Diagnostics.CodeAnalysis;
using System.Linq;
+using System.Runtime.CompilerServices;
namespace Tensorflow
{
///
- /// Represents the shape of a `Tensor`.
+ /// Represents the shape of a `Tensor`.
///
+ /// https://www.tensorflow.org/api_docs/python/tf/TensorShape
public class TensorShape
{
- private Shape shape;
+ private readonly Shape shape;
+
+ ///
+ /// Returns a list of Dimensions, or None if the shape is unspecified.
+ ///
public int[] dims => shape.Dimensions;
+
+ ///
+ /// Returns the rank of this shape.
+ ///
public int ndim => shape.NDim;
+
+ ///
+ /// Returns the rank of this shape.
+ ///
+ public int rank => shape.NDim;
+
+ ///
+ /// Returns the size this shape represents.
+ ///
public int size => shape.Size;
public TensorShape(TensorShapeProto proto)
{
if (proto.UnknownRank) return;
+ switch (proto.Dim.Count)
+ {
+ case 0: shape = new Shape(new int[0]); break;
+ case 1: shape = Shape.Vector((int) proto.Dim[0].Size); break;
+ case 2: shape = Shape.Matrix((int) proto.Dim[0].Size, (int) proto.Dim[1].Size); break;
+ default:
+ var protodims = proto.Dim;
+ var len = protodims.Count;
+ var dims = new int[len];
+ for (int i = 0; i < len; i++)
+ dims[i] = (int) protodims[i].Size;
+
- shape.reshape(proto.Dim.Select(x => (int)x.Size).ToArray());
+ shape = new Shape(dims); break;
+ }
}
public TensorShape(params int[] dims)
{
- shape = new Shape(dims);
+ switch (dims.Length)
+ {
+ case 0: shape = new Shape(new int[0]); break;
+ case 1: shape = Shape.Vector((int) dims[0]); break;
+ case 2: shape = Shape.Matrix(dims[0], dims[1]); break;
+ default: shape = new Shape(dims); break;
+ }
}
+ ///
+ ///
+ ///
+ ///
+ ///
+ /// When is not an Index.
+ [SuppressMessage("ReSharper", "PossibleInvalidOperationException")]
public TensorShape this[Slice slice]
{
get
{
+ if (slice.Start.HasValue == false || slice.Length.HasValue == false)
+ throw new ArgumentException("Slice must has Start and Length.");
+
return new TensorShape(dims.Skip(slice.Start.Value)
.Take(slice.Length.Value)
.ToArray());
@@ -37,7 +86,7 @@ namespace Tensorflow
}
///
- /// Returns True iff `self` is fully defined in every dimension.
+ /// Returns True iff `self` is fully defined in every dimension.
///
///
public bool is_fully_defined()
@@ -50,6 +99,7 @@ namespace Tensorflow
throw new NotImplementedException("TensorShape is_compatible_with");
}
+ [SuppressMessage("ReSharper", "ParameterHidesMember")]
public TensorShape with_rank_at_least(int rank)
{
if (rank != ndim)
@@ -59,35 +109,63 @@ namespace Tensorflow
}
///
- /// Returns the concatenation of the dimension in `self` and `other`.
+ /// Returns the concatenation of the dimension in `self` and `other`.
///
///
///
- public TensorShape concatenate(int[] other_)
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ public TensorShape concatenate(int[] other)
{
- var other = new TensorShape(other_);
+ return concatenate(new TensorShape(other));
+ }
- if (ndim < 0 || other.ndim < 0)
+ ///
+ /// Returns the concatenation of the dimension in `self` and `other`.
+ ///
+ ///
+ ///
+ public TensorShape concatenate(TensorShape other)
+ {
+ var otherShape = other;
+
+ if (ndim < 0 || otherShape.ndim < 0)
return new TensorShape();
else
{
- var concatenate_dims = new int[ndim + other.ndim];
+ var concatenate_dims = new int[ndim + otherShape.ndim];
for (int i = 0; i < ndim; i++)
concatenate_dims[i] = dims[i];
- for (int i = 0; i < other.ndim; i++)
- concatenate_dims[ndim + i] = other.dims[i];
+ for (int i = 0; i < otherShape.ndim; i++)
+ concatenate_dims[ndim + i] = otherShape.dims[i];
return new TensorShape(concatenate_dims);
}
}
- public static implicit operator TensorShape(Shape shape) => new TensorShape(shape.Dimensions);
- public static implicit operator Shape(TensorShape shape) => new Shape(shape.dims);
+ public static implicit operator TensorShape(Shape shape) => new TensorShape((int[]) shape.Dimensions.Clone());
+ public static implicit operator Shape(TensorShape shape) => new Shape((int[]) shape.dims.Clone());
+
+ public static implicit operator int[](TensorShape shape) => (int[])shape.dims.Clone(); //we clone to avoid any changes
public static implicit operator TensorShape(int[] dims) => new TensorShape(dims);
- public static implicit operator int[](TensorShape shape) => shape.dims;
+
+ public static explicit operator int(TensorShape shape) => shape.size;
+ public static explicit operator TensorShape(int dim) => new TensorShape(dim);
+
+ public static explicit operator (int, int)(TensorShape shape) => shape.dims.Length == 2 ? (shape.dims[0], shape.dims[1]) : (0, 0);
public static implicit operator TensorShape((int, int) dims) => new TensorShape(dims.Item1, dims.Item2);
+
+ public static explicit operator (int, int, int)(TensorShape shape) => shape.dims.Length == 3 ? (shape.dims[0], shape.dims[1], shape.dims[2]) : (0, 0, 0);
public static implicit operator TensorShape((int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3);
+
+ public static explicit operator (int, int, int, int)(TensorShape shape) => shape.dims.Length == 4 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3]) : (0, 0, 0, 0);
public static implicit operator TensorShape((int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4);
+
+ public static explicit operator (int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 5 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4]) : (0, 0, 0, 0, 0);
+ public static implicit operator TensorShape((int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5);
+
+ public static explicit operator (int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 6 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5]) : (0, 0, 0, 0, 0, 0);
+ public static implicit operator TensorShape((int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6);
+
}
}
diff --git a/src/TensorFlowNET.Core/Tensors/dtypes.cs b/src/TensorFlowNET.Core/Tensors/dtypes.cs
index 807dc6f5..37f1ca61 100644
--- a/src/TensorFlowNET.Core/Tensors/dtypes.cs
+++ b/src/TensorFlowNET.Core/Tensors/dtypes.cs
@@ -15,6 +15,8 @@
******************************************************************************/
using System;
+using System.Numerics;
+using NumSharp.Backends;
namespace Tensorflow
{
@@ -23,35 +25,100 @@ namespace Tensorflow
public static TF_DataType int8 = TF_DataType.TF_INT8;
public static TF_DataType int32 = TF_DataType.TF_INT32;
public static TF_DataType int64 = TF_DataType.TF_INT64;
+ public static TF_DataType uint8 = TF_DataType.TF_UINT8;
+ public static TF_DataType uint32 = TF_DataType.TF_UINT32;
+ public static TF_DataType uint64 = TF_DataType.TF_UINT64;
public static TF_DataType float32 = TF_DataType.TF_FLOAT; // is that float32?
public static TF_DataType float16 = TF_DataType.TF_HALF;
public static TF_DataType float64 = TF_DataType.TF_DOUBLE;
- public static Type as_numpy_datatype(this TF_DataType type)
+ ///
+ ///
+ ///
+ ///
+ /// equivalent to , if none exists, returns null.
+ public static Type as_numpy_dtype(this TF_DataType type)
{
switch (type)
{
case TF_DataType.TF_BOOL:
return typeof(bool);
+ case TF_DataType.TF_UINT8:
+ return typeof(byte);
case TF_DataType.TF_INT64:
return typeof(long);
+ case TF_DataType.TF_UINT64:
+ return typeof(ulong);
case TF_DataType.TF_INT32:
return typeof(int);
+ case TF_DataType.TF_UINT32:
+ return typeof(uint);
case TF_DataType.TF_INT16:
return typeof(short);
+ case TF_DataType.TF_UINT16:
+ return typeof(ushort);
case TF_DataType.TF_FLOAT:
return typeof(float);
case TF_DataType.TF_DOUBLE:
return typeof(double);
case TF_DataType.TF_STRING:
return typeof(string);
+ case TF_DataType.TF_COMPLEX128:
+ case TF_DataType.TF_COMPLEX64: //64 is also TF_COMPLEX
+ return typeof(Complex);
default:
return null;
}
}
- // "sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"
- public static TF_DataType as_dtype(Type type, TF_DataType? dtype = null)
+ ///
+ ///
+ ///
+ ///
+ ///
+ /// When has no equivalent
+ public static NPTypeCode as_numpy_typecode(this TF_DataType type)
+ {
+ switch (type)
+ {
+ case TF_DataType.TF_BOOL:
+ return NPTypeCode.Boolean;
+ case TF_DataType.TF_UINT8:
+ return NPTypeCode.Byte;
+ case TF_DataType.TF_INT64:
+ return NPTypeCode.Int64;
+ case TF_DataType.TF_INT32:
+ return NPTypeCode.Int32;
+ case TF_DataType.TF_INT16:
+ return NPTypeCode.Int16;
+ case TF_DataType.TF_UINT64:
+ return NPTypeCode.UInt64;
+ case TF_DataType.TF_UINT32:
+ return NPTypeCode.UInt32;
+ case TF_DataType.TF_UINT16:
+ return NPTypeCode.UInt16;
+ case TF_DataType.TF_FLOAT:
+ return NPTypeCode.Single;
+ case TF_DataType.TF_DOUBLE:
+ return NPTypeCode.Double;
+ case TF_DataType.TF_STRING:
+ return NPTypeCode.String;
+ case TF_DataType.TF_COMPLEX128:
+ case TF_DataType.TF_COMPLEX64: //64 is also TF_COMPLEX
+ return NPTypeCode.Complex;
+ default:
+ throw new NotSupportedException($"Unable to convert {type} to a NumSharp typecode.");
+ }
+ }
+
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ /// When has no equivalent
+ public static TF_DataType as_dtype(this Type type, TF_DataType? dtype = null)
{
switch (type.Name)
{
@@ -98,7 +165,7 @@ namespace Tensorflow
dtype = TF_DataType.TF_BOOL;
break;
default:
- throw new Exception("as_dtype Not Implemented");
+ throw new NotSupportedException($"Unable to convert {type} to a NumSharp typecode.");
}
return dtype.Value;
@@ -106,16 +173,7 @@ namespace Tensorflow
public static DataType as_datatype_enum(this TF_DataType type)
{
- DataType dtype = DataType.DtInvalid;
-
- switch (type)
- {
- default:
- Enum.TryParse(((int)type).ToString(), out dtype);
- break;
- }
-
- return dtype;
+ return Enum.TryParse(((int) type).ToString(), out DataType dtype) ? dtype : DataType.DtInvalid;
}
public static TF_DataType as_base_dtype(this TF_DataType type)
@@ -132,7 +190,7 @@ namespace Tensorflow
public static Type as_numpy_dtype(this DataType type)
{
- return type.as_tf_dtype().as_numpy_datatype();
+ return type.as_tf_dtype().as_numpy_dtype();
}
public static DataType as_base_dtype(this DataType type)
@@ -144,16 +202,7 @@ namespace Tensorflow
public static TF_DataType as_tf_dtype(this DataType type)
{
- TF_DataType dtype = TF_DataType.DtInvalid;
-
- switch (type)
- {
- default:
- Enum.TryParse(((int)type).ToString(), out dtype);
- break;
- }
-
- return dtype;
+ return Enum.TryParse(((int) type).ToString(), out TF_DataType dtype) ? dtype : TF_DataType.DtInvalid;
}
public static TF_DataType as_ref(this TF_DataType type)
diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs
index ded105c7..43848da6 100644
--- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs
+++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs
@@ -17,6 +17,7 @@
using NumSharp;
using System;
using System.Linq;
+using NumSharp.Utilities;
namespace Tensorflow
{
@@ -109,7 +110,7 @@ namespace Tensorflow
// We first convert value to a numpy array or scalar.
NDArray nparray = null;
- var np_dt = dtype.as_numpy_datatype();
+ var np_dt = dtype.as_numpy_dtype();
if (values is NDArray nd)
{
@@ -188,37 +189,37 @@ namespace Tensorflow
if (values.GetType().IsArray)
nparray = np.array((int[])values, np_dt);
else
- nparray = Convert.ToInt32(values);
+ nparray = Converts.ToInt32(values);
break;
case "Int64":
if (values.GetType().IsArray)
nparray = np.array((int[])values, np_dt);
else
- nparray = Convert.ToInt64(values);
+ nparray = Converts.ToInt64(values);
break;
case "Single":
if (values.GetType().IsArray)
nparray = np.array((float[])values, np_dt);
else
- nparray = Convert.ToSingle(values);
+ nparray = Converts.ToSingle(values);
break;
case "Double":
if (values.GetType().IsArray)
nparray = np.array((double[])values, np_dt);
else
- nparray = Convert.ToDouble(values);
+ nparray = Converts.ToDouble(values);
break;
case "String":
if (values.GetType().IsArray)
nparray = np.array((string[])values, np_dt);
else
- nparray = NDArray.FromString(Convert.ToString(values));
+ nparray = NDArray.FromString(Converts.ToString(values));
break;
case "Boolean":
if (values.GetType().IsArray)
nparray = np.array((bool[])values, np_dt);
else
- nparray = Convert.ToBoolean(values);
+ nparray = Converts.ToBoolean(values);
break;
default:
throw new NotImplementedException($"make_tensor_proto: Support for type {np_dt.Name} Not Implemented");
diff --git a/src/TensorFlowNET.Core/globals.regen b/src/TensorFlowNET.Core/globals.regen
new file mode 100644
index 00000000..146155b3
--- /dev/null
+++ b/src/TensorFlowNET.Core/globals.regen
@@ -0,0 +1,38 @@
+%all_dtypes = ["NDArray","Complex","Boolean","Byte","Int16","UInt16","Int32","UInt32","Int64","UInt64","Char","Double","Single","String"]
+%all_dtypes_lowercase = ["NDArray","Complex","bool","byte","short","ushort","int","uint","long","ulong","char","double","float","string"]
+
+%supported_primitives = ["Boolean","Byte","Int16","UInt16","Int32","UInt32","Int64","UInt64","Char","Double","Single","String"]
+%supported_primitives_lowercase = ["bool","byte","short","ushort","int","uint","long","ulong","char","double","float","string"]
+
+%supported_numericals = ["Byte","Int16","UInt16","Int32","UInt32","Int64","UInt64","Char","Double","Single"]
+%supported_numericals_lowercase = ["byte","short","ushort","int","uint","long","ulong","char","double","float"]
+%supported_numericals_defaultvals = ["0","0","0","0","0u","0L","0UL","'\0'","0d","0f"]
+%supported_numericals_onevales = ["1","1","1","1","1u","1L","1UL",1,"1d","1f"]
+%supported_numericals_TF_DataType = ["TF_DataType.TF_UINT8","TF_DataType.TF_INT16","TF_DataType.TF_UINT16","TF_DataType.TF_INT32","TF_DataType.TF_UINT32","TF_DataType.TF_INT64","TF_DataType.TF_UINT64","TF_DataType.TF_UINT8","TF_DataType.TF_DOUBLE","TF_DataType.TF_FLOAT"]
+
+//this is the type we use in summerizing/reducting:
+%supported_numericals_accumulatingType = ["UInt32","Int32","UInt32","Int32","UInt32","Int64","UInt64","UInt32","Double","Single"]
+%supported_numericals_accumulatingType_defaultvals = ["0","0","0","0","0u","0L","0UL","'\0'","0d","0f"]
+
+%supported_numericals_signed = ["Int16","Int32","Int64","Double","Single"]
+%supported_numericals_signed_lowercase = ["short","int","long","double","float"]
+%supported_numericals_signed_defaultvals = ["0","0","0L","0d","0f"]
+%supported_numericals_signed_onevales = ["1","1","1L","1d","1f"]
+
+%supported_numericals_unsigned = ["Byte","UInt16","UInt32","UInt64","Char"]
+%supported_numericals_unsigned_lowercase = ["byte","ushort","uint","ulong","char"]
+%supported_numericals_unsigned_defaultvals = ["0","0","0U","0UL","'\0'"]
+%supported_numericals_unsigned_onevales = ["1","1","1U","1UL","'\1'"]
+
+%supported_dtypes = ["Boolean","Byte","Int16","UInt16","Int32","UInt32","Int64","UInt64","Char","Double","Single"]
+%supported_numericals_TF_DataType = ["TF_DataType.TF_UINT8","TF_DataType.TF_INT16","TF_DataType.TF_UINT16","TF_DataType.TF_INT32","TF_DataType.TF_UINT32","TF_DataType.TF_INT64","TF_DataType.TF_UINT64","TF_DataType.TF_UINT8","TF_DataType.TF_DOUBLE","TF_DataType.TF_FLOAT"]
+
+%supported_dtypes_lowercase = ["bool","byte","short","ushort","int","uint","long","ulong","char","double","float"]
+%supported_dtypes_defaultvals = [false,"0","0","0","0","0u","0L","0UL","'\0'","0d","0f"]
+%supported_dtypes_onevales = [true,"1","1","1","1","1u","1L","1UL","'\1'","1d","1f"]
+%supported_dtypes_dtype = ["bool","uint8","int16","uint16","int32","uint32","int64","uint64","uint8","float64","float32"]
+
+//this is the type we use in summerizing/reducting:
+%supported_dtypes_accumulatingType = ["Int32","UInt32","Int32","UInt32","Int32","UInt32","Int64","UInt64","UInt32","Double","Single"]
+%supported_dtypes_accumulatingType_defaultvals = [false, "0","0","0","0u","0L","0UL","'\0'","0d","0f"]
+
diff --git a/src/TensorFlowNET.Core/ops.GraphKeys.cs b/src/TensorFlowNET.Core/ops.GraphKeys.cs
index 17b095a4..c5a06433 100644
--- a/src/TensorFlowNET.Core/ops.GraphKeys.cs
+++ b/src/TensorFlowNET.Core/ops.GraphKeys.cs
@@ -29,55 +29,111 @@ namespace Tensorflow
///
public class GraphKeys
{
+ #region const
+
+
+ ///
+ /// the subset of `Variable` objects that will be trained by an optimizer.
+ ///
+ public const string TRAINABLE_VARIABLES_ = "trainable_variables";
+
+ ///
+ /// Trainable resource-style variables.
+ ///
+ public const string TRAINABLE_RESOURCE_VARIABLES_ = "trainable_resource_variables";
+
+ ///
+ /// Key for streaming model ports.
+ ///
+ public const string _STREAMING_MODEL_PORTS_ = "streaming_model_ports";
+
+ ///
+ /// Key to collect losses
+ ///
+ public const string LOSSES_ = "losses";
+
+ ///
+ /// Key to collect Variable objects that are global (shared across machines).
+ /// Default collection for all variables, except local ones.
+ ///
+ public const string GLOBAL_VARIABLES_ = "variables";
+
+ public const string TRAIN_OP_ = "train_op";
+
+ public const string GLOBAL_STEP_ = "global_step";
+
+ public string[] _VARIABLE_COLLECTIONS_ = new string[] { "variables", "trainable_variables", "model_variables" };
+ ///
+ /// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing.
+ ///
+ public const string SAVEABLE_OBJECTS_ = "saveable_objects";
+ ///
+ /// Key to collect update_ops
+ ///
+ public const string UPDATE_OPS_ = "update_ops";
+
+ // Key to collect summaries.
+ public const string SUMMARIES_ = "summaries";
+
+ // Used to store v2 summary names.
+ public const string _SUMMARY_COLLECTION_ = "_SUMMARY_V2";
+
+ // Key for control flow context.
+ public const string COND_CONTEXT_ = "cond_context";
+ public const string WHILE_CONTEXT_ = "while_context";
+
+ #endregion
+
+
///
/// the subset of `Variable` objects that will be trained by an optimizer.
///
- public string TRAINABLE_VARIABLES = "trainable_variables";
+ public string TRAINABLE_VARIABLES => TRAINABLE_VARIABLES_;
///
/// Trainable resource-style variables.
///
- public string TRAINABLE_RESOURCE_VARIABLES = "trainable_resource_variables";
+ public string TRAINABLE_RESOURCE_VARIABLES => TRAINABLE_RESOURCE_VARIABLES_;
///
/// Key for streaming model ports.
///
- public string _STREAMING_MODEL_PORTS = "streaming_model_ports";
+ public string _STREAMING_MODEL_PORTS => _STREAMING_MODEL_PORTS_;
///
/// Key to collect losses
///
- public string LOSSES = "losses";
+ public string LOSSES => LOSSES_;
///
/// Key to collect Variable objects that are global (shared across machines).
/// Default collection for all variables, except local ones.
///
- public string GLOBAL_VARIABLES = "variables";
+ public string GLOBAL_VARIABLES => GLOBAL_VARIABLES_;
- public string TRAIN_OP = "train_op";
+ public string TRAIN_OP => TRAIN_OP_;
- public string GLOBAL_STEP = "global_step";
+ public string GLOBAL_STEP => GLOBAL_STEP_;
- public string[] _VARIABLE_COLLECTIONS = new string[] { "variables", "trainable_variables", "model_variables" };
+ public string[] _VARIABLE_COLLECTIONS => _VARIABLE_COLLECTIONS_;
///
/// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing.
///
- public string SAVEABLE_OBJECTS = "saveable_objects";
+ public string SAVEABLE_OBJECTS => SAVEABLE_OBJECTS_;
///
/// Key to collect update_ops
///
- public string UPDATE_OPS = "update_ops";
+ public string UPDATE_OPS => UPDATE_OPS_;
// Key to collect summaries.
- public string SUMMARIES = "summaries";
+ public string SUMMARIES => SUMMARIES_;
// Used to store v2 summary names.
- public string _SUMMARY_COLLECTION = "_SUMMARY_V2";
+ public string _SUMMARY_COLLECTION => _SUMMARY_COLLECTION_;
// Key for control flow context.
- public string COND_CONTEXT = "cond_context";
- public string WHILE_CONTEXT = "while_context";
+ public string COND_CONTEXT => COND_CONTEXT_;
+ public string WHILE_CONTEXT => WHILE_CONTEXT_;
}
}
}
diff --git a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.GPU.csproj b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.GPU.csproj
index 1bd3d530..55e9b27d 100644
--- a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.GPU.csproj
+++ b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.GPU.csproj
@@ -6,6 +6,14 @@
false
+
+ bin\debug-gpu
+
+
+
+ bin\release-gpu
+
+
diff --git a/test/TensorFlowNET.UnitTest/ConstantTest.cs b/test/TensorFlowNET.UnitTest/ConstantTest.cs
index c1d4c9e5..b532e558 100644
--- a/test/TensorFlowNET.UnitTest/ConstantTest.cs
+++ b/test/TensorFlowNET.UnitTest/ConstantTest.cs
@@ -98,9 +98,9 @@ namespace TensorFlowNET.UnitTest
{
var result = sess.run(tensor);
- Assert.AreEqual(result[0].shape[0], 3);
- Assert.AreEqual(result[0].shape[1], 2);
- Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0, 0, 0 }, result[0].Data()));
+ Assert.AreEqual(result.shape[0], 3);
+ Assert.AreEqual(result.shape[1], 2);
+ Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0, 0, 0 }, result.Data()));
}
// big size
@@ -109,13 +109,13 @@ namespace TensorFlowNET.UnitTest
{
var result = sess.run(tensor);
- Assert.AreEqual(result[0].shape[0], 200);
- Assert.AreEqual(result[0].shape[1], 100);
+ Assert.AreEqual(result.shape[0], 200);
+ Assert.AreEqual(result.shape[1], 100);
- var data = result[0].Data();
+ var data = result.Data();
Assert.AreEqual(0, data[0]);
Assert.AreEqual(0, data[500]);
- Assert.AreEqual(0, data[result[0].size - 1]);
+ Assert.AreEqual(0, data[result.size - 1]);
}
}
@@ -127,9 +127,9 @@ namespace TensorFlowNET.UnitTest
{
var result = sess.run(ones);
- Assert.AreEqual(result[0].shape[0], 3);
- Assert.AreEqual(result[0].shape[1], 2);
- Assert.IsTrue(new[] { 1, 1, 1, 1, 1, 1 }.SequenceEqual(result[0].Data()));
+ Assert.AreEqual(result.shape[0], 3);
+ Assert.AreEqual(result.shape[1], 2);
+ Assert.IsTrue(new[] { 1, 1, 1, 1, 1, 1 }.SequenceEqual(result.Data()));
}
}
@@ -142,9 +142,9 @@ namespace TensorFlowNET.UnitTest
{
var result = sess.run(halfes);
- Assert.AreEqual(result[0].shape[0], 3);
- Assert.AreEqual(result[0].shape[1], 2);
- Assert.IsTrue(new[] { .5, .5, .5, .5, .5, .5 }.SequenceEqual(result[0].Data()));
+ Assert.AreEqual(result.shape[0], 3);
+ Assert.AreEqual(result.shape[1], 2);
+ Assert.IsTrue(new[] { .5, .5, .5, .5, .5, .5 }.SequenceEqual(result.Data()));
}
}
@@ -161,10 +161,10 @@ namespace TensorFlowNET.UnitTest
using (var sess = tf.Session())
{
var result = sess.run(tensor);
- var data = result[0].Data();
+ var data = result.Data();
- Assert.AreEqual(result[0].shape[0], 2);
- Assert.AreEqual(result[0].shape[1], 3);
+ Assert.AreEqual(result.shape[0], 2);
+ Assert.AreEqual(result.shape[1], 3);
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 1, 1, 2, 1, 3 }, data));
}
}
@@ -177,7 +177,7 @@ namespace TensorFlowNET.UnitTest
var c = a * b;
var sess = tf.Session();
- double result = sess.run(c)[0];
+ double result = sess.run(c);
sess.close();
Assert.AreEqual(6.0, result);
diff --git a/test/TensorFlowNET.UnitTest/GradientTest.cs b/test/TensorFlowNET.UnitTest/GradientTest.cs
index b52bc1cf..c8e57ba4 100644
--- a/test/TensorFlowNET.UnitTest/GradientTest.cs
+++ b/test/TensorFlowNET.UnitTest/GradientTest.cs
@@ -41,7 +41,7 @@ namespace TensorFlowNET.UnitTest
var grad = tf.gradients(y, x);
Assert.AreEqual(grad[0].name, "gradients/AddN:0");
- float r = sess.run(grad[0])[0];
+ float r = sess.run(grad[0]);
Assert.AreEqual(r, 1.4f);
}
}
@@ -57,7 +57,7 @@ namespace TensorFlowNET.UnitTest
var grad = tf.gradients(y, x);
Assert.AreEqual(grad[0].name, "gradients/AddN:0");
- float r = sess.run(grad[0])[0];
+ float r = sess.run(grad[0]);
Assert.AreEqual(r, 14.700001f);
});
}
@@ -94,7 +94,7 @@ namespace TensorFlowNET.UnitTest
using (var sess = tf.Session(graph))
{
- var r = sess.run(slice)[0];
+ var r = sess.run(slice);
Assert.IsTrue(Enumerable.SequenceEqual(r.shape, new[] { 2, 1, 2 }));
Assert.IsTrue(Enumerable.SequenceEqual(r[0].GetData(), new[] { 11, 13 }));
diff --git a/test/TensorFlowNET.UnitTest/OperationsTest.cs b/test/TensorFlowNET.UnitTest/OperationsTest.cs
index 4c6ae3d0..0caa5259 100644
--- a/test/TensorFlowNET.UnitTest/OperationsTest.cs
+++ b/test/TensorFlowNET.UnitTest/OperationsTest.cs
@@ -44,7 +44,7 @@ namespace TensorFlowNET.UnitTest
var o = sess.run(c,
new FeedItem(a, 3.0f),
new FeedItem(b, 2.0f));
- Assert.AreEqual((float)o[0], 5.0f);
+ Assert.AreEqual((float)o, 5.0f);
}
}
@@ -58,7 +58,7 @@ namespace TensorFlowNET.UnitTest
using (var sess = tf.Session())
{
var o = sess.run(c);
- Assert.AreEqual((float)o[0], 9.0f);
+ Assert.AreEqual((float)o, 9.0f);
}
}
@@ -72,7 +72,7 @@ namespace TensorFlowNET.UnitTest
using (var sess = tf.Session())
{
var o = sess.run(b);
- Assert.IsTrue(o[0].array_equal(check));
+ Assert.IsTrue(o.array_equal(check));
}
}
@@ -86,7 +86,7 @@ namespace TensorFlowNET.UnitTest
using (var sess = tf.Session())
{
var o = sess.run(b);
- Assert.IsTrue(o[0].array_equal(check));
+ Assert.IsTrue(o.array_equal(check));
}
}
@@ -100,7 +100,7 @@ namespace TensorFlowNET.UnitTest
using (var sess = tf.Session())
{
var o = sess.run(b);
- Assert.IsTrue(o[0].array_equal(check));
+ Assert.IsTrue(o.array_equal(check));
}
b = tf.cumsum(a, exclusive: true);
@@ -109,7 +109,7 @@ namespace TensorFlowNET.UnitTest
using (var sess = tf.Session())
{
var o = sess.run(b);
- Assert.IsTrue(o[0].array_equal(check));
+ Assert.IsTrue(o.array_equal(check));
}
b = tf.cumsum(a, reverse: true);
@@ -118,7 +118,7 @@ namespace TensorFlowNET.UnitTest
using (var sess = tf.Session())
{
var o = sess.run(b);
- Assert.IsTrue(o[0].array_equal(check));
+ Assert.IsTrue(o.array_equal(check));
}
b = tf.cumsum(a, exclusive:true, reverse: true);
@@ -127,7 +127,7 @@ namespace TensorFlowNET.UnitTest
using (var sess = tf.Session())
{
var o = sess.run(b);
- Assert.IsTrue(o[0].array_equal(check));
+ Assert.IsTrue(o.array_equal(check));
}
}
@@ -143,7 +143,7 @@ namespace TensorFlowNET.UnitTest
using (var sess = tf.Session())
{
var o = sess.run(d);
- Assert.IsTrue(o[0].array_equal(check));
+ Assert.IsTrue(o.array_equal(check));
}
d = tf.cast(tf.logical_not(b), tf.int32);
@@ -152,7 +152,7 @@ namespace TensorFlowNET.UnitTest
using (var sess = tf.Session())
{
var o = sess.run(d);
- Assert.IsTrue(o[0].array_equal(check));
+ Assert.IsTrue(o.array_equal(check));
}
d = tf.cast(tf.logical_or(b, c), tf.int32);
@@ -161,7 +161,7 @@ namespace TensorFlowNET.UnitTest
using (var sess = tf.Session())
{
var o = sess.run(d);
- Assert.IsTrue(o[0].array_equal(check));
+ Assert.IsTrue(o.array_equal(check));
}
d = tf.cast(tf.logical_xor(b, c), tf.int32);
@@ -170,7 +170,7 @@ namespace TensorFlowNET.UnitTest
using (var sess = tf.Session())
{
var o = sess.run(d);
- Assert.IsTrue(o[0].array_equal(check));
+ Assert.IsTrue(o.array_equal(check));
}
}
@@ -197,7 +197,7 @@ namespace TensorFlowNET.UnitTest
var o = sess.run(c,
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], intResult);
+ Assert.AreEqual((int)o, intResult);
}
// Testing `operator +(Tensor x, Tensor y)`
@@ -207,7 +207,7 @@ namespace TensorFlowNET.UnitTest
var o = sess.run(c,
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], intResult);
+ Assert.AreEqual((int)o, intResult);
}
// Testing `operator +(Tensor x, int y)`
@@ -216,7 +216,7 @@ namespace TensorFlowNET.UnitTest
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], intResult);
+ Assert.AreEqual((int)o, intResult);
}
// Testing `operator +(int x, Tensor y)`
@@ -225,7 +225,7 @@ namespace TensorFlowNET.UnitTest
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], intResult);
+ Assert.AreEqual((int)o, intResult);
}
#endregion
@@ -246,7 +246,7 @@ namespace TensorFlowNET.UnitTest
var o = sess.run(c,
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols))));
- Assert.AreEqual((float)o[0], floatResult);
+ Assert.AreEqual((float)o, floatResult);
}
// Testing `operator +(Tensor x, Tensor y)
@@ -256,7 +256,7 @@ namespace TensorFlowNET.UnitTest
var o = sess.run(c,
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols))));
- Assert.AreEqual((float)o[0], floatResult);
+ Assert.AreEqual((float)o, floatResult);
}
// Testing `operator +(Tensor x, float y)
@@ -265,7 +265,7 @@ namespace TensorFlowNET.UnitTest
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))));
- Assert.AreEqual((float)o[0], floatResult);
+ Assert.AreEqual((float)o, floatResult);
}
// Testing `operator +(float x, Tensor y)
@@ -274,7 +274,7 @@ namespace TensorFlowNET.UnitTest
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))));
- Assert.AreEqual((float)o[0], floatResult);
+ Assert.AreEqual((float)o, floatResult);
}
#endregion
@@ -295,7 +295,7 @@ namespace TensorFlowNET.UnitTest
var o = sess.run(c,
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols))));
- Assert.AreEqual((double)o[0], doubleResult);
+ Assert.AreEqual((double)o, doubleResult);
}
// Testing `operator +(Tensor x, Tensor y)
@@ -305,7 +305,7 @@ namespace TensorFlowNET.UnitTest
var o = sess.run(c,
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols))));
- Assert.AreEqual((double)o[0], doubleResult);
+ Assert.AreEqual((double)o, doubleResult);
}
// Testing `operator +(Tensor x, double y)
@@ -314,7 +314,7 @@ namespace TensorFlowNET.UnitTest
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))));
- Assert.AreEqual((double)o[0], doubleResult);
+ Assert.AreEqual((double)o, doubleResult);
}
// Testing `operator +(double x, Tensor y)
@@ -323,7 +323,7 @@ namespace TensorFlowNET.UnitTest
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))));
- Assert.AreEqual((double)o[0], doubleResult);
+ Assert.AreEqual((double)o, doubleResult);
}
#endregion
}
@@ -352,7 +352,7 @@ namespace TensorFlowNET.UnitTest
var o = sess.run(c,
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], intResult);
+ Assert.AreEqual((int)o, intResult);
}
// Testing `operator -(Tensor x, Tensor y)
@@ -362,7 +362,7 @@ namespace TensorFlowNET.UnitTest
var o = sess.run(c,
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], intResult);
+ Assert.AreEqual((int)o, intResult);
}
// Testing `operator -(Tensor x, int y)
@@ -371,7 +371,7 @@ namespace TensorFlowNET.UnitTest
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], intResult);
+ Assert.AreEqual((int)o, intResult);
}
// Testing `operator -(int x, Tensor y)
@@ -380,7 +380,7 @@ namespace TensorFlowNET.UnitTest
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], Math.Abs(intResult));
+ Assert.AreEqual((int)o, Math.Abs(intResult));
}
// Testing `operator -(Tensor x)
@@ -389,7 +389,7 @@ namespace TensorFlowNET.UnitTest
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], intResultTwo);
+ Assert.AreEqual((int)o, intResultTwo);
}
#endregion
@@ -411,7 +411,7 @@ namespace TensorFlowNET.UnitTest
var o = sess.run(c,
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols))));
- Assert.AreEqual((float)o[0], floatResult);
+ Assert.AreEqual((float)o, floatResult);
}
// Testing `operator -(Tensor x, Tensor y)
@@ -421,7 +421,7 @@ namespace TensorFlowNET.UnitTest
var o = sess.run(c,
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols))));
- Assert.AreEqual((float)o[0], floatResult);
+ Assert.AreEqual((float)o, floatResult);
}
// Testing `operator -(Tensor x, float y)
@@ -430,7 +430,7 @@ namespace TensorFlowNET.UnitTest
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))));
- Assert.AreEqual((float)o[0], floatResult);
+ Assert.AreEqual((float)o, floatResult);
}
// Testing `operator -(float x, Tensor y)
@@ -439,7 +439,7 @@ namespace TensorFlowNET.UnitTest
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))));
- Assert.AreEqual((float)o[0], Math.Abs(floatResult));
+ Assert.AreEqual((float)o, Math.Abs(floatResult));
}
// Testing `operator -(Tensor x)
@@ -448,7 +448,7 @@ namespace TensorFlowNET.UnitTest
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))));
- Assert.AreEqual((float)o[0], floatResultTwo);
+ Assert.AreEqual((float)o, floatResultTwo);
}
#endregion
@@ -470,7 +470,7 @@ namespace TensorFlowNET.UnitTest
var o = sess.run(c,
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols))));
- Assert.AreEqual((double)o[0], doubleResult);
+ Assert.AreEqual((double)o, doubleResult);
}
// Testing `operator -(Tensor x, Tensor y)
@@ -480,7 +480,7 @@ namespace TensorFlowNET.UnitTest
var o = sess.run(c,
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols))));
- Assert.AreEqual((double)o[0], doubleResult);
+ Assert.AreEqual((double)o, doubleResult);
}
// Testing `operator -(Tensor x, double y)
@@ -489,7 +489,7 @@ namespace TensorFlowNET.UnitTest
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))));
- Assert.AreEqual((double)o[0], doubleResult);
+ Assert.AreEqual((double)o, doubleResult);
}
// Testing `operator -(double x, Tensor y)
@@ -498,7 +498,7 @@ namespace TensorFlowNET.UnitTest
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))));
- Assert.AreEqual((double)o[0], Math.Abs(doubleResult));
+ Assert.AreEqual((double)o, Math.Abs(doubleResult));
}
// Testing `operator -(Tensor x)
@@ -507,7 +507,7 @@ namespace TensorFlowNET.UnitTest
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))));
- Assert.AreEqual((double)o[0], doubleResultTwo);
+ Assert.AreEqual((double)o, doubleResultTwo);
}
#endregion
}
@@ -593,7 +593,7 @@ namespace TensorFlowNET.UnitTest
var o = sess.run(c,
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], intResult);
+ Assert.AreEqual((int)o, intResult);
}
// Testing `operator *(Tensor x, Tensor y)
@@ -603,7 +603,7 @@ namespace TensorFlowNET.UnitTest
var o = sess.run(c,
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], intResult);
+ Assert.AreEqual((int)o, intResult);
}
// Testing `operator *(Tensor x, int y)
@@ -612,7 +612,7 @@ namespace TensorFlowNET.UnitTest
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], intResult);
+ Assert.AreEqual((int)o, intResult);
}
// Testing `operator *(int x, Tensor y)
@@ -621,7 +621,7 @@ namespace TensorFlowNET.UnitTest
{
var o = sess.run(c,
new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], intResult);
+ Assert.AreEqual((int)o, intResult);
}
#endregion
@@ -642,7 +642,7 @@ namespace TensorFlowNET.UnitTest
var o = sess.run(c,
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols))));
- Assert.AreEqual((float)o[0], floatResult);
+ Assert.AreEqual((float)o, floatResult);
}
// Testing `operator *(Tensor x, Tensor y)
@@ -652,7 +652,7 @@ namespace TensorFlowNET.UnitTest
var o = sess.run(c,
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols))));
- Assert.AreEqual((float)o[0], floatResult);
+ Assert.AreEqual((float)o, floatResult);
}
// Testing `operator *(Tensor x, float y)
@@ -661,7 +661,7 @@ namespace TensorFlowNET.UnitTest
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))));
- Assert.AreEqual((float)o[0], floatResult);
+ Assert.AreEqual((float)o, floatResult);
}
// Testing `operator *(float x, Tensor y)
@@ -670,7 +670,7 @@ namespace TensorFlowNET.UnitTest
{
var o = sess.run(c,
new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols))));
- Assert.AreEqual((float)o[0], floatResult);
+ Assert.AreEqual((float)o, floatResult);
}
#endregion
@@ -691,7 +691,7 @@ namespace TensorFlowNET.UnitTest
var o = sess.run(c,
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols))));
- Assert.AreEqual((double)o[0], doubleResult);
+ Assert.AreEqual((double)o, doubleResult);
}
// Testing `operator *(Tensor x, Tensor y)
@@ -701,7 +701,7 @@ namespace TensorFlowNET.UnitTest
var o = sess.run(c,
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols))));
- Assert.AreEqual((double)o[0], doubleResult);
+ Assert.AreEqual((double)o, doubleResult);
}
// Testing `operator *(Tensor x, double y)
@@ -710,7 +710,7 @@ namespace TensorFlowNET.UnitTest
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))));
- Assert.AreEqual((double)o[0], doubleResult);
+ Assert.AreEqual((double)o, doubleResult);
}
// Testing `operator *(double x, Tensor y)
@@ -719,7 +719,7 @@ namespace TensorFlowNET.UnitTest
{
var o = sess.run(c,
new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols))));
- Assert.AreEqual((double)o[0], doubleResult);
+ Assert.AreEqual((double)o, doubleResult);
}
#endregion
}
@@ -747,7 +747,7 @@ namespace TensorFlowNET.UnitTest
var o = sess.run(c,
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], intResult);
+ Assert.AreEqual((int)o, intResult);
}
// Testing `operator /(Tensor x, Tensor y)
@@ -757,7 +757,7 @@ namespace TensorFlowNET.UnitTest
var o = sess.run(c,
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], intResult);
+ Assert.AreEqual((int)o, intResult);
}
// Testing `operator /(Tensor x, int y)
@@ -766,7 +766,7 @@ namespace TensorFlowNET.UnitTest
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], intResult);
+ Assert.AreEqual((int)o, intResult);
}
// Testing `operator /(int x, Tensor y)
@@ -775,7 +775,7 @@ namespace TensorFlowNET.UnitTest
{
var o = sess.run(c,
new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], intResult);
+ Assert.AreEqual((int)o, intResult);
}
#endregion
@@ -796,7 +796,7 @@ namespace TensorFlowNET.UnitTest
var o = sess.run(c,
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols))));
- Assert.AreEqual((float)o[0], floatResult);
+ Assert.AreEqual((float)o, floatResult);
}
// Testing `operator /(Tensor x, Tensor y)
@@ -806,7 +806,7 @@ namespace TensorFlowNET.UnitTest
var o = sess.run(c,
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols))));
- Assert.AreEqual((float)o[0], floatResult);
+ Assert.AreEqual((float)o, floatResult);
}
// Testing `operator /(Tensor x, float y)
@@ -815,7 +815,7 @@ namespace TensorFlowNET.UnitTest
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))));
- Assert.AreEqual((float)o[0], floatResult);
+ Assert.AreEqual((float)o, floatResult);
}
// Testing `operator /(float x, Tensor y)
@@ -824,7 +824,7 @@ namespace TensorFlowNET.UnitTest
{
var o = sess.run(c,
new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols))));
- Assert.AreEqual((float)o[0], floatResult);
+ Assert.AreEqual((float)o, floatResult);
}
#endregion
@@ -845,7 +845,7 @@ namespace TensorFlowNET.UnitTest
var o = sess.run(c,
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols))));
- Assert.AreEqual((double)o[0], doubleResult);
+ Assert.AreEqual((double)o, doubleResult);
}
// Testing `operator /(Tensor x, Tensor y)
@@ -855,7 +855,7 @@ namespace TensorFlowNET.UnitTest
var o = sess.run(c,
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols))));
- Assert.AreEqual((double)o[0], doubleResult);
+ Assert.AreEqual((double)o, doubleResult);
}
// Testing `operator /(Tensor x, double y)
@@ -864,7 +864,7 @@ namespace TensorFlowNET.UnitTest
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))));
- Assert.AreEqual((double)o[0], doubleResult);
+ Assert.AreEqual((double)o, doubleResult);
}
// Testing `operator /(double x, Tensor y)
@@ -873,7 +873,7 @@ namespace TensorFlowNET.UnitTest
{
var o = sess.run(c,
new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols))));
- Assert.AreEqual((double)o[0], doubleResult);
+ Assert.AreEqual((double)o, doubleResult);
}
#endregion
}
@@ -901,7 +901,7 @@ namespace TensorFlowNET.UnitTest
var o = sess.run(c,
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], intResult);
+ Assert.AreEqual((int)o, intResult);
}
// Testing `operator >(Tensor x, Tensor y)
@@ -911,7 +911,7 @@ namespace TensorFlowNET.UnitTest
var o = sess.run(c,
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], intResult);
+ Assert.AreEqual((int)o, intResult);
}
// Testing `operator >(Tensor x, int y)
@@ -920,7 +920,7 @@ namespace TensorFlowNET.UnitTest
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], intResult);
+ Assert.AreEqual((int)o, intResult);
}
// Testing `operator >(int x, Tensor y)
@@ -929,7 +929,7 @@ namespace TensorFlowNET.UnitTest
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], intResultTwo);
+ Assert.AreEqual((int)o, intResultTwo);
}
#endregion
@@ -950,7 +950,7 @@ namespace TensorFlowNET.UnitTest
var o = sess.run(c,
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], floatResult);
+ Assert.AreEqual((int)o, floatResult);
}
// Testing `operator >(Tensor x, Tensor y)
@@ -960,7 +960,7 @@ namespace TensorFlowNET.UnitTest
var o = sess.run(c,
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], floatResult);
+ Assert.AreEqual((int)o, floatResult);
}
// Testing `operator >(Tensor x, float y)
@@ -969,7 +969,7 @@ namespace TensorFlowNET.UnitTest
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], floatResult);
+ Assert.AreEqual((int)o, floatResult);
}
// Testing `operator >(float x, Tensor y)
@@ -978,7 +978,7 @@ namespace TensorFlowNET.UnitTest
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], floatResultTwo);
+ Assert.AreEqual((int)o, floatResultTwo);
}
#endregion
@@ -999,7 +999,7 @@ namespace TensorFlowNET.UnitTest
var o = sess.run(c,
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], doubleResult);
+ Assert.AreEqual((int)o, doubleResult);
}
// Testing `operator >(Tensor x, Tensor y)
@@ -1009,7 +1009,7 @@ namespace TensorFlowNET.UnitTest
var o = sess.run(c,
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], doubleResult);
+ Assert.AreEqual((int)o, doubleResult);
}
// Testing `operator >(Tensor x, double y)
@@ -1018,7 +1018,7 @@ namespace TensorFlowNET.UnitTest
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], doubleResult);
+ Assert.AreEqual((int)o, doubleResult);
}
// Testing `operator >(double x, Tensor y)
@@ -1027,7 +1027,7 @@ namespace TensorFlowNET.UnitTest
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], doubleResultTwo);
+ Assert.AreEqual((int)o, doubleResultTwo);
}
#endregion
}
@@ -1055,7 +1055,7 @@ namespace TensorFlowNET.UnitTest
var o = sess.run(c,
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], intResult);
+ Assert.AreEqual((int)o, intResult);
}
// Testing `operator <(Tensor x, Tensor y)
@@ -1065,7 +1065,7 @@ namespace TensorFlowNET.UnitTest
var o = sess.run(c,
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], intResult);
+ Assert.AreEqual((int)o, intResult);
}
// Testing `operator <(Tensor x, int y)
@@ -1074,7 +1074,7 @@ namespace TensorFlowNET.UnitTest
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], intResult);
+ Assert.AreEqual((int)o, intResult);
}
// Testing `operator <(int x, Tensor y)
@@ -1083,7 +1083,7 @@ namespace TensorFlowNET.UnitTest
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], intResultTwo);
+ Assert.AreEqual((int)o, intResultTwo);
}
#endregion
@@ -1104,7 +1104,7 @@ namespace TensorFlowNET.UnitTest
var o = sess.run(c,
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], floatResult);
+ Assert.AreEqual((int)o, floatResult);
}
// Testing `operator <(Tensor x, Tensor y)
@@ -1114,7 +1114,7 @@ namespace TensorFlowNET.UnitTest
var o = sess.run(c,
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], floatResult);
+ Assert.AreEqual((int)o, floatResult);
}
// Testing `operator <(Tensor x, float y)
@@ -1123,7 +1123,7 @@ namespace TensorFlowNET.UnitTest
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], floatResult);
+ Assert.AreEqual((int)o, floatResult);
}
// Testing `operator <(float x, Tensor y)
@@ -1132,7 +1132,7 @@ namespace TensorFlowNET.UnitTest
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], floatResultTwo);
+ Assert.AreEqual((int)o, floatResultTwo);
}
#endregion
@@ -1153,7 +1153,7 @@ namespace TensorFlowNET.UnitTest
var o = sess.run(c,
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], doubleResult);
+ Assert.AreEqual((int)o, doubleResult);
}
// Testing `operator <(Tensor x, Tensor y)
@@ -1163,7 +1163,7 @@ namespace TensorFlowNET.UnitTest
var o = sess.run(c,
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], doubleResult);
+ Assert.AreEqual((int)o, doubleResult);
}
// Testing `operator <(Tensor x, double y)
@@ -1172,7 +1172,7 @@ namespace TensorFlowNET.UnitTest
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], doubleResult);
+ Assert.AreEqual((int)o, doubleResult);
}
// Testing `operator <(double x, Tensor y)
@@ -1181,7 +1181,7 @@ namespace TensorFlowNET.UnitTest
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], doubleResultTwo);
+ Assert.AreEqual((int)o, doubleResultTwo);
}
#endregion
}
@@ -1209,7 +1209,7 @@ namespace TensorFlowNET.UnitTest
var o = sess.run(c,
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], intResult);
+ Assert.AreEqual((int)o, intResult);
}
// Testing `operator >=(Tensor x, Tensor y)
@@ -1219,7 +1219,7 @@ namespace TensorFlowNET.UnitTest
var o = sess.run(c,
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], intResult);
+ Assert.AreEqual((int)o, intResult);
}
// Testing `operator >=(Tensor x, int y)
@@ -1228,7 +1228,7 @@ namespace TensorFlowNET.UnitTest
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], intResult);
+ Assert.AreEqual((int)o, intResult);
}
// Testing `operator >=(int x, Tensor y)
@@ -1237,7 +1237,7 @@ namespace TensorFlowNET.UnitTest
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], intResultTwo);
+ Assert.AreEqual((int)o, intResultTwo);
}
#endregion
@@ -1258,7 +1258,7 @@ namespace TensorFlowNET.UnitTest
var o = sess.run(c,
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], floatResult);
+ Assert.AreEqual((int)o, floatResult);
}
// Testing `operator >=(Tensor x, Tensor y)
@@ -1268,7 +1268,7 @@ namespace TensorFlowNET.UnitTest
var o = sess.run(c,
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], floatResult);
+ Assert.AreEqual((int)o, floatResult);
}
// Testing `operator >=(Tensor x, float y)
@@ -1277,7 +1277,7 @@ namespace TensorFlowNET.UnitTest
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], floatResult);
+ Assert.AreEqual((int)o, floatResult);
}
// Testing `operator >=(float x, Tensor y)
@@ -1286,7 +1286,7 @@ namespace TensorFlowNET.UnitTest
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], floatResultTwo);
+ Assert.AreEqual((int)o, floatResultTwo);
}
#endregion
@@ -1307,7 +1307,7 @@ namespace TensorFlowNET.UnitTest
var o = sess.run(c,
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], doubleResult);
+ Assert.AreEqual((int)o, doubleResult);
}
// Testing `operator >=(Tensor x, Tensor y)
@@ -1317,7 +1317,7 @@ namespace TensorFlowNET.UnitTest
var o = sess.run(c,
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], doubleResult);
+ Assert.AreEqual((int)o, doubleResult);
}
// Testing `operator >=(Tensor x, double y)
@@ -1326,7 +1326,7 @@ namespace TensorFlowNET.UnitTest
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], doubleResult);
+ Assert.AreEqual((int)o, doubleResult);
}
// Testing `operator >=(double x, Tensor y)
@@ -1335,7 +1335,7 @@ namespace TensorFlowNET.UnitTest
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], doubleResultTwo);
+ Assert.AreEqual((int)o, doubleResultTwo);
}
#endregion
}
@@ -1363,7 +1363,7 @@ namespace TensorFlowNET.UnitTest
var o = sess.run(c,
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], intResult);
+ Assert.AreEqual((int)o, intResult);
}
// Testing `operator <=(Tensor x, Tensor y)
@@ -1373,7 +1373,7 @@ namespace TensorFlowNET.UnitTest
var o = sess.run(c,
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], intResult);
+ Assert.AreEqual((int)o, intResult);
}
// Testing `operator <=(Tensor x, int y)
@@ -1382,7 +1382,7 @@ namespace TensorFlowNET.UnitTest
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], intResult);
+ Assert.AreEqual((int)o, intResult);
}
// Testing `operator <=(int x, Tensor y)
@@ -1391,7 +1391,7 @@ namespace TensorFlowNET.UnitTest
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], intResultTwo);
+ Assert.AreEqual((int)o, intResultTwo);
}
#endregion
@@ -1412,7 +1412,7 @@ namespace TensorFlowNET.UnitTest
var o = sess.run(c,
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], floatResult);
+ Assert.AreEqual((int)o, floatResult);
}
// Testing `operator <=(Tensor x, Tensor y)
@@ -1422,7 +1422,7 @@ namespace TensorFlowNET.UnitTest
var o = sess.run(c,
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], floatResult);
+ Assert.AreEqual((int)o, floatResult);
}
// Testing `operator <=(Tensor x, float y)
@@ -1431,7 +1431,7 @@ namespace TensorFlowNET.UnitTest
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], floatResult);
+ Assert.AreEqual((int)o, floatResult);
}
// Testing `operator <=(float x, Tensor y)
@@ -1440,7 +1440,7 @@ namespace TensorFlowNET.UnitTest
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], floatResultTwo);
+ Assert.AreEqual((int)o, floatResultTwo);
}
#endregion
@@ -1461,7 +1461,7 @@ namespace TensorFlowNET.UnitTest
var o = sess.run(c,
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], doubleResult);
+ Assert.AreEqual((int)o, doubleResult);
}
// Testing `operator <=(Tensor x, Tensor y)
@@ -1471,7 +1471,7 @@ namespace TensorFlowNET.UnitTest
var o = sess.run(c,
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], doubleResult);
+ Assert.AreEqual((int)o, doubleResult);
}
// Testing `operator <=(Tensor x, double y)
@@ -1480,7 +1480,7 @@ namespace TensorFlowNET.UnitTest
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], doubleResult);
+ Assert.AreEqual((int)o, doubleResult);
}
// Testing `operator <=(double x, Tensor y)
@@ -1489,7 +1489,7 @@ namespace TensorFlowNET.UnitTest
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))));
- Assert.AreEqual((int)o[0], doubleResultTwo);
+ Assert.AreEqual((int)o, doubleResultTwo);
}
#endregion
}
diff --git a/test/TensorFlowNET.UnitTest/PlaceholderTest.cs b/test/TensorFlowNET.UnitTest/PlaceholderTest.cs
index 14b16c23..5135bd25 100644
--- a/test/TensorFlowNET.UnitTest/PlaceholderTest.cs
+++ b/test/TensorFlowNET.UnitTest/PlaceholderTest.cs
@@ -17,7 +17,7 @@ namespace TensorFlowNET.UnitTest
{
var result = sess.run(y,
new FeedItem(x, 2));
- Assert.AreEqual((int)result[0], 6);
+ Assert.AreEqual((int)result, 6);
}
}
}
diff --git a/test/TensorFlowNET.UnitTest/SessionTest.cs b/test/TensorFlowNET.UnitTest/SessionTest.cs
index 9c8485ec..8fd4dc8a 100644
--- a/test/TensorFlowNET.UnitTest/SessionTest.cs
+++ b/test/TensorFlowNET.UnitTest/SessionTest.cs
@@ -45,7 +45,7 @@ namespace TensorFlowNET.UnitTest
EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype);
EXPECT_EQ(0, outTensor.NDims);
ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize);
- var output_contents = outTensor.Data();
+ var output_contents = outTensor.ToArray();
EXPECT_EQ(3 + 2, output_contents[0]);
// Add another operation to the graph.
@@ -66,7 +66,7 @@ namespace TensorFlowNET.UnitTest
EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype);
EXPECT_EQ(0, outTensor.NDims); // scalar
ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize);
- output_contents = outTensor.Data();
+ output_contents = outTensor.ToArray();
EXPECT_EQ(-(7 + 2), output_contents[0]);
// Clean up
diff --git a/test/TensorFlowNET.UnitTest/TensorTest.cs b/test/TensorFlowNET.UnitTest/TensorTest.cs
index 07da9dca..11557f14 100644
--- a/test/TensorFlowNET.UnitTest/TensorTest.cs
+++ b/test/TensorFlowNET.UnitTest/TensorTest.cs
@@ -112,7 +112,7 @@ namespace TensorFlowNET.UnitTest
var nd = np.array(1f, 2f, 3f, 4f, 5f, 6f).reshape(2, 3);
var tensor = new Tensor(nd);
- var array = tensor.Data();
+ var array = tensor.ToArray();
EXPECT_EQ(tensor.dtype, TF_DataType.TF_FLOAT);
EXPECT_EQ(tensor.rank, nd.ndim);
diff --git a/test/TensorFlowNET.UnitTest/VariableTest.cs b/test/TensorFlowNET.UnitTest/VariableTest.cs
index 4c5ddd7a..7673cac8 100644
--- a/test/TensorFlowNET.UnitTest/VariableTest.cs
+++ b/test/TensorFlowNET.UnitTest/VariableTest.cs
@@ -1,4 +1,5 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
+using NumSharp;
using Tensorflow;
using static Tensorflow.Binding;
@@ -16,7 +17,7 @@ namespace TensorFlowNET.UnitTest
{
session.run(x.initializer);
var result = session.run(x);
- Assert.AreEqual(10, (int)result[0]);
+ Assert.AreEqual(10, (int)result);
}
}
@@ -81,7 +82,7 @@ namespace TensorFlowNET.UnitTest
using (var session = tf.Session())
{
session.run(model);
- int result = session.run(y)[0];
+ int result = session.run(y);
Assert.AreEqual(result, 4);
}
}
@@ -97,12 +98,12 @@ namespace TensorFlowNET.UnitTest
var sess = tf.Session(graph);
sess.run(init);
- var result = sess.run(variable);
- Assert.IsTrue((int)result[0] == 31);
+ NDArray result = sess.run(variable);
+ Assert.IsTrue((int)result == 31);
var assign = variable.assign(12);
result = sess.run(assign);
- Assert.IsTrue((int)result[0] == 12);
+ Assert.IsTrue((int)result == 12);
}
[TestMethod]
@@ -139,7 +140,7 @@ namespace TensorFlowNET.UnitTest
for(int i = 0; i < 5; i++)
{
x = x + 1;
- result = session.run(x)[0];
+ result = session.run(x);
print(result);
}
}
diff --git a/test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs b/test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs
index 1fd7d3aa..3a5515d9 100644
--- a/test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs
+++ b/test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs
@@ -31,7 +31,7 @@ namespace TensorFlowNET.UnitTest.nn_test
var y_np = this._ZeroFraction(x_np);
var x_tf = constant_op.constant(x_np);
- x_tf.SetShape(x_shape);
+ x_tf.set_shape(x_shape);
var y_tf = nn_impl.zero_fraction(x_tf);
var y_tf_np = self.evaluate(y_tf);