
Performance optimization, refactoring and revamping. (#362)

* Refactored DisposableObject

* Added different build directory for TensorflowNET.Examples.GPU

* _FetchHandler: Switched to NPTypeCode

* gfile.cs, Walk(...): Handle case when directory top doesn't exist.

* Tensor.Creation: Performance-optimized creating a tensor from an NDArray of strings

* Graph.cs: refactored and added docs

* Tensor.Creation.cs: performance optimizations

* Tensor.Explicit.cs: performance optimizations

* Copied globals.regen from NumSharp

- Added supported_numericals_TF_DataType

* Tensor performance optimizations and cleanup, revamped dtypes.cs, some renames.

- Cleanup and docs to all Tensor.cs files
- Changed all uses of System.Convert to NumSharp.Utilities.Converts
- Added all missing types in dtypes.cs
- Renamed tensor.Data<T> to tensor.ToArray<T>, added obsolete message (see the migration sketch after this list)
- Renamed tensor.Data() to tensor.BufferToArray(), added obsolete message
- Made GraphKeys use const string instead of allocating strings at every use of GraphKeys.

* Tensor: Added guards for explicit casts.

* Tensor: Added explicit cast to string

* Tensor.ToArray<T>(): Added support for cases when tensor is scalar.

* Tensor.BufferToArray(): Fixed to use long instead of int.

* TensorShape: Revamped and documented.

* BaseSession: Added Session.run(ITensorOrOperation fetche, params FeedItem[] feed_dict)

* Tensor: renamed _dtype to _override_dtype

- Fixed all locations _dtype is used incorrectly.

* Fixed unit tests

* Tensor.Operations: Reverted commit

* DisposableObject: sorted out internal_dispose to properly handle Dispose() calls

* Tensor.DisposeUnmanagedResources: Nullify _handle after delete.

* TensorShape.this[...]: fixed guard check.

* DisposableObject #362
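
For reference, a minimal before/after sketch of the public API renames called out above (variable names are illustrative; the old members still compile but are marked [Obsolete]):

    // Given an existing Tensor `t` and TensorShape `shape` (illustrative names):
    float[] values  = t.Data<float>();     // old, now [Obsolete]
    float[] values2 = t.ToArray<float>();  // replacement

    byte[] raw  = t.Data();                // old, now [Obsolete]
    byte[] raw2 = t.BufferToArray();       // replacement

    t.SetShape(shape);                     // old, now [Obsolete]
    t.set_shape(shape);                    // replacement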
tags/v0.12
Eli Belash, Haiping (6 years ago)
parent commit 6c8c2e5ec9
34 changed files with 994 additions and 500 deletions

  1. src/TensorFlowNET.Core/APIs/c_api.cs (+1 -1)
  2. src/TensorFlowNET.Core/Binding.Util.cs (+5 -9)
  3. src/TensorFlowNET.Core/Buffers/Buffer.cs (+1 -1)
  4. src/TensorFlowNET.Core/DisposableObject.cs (+21 -16)
  5. src/TensorFlowNET.Core/Eager/ContextOptions.cs (+2 -1)
  6. src/TensorFlowNET.Core/Graphs/Graph.cs (+52 -51)
  7. src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs (+1 -1)
  8. src/TensorFlowNET.Core/IO/gfile.cs (+4 -0)
  9. src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs (+1 -1)
  10. src/TensorFlowNET.Core/Operations/NnOps/rnn.cs (+1 -1)
  11. src/TensorFlowNET.Core/Operations/array_ops.py.cs (+1 -1)
  12. src/TensorFlowNET.Core/Operations/nn_ops.cs (+1 -1)
  13. src/TensorFlowNET.Core/Sessions/BaseSession.cs (+10 -5)
  14. src/TensorFlowNET.Core/Sessions/SessionOptions.cs (+1 -1)
  15. src/TensorFlowNET.Core/Sessions/_FetchHandler.cs (+12 -11)
  16. src/TensorFlowNET.Core/Status/Status.cs (+3 -3)
  17. src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs (+81 -58)
  18. src/TensorFlowNET.Core/Tensors/Tensor.Explicit.cs (+85 -28)
  19. src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs (+3 -4)
  20. src/TensorFlowNET.Core/Tensors/Tensor.cs (+274 -103)
  21. src/TensorFlowNET.Core/Tensors/TensorShape.cs (+93 -15)
  22. src/TensorFlowNET.Core/Tensors/dtypes.cs (+74 -25)
  23. src/TensorFlowNET.Core/Tensors/tensor_util.cs (+8 -7)
  24. src/TensorFlowNET.Core/globals.regen (+38 -0)
  25. src/TensorFlowNET.Core/ops.GraphKeys.cs (+70 -14)
  26. test/TensorFlowNET.Examples/TensorFlowNET.Examples.GPU.csproj (+8 -0)
  27. test/TensorFlowNET.UnitTest/ConstantTest.cs (+17 -17)
  28. test/TensorFlowNET.UnitTest/GradientTest.cs (+3 -3)
  29. test/TensorFlowNET.UnitTest/OperationsTest.cs (+111 -111)
  30. test/TensorFlowNET.UnitTest/PlaceholderTest.cs (+1 -1)
  31. test/TensorFlowNET.UnitTest/SessionTest.cs (+2 -2)
  32. test/TensorFlowNET.UnitTest/TensorTest.cs (+1 -1)
  33. test/TensorFlowNET.UnitTest/VariableTest.cs (+7 -6)
  34. test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs (+1 -1)

src/TensorFlowNET.Core/APIs/c_api.cs (+1 -1)

@@ -59,6 +59,6 @@ namespace Tensorflow
}

[DllImport(TensorFlowLibName)]
public static unsafe extern IntPtr TF_Version();
public static extern IntPtr TF_Version();
}
}
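
The `unsafe` modifier is dropped because the P/Invoke itself does not need it. A short usage sketch follows; reading the returned const char* via Marshal is typical interop practice and an assumption here, not something shown in this diff:

    // Sketch: reading the version string returned by the native library.
    IntPtr ptr = c_api.TF_Version();   // const char* owned by libtensorflow
    string version = System.Runtime.InteropServices.Marshal.PtrToStringAnsi(ptr);
    Console.WriteLine($"libtensorflow version: {version}");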

src/TensorFlowNET.Core/Binding.Util.cs (+5 -9)

@@ -308,15 +308,14 @@ namespace Tensorflow
public static IEnumerable TupleToEnumerable(object tuple)
{
Type t = tuple.GetType();
if(t.IsGenericType && (t.FullName.StartsWith("System.Tuple") || t.FullName.StartsWith("System.ValueTuple")))
if (t.IsGenericType && (t.FullName.StartsWith("System.Tuple") || t.FullName.StartsWith("System.ValueTuple")))
{
var flds = t.GetFields();
for(int i = 0; i < flds.Length;i++)
for (int i = 0; i < flds.Length; i++)
{
yield return flds[i].GetValue(tuple);
}
}
else
} else
{
throw new System.Exception("Expected Tuple.");
}
@@ -329,12 +328,9 @@ namespace Tensorflow

public static bool isinstance(object Item1, object tuple)
{
var tup = TupleToEnumerable(tuple);
foreach(var t in tup)
{
if(isinstance(Item1, (Type)t))
foreach (var t in TupleToEnumerable(tuple))
if (isinstance(Item1, (Type) t))
return true;
}
return false;
}
}


src/TensorFlowNET.Core/Buffers/Buffer.cs (+1 -1)

@@ -66,7 +66,7 @@ namespace Tensorflow
return buffer.Data;
}

protected override void DisposeUnManagedState(IntPtr handle)
protected override void DisposeUnmanagedResources(IntPtr handle)
=> c_api.TF_DeleteBuffer(handle);
}
}

src/TensorFlowNET.Core/DisposableObject.cs (+21 -16)

@@ -29,18 +29,10 @@ namespace Tensorflow

protected DisposableObject() { }

public DisposableObject(IntPtr handle)
{
_handle = handle;
}

protected virtual void DisposeManagedState()
{
}
protected DisposableObject(IntPtr handle)
=> _handle = handle;

protected abstract void DisposeUnManagedState(IntPtr handle);

protected virtual void Dispose(bool disposing)
private void internal_dispose(bool disposing)
{
if (disposing)
{
@@ -48,30 +40,43 @@ namespace Tensorflow
if (_handle != IntPtr.Zero)
{
// dispose managed state (managed objects).
DisposeManagedState();
DisposeManagedResources();

// set large fields to null.
DisposeUnManagedState(_handle);
DisposeUnmanagedResources(_handle);

_handle = IntPtr.Zero;
}
}
}

/// <summary>
/// Dispose any managed resources.
/// </summary>
/// <remarks>Equivalent to what you would perform inside <see cref="Dispose()"/></remarks>
protected virtual void DisposeManagedResources()
{
}

/// <summary>
/// Dispose any unmanaged resources related to given <paramref name="handle"/>.
/// </summary>
protected abstract void DisposeUnmanagedResources(IntPtr handle);

// override a finalizer only if Dispose(bool disposing) above has code to free unmanaged resources.
~DisposableObject()
{
// Do not change this code. Put cleanup code in Dispose(bool disposing) above.
Dispose(false);
internal_dispose(false);
}

// This code added to correctly implement the disposable pattern.
public void Dispose()
{
// Do not change this code. Put cleanup code in Dispose(bool disposing) above.
Dispose(true);
internal_dispose(true);
// uncomment the following line if the finalizer is overridden above.
GC.SuppressFinalize(this);
}
}
}
}
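
To illustrate the refactored base class, a minimal hypothetical subclass (the type and the native call are invented for illustration): only DisposeUnmanagedResources must be overridden, DisposeManagedResources is optional, and the Dispose()/finalizer plumbing is centralized in internal_dispose.

    // Hypothetical example, not part of this commit:
    public sealed class NativeWidget : DisposableObject
    {
        public NativeWidget(IntPtr handle) : base(handle) { }

        // Optional: release managed state (equivalent to what you would do in Dispose()).
        protected override void DisposeManagedResources()
        {
        }

        // Mandatory: free the unmanaged resource behind `handle`.
        protected override void DisposeUnmanagedResources(IntPtr handle)
            => native_lib.Free(handle); // hypothetical native call
    }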

src/TensorFlowNET.Core/Eager/ContextOptions.cs (+2 -1)

@@ -1,8 +1,9 @@
using System;
using System.IO;

namespace Tensorflow.Eager
{
public class ContextOptions : IDisposable
public class ContextOptions : IDisposable //TODO! Eli: Shouldn't this inherieting DisposableObject?
{
private IntPtr _handle;



src/TensorFlowNET.Core/Graphs/Graph.cs (+52 -51)

@@ -23,57 +23,58 @@ using static Tensorflow.Binding;

namespace Tensorflow
{
/*
A TensorFlow computation, represented as a dataflow graph.

A `Graph` contains a set of
`tf.Operation` objects,
which represent units of computation; and
`tf.Tensor` objects, which represent
the units of data that flow between operations.

A default `Graph` is always registered, and accessible by calling
`tf.get_default_graph`.
To add an operation to the default graph, simply call one of the functions
that defines a new `Operation`:

```python
c = tf.constant(4.0)
assert c.graph is tf.get_default_graph()
```

Another typical usage involves the
`tf.Graph.as_default`
context manager, which overrides the current default graph for the
lifetime of the context:

```python
g = tf.Graph()
with g.as_default():
# Define operations and tensors in `g`.
c = tf.constant(30.0)
assert c.graph is g
```

Important note: This class *is not* thread-safe for graph construction. All
operations should be created from a single thread, or external
synchronization must be provided. Unless otherwise specified, all methods
are not thread-safe.

A `Graph` instance supports an arbitrary number of "collections"
that are identified by name. For convenience when building a large
graph, collections can store groups of related objects: for
example, the `tf.Variable` uses a collection (named
`tf.GraphKeys.GLOBAL_VARIABLES`) for
all variables that are created during the construction of a graph. The caller
may define additional collections by specifying a new name.
*/

/// <summary>
/// TensorFlow uses a dataflow graph to represent your computation in terms of the dependencies between individual operations.
/// This leads to a low-level programming model in which you first define the dataflow graph,
/// then create a TensorFlow session to run parts of the graph across a set of local and remote devices.
/// https://www.tensorflow.org/guide/graphs
/// TensorFlow uses a dataflow graph to represent your computation in terms of the dependencies between individual operations.
/// This leads to a low-level programming model in which you first define the dataflow graph,
/// then create a TensorFlow session to run parts of the graph across a set of local and remote devices.
/// </summary>
/*
A TensorFlow computation, represented as a dataflow graph.
A `Graph` contains a set of
`tf.Operation` objects,
which represent units of computation; and
`tf.Tensor` objects, which represent
the units of data that flow between operations.
A default `Graph` is always registered, and accessible by calling
`tf.get_default_graph`.
To add an operation to the default graph, simply call one of the functions
that defines a new `Operation`:
```python
c = tf.constant(4.0)
assert c.graph is tf.get_default_graph()
```
Another typical usage involves the
`tf.Graph.as_default`
context manager, which overrides the current default graph for the
lifetime of the context:
```python
g = tf.Graph()
with g.as_default():
# Define operations and tensors in `g`.
c = tf.constant(30.0)
assert c.graph is g
```
Important note: This class *is not* thread-safe for graph construction. All
operations should be created from a single thread, or external
synchronization must be provided. Unless otherwise specified, all methods
are not thread-safe.
A `Graph` instance supports an arbitrary number of "collections"
that are identified by name. For convenience when building a large
graph, collections can store groups of related objects: for
example, the `tf.Variable` uses a collection (named
`tf.GraphKeys.GLOBAL_VARIABLES`) for
all variables that are created during the construction of a graph. The caller
may define additional collections by specifying a new name.
*/
/// <remarks>https://www.tensorflow.org/guide/graphs <br></br>https://www.tensorflow.org/api_docs/python/tf/Graph</remarks>
public partial class Graph : DisposableObject, IEnumerable<Operation>
{
private Dictionary<int, ITensorOrOperation> _nodes_by_id;
@@ -439,12 +440,12 @@ namespace Tensorflow
_unfetchable_ops.Add(op);
}
protected override void DisposeManagedState()
protected override void DisposeManagedResources()
{
ops.default_graph_stack.remove(this);
}
protected override void DisposeUnManagedState(IntPtr handle)
protected override void DisposeUnmanagedResources(IntPtr handle)
{
c_api.TF_DeleteGraph(handle);
}
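
A rough C# analogue of the Python snippet quoted in the docs above, assuming the binding's tf.Graph, Graph.as_default and tf.constant behave like their Python counterparts (sketch only):

    using (var g = tf.Graph().as_default())
    {
        // Operations created here are added to `g` rather than the global default graph.
        var c = tf.constant(30.0f);
        // c.graph is expected to reference g for the lifetime of this scope.
    }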


src/TensorFlowNET.Core/Graphs/ImportGraphDefOptions.cs (+1 -1)

@@ -37,7 +37,7 @@ namespace Tensorflow
c_api.TF_ImportGraphDefOptionsAddReturnOutput(_handle, name, index);
}

protected override void DisposeUnManagedState(IntPtr handle)
protected override void DisposeUnmanagedResources(IntPtr handle)
=> c_api.TF_DeleteImportGraphDefOptions(handle);

public static implicit operator IntPtr(ImportGraphDefOptions opts) => opts._handle;


src/TensorFlowNET.Core/IO/gfile.cs (+4 -0)

@@ -16,6 +16,7 @@

using System.Collections.Generic;
using System.IO;
using System.Linq;

namespace Tensorflow.IO
{
@@ -28,6 +29,9 @@ namespace Tensorflow.IO
/// <param name="in_order">Traverse in order if True, post order if False.</param>
public IEnumerable<(string, string[], string[])> Walk(string top, bool in_order = true)
{
if (!Directory.Exists(top))
return Enumerable.Empty<(string, string[], string[])>();

return walk_v2(top, in_order);
}
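
With this guard, walking a non-existent root yields an empty sequence instead of throwing. A hedged usage sketch, assuming gfile is reachable as tf.gfile as elsewhere in the binding:

    // Produces no tuples (rather than throwing) when the directory is missing.
    foreach (var (root, dirs, files) in tf.gfile.Walk(@"C:\path\that\may\not\exist"))
    {
        Console.WriteLine($"{root}: {dirs.Length} dirs, {files.Length} files");
    }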



src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs (+1 -1)

@@ -141,7 +141,7 @@ namespace Tensorflow.Operations
data, frame_name, is_constant, parallel_iterations, name: name);

if (use_input_shape)
result.SetShape(data.TensorShape);
result.set_shape(data.TensorShape);

return result;
}


src/TensorFlowNET.Core/Operations/NnOps/rnn.cs (+1 -1)

@@ -233,7 +233,7 @@ namespace Tensorflow.Operations
dims.AddRange(x_static_shape.dims.Skip(2));
var shape = new TensorShape(dims.ToArray());

x_t.SetShape(shape);
x_t.set_shape(shape);

return x_t;
}


src/TensorFlowNET.Core/Operations/array_ops.py.cs (+1 -1)

@@ -351,7 +351,7 @@ namespace Tensorflow
var input_shape = tensor_util.to_shape(input_tensor.shape);
if (optimize && input_tensor.NDims > -1 && input_shape.is_fully_defined())
{
var nd = np.array(input_tensor.shape).astype(out_type.as_numpy_datatype());
var nd = np.array(input_tensor.shape).astype(out_type.as_numpy_dtype());
return constant_op.constant(nd, name: name);
}
}


src/TensorFlowNET.Core/Operations/nn_ops.cs (+1 -1)

@@ -98,7 +98,7 @@ namespace Tensorflow
// float to be selected, hence we use a >= comparison.
var keep_mask = random_tensor >= rate;
var ret = x * scale * math_ops.cast(keep_mask, x.dtype);
ret.SetShape(x.TensorShape);
ret.set_shape(x.TensorShape);
return ret;
});
}


src/TensorFlowNET.Core/Sessions/BaseSession.cs (+10 -5)

@@ -49,7 +49,7 @@ namespace Tensorflow
// dispose newOpts
if (opts == null)
c_api.TF_DeleteSessionOptions(newOpts);
newOpts.Dispose();
status.Check(true);
}
@@ -64,6 +64,11 @@ namespace Tensorflow
return _run(fetche, feed_dict)[0];
}
public virtual NDArray run(ITensorOrOperation fetche, params FeedItem[] feed_dict)
{
return _run(fetche, feed_dict)[0];
}
public virtual (NDArray, NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict)
{
var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4 }, feed_dict);
@@ -273,7 +278,7 @@ namespace Tensorflow
{
var tensor = new Tensor(output);
NDArray nd = null;
Type type = tensor.dtype.as_numpy_datatype();
Type type = tensor.dtype.as_numpy_dtype();
var ndims = tensor.shape;
var offset = c_api.TF_TensorData(output);
@@ -285,7 +290,7 @@ namespace Tensorflow
nd = NDArray.Scalar(*(bool*)offset);
break;
case TF_DataType.TF_STRING:
var bytes = tensor.Data();
var bytes = tensor.BufferToArray();
// wired, don't know why we have to start from offset 9.
// length in the begin
var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]);
@@ -324,7 +329,7 @@ namespace Tensorflow
nd = np.array(bools).reshape(ndims);
break;
case TF_DataType.TF_STRING:
var bytes = tensor.Data();
var bytes = tensor.BufferToArray();
// wired, don't know why we have to start from offset 9.
// length in the begin
var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]);
@@ -396,7 +401,7 @@ namespace Tensorflow
Dispose();
}
protected override void DisposeUnManagedState(IntPtr handle)
protected override void DisposeUnmanagedResources(IntPtr handle)
{
using (var status = new Status())
{
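
The new single-fetch overload lets callers run one Tensor or Operation without wrapping it in an array. A minimal sketch; the placeholder and constant setup is illustrative only and not part of this diff:

    var x = tf.placeholder(tf.float32);
    var two = tf.constant(2.0f);
    var y = x * two;

    using (var sess = tf.Session())
    {
        // Uses the new run(ITensorOrOperation, params FeedItem[]) overload.
        NDArray result = sess.run(y, new FeedItem(x, 3.0f));
        Console.WriteLine(result); // expected: 6
    }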


src/TensorFlowNET.Core/Sessions/SessionOptions.cs (+1 -1)

@@ -32,7 +32,7 @@ namespace Tensorflow
_handle = handle;
}

protected override void DisposeUnManagedState(IntPtr handle)
protected override void DisposeUnmanagedResources(IntPtr handle)
=> c_api.TF_DeleteSessionOptions(handle);

public void SetConfig(ConfigProto config)


src/TensorFlowNET.Core/Sessions/_FetchHandler.cs (+12 -11)

@@ -17,6 +17,7 @@
using NumSharp;
using System;
using System.Collections.Generic;
using NumSharp.Backends;

namespace Tensorflow
{
@@ -71,18 +72,18 @@ namespace Tensorflow
{
if(tensor_values.Length > 0)
{
switch (tensor_values[0].dtype.Name)
switch (tensor_values[0].typecode)
{
case "Int32":
case NPTypeCode.Int32:
full_values.Add(float.NaN);
break;
case "Single":
case NPTypeCode.Single:
full_values.Add(float.NaN);
break;
case "String":
case NPTypeCode.String:
full_values.Add(float.NaN);
break;
case "Char":
case NPTypeCode.Char:
full_values.Add(float.NaN);
break;
default:
@@ -100,21 +101,21 @@ namespace Tensorflow
j += 1;
if (value.ndim == 0)
{
switch (value.dtype.Name)
switch (value.typecode)
{
case "Int16":
case NPTypeCode.Int16:
full_values.Add(value.GetValue<short>(0));
break;
case "Int32":
case NPTypeCode.Int32:
full_values.Add(value.GetValue<int>(0));
break;
case "Int64":
case NPTypeCode.Int64:
full_values.Add(value.GetValue<long>(0));
break;
case "Single":
case NPTypeCode.Single:
full_values.Add(value.GetValue<float>(0));
break;
case "Double":
case NPTypeCode.Double:
full_values.Add(value.GetValue<double>(0));
break;
/*case "String":


src/TensorFlowNET.Core/Status/Status.cs (+3 -3)

@@ -50,7 +50,7 @@ namespace Tensorflow
/// </summary>
public void Check(bool throwException = false)
{
if(Code != TF_Code.TF_OK)
if (Code != TF_Code.TF_OK)
{
Console.WriteLine(Message);
if (throwException)
@@ -65,7 +65,7 @@ namespace Tensorflow
return status._handle;
}

protected override void DisposeUnManagedState(IntPtr handle)
protected override void DisposeUnmanagedResources(IntPtr handle)
=> c_api.TF_DeleteStatus(handle);
}
}
}

src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs (+81 -58)

@@ -16,11 +16,13 @@

using NumSharp;
using System;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Text;
using NumSharp.Backends;
using NumSharp.Backends.Unmanaged;
using static Tensorflow.c_api;

@@ -462,7 +464,7 @@ namespace Tensorflow
*v = value;
_handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(Complex)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(Complex), deallocator: _hGlobalDeallocator, ref _deallocatorArgs);
IsMemoryOwner=true;
}
}
#endif

/// <summary>
@@ -477,7 +479,7 @@ namespace Tensorflow

IntPtr tensor = c_api.TF_TensorData(handle);
Marshal.WriteInt64(tensor, 0);
fixed (byte* src = &buffer[0])
fixed (byte* src = buffer)
c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(Int64)), size, status);
_handle = handle;
status.Check(true);
@@ -486,35 +488,55 @@ namespace Tensorflow
public unsafe Tensor(NDArray nd, TF_DataType? tensorDType = null)
{
// todo: handle nd of type "String" here too
if (tensorDType == TF_DataType.TF_STRING && nd.dtype.Name == "Byte")
if (tensorDType == TF_DataType.TF_STRING && nd.typecode == NPTypeCode.Byte)
{
var buffer = nd.ToArray<byte>();
var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length);
var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8));

IntPtr tensor = c_api.TF_TensorData(handle);
Marshal.WriteInt64(tensor, 0);
if (nd.Unsafe.Storage.Shape.IsContiguous)
{
var bytesLength = (UIntPtr)nd.size;
var size = c_api.TF_StringEncodedSize(bytesLength);
var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr) ((ulong) size + 8));

IntPtr tensor = c_api.TF_TensorData(handle);
Marshal.WriteInt64(tensor, 0);

var status = new Status();
c_api.TF_StringEncode((byte*) nd.Unsafe.Address, bytesLength, (sbyte*) (tensor + sizeof(Int64)), size, status);

status.Check(true);
_handle = handle;
IsMemoryOwner = false;
}
else
{
var buffer = nd.ToArray<byte>();
var size = c_api.TF_StringEncodedSize((UIntPtr) buffer.Length);
var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr) ((ulong) size + 8));

IntPtr tensor = c_api.TF_TensorData(handle);
Marshal.WriteInt64(tensor, 0);

var status = new Status();
fixed (byte* src = buffer)
c_api.TF_StringEncode(src, (UIntPtr) buffer.Length, (sbyte*) (tensor + sizeof(Int64)), size, status);

status.Check(true);
_handle = handle;
IsMemoryOwner = false;
}

var status = new Status();
fixed (byte* src = &buffer[0])
c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(Int64)), size, status);

status.Check(true);
_handle=handle;
IsMemoryOwner = false;
return;
}

_handle = CreateTensorFromNDArray(nd, tensorDType);
IsMemoryOwner = true;
}

private unsafe IntPtr CreateTensorFromNDArray(NDArray nd, TF_DataType? given_dtype)
{
if (nd.dtype.Name == "String")
throw new NotImplementedException("Support for NDArray of type string not implemented yet");
if (nd.dtype.Name == "String")
throw new NotImplementedException("Support for NDArray of type string not implemented yet");
IArraySlice arraySlice;
var shape = nd.Unsafe.Storage.Shape;
if (shape.IsSliced || shape.IsBroadcasted)
if (nd.Unsafe.Storage.Shape.IsContiguous == false)
{
// the memory is NOT contiguous, so we have to copy the view into a contiguous memory block.
arraySlice = nd.CloneData();
@@ -527,51 +549,52 @@ namespace Tensorflow
this.Tag = arraySlice; // keep a reference to the memory block to make sure it is not disposed while TF is using it
var ptr = new IntPtr(arraySlice.Address);
int num_bytes = (nd.size * nd.dtypesize);
var dtype = given_dtype ?? ToTFDataType(nd.dtype);
var dtype = given_dtype ?? nd.dtype.as_dtype();
var handle = TF_NewTensor(dtype, dims: nd.shape.Select(i=>(long)i).ToArray(), num_dims: nd.ndim, data: ptr, len: (UIntPtr)num_bytes, deallocator: _nothingDeallocator, ref _deallocatorArgs);
IsMemoryOwner = false;
return handle;
}
public unsafe Tensor(byte[][] buffer, long[] shape)
{
int size = 0;
foreach (var b in buffer)
{
size += (int)TF_StringEncodedSize((UIntPtr)b.Length);
}
int totalSize = size + buffer.Length * 8;
ulong offset = 0;
IntPtr handle = TF_AllocateTensor(TF_DataType.TF_STRING, shape, shape.Length, (UIntPtr)totalSize);
// Clear offset table
IntPtr pOffset = TF_TensorData(handle);
IntPtr dst = pOffset + buffer.Length * 8;
IntPtr dstLimit = pOffset + totalSize;
for (int i = 0; i < buffer.Length; i++)
{
Marshal.WriteInt64(pOffset, (long)offset);
using (var status = new Status())
{
fixed (byte* src = &buffer[i][0])
{
var written = TF_StringEncode(src, (UIntPtr)buffer[i].Length, (sbyte*)dst, (UIntPtr)(dstLimit.ToInt64() - dst.ToInt64()), status);
status.Check(true);
pOffset += 8;
dst += (int)written;
offset += written;
}
}
}
_handle = handle;

}

public unsafe Tensor(byte[][] buffer, long[] shape)
{
int size = 0;
foreach (var b in buffer)
{
size += (int)TF_StringEncodedSize((UIntPtr)b.Length);
}
int totalSize = size + buffer.Length * 8;
ulong offset = 0;
IntPtr handle = TF_AllocateTensor(TF_DataType.TF_STRING, shape, shape.Length, (UIntPtr)totalSize);

// Clear offset table
IntPtr pOffset = TF_TensorData(handle);
IntPtr dst = pOffset + buffer.Length * 8;
IntPtr dstLimit = pOffset + totalSize;
for (int i = 0; i < buffer.Length; i++)
{
Marshal.WriteInt64(pOffset, (long)offset);
using (var status = new Status())
{
fixed (byte* src = &buffer[i][0])
{
var written = TF_StringEncode(src, (UIntPtr)buffer[i].Length, (sbyte*)dst, (UIntPtr)(dstLimit.ToInt64() - dst.ToInt64()), status);
status.Check(true);
pOffset += 8;
dst += (int)written;
offset += written;
}
}
}

_handle = handle;
}

public Tensor(Operation op, int value_index, TF_DataType dtype)
{
_op = op;
_value_index = value_index;
_dtype = dtype;
_override_dtype = dtype;
_id = ops.uid();
}

@@ -589,11 +612,11 @@ namespace Tensorflow
/// specified dimensions.
/// </remarks>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
[SuppressMessage("ReSharper", "LocalVariableHidesMember")]
protected unsafe IntPtr CreateTensorWithoutCopying(TF_DataType dt, long[] shape, Array data, int element_size)
{
if (dt == TF_DataType.TF_STRING && data is byte[])
if (dt == TF_DataType.TF_STRING && data is byte[] buffer)
{
var buffer = (byte[])data;
var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length);
var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8));
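
The byte[][] constructor above builds a TF_STRING tensor from an 8-byte offset table followed by TF_StringEncode'd payloads. A hedged sketch of constructing one from managed strings (the UTF-8 encoding step is an assumption of typical usage):

    using System.Linq;
    using System.Text;

    var strings = new[] { "hello", "world" };
    var encoded = strings.Select(s => Encoding.UTF8.GetBytes(s)).ToArray();

    // One-dimensional string tensor with as many elements as input strings.
    var tensor = new Tensor(encoded, new long[] { encoded.Length });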



src/TensorFlowNET.Core/Tensors/Tensor.Explicit.cs (+85 -28)

@@ -1,4 +1,5 @@
using System;
using System.Runtime.CompilerServices;

namespace Tensorflow
{
@@ -6,86 +7,142 @@ namespace Tensorflow
{
public static explicit operator bool(Tensor tensor)
{
EnsureScalar(tensor);
return tensor.Data<bool>()[0];
unsafe
{
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_BOOL);
return *(bool*) tensor.buffer;
}
}

public static explicit operator sbyte(Tensor tensor)
{
EnsureScalar(tensor);
return tensor.Data<sbyte>()[0];
unsafe
{
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_INT8);
return *(sbyte*) tensor.buffer;
}
}

public static explicit operator byte(Tensor tensor)
{
EnsureScalar(tensor);
return tensor.Data<byte>()[0];
unsafe
{
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_UINT8);
return *(byte*) tensor.buffer;
}
}

public static explicit operator ushort(Tensor tensor)
{
EnsureScalar(tensor);
return tensor.Data<ushort>()[0];
unsafe
{
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_UINT16);
return *(ushort*) tensor.buffer;
}
}

public static explicit operator short(Tensor tensor)
{
EnsureScalar(tensor);
return tensor.Data<short>()[0];
unsafe
{
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_INT16);
return *(short*) tensor.buffer;
}
}

public static explicit operator int(Tensor tensor)
{
EnsureScalar(tensor);
return tensor.Data<int>()[0];
unsafe
{
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_INT32);
return *(int*) tensor.buffer;
}
}

public static explicit operator uint(Tensor tensor)
{
EnsureScalar(tensor);
return tensor.Data<uint>()[0];
unsafe
{
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_UINT32);
return *(uint*) tensor.buffer;
}
}

public static explicit operator long(Tensor tensor)
{
EnsureScalar(tensor);
return tensor.Data<long>()[0];
unsafe
{
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_INT64);
return *(long*) tensor.buffer;
}
}

public static explicit operator ulong(Tensor tensor)
{
EnsureScalar(tensor);
return tensor.Data<ulong>()[0];
unsafe
{
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_UINT64);
return *(ulong*) tensor.buffer;
}
}

public static explicit operator float(Tensor tensor)
{
EnsureScalar(tensor);
return tensor.Data<float>()[0];
unsafe
{
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_FLOAT);
return *(float*) tensor.buffer;
}
}

public static explicit operator double(Tensor tensor)
{
EnsureScalar(tensor);
return tensor.Data<double>()[0];
unsafe
{
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_DOUBLE);
return *(double*) tensor.buffer;
}
}

public static explicit operator string(Tensor tensor)
{
unsafe
{
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_STRING);
return new string((char*) tensor.buffer, 0, (int) tensor.size);
}
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static void EnsureDType(Tensor tensor, TF_DataType @is)
{
if (tensor.dtype != @is)
throw new InvalidCastException($"Unable to cast scalar tensor {tensor.dtype} to {@is}");
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static void EnsureScalar(Tensor tensor)
{
if (tensor == null)
{
throw new ArgumentNullException(nameof(tensor));
}

if (tensor.TensorShape.ndim != 0)
{
throw new ArgumentException("Tensor must have 0 dimensions in order to convert to scalar");
}

if (tensor.TensorShape.size != 1)
{
throw new ArgumentException("Tensor must have size 1 in order to convert to scalar");
}
}

}
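
With the new guards, a cast succeeds only for a scalar tensor whose dtype matches the requested CLR type. A short sketch of the expected behavior, assuming the scalar and array constructors generated in Tensor.Creation:

    var scalar = new Tensor(42);           // TF_INT32, rank 0
    int value = (int) scalar;              // 42

    // (long) scalar                    -> InvalidCastException (dtype is TF_INT32, not TF_INT64)
    // (int) new Tensor(new[] {1, 2})   -> ArgumentException (tensor is not a scalar)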


src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs (+3 -4)

@@ -69,11 +69,12 @@ namespace Tensorflow
TF_DataType.TF_QINT8, TF_DataType.TF_QINT16, TF_DataType.TF_QINT32,
TF_DataType.TF_UINT8, TF_DataType.TF_UINT16, TF_DataType.TF_UINT32, TF_DataType.TF_UINT64
};

public static Tensor operator /(double x, Tensor y) => BinaryOpWrapper("truediv", x, y);
public static Tensor operator /(float x, Tensor y) => BinaryOpWrapper("truediv", x, y);
public static Tensor operator /(int x, Tensor y) => BinaryOpWrapper("floordiv", x, y);
public static Tensor operator /(Tensor x, Tensor y) =>
_intTfDataTypes.Contains(x._dtype)
_intTfDataTypes.Contains(x.dtype)
? BinaryOpWrapper("floordiv", x, y)
: BinaryOpWrapper("truediv", x, y);
public static Tensor operator /(Tensor x, int y) => BinaryOpWrapper("floordiv", x, y);
@@ -122,8 +123,7 @@ namespace Tensorflow
if (y is Tensor tr)
dtype = tr.dtype.as_base_dtype();

var namescope = ops.name_scope(null, name, new { x, y });
return tf_with(namescope, scope =>
return tf_with(ops.name_scope(null, name, new { x, y }), scope =>
{
Tensor result = null;
var x1 = ops.convert_to_tensor(x, dtype: dtype, name: "x");
@@ -155,7 +155,6 @@ namespace Tensorflow

return result;
});

}
}
}
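
The division operator now inspects the public dtype instead of the private override field; as in Python, integer tensors map to floordiv and floating-point tensors to truediv. A hedged sketch:

    var a = tf.constant(7);       // TF_INT32
    var b = tf.constant(2);
    var q = a / b;                // builds a "floordiv" op; evaluates to 3 when run in a session

    var x = tf.constant(7.0f);    // TF_FLOAT
    var y = tf.constant(2.0f);
    var r = x / y;                // builds a "truediv" op; evaluates to 3.5 when run in a session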

src/TensorFlowNET.Core/Tensors/Tensor.cs (+274 -103)

@@ -17,9 +17,16 @@
using NumSharp;
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Globalization;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Text;
using System.Threading.Tasks;
using NumSharp.Backends;
using NumSharp.Backends.Unmanaged;
using NumSharp.Utilities;
using Tensorflow.Framework;
using static Tensorflow.Binding;

@@ -29,42 +36,68 @@ namespace Tensorflow
/// A tensor is a generalization of vectors and matrices to potentially higher dimensions.
/// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes.
/// </summary>
[SuppressMessage("ReSharper", "ConvertToAutoProperty")]
public partial class Tensor : DisposableObject, ITensorOrOperation, _TensorLike
{
private int _id;
private Operation _op;
private readonly int _id;
private readonly Operation _op;
private readonly int _value_index;
private TF_Output? _tf_output;
private readonly TF_DataType _override_dtype;

public int Id => _id;

/// <summary>
/// The Graph that contains this tensor.
/// </summary>
public Graph graph => op?.graph;

/// <summary>
/// The Operation that produces this tensor as an output.
/// </summary>
public Operation op => _op;

public Tensor[] outputs => op.outputs;

/// <summary>
/// The string name of this tensor.
/// The string name of this tensor.
/// </summary>
public string name => $"{(op == null ? "<unnamed Operation>" : $"{op.name}:{_value_index}")}";

private int _value_index;
/// <summary>
/// The index of this tensor in the outputs of its Operation.
/// </summary>
public int value_index => _value_index;

private TF_DataType _dtype = TF_DataType.DtInvalid;
public TF_DataType dtype => _handle == IntPtr.Zero ? _dtype : c_api.TF_TensorType(_handle);
/// <summary>
/// The DType of elements in this tensor.
/// </summary>
public TF_DataType dtype => _handle == IntPtr.Zero ? _override_dtype : c_api.TF_TensorType(_handle);

public ulong bytesize => _handle == IntPtr.Zero ? 0 : c_api.TF_TensorByteSize(_handle);

public ulong itemsize => _handle == IntPtr.Zero ? 0 : c_api.TF_DataTypeSize(dtype);
public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / itemsize;

public IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle);
public int num_consumers(TF_Output oper_out) => _handle == IntPtr.Zero ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out);
public int NDims => rank;

private TF_Output? _tf_output;
/// <summary>
/// The name of the device on which this tensor will be produced, or null.
/// </summary>
public string Device => op.Device;

public int[] dims => shape;

/// <summary>
/// used for keep other pointer when do implicit operating
/// Used for keep other pointer when do implicit operating
/// </summary>
public object Tag { get; set; }


/// <summary>
/// Returns the shape of a tensor.
/// </summary>
/// <remarks>https://www.tensorflow.org/api_docs/python/tf/shape</remarks>
public int[] shape
{
get
@@ -76,14 +109,13 @@ namespace Tensorflow
var status = new Status();
c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, status);
status.Check();
}
else
} else
{
for (int i = 0; i < rank; i++)
dims[i] = c_api.TF_Dim(_handle, i);
}

return dims.Select(x => Convert.ToInt32(x)).ToArray();
return dims.Select(x => ((IConvertible) x).ToInt32(CultureInfo.InvariantCulture)).ToArray();
}

set
@@ -93,38 +125,52 @@ namespace Tensorflow
if (value == null)
c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), null, -1, status);
else
c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), value.Select(x => Convert.ToInt64(x)).ToArray(), value.Length, status);
c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), value.Select(Convert.ToInt64).ToArray(), value.Length, status);
}
}

public int[] _shape_tuple()
{
if (shape == null) return null;
return shape.Select(x => (int)x).ToArray();
return (int[]) shape.Clone();
}

public TensorShape TensorShape => tensor_util.to_shape(shape);

public void SetShape(TensorShape shape)
/// <summary>
/// Updates the shape of this tensor.
/// </summary>
public void set_shape(TensorShape shape)
{
this.shape = shape.dims;
this.shape = (int[]) shape.dims.Clone();
}

/// <summary>
/// Updates the shape of this tensor.
/// </summary>
[Obsolete("Please use set_shape(TensorShape shape) instead.", false)]
public void SetShape(TensorShape shape)
{
this.shape = (int[]) shape.dims.Clone();
}

/// <summary>
/// Updates the shape of this tensor.
/// </summary>
public void set_shape(Tensor shape)
{
// ReSharper disable once MergeConditionalExpression
this.shape = shape is null ? null : shape.shape;
}

public int[] dims => shape;

/// <summary>
/// number of dimensions
/// 0 Scalar (magnitude only)
/// 1 Vector (magnitude and direction)
/// 2 Matrix (table of numbers)
/// 3 3-Tensor (cube of numbers)
/// number of dimensions <br></br>
/// 0 Scalar (magnitude only) <br></br>
/// 1 Vector (magnitude and direction) <br></br>
/// 2 Matrix (table of numbers) <br></br>
/// 3 3-Tensor (cube of numbers) <br></br>
/// n n-Tensor (you get the idea)
/// </summary>
/// <remarks>https://www.tensorflow.org/api_docs/python/tf/rank</remarks>
public int rank
{
get
@@ -137,17 +183,15 @@ namespace Tensorflow
status.Check();
return ndim;
}
else
{
return c_api.TF_NumDims(_handle);
}

return c_api.TF_NumDims(_handle);
}
}

public int NDims => rank;
public string Device => op.Device;
/// <summary>
/// Returns a list of Operations that consume this tensor.
/// </summary>
/// <returns></returns>
public Operation[] consumers()
{
var output = _as_tf_output();
@@ -157,37 +201,191 @@ namespace Tensorflow

public TF_Output _as_tf_output()
{
if(!_tf_output.HasValue)
if (!_tf_output.HasValue)
_tf_output = new TF_Output(op, value_index);

return _tf_output.Value;
}

public T[] Data<T>()
[Obsolete("Please use ToArray<T>() instead.", false)]
public T[] Data<T>() where T : unmanaged
{
return ToArray<T>();
}

/// <summary>
///
/// </summary>
/// <typeparam name="T"></typeparam>
/// <returns></returns>
/// <exception cref="ArgumentException">When <typeparam name="T"> is string </typeparam></exception>
public T[] ToArray<T>() where T : unmanaged
{
// Column major order
// https://en.wikipedia.org/wiki/File:Row_and_column_major_order.svg
// matrix:[[1, 2, 3], [4, 5, 6]]
// index: 0 2 4 1 3 5
// result: 1 4 2 5 3 6
var data = new T[size];

for (ulong i = 0; i < size; i++)
//when T is string
if (typeof(T) == typeof(string))
{
data[i] = Marshal.PtrToStructure<T>(buffer + (int)(i * itemsize));
if (dtype != TF_DataType.TF_STRING)
throw new ArgumentException($"Given <{typeof(T).Name}> can't be converted to string.");

return (T[]) (object) StringData();
}

return data;
//Are the types matching?
if (typeof(T).as_dtype() == dtype)
{
if (NDims == 0 && size == 1) //is it a scalar?
{
unsafe
{
return new T[] {*(T*) buffer};
}
}

//types match, no need to perform cast
var ret = new T[size];
unsafe
{
var len = (long) size;
fixed (T* dstRet = ret)
{
T* dst = dstRet; //local stack copy
if (typeof(T).IsPrimitive)
{
var src = (T*) buffer;
len *= ((long) itemsize);
System.Buffer.MemoryCopy(src, dst, len, len);
} else
{
var itemsize = (long) this.itemsize;
var buffer = this.buffer.ToInt64();
Parallel.For(0L, len, i => dst[i] = Marshal.PtrToStructure<T>(new IntPtr(buffer + i * itemsize)));
}
}
}

return ret;
} else
{
//types do not match, need to perform cast
if (NDims == 0 && size == 1) //is it a scalar?
{
unsafe
{
#if _REGEN
#region Compute
switch (dtype.as_numpy_dtype().GetTypeCode())
{
%foreach supported_dtypes,supported_dtypes_lowercase%
case NPTypeCode.#1: return new T[] {Converts.ChangeType<T>(*(#2*) buffer, NPTypeCode.#1)};
%
case NPTypeCode.String: return new T[] {Converts.ChangeType<T>((string)this, NPTypeCode.String)};
default:
throw new NotSupportedException();
}
#endregion
#else
#region Compute
switch (dtype.as_numpy_dtype()?.GetTypeCode())
{
case NPTypeCode.Boolean: return new T[] {Converts.ChangeType<T>(*(bool*) buffer, NPTypeCode.Boolean)};
case NPTypeCode.Byte: return new T[] {Converts.ChangeType<T>(*(byte*) buffer, NPTypeCode.Byte)};
case NPTypeCode.Int16: return new T[] {Converts.ChangeType<T>(*(short*) buffer, NPTypeCode.Int16)};
case NPTypeCode.UInt16: return new T[] {Converts.ChangeType<T>(*(ushort*) buffer, NPTypeCode.UInt16)};
case NPTypeCode.Int32: return new T[] {Converts.ChangeType<T>(*(int*) buffer, NPTypeCode.Int32)};
case NPTypeCode.UInt32: return new T[] {Converts.ChangeType<T>(*(uint*) buffer, NPTypeCode.UInt32)};
case NPTypeCode.Int64: return new T[] {Converts.ChangeType<T>(*(long*) buffer, NPTypeCode.Int64)};
case NPTypeCode.UInt64: return new T[] {Converts.ChangeType<T>(*(ulong*) buffer, NPTypeCode.UInt64)};
case NPTypeCode.Char: return new T[] {Converts.ChangeType<T>(*(char*) buffer, NPTypeCode.Char)};
case NPTypeCode.Double: return new T[] {Converts.ChangeType<T>(*(double*) buffer, NPTypeCode.Double)};
case NPTypeCode.Single: return new T[] {Converts.ChangeType<T>(*(float*) buffer, NPTypeCode.Single)};
case NPTypeCode.String: return new T[] {Converts.ChangeType<T>((string)this, NPTypeCode.String)};
default:
throw new NotSupportedException();
}
#endregion
#endif
}
}

var ret = new T[size];
unsafe
{
var len = (long) size;
fixed (T* dstRet = ret)
{
T* dst = dstRet; //local stack copy

#if _REGEN
#region Compute
switch (dtype.as_numpy_dtype().GetTypeCode())
{
%foreach supported_dtypes,supported_dtypes_lowercase%
case NPTypeCode.#1: new UnmanagedMemoryBlock<#2>((#2*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
%
default:
throw new NotSupportedException();
}
#endregion
#else
#region Compute
switch (dtype.as_numpy_dtype().GetTypeCode())
{
case NPTypeCode.Boolean: new UnmanagedMemoryBlock<bool>((bool*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.Byte: new UnmanagedMemoryBlock<byte>((byte*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.Int16: new UnmanagedMemoryBlock<short>((short*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.UInt16: new UnmanagedMemoryBlock<ushort>((ushort*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.Int32: new UnmanagedMemoryBlock<int>((int*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.UInt32: new UnmanagedMemoryBlock<uint>((uint*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.Int64: new UnmanagedMemoryBlock<long>((long*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.UInt64: new UnmanagedMemoryBlock<ulong>((ulong*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.Char: new UnmanagedMemoryBlock<char>((char*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.Double: new UnmanagedMemoryBlock<double>((double*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.Single: new UnmanagedMemoryBlock<float>((float*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break;
case NPTypeCode.String: throw new NotSupportedException("Unable to convert from string to other dtypes"); //TODO! this should call Converts.To<T>
default:
throw new NotSupportedException();
}
#endregion
#endif
}
}
return ret;
}
}

/// <summary>
/// Copies the memory of current buffer onto newly allocated array.
/// </summary>
/// <returns></returns>
[Obsolete("Please use set_shape(TensorShape shape) instead.", false)]
public byte[] Data()
{
var data = new byte[bytesize];
Marshal.Copy(buffer, data, 0, (int)bytesize);
return data;
return BufferToArray();
}

/// <summary>
/// Copies the memory of current buffer onto newly allocated array.
/// </summary>
/// <returns></returns>
public byte[] BufferToArray()
{
unsafe
{
// ReSharper disable once LocalVariableHidesMember
var bytesize = (long) this.bytesize;
var data = new byte[bytesize];
fixed (byte* dst = data)
System.Buffer.MemoryCopy(buffer.ToPointer(), dst, bytesize, bytesize);

return data;
}
}

public unsafe string[] StringData()
/// Used internally in ToArray&lt;T&gt;
private unsafe string[] StringData()
{
//
// TF_STRING tensors are encoded with a table of 8-byte offsets followed by TF_StringEncode-encoded bytes.
@@ -199,19 +397,19 @@ namespace Tensorflow

var buffer = new byte[size][];
var src = c_api.TF_TensorData(_handle);
var srcLen = (IntPtr)(src.ToInt64() + (long)bytesize);
src += (int)(size * 8);
var srcLen = (IntPtr) (src.ToInt64() + (long) bytesize);
src += (int) (size * 8);
for (int i = 0; i < buffer.Length; i++)
{
using (var status = new Status())
{
IntPtr dst = IntPtr.Zero;
UIntPtr dstLen = UIntPtr.Zero;
var read = c_api.TF_StringDecode((byte*)src, (UIntPtr)(srcLen.ToInt64() - src.ToInt64()), (byte**)&dst, &dstLen, status);
var read = c_api.TF_StringDecode((byte*) src, (UIntPtr) (srcLen.ToInt64() - src.ToInt64()), (byte**) &dst, &dstLen, status);
status.Check(true);
buffer[i] = new byte[(int)dstLen];
buffer[i] = new byte[(int) dstLen];
Marshal.Copy(dst, buffer[i], 0, buffer[i].Length);
src += (int)read;
src += (int) read;
}
}

@@ -229,51 +427,29 @@ namespace Tensorflow
}

/// <summary>
/// Evaluates this tensor in a `Session`.
/// Evaluates this tensor in a `Session`.
/// </summary>
/// <param name="feed_dict">A dictionary that maps `Tensor` objects to feed values.</param>
/// <param name="session">The `Session` to be used to evaluate this tensor.</param>
/// <returns></returns>
/// <returns>A <see cref="NumSharp"/> array corresponding to the value of this tensor.</returns>
public NDArray eval(params FeedItem[] feed_dict)
{
return ops._eval_using_default_session(this, feed_dict, graph);
}

/// <summary>
/// Evaluates this tensor in a `Session`.
/// </summary>
/// <param name="feed_dict">A dictionary that maps `Tensor` objects to feed values.</param>
/// <param name="session">The `Session` to be used to evaluate this tensor.</param>
/// <returns>A <see cref="NumSharp"/> array corresponding to the value of this tensor.</returns>
public NDArray eval(Session session, FeedItem[] feed_dict = null)
{
return ops._eval_using_default_session(this, feed_dict, graph, session);
}

public TF_DataType ToTFDataType(Type type)
{
switch (type.Name)
{
case "Char":
return TF_DataType.TF_UINT8;
case "Int16":
return TF_DataType.TF_INT16;
case "Int32":
return TF_DataType.TF_INT32;
case "Int64":
return TF_DataType.TF_INT64;
case "Single":
return TF_DataType.TF_FLOAT;
case "Double":
return TF_DataType.TF_DOUBLE;
case "Byte":
return TF_DataType.TF_UINT8;
case "String":
return TF_DataType.TF_STRING;
case "Boolean":
return TF_DataType.TF_BOOL;
default:
throw new NotImplementedException("ToTFDataType error");
}
}

public Tensor slice(Slice slice)
{
var slice_spec = new int[] { slice.Start.Value };
var slice_spec = new int[] {slice.Start.Value};
var begin = new List<int>();
var end = new List<int>();
var strides = new List<int>();
@@ -289,26 +465,26 @@ namespace Tensorflow
if (slice.Stop.HasValue)
{
end.Add(slice.Stop.Value);
}
else
} else
{
end.Add(0);
end_mask |= (1 << index);
}

strides.Add(slice.Step);

index += 1;
}

return tf_with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope =>
return tf_with(ops.name_scope(null, "strided_slice", new {begin, end, strides}), scope =>
{
string name = scope;
if (begin != null)
{
var (packed_begin, packed_end, packed_strides) =
(array_ops.stack(begin.ToArray()),
array_ops.stack(end.ToArray()),
array_ops.stack(strides.ToArray()));
array_ops.stack(end.ToArray()),
array_ops.stack(strides.ToArray()));

return gen_array_ops.strided_slice(
this,
@@ -320,7 +496,6 @@ namespace Tensorflow
shrink_axis_mask: shrink_axis_mask,
new_axis_mask: new_axis_mask,
ellipsis_mask: ellipsis_mask,

name: name);
}

@@ -330,7 +505,7 @@ namespace Tensorflow

public Tensor slice(int start)
{
var slice_spec = new int[] { start };
var slice_spec = new int[] {start};
var begin = new List<int>();
var end = new List<int>();
var strides = new List<int>();
@@ -349,15 +524,15 @@ namespace Tensorflow
index += 1;
}

return tf_with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope =>
return tf_with(ops.name_scope(null, "strided_slice", new {begin, end, strides}), scope =>
{
string name = scope;
if (begin != null)
{
var (packed_begin, packed_end, packed_strides) =
(array_ops.stack(begin.ToArray()),
array_ops.stack(end.ToArray()),
array_ops.stack(strides.ToArray()));
array_ops.stack(end.ToArray()),
array_ops.stack(strides.ToArray()));

return gen_array_ops.strided_slice(
this,
@@ -369,7 +544,6 @@ namespace Tensorflow
shrink_axis_mask: shrink_axis_mask,
new_axis_mask: new_axis_mask,
ellipsis_mask: ellipsis_mask,

name: name);
}

@@ -392,15 +566,12 @@ namespace Tensorflow
return $"tf.Tensor '{name}' shape=({string.Join(",", shape)}) dtype={dtype}";
}

protected override void DisposeManagedState()
{
}

protected override void DisposeUnManagedState(IntPtr handle)
protected override void DisposeUnmanagedResources(IntPtr handle)
{
if(handle != IntPtr.Zero)
if (handle != IntPtr.Zero)
{
c_api.TF_DeleteTensor(handle);
_handle = IntPtr.Zero;
}
}

@@ -417,4 +588,4 @@ namespace Tensorflow

public int tensor_int_val { get; set; }
}
}
}
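
A short usage sketch of the reworked copy-out APIs, assuming the float-array constructor in Tensor.Creation: ToArray<T> fast-copies when T matches the tensor dtype, element-casts otherwise, and BufferToArray returns the raw buffer bytes.

    var t = new Tensor(new[] { 1f, 2f, 3f });   // TF_FLOAT

    float[] same  = t.ToArray<float>();   // dtypes match: single MemoryCopy
    double[] cast = t.ToArray<double>();  // dtypes differ: element-wise cast via UnmanagedMemoryBlock
    byte[] raw    = t.BufferToArray();    // raw copy of the underlying TF buffer

    // t.Data<float>() and t.Data() still compile but are marked [Obsolete].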

src/TensorFlowNET.Core/Tensors/TensorShape.cs (+93 -15)

@@ -1,35 +1,84 @@
using NumSharp;
using System;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Runtime.CompilerServices;

namespace Tensorflow
{
/// <summary>
/// Represents the shape of a `Tensor`.
/// Represents the shape of a `Tensor`.
/// </summary>
/// <remarks>https://www.tensorflow.org/api_docs/python/tf/TensorShape</remarks>
public class TensorShape
{
private Shape shape;
private readonly Shape shape;

/// <summary>
/// Returns a list of Dimensions, or None if the shape is unspecified.
/// </summary>
public int[] dims => shape.Dimensions;

/// <summary>
/// Returns the rank of this shape.
/// </summary>
public int ndim => shape.NDim;

/// <summary>
/// Returns the rank of this shape.
/// </summary>
public int rank => shape.NDim;

/// <summary>
/// Returns the size this shape represents.
/// </summary>
public int size => shape.Size;

public TensorShape(TensorShapeProto proto)
{
if (proto.UnknownRank) return;
switch (proto.Dim.Count)
{
case 0: shape = new Shape(new int[0]); break;
case 1: shape = Shape.Vector((int) proto.Dim[0].Size); break;
case 2: shape = Shape.Matrix((int) proto.Dim[0].Size, (int) proto.Dim[1].Size); break;
default:
var protodims = proto.Dim;
var len = protodims.Count;
var dims = new int[len];
for (int i = 0; i < len; i++)
dims[i] = (int) protodims[i].Size;


shape.reshape(proto.Dim.Select(x => (int)x.Size).ToArray());
shape = new Shape(dims); break;
}
}

public TensorShape(params int[] dims)
{
shape = new Shape(dims);
switch (dims.Length)
{
case 0: shape = new Shape(new int[0]); break;
case 1: shape = Shape.Vector((int) dims[0]); break;
case 2: shape = Shape.Matrix(dims[0], dims[1]); break;
default: shape = new Shape(dims); break;
}
}

/// <summary>
///
/// </summary>
/// <param name="slice"></param>
/// <returns></returns>
/// <exception cref="ArgumentException">When <see cref="Slice"/> is not an Index.</exception>
[SuppressMessage("ReSharper", "PossibleInvalidOperationException")]
public TensorShape this[Slice slice]
{
get
{
if (slice.Start.HasValue == false || slice.Length.HasValue == false)
throw new ArgumentException("Slice must has Start and Length.");

return new TensorShape(dims.Skip(slice.Start.Value)
.Take(slice.Length.Value)
.ToArray());
@@ -37,7 +86,7 @@ namespace Tensorflow
}

/// <summary>
/// Returns True iff `self` is fully defined in every dimension.
/// Returns True iff `self` is fully defined in every dimension.
/// </summary>
/// <returns></returns>
public bool is_fully_defined()
@@ -50,6 +99,7 @@ namespace Tensorflow
throw new NotImplementedException("TensorShape is_compatible_with");
}

[SuppressMessage("ReSharper", "ParameterHidesMember")]
public TensorShape with_rank_at_least(int rank)
{
if (rank != ndim)
@@ -59,35 +109,63 @@ namespace Tensorflow
}

/// <summary>
/// Returns the concatenation of the dimension in `self` and `other`.
/// Returns the concatenation of the dimension in `self` and `other`.
/// </summary>
/// <param name="other"></param>
/// <returns></returns>
public TensorShape concatenate(int[] other_)
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public TensorShape concatenate(int[] other)
{
var other = new TensorShape(other_);
return concatenate(new TensorShape(other));
}

if (ndim < 0 || other.ndim < 0)
/// <summary>
/// Returns the concatenation of the dimension in `self` and `other`.
/// </summary>
/// <param name="other"></param>
/// <returns></returns>
public TensorShape concatenate(TensorShape other)
{
var otherShape = other;

if (ndim < 0 || otherShape.ndim < 0)
return new TensorShape();
else
{
var concatenate_dims = new int[ndim + other.ndim];
var concatenate_dims = new int[ndim + otherShape.ndim];
for (int i = 0; i < ndim; i++)
concatenate_dims[i] = dims[i];

for (int i = 0; i < other.ndim; i++)
concatenate_dims[ndim + i] = other.dims[i];
for (int i = 0; i < otherShape.ndim; i++)
concatenate_dims[ndim + i] = otherShape.dims[i];

return new TensorShape(concatenate_dims);
}
}

public static implicit operator TensorShape(Shape shape) => new TensorShape(shape.Dimensions);
public static implicit operator Shape(TensorShape shape) => new Shape(shape.dims);
public static implicit operator TensorShape(Shape shape) => new TensorShape((int[]) shape.Dimensions.Clone());
public static implicit operator Shape(TensorShape shape) => new Shape((int[]) shape.dims.Clone());
public static implicit operator int[](TensorShape shape) => (int[])shape.dims.Clone(); //we clone to avoid any changes
public static implicit operator TensorShape(int[] dims) => new TensorShape(dims);
public static implicit operator int[](TensorShape shape) => shape.dims;

public static explicit operator int(TensorShape shape) => shape.size;
public static explicit operator TensorShape(int dim) => new TensorShape(dim);

public static explicit operator (int, int)(TensorShape shape) => shape.dims.Length == 2 ? (shape.dims[0], shape.dims[1]) : (0, 0);
public static implicit operator TensorShape((int, int) dims) => new TensorShape(dims.Item1, dims.Item2);

public static explicit operator (int, int, int)(TensorShape shape) => shape.dims.Length == 3 ? (shape.dims[0], shape.dims[1], shape.dims[2]) : (0, 0, 0);
public static implicit operator TensorShape((int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3);

public static explicit operator (int, int, int, int)(TensorShape shape) => shape.dims.Length == 4 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3]) : (0, 0, 0, 0);
public static implicit operator TensorShape((int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4);

public static explicit operator (int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 5 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4]) : (0, 0, 0, 0, 0);
public static implicit operator TensorShape((int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5);

public static explicit operator (int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 6 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5]) : (0, 0, 0, 0, 0, 0);
public static implicit operator TensorShape((int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6);

}
}
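
The new conversion operators make shapes interchangeable with int arrays and value tuples. A brief sketch, assuming NumSharp's Slice(start, stop) constructor for the indexer:

    TensorShape s1 = new[] { 2, 3, 4 };          // from int[]
    TensorShape s2 = (28, 28, 1);                // from a value tuple

    int[] dims = s2;                             // cloned copy of the dimensions
    (int h, int w, int c) = ((int, int, int)) s2;

    var s3 = s1.concatenate(new[] { 5 });        // (2, 3, 4, 5)
    var tail = s1[new Slice(1, 3)];              // dims 1..2 -> (3, 4); Slice must carry Start and Length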

src/TensorFlowNET.Core/Tensors/dtypes.cs (+74 -25)

@@ -15,6 +15,8 @@
******************************************************************************/

using System;
using System.Numerics;
using NumSharp.Backends;

namespace Tensorflow
{
@@ -23,35 +25,100 @@ namespace Tensorflow
public static TF_DataType int8 = TF_DataType.TF_INT8;
public static TF_DataType int32 = TF_DataType.TF_INT32;
public static TF_DataType int64 = TF_DataType.TF_INT64;
public static TF_DataType uint8 = TF_DataType.TF_UINT8;
public static TF_DataType uint32 = TF_DataType.TF_UINT32;
public static TF_DataType uint64 = TF_DataType.TF_UINT64;
public static TF_DataType float32 = TF_DataType.TF_FLOAT; // is that float32?
public static TF_DataType float16 = TF_DataType.TF_HALF;
public static TF_DataType float64 = TF_DataType.TF_DOUBLE;

public static Type as_numpy_datatype(this TF_DataType type)
/// <summary>
///
/// </summary>
/// <param name="type"></param>
/// <returns><see cref="System.Type"/> equivalent to <paramref name="type"/>, if none exists, returns null.</returns>
public static Type as_numpy_dtype(this TF_DataType type)
{
switch (type)
{
case TF_DataType.TF_BOOL:
return typeof(bool);
case TF_DataType.TF_UINT8:
return typeof(byte);
case TF_DataType.TF_INT64:
return typeof(long);
case TF_DataType.TF_UINT64:
return typeof(ulong);
case TF_DataType.TF_INT32:
return typeof(int);
case TF_DataType.TF_UINT32:
return typeof(uint);
case TF_DataType.TF_INT16:
return typeof(short);
case TF_DataType.TF_UINT16:
return typeof(ushort);
case TF_DataType.TF_FLOAT:
return typeof(float);
case TF_DataType.TF_DOUBLE:
return typeof(double);
case TF_DataType.TF_STRING:
return typeof(string);
case TF_DataType.TF_COMPLEX128:
case TF_DataType.TF_COMPLEX64: //64 is also TF_COMPLEX
return typeof(Complex);
default:
return null;
}
}

// "sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"
public static TF_DataType as_dtype(Type type, TF_DataType? dtype = null)
/// <summary>
///
/// </summary>
/// <param name="type"></param>
/// <returns></returns>
/// <exception cref="ArgumentException">When <paramref name="type"/> has no equivalent <see cref="NPTypeCode"/></exception>
public static NPTypeCode as_numpy_typecode(this TF_DataType type)
{
switch (type)
{
case TF_DataType.TF_BOOL:
return NPTypeCode.Boolean;
case TF_DataType.TF_UINT8:
return NPTypeCode.Byte;
case TF_DataType.TF_INT64:
return NPTypeCode.Int64;
case TF_DataType.TF_INT32:
return NPTypeCode.Int32;
case TF_DataType.TF_INT16:
return NPTypeCode.Int16;
case TF_DataType.TF_UINT64:
return NPTypeCode.UInt64;
case TF_DataType.TF_UINT32:
return NPTypeCode.UInt32;
case TF_DataType.TF_UINT16:
return NPTypeCode.UInt16;
case TF_DataType.TF_FLOAT:
return NPTypeCode.Single;
case TF_DataType.TF_DOUBLE:
return NPTypeCode.Double;
case TF_DataType.TF_STRING:
return NPTypeCode.String;
case TF_DataType.TF_COMPLEX128:
case TF_DataType.TF_COMPLEX64: //64 is also TF_COMPLEX
return NPTypeCode.Complex;
default:
throw new NotSupportedException($"Unable to convert {type} to a NumSharp typecode.");
}
}

/// <summary>
///
/// </summary>
/// <param name="type"></param>
/// <param name="dtype"></param>
/// <returns></returns>
/// <exception cref="ArgumentException">When <paramref name="type"/> has no equivalent <see cref="TF_DataType"/></exception>
public static TF_DataType as_dtype(this Type type, TF_DataType? dtype = null)
{
switch (type.Name)
{
@@ -98,7 +165,7 @@ namespace Tensorflow
dtype = TF_DataType.TF_BOOL;
break;
default:
throw new Exception("as_dtype Not Implemented");
throw new NotSupportedException($"Unable to convert {type} to a TF_DataType.");
}

return dtype.Value;
@@ -106,16 +173,7 @@ namespace Tensorflow

public static DataType as_datatype_enum(this TF_DataType type)
{
DataType dtype = DataType.DtInvalid;

switch (type)
{
default:
Enum.TryParse(((int)type).ToString(), out dtype);
break;
}

return dtype;
return Enum.TryParse(((int) type).ToString(), out DataType dtype) ? dtype : DataType.DtInvalid;
}

public static TF_DataType as_base_dtype(this TF_DataType type)
@@ -132,7 +190,7 @@ namespace Tensorflow

public static Type as_numpy_dtype(this DataType type)
{
return type.as_tf_dtype().as_numpy_datatype();
return type.as_tf_dtype().as_numpy_dtype();
}

public static DataType as_base_dtype(this DataType type)
@@ -144,16 +202,7 @@ namespace Tensorflow

public static TF_DataType as_tf_dtype(this DataType type)
{
TF_DataType dtype = TF_DataType.DtInvalid;

switch (type)
{
default:
Enum.TryParse(((int)type).ToString(), out dtype);
break;
}

return dtype;
return Enum.TryParse(((int) type).ToString(), out TF_DataType dtype) ? dtype : TF_DataType.DtInvalid;
}

public static TF_DataType as_ref(this TF_DataType type)
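Editorial sketch (not part of the diff), showing how the renamed and newly added extension methods above are expected to be called; the expected results in the comments follow the conversion tables in this file:

using System;
using NumSharp.Backends;
using Tensorflow;

class DtypesUsageSketch
{
    static void Main()
    {
        Type clrType = TF_DataType.TF_FLOAT.as_numpy_dtype();        // typeof(float); replaces as_numpy_datatype()
        NPTypeCode code = TF_DataType.TF_DOUBLE.as_numpy_typecode(); // NPTypeCode.Double, the new NumSharp typecode helper
        TF_DataType dtype = typeof(float).as_dtype();                // TF_DataType.TF_FLOAT, now callable as an extension method
        Console.WriteLine($"{clrType} / {code} / {dtype}");
    }
}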


+ 8
- 7
src/TensorFlowNET.Core/Tensors/tensor_util.cs View File

@@ -17,6 +17,7 @@
using NumSharp;
using System;
using System.Linq;
using NumSharp.Utilities;

namespace Tensorflow
{
@@ -109,7 +110,7 @@ namespace Tensorflow

// We first convert value to a numpy array or scalar.
NDArray nparray = null;
var np_dt = dtype.as_numpy_datatype();
var np_dt = dtype.as_numpy_dtype();

if (values is NDArray nd)
{
@@ -188,37 +189,37 @@ namespace Tensorflow
if (values.GetType().IsArray)
nparray = np.array((int[])values, np_dt);
else
nparray = Convert.ToInt32(values);
nparray = Converts.ToInt32(values);
break;
case "Int64":
if (values.GetType().IsArray)
nparray = np.array((int[])values, np_dt);
else
nparray = Convert.ToInt64(values);
nparray = Converts.ToInt64(values);
break;
case "Single":
if (values.GetType().IsArray)
nparray = np.array((float[])values, np_dt);
else
nparray = Convert.ToSingle(values);
nparray = Converts.ToSingle(values);
break;
case "Double":
if (values.GetType().IsArray)
nparray = np.array((double[])values, np_dt);
else
nparray = Convert.ToDouble(values);
nparray = Converts.ToDouble(values);
break;
case "String":
if (values.GetType().IsArray)
nparray = np.array((string[])values, np_dt);
else
nparray = NDArray.FromString(Convert.ToString(values));
nparray = NDArray.FromString(Converts.ToString(values));
break;
case "Boolean":
if (values.GetType().IsArray)
nparray = np.array((bool[])values, np_dt);
else
nparray = Convert.ToBoolean(values);
nparray = Converts.ToBoolean(values);
break;
default:
throw new NotImplementedException($"make_tensor_proto: Support for type {np_dt.Name} Not Implemented");
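Editorial sketch (not part of the diff) of the substitution this hunk applies throughout make_tensor_proto: NumSharp.Utilities.Converts takes over the scalar conversions previously done with System.Convert. The boxed value below is hypothetical:

using System;
using NumSharp.Utilities;

class ConvertsSketch
{
    static void Main()
    {
        object values = 42;                          // hypothetical boxed scalar, standing in for `values` in make_tensor_proto
        int asInt = Converts.ToInt32(values);        // was: Convert.ToInt32(values)
        double asDouble = Converts.ToDouble(values); // was: Convert.ToDouble(values)
        Console.WriteLine($"{asInt} {asDouble}");
    }
}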


+ 38
- 0
src/TensorFlowNET.Core/globals.regen View File

@@ -0,0 +1,38 @@
%all_dtypes = ["NDArray","Complex","Boolean","Byte","Int16","UInt16","Int32","UInt32","Int64","UInt64","Char","Double","Single","String"]
%all_dtypes_lowercase = ["NDArray","Complex","bool","byte","short","ushort","int","uint","long","ulong","char","double","float","string"]

%supported_primitives = ["Boolean","Byte","Int16","UInt16","Int32","UInt32","Int64","UInt64","Char","Double","Single","String"]
%supported_primitives_lowercase = ["bool","byte","short","ushort","int","uint","long","ulong","char","double","float","string"]

%supported_numericals = ["Byte","Int16","UInt16","Int32","UInt32","Int64","UInt64","Char","Double","Single"]
%supported_numericals_lowercase = ["byte","short","ushort","int","uint","long","ulong","char","double","float"]
%supported_numericals_defaultvals = ["0","0","0","0","0u","0L","0UL","'\0'","0d","0f"]
%supported_numericals_onevales = ["1","1","1","1","1u","1L","1UL",1,"1d","1f"]
%supported_numericals_TF_DataType = ["TF_DataType.TF_UINT8","TF_DataType.TF_INT16","TF_DataType.TF_UINT16","TF_DataType.TF_INT32","TF_DataType.TF_UINT32","TF_DataType.TF_INT64","TF_DataType.TF_UINT64","TF_DataType.TF_UINT8","TF_DataType.TF_DOUBLE","TF_DataType.TF_FLOAT"]

// this is the type we use when summarizing/reducing:
%supported_numericals_accumulatingType = ["UInt32","Int32","UInt32","Int32","UInt32","Int64","UInt64","UInt32","Double","Single"]
%supported_numericals_accumulatingType_defaultvals = ["0","0","0","0","0u","0L","0UL","'\0'","0d","0f"]

%supported_numericals_signed = ["Int16","Int32","Int64","Double","Single"]
%supported_numericals_signed_lowercase = ["short","int","long","double","float"]
%supported_numericals_signed_defaultvals = ["0","0","0L","0d","0f"]
%supported_numericals_signed_onevales = ["1","1","1L","1d","1f"]

%supported_numericals_unsigned = ["Byte","UInt16","UInt32","UInt64","Char"]
%supported_numericals_unsigned_lowercase = ["byte","ushort","uint","ulong","char"]
%supported_numericals_unsigned_defaultvals = ["0","0","0U","0UL","'\0'"]
%supported_numericals_unsigned_onevales = ["1","1","1U","1UL","'\1'"]

%supported_dtypes = ["Boolean","Byte","Int16","UInt16","Int32","UInt32","Int64","UInt64","Char","Double","Single"]
%supported_numericals_TF_DataType = ["TF_DataType.TF_UINT8","TF_DataType.TF_INT16","TF_DataType.TF_UINT16","TF_DataType.TF_INT32","TF_DataType.TF_UINT32","TF_DataType.TF_INT64","TF_DataType.TF_UINT64","TF_DataType.TF_UINT8","TF_DataType.TF_DOUBLE","TF_DataType.TF_FLOAT"]

%supported_dtypes_lowercase = ["bool","byte","short","ushort","int","uint","long","ulong","char","double","float"]
%supported_dtypes_defaultvals = [false,"0","0","0","0","0u","0L","0UL","'\0'","0d","0f"]
%supported_dtypes_onevales = [true,"1","1","1","1","1u","1L","1UL","'\1'","1d","1f"]
%supported_dtypes_dtype = ["bool","uint8","int16","uint16","int32","uint32","int64","uint64","uint8","float64","float32"]

// this is the type we use when summarizing/reducing:
%supported_dtypes_accumulatingType = ["Int32","UInt32","Int32","UInt32","Int32","UInt32","Int64","UInt64","UInt32","Double","Single"]
%supported_dtypes_accumulatingType_defaultvals = [false, "0","0","0","0u","0L","0UL","'\0'","0d","0f"]


+ 70
- 14
src/TensorFlowNET.Core/ops.GraphKeys.cs View File

@@ -29,55 +29,111 @@ namespace Tensorflow
/// </summary>
public class GraphKeys
{
#region const

/// <summary>
/// The subset of `Variable` objects that will be trained by an optimizer.
/// </summary>
public const string TRAINABLE_VARIABLES_ = "trainable_variables";

/// <summary>
/// Trainable resource-style variables.
/// </summary>
public const string TRAINABLE_RESOURCE_VARIABLES_ = "trainable_resource_variables";

/// <summary>
/// Key for streaming model ports.
/// </summary>
public const string _STREAMING_MODEL_PORTS_ = "streaming_model_ports";

/// <summary>
/// Key to collect losses
/// </summary>
public const string LOSSES_ = "losses";

/// <summary>
/// Key to collect Variable objects that are global (shared across machines).
/// Default collection for all variables, except local ones.
/// </summary>
public const string GLOBAL_VARIABLES_ = "variables";

public const string TRAIN_OP_ = "train_op";

public const string GLOBAL_STEP_ = "global_step";

public string[] _VARIABLE_COLLECTIONS_ = new string[] { "variables", "trainable_variables", "model_variables" };
/// <summary>
/// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing.
/// </summary>
public const string SAVEABLE_OBJECTS_ = "saveable_objects";
/// <summary>
/// Key to collect update_ops
/// </summary>
public const string UPDATE_OPS_ = "update_ops";

// Key to collect summaries.
public const string SUMMARIES_ = "summaries";

// Used to store v2 summary names.
public const string _SUMMARY_COLLECTION_ = "_SUMMARY_V2";

// Key for control flow context.
public const string COND_CONTEXT_ = "cond_context";
public const string WHILE_CONTEXT_ = "while_context";

#endregion

/// <summary>
/// The subset of `Variable` objects that will be trained by an optimizer.
/// </summary>
public string TRAINABLE_VARIABLES = "trainable_variables";
public string TRAINABLE_VARIABLES => TRAINABLE_VARIABLES_;

/// <summary>
/// Trainable resource-style variables.
/// </summary>
public string TRAINABLE_RESOURCE_VARIABLES = "trainable_resource_variables";
public string TRAINABLE_RESOURCE_VARIABLES => TRAINABLE_RESOURCE_VARIABLES_;

/// <summary>
/// Key for streaming model ports.
/// </summary>
public string _STREAMING_MODEL_PORTS = "streaming_model_ports";
public string _STREAMING_MODEL_PORTS => _STREAMING_MODEL_PORTS_;

/// <summary>
/// Key to collect losses
/// </summary>
public string LOSSES = "losses";
public string LOSSES => LOSSES_;

/// <summary>
/// Key to collect Variable objects that are global (shared across machines).
/// Default collection for all variables, except local ones.
/// </summary>
public string GLOBAL_VARIABLES = "variables";
public string GLOBAL_VARIABLES => GLOBAL_VARIABLES_;

public string TRAIN_OP = "train_op";
public string TRAIN_OP => TRAIN_OP_;

public string GLOBAL_STEP = "global_step";
public string GLOBAL_STEP => GLOBAL_STEP_;

public string[] _VARIABLE_COLLECTIONS = new string[] { "variables", "trainable_variables", "model_variables" };
public string[] _VARIABLE_COLLECTIONS => _VARIABLE_COLLECTIONS_;
/// <summary>
/// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing.
/// </summary>
public string SAVEABLE_OBJECTS = "saveable_objects";
public string SAVEABLE_OBJECTS => SAVEABLE_OBJECTS_;
/// <summary>
/// Key to collect update_ops
/// </summary>
public string UPDATE_OPS = "update_ops";
public string UPDATE_OPS => UPDATE_OPS_;

// Key to collect summaries.
public string SUMMARIES = "summaries";
public string SUMMARIES => SUMMARIES_;

// Used to store v2 summary names.
public string _SUMMARY_COLLECTION = "_SUMMARY_V2";
public string _SUMMARY_COLLECTION => _SUMMARY_COLLECTION_;

// Key for control flow context.
public string COND_CONTEXT = "cond_context";
public string WHILE_CONTEXT = "while_context";
public string COND_CONTEXT => COND_CONTEXT_;
public string WHILE_CONTEXT => WHILE_CONTEXT_;
}
}
}
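Editorial sketch (not part of the diff) of the design choice above: each collection name is now a compile-time constant, and the original instance properties simply forward to it, so existing callers keep compiling. The `ops.` nesting below is an assumption based on the file name ops.GraphKeys.cs:

using System;
using Tensorflow;

class GraphKeysSketch
{
    static void Main()
    {
        const string key = ops.GraphKeys.TRAINABLE_VARIABLES_;     // usable wherever a compile-time constant is required
        string sameKey = new ops.GraphKeys().TRAINABLE_VARIABLES;  // pre-existing instance API, now a pass-through property
        Console.WriteLine(key == sameKey);                         // True: both resolve to "trainable_variables"
    }
}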

+ 8
- 0
test/TensorFlowNET.Examples/TensorFlowNET.Examples.GPU.csproj View File

@@ -6,6 +6,14 @@
<GeneratePackageOnBuild>false</GeneratePackageOnBuild>
</PropertyGroup>

<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'">
<OutputPath>bin\debug-gpu</OutputPath>
</PropertyGroup>

<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'">
<OutputPath>bin\release-gpu</OutputPath>
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Colorful.Console" Version="1.2.9" />
<PackageReference Include="Newtonsoft.Json" Version="12.0.2" />


+ 17
- 17
test/TensorFlowNET.UnitTest/ConstantTest.cs View File

@@ -98,9 +98,9 @@ namespace TensorFlowNET.UnitTest
{
var result = sess.run(tensor);

Assert.AreEqual(result[0].shape[0], 3);
Assert.AreEqual(result[0].shape[1], 2);
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0, 0, 0 }, result[0].Data<int>()));
Assert.AreEqual(result.shape[0], 3);
Assert.AreEqual(result.shape[1], 2);
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0, 0, 0 }, result.Data<int>()));
}

// big size
@@ -109,13 +109,13 @@ namespace TensorFlowNET.UnitTest
{
var result = sess.run(tensor);

Assert.AreEqual(result[0].shape[0], 200);
Assert.AreEqual(result[0].shape[1], 100);
Assert.AreEqual(result.shape[0], 200);
Assert.AreEqual(result.shape[1], 100);

var data = result[0].Data<int>();
var data = result.Data<int>();
Assert.AreEqual(0, data[0]);
Assert.AreEqual(0, data[500]);
Assert.AreEqual(0, data[result[0].size - 1]);
Assert.AreEqual(0, data[result.size - 1]);
}
}

@@ -127,9 +127,9 @@ namespace TensorFlowNET.UnitTest
{
var result = sess.run(ones);

Assert.AreEqual(result[0].shape[0], 3);
Assert.AreEqual(result[0].shape[1], 2);
Assert.IsTrue(new[] { 1, 1, 1, 1, 1, 1 }.SequenceEqual(result[0].Data<int>()));
Assert.AreEqual(result.shape[0], 3);
Assert.AreEqual(result.shape[1], 2);
Assert.IsTrue(new[] { 1, 1, 1, 1, 1, 1 }.SequenceEqual(result.Data<int>()));
}
}

@@ -142,9 +142,9 @@ namespace TensorFlowNET.UnitTest
{
var result = sess.run(halfes);

Assert.AreEqual(result[0].shape[0], 3);
Assert.AreEqual(result[0].shape[1], 2);
Assert.IsTrue(new[] { .5, .5, .5, .5, .5, .5 }.SequenceEqual(result[0].Data<double>()));
Assert.AreEqual(result.shape[0], 3);
Assert.AreEqual(result.shape[1], 2);
Assert.IsTrue(new[] { .5, .5, .5, .5, .5, .5 }.SequenceEqual(result.Data<double>()));
}
}

@@ -161,10 +161,10 @@ namespace TensorFlowNET.UnitTest
using (var sess = tf.Session())
{
var result = sess.run(tensor);
var data = result[0].Data<int>();
var data = result.Data<int>();

Assert.AreEqual(result[0].shape[0], 2);
Assert.AreEqual(result[0].shape[1], 3);
Assert.AreEqual(result.shape[0], 2);
Assert.AreEqual(result.shape[1], 3);
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 1, 1, 2, 1, 3 }, data));
}
}
@@ -177,7 +177,7 @@ namespace TensorFlowNET.UnitTest
var c = a * b;

var sess = tf.Session();
double result = sess.run(c)[0];
double result = sess.run(c);
sess.close();

Assert.AreEqual(6.0, result);
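Editorial sketch (not part of the diff) of the single-fetch pattern these test updates exercise: running one fetch now returns the value directly, so the old `result[0]` indexing disappears. Constants chosen purely for illustration:

using Tensorflow;
using static Tensorflow.Binding;

class SingleFetchSketch
{
    static void Main()
    {
        var a = tf.constant(2.0);
        var b = tf.constant(3.0);
        var c = a * b;
        using (var sess = tf.Session())
        {
            double result = sess.run(c);   // single fetch yields the result itself, not an array of results
            print(result);                 // 6
        }
    }
}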


+ 3
- 3
test/TensorFlowNET.UnitTest/GradientTest.cs View File

@@ -41,7 +41,7 @@ namespace TensorFlowNET.UnitTest
var grad = tf.gradients(y, x);
Assert.AreEqual(grad[0].name, "gradients/AddN:0");

float r = sess.run(grad[0])[0];
float r = sess.run(grad[0]);
Assert.AreEqual(r, 1.4f);
}
}
@@ -57,7 +57,7 @@ namespace TensorFlowNET.UnitTest
var grad = tf.gradients(y, x);
Assert.AreEqual(grad[0].name, "gradients/AddN:0");

float r = sess.run(grad[0])[0];
float r = sess.run(grad[0]);
Assert.AreEqual(r, 14.700001f);
});
}
@@ -94,7 +94,7 @@ namespace TensorFlowNET.UnitTest

using (var sess = tf.Session(graph))
{
var r = sess.run(slice)[0];
var r = sess.run(slice);

Assert.IsTrue(Enumerable.SequenceEqual(r.shape, new[] { 2, 1, 2 }));
Assert.IsTrue(Enumerable.SequenceEqual(r[0].GetData<int>(), new[] { 11, 13 }));


+ 111
- 111
test/TensorFlowNET.UnitTest/OperationsTest.cs
File diff suppressed because it is too large
View File


+ 1
- 1
test/TensorFlowNET.UnitTest/PlaceholderTest.cs View File

@@ -17,7 +17,7 @@ namespace TensorFlowNET.UnitTest
{
var result = sess.run(y,
new FeedItem(x, 2));
Assert.AreEqual((int)result[0], 6);
Assert.AreEqual((int)result, 6);
}
}
}


+ 2
- 2
test/TensorFlowNET.UnitTest/SessionTest.cs View File

@@ -45,7 +45,7 @@ namespace TensorFlowNET.UnitTest
EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype);
EXPECT_EQ(0, outTensor.NDims);
ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize);
var output_contents = outTensor.Data<int>();
var output_contents = outTensor.ToArray<int>();
EXPECT_EQ(3 + 2, output_contents[0]);

// Add another operation to the graph.
@@ -66,7 +66,7 @@ namespace TensorFlowNET.UnitTest
EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype);
EXPECT_EQ(0, outTensor.NDims); // scalar
ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize);
output_contents = outTensor.Data<int>();
output_contents = outTensor.ToArray<int>();
EXPECT_EQ(-(7 + 2), output_contents[0]);

// Clean up


+ 1
- 1
test/TensorFlowNET.UnitTest/TensorTest.cs View File

@@ -112,7 +112,7 @@ namespace TensorFlowNET.UnitTest
var nd = np.array(1f, 2f, 3f, 4f, 5f, 6f).reshape(2, 3);

var tensor = new Tensor(nd);
var array = tensor.Data<float>();
var array = tensor.ToArray<float>();

EXPECT_EQ(tensor.dtype, TF_DataType.TF_FLOAT);
EXPECT_EQ(tensor.rank, nd.ndim);
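Editorial sketch (not part of the diff), mirroring the test above: buffer reads now go through tensor.ToArray<T>() in place of Data<T>():

using NumSharp;
using Tensorflow;

class ToArraySketch
{
    static void Main()
    {
        var nd = np.array(1f, 2f, 3f, 4f, 5f, 6f).reshape(2, 3);
        var tensor = new Tensor(nd);
        float[] values = tensor.ToArray<float>();   // was: tensor.Data<float>()
        System.Console.WriteLine(values.Length);    // 6
    }
}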


+ 7
- 6
test/TensorFlowNET.UnitTest/VariableTest.cs View File

@@ -1,4 +1,5 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using NumSharp;
using Tensorflow;
using static Tensorflow.Binding;

@@ -16,7 +17,7 @@ namespace TensorFlowNET.UnitTest
{
session.run(x.initializer);
var result = session.run(x);
Assert.AreEqual(10, (int)result[0]);
Assert.AreEqual(10, (int)result);
}
}

@@ -81,7 +82,7 @@ namespace TensorFlowNET.UnitTest
using (var session = tf.Session())
{
session.run(model);
int result = session.run(y)[0];
int result = session.run(y);
Assert.AreEqual(result, 4);
}
}
@@ -97,12 +98,12 @@ namespace TensorFlowNET.UnitTest
var sess = tf.Session(graph);
sess.run(init);

var result = sess.run(variable);
Assert.IsTrue((int)result[0] == 31);
NDArray result = sess.run(variable);
Assert.IsTrue((int)result == 31);

var assign = variable.assign(12);
result = sess.run(assign);
Assert.IsTrue((int)result[0] == 12);
Assert.IsTrue((int)result == 12);
}

[TestMethod]
@@ -139,7 +140,7 @@ namespace TensorFlowNET.UnitTest
for(int i = 0; i < 5; i++)
{
x = x + 1;
result = session.run(x)[0];
result = session.run(x);
print(result);
}
}


+ 1
- 1
test/TensorFlowNET.UnitTest/nn_test/ZeroFractionTest.cs View File

@@ -31,7 +31,7 @@ namespace TensorFlowNET.UnitTest.nn_test
var y_np = this._ZeroFraction(x_np);
var x_tf = constant_op.constant(x_np);
x_tf.SetShape(x_shape);
x_tf.set_shape(x_shape);
var y_tf = nn_impl.zero_fraction(x_tf);
var y_tf_np = self.evaluate<NDArray>(y_tf);

