* Refactored DisposableObject * Added different build directory for TensorflowNET.Examples.GPU * _FetchHandler: Switched to NPTypeCode * gfile.cs, Walk(...): Handle case when directory top doesn't exist. * Tensor.Creation: Perf-opted when creating tensor from NDArray of string * Graph.cs: refactor and added docs * Tensor.Creation.cs: perf-ops * Tensor.Explicit.cs: perf-ops * Copied globals.regen from NumSharp - Added supported_numericals_TF_DataType * Tensor perf-ops and cleanup, Revamped dtypes.cs, some renames. - Cleanup and docs to all Tensor.cs files - Changed all uses of System.Convert to NumSharp.Utilities.Converts - Added all missing types in dtypes.cs - Renamed tensor.Data<T> to tensor.ToArray<T>, added obsolete message - Renamed tensor.Data() to tensor.BufferToArray(), added obsolete message - Made GraphKeys to use const string instead allocating strings at every use of GraphKeys. * Tensor: Added guards for explicit casts. * Tensor: Added explicit cast to string * Tensor.ToArray<T>(): Added support for cases when tensor is scalar. * Tensor.BufferToArray(): Fixed to use long instead of int. * TensorShape: Revamped and documented. * BaseSession: Added Session.run(ITensorOrOperation fetche, params FeedItem[] feed_dict) * Tensor: renamed _dtype to _override_dtype - Fixed all locations _dtype is used incorrectly. * Fixed unit tests * Tensor.Operations: Reverted commit * DisposableObject: sorted internal_dispose to properly handle Dispose() calls * Tensor.DisposeUnmanagedResources: Nullify _handle after delete. * TensorShape.this[...]: fixed guard check. * DisposableObject #362tags/v0.12
@@ -59,6 +59,6 @@ namespace Tensorflow | |||
} | |||
[DllImport(TensorFlowLibName)] | |||
public static unsafe extern IntPtr TF_Version(); | |||
public static extern IntPtr TF_Version(); | |||
} | |||
} |
@@ -308,15 +308,14 @@ namespace Tensorflow | |||
public static IEnumerable TupleToEnumerable(object tuple) | |||
{ | |||
Type t = tuple.GetType(); | |||
if(t.IsGenericType && (t.FullName.StartsWith("System.Tuple") || t.FullName.StartsWith("System.ValueTuple"))) | |||
if (t.IsGenericType && (t.FullName.StartsWith("System.Tuple") || t.FullName.StartsWith("System.ValueTuple"))) | |||
{ | |||
var flds = t.GetFields(); | |||
for(int i = 0; i < flds.Length;i++) | |||
for (int i = 0; i < flds.Length; i++) | |||
{ | |||
yield return flds[i].GetValue(tuple); | |||
} | |||
} | |||
else | |||
} else | |||
{ | |||
throw new System.Exception("Expected Tuple."); | |||
} | |||
@@ -329,12 +328,9 @@ namespace Tensorflow | |||
public static bool isinstance(object Item1, object tuple) | |||
{ | |||
var tup = TupleToEnumerable(tuple); | |||
foreach(var t in tup) | |||
{ | |||
if(isinstance(Item1, (Type)t)) | |||
foreach (var t in TupleToEnumerable(tuple)) | |||
if (isinstance(Item1, (Type) t)) | |||
return true; | |||
} | |||
return false; | |||
} | |||
} | |||
@@ -66,7 +66,7 @@ namespace Tensorflow | |||
return buffer.Data; | |||
} | |||
protected override void DisposeUnManagedState(IntPtr handle) | |||
protected override void DisposeUnmanagedResources(IntPtr handle) | |||
=> c_api.TF_DeleteBuffer(handle); | |||
} | |||
} |
@@ -29,18 +29,10 @@ namespace Tensorflow | |||
protected DisposableObject() { } | |||
public DisposableObject(IntPtr handle) | |||
{ | |||
_handle = handle; | |||
} | |||
protected virtual void DisposeManagedState() | |||
{ | |||
} | |||
protected DisposableObject(IntPtr handle) | |||
=> _handle = handle; | |||
protected abstract void DisposeUnManagedState(IntPtr handle); | |||
protected virtual void Dispose(bool disposing) | |||
private void internal_dispose(bool disposing) | |||
{ | |||
if (disposing) | |||
{ | |||
@@ -48,30 +40,43 @@ namespace Tensorflow | |||
if (_handle != IntPtr.Zero) | |||
{ | |||
// dispose managed state (managed objects). | |||
DisposeManagedState(); | |||
DisposeManagedResources(); | |||
// set large fields to null. | |||
DisposeUnManagedState(_handle); | |||
DisposeUnmanagedResources(_handle); | |||
_handle = IntPtr.Zero; | |||
} | |||
} | |||
} | |||
/// <summary> | |||
/// Dispose any managed resources. | |||
/// </summary> | |||
/// <remarks>Equivalent to what you would perform inside <see cref="Dispose()"/></remarks> | |||
protected virtual void DisposeManagedResources() | |||
{ | |||
} | |||
/// <summary> | |||
/// Dispose any unmanaged resources related to given <paramref name="handle"/>. | |||
/// </summary> | |||
protected abstract void DisposeUnmanagedResources(IntPtr handle); | |||
// override a finalizer only if Dispose(bool disposing) above has code to free unmanaged resources. | |||
~DisposableObject() | |||
{ | |||
// Do not change this code. Put cleanup code in Dispose(bool disposing) above. | |||
Dispose(false); | |||
internal_dispose(false); | |||
} | |||
// This code added to correctly implement the disposable pattern. | |||
public void Dispose() | |||
{ | |||
// Do not change this code. Put cleanup code in Dispose(bool disposing) above. | |||
Dispose(true); | |||
internal_dispose(true); | |||
// uncomment the following line if the finalizer is overridden above. | |||
GC.SuppressFinalize(this); | |||
} | |||
} | |||
} | |||
} |
@@ -1,8 +1,9 @@ | |||
using System; | |||
using System.IO; | |||
namespace Tensorflow.Eager | |||
{ | |||
public class ContextOptions : IDisposable | |||
public class ContextOptions : IDisposable //TODO! Eli: Shouldn't this inherieting DisposableObject? | |||
{ | |||
private IntPtr _handle; | |||
@@ -23,57 +23,58 @@ using static Tensorflow.Binding; | |||
namespace Tensorflow | |||
{ | |||
/* | |||
A TensorFlow computation, represented as a dataflow graph. | |||
A `Graph` contains a set of | |||
`tf.Operation` objects, | |||
which represent units of computation; and | |||
`tf.Tensor` objects, which represent | |||
the units of data that flow between operations. | |||
A default `Graph` is always registered, and accessible by calling | |||
`tf.get_default_graph`. | |||
To add an operation to the default graph, simply call one of the functions | |||
that defines a new `Operation`: | |||
```python | |||
c = tf.constant(4.0) | |||
assert c.graph is tf.get_default_graph() | |||
``` | |||
Another typical usage involves the | |||
`tf.Graph.as_default` | |||
context manager, which overrides the current default graph for the | |||
lifetime of the context: | |||
```python | |||
g = tf.Graph() | |||
with g.as_default(): | |||
# Define operations and tensors in `g`. | |||
c = tf.constant(30.0) | |||
assert c.graph is g | |||
``` | |||
Important note: This class *is not* thread-safe for graph construction. All | |||
operations should be created from a single thread, or external | |||
synchronization must be provided. Unless otherwise specified, all methods | |||
are not thread-safe. | |||
A `Graph` instance supports an arbitrary number of "collections" | |||
that are identified by name. For convenience when building a large | |||
graph, collections can store groups of related objects: for | |||
example, the `tf.Variable` uses a collection (named | |||
`tf.GraphKeys.GLOBAL_VARIABLES`) for | |||
all variables that are created during the construction of a graph. The caller | |||
may define additional collections by specifying a new name. | |||
*/ | |||
/// <summary> | |||
/// TensorFlow uses a dataflow graph to represent your computation in terms of the dependencies between individual operations. | |||
/// This leads to a low-level programming model in which you first define the dataflow graph, | |||
/// then create a TensorFlow session to run parts of the graph across a set of local and remote devices. | |||
/// https://www.tensorflow.org/guide/graphs | |||
/// TensorFlow uses a dataflow graph to represent your computation in terms of the dependencies between individual operations. | |||
/// This leads to a low-level programming model in which you first define the dataflow graph, | |||
/// then create a TensorFlow session to run parts of the graph across a set of local and remote devices. | |||
/// </summary> | |||
/* | |||
A TensorFlow computation, represented as a dataflow graph. | |||
A `Graph` contains a set of | |||
`tf.Operation` objects, | |||
which represent units of computation; and | |||
`tf.Tensor` objects, which represent | |||
the units of data that flow between operations. | |||
A default `Graph` is always registered, and accessible by calling | |||
`tf.get_default_graph`. | |||
To add an operation to the default graph, simply call one of the functions | |||
that defines a new `Operation`: | |||
```python | |||
c = tf.constant(4.0) | |||
assert c.graph is tf.get_default_graph() | |||
``` | |||
Another typical usage involves the | |||
`tf.Graph.as_default` | |||
context manager, which overrides the current default graph for the | |||
lifetime of the context: | |||
```python | |||
g = tf.Graph() | |||
with g.as_default(): | |||
# Define operations and tensors in `g`. | |||
c = tf.constant(30.0) | |||
assert c.graph is g | |||
``` | |||
Important note: This class *is not* thread-safe for graph construction. All | |||
operations should be created from a single thread, or external | |||
synchronization must be provided. Unless otherwise specified, all methods | |||
are not thread-safe. | |||
A `Graph` instance supports an arbitrary number of "collections" | |||
that are identified by name. For convenience when building a large | |||
graph, collections can store groups of related objects: for | |||
example, the `tf.Variable` uses a collection (named | |||
`tf.GraphKeys.GLOBAL_VARIABLES`) for | |||
all variables that are created during the construction of a graph. The caller | |||
may define additional collections by specifying a new name. | |||
*/ | |||
/// <remarks>https://www.tensorflow.org/guide/graphs <br></br>https://www.tensorflow.org/api_docs/python/tf/Graph</remarks> | |||
public partial class Graph : DisposableObject, IEnumerable<Operation> | |||
{ | |||
private Dictionary<int, ITensorOrOperation> _nodes_by_id; | |||
@@ -439,12 +440,12 @@ namespace Tensorflow | |||
_unfetchable_ops.Add(op); | |||
} | |||
protected override void DisposeManagedState() | |||
protected override void DisposeManagedResources() | |||
{ | |||
ops.default_graph_stack.remove(this); | |||
} | |||
protected override void DisposeUnManagedState(IntPtr handle) | |||
protected override void DisposeUnmanagedResources(IntPtr handle) | |||
{ | |||
c_api.TF_DeleteGraph(handle); | |||
} | |||
@@ -37,7 +37,7 @@ namespace Tensorflow | |||
c_api.TF_ImportGraphDefOptionsAddReturnOutput(_handle, name, index); | |||
} | |||
protected override void DisposeUnManagedState(IntPtr handle) | |||
protected override void DisposeUnmanagedResources(IntPtr handle) | |||
=> c_api.TF_DeleteImportGraphDefOptions(handle); | |||
public static implicit operator IntPtr(ImportGraphDefOptions opts) => opts._handle; | |||
@@ -16,6 +16,7 @@ | |||
using System.Collections.Generic; | |||
using System.IO; | |||
using System.Linq; | |||
namespace Tensorflow.IO | |||
{ | |||
@@ -28,6 +29,9 @@ namespace Tensorflow.IO | |||
/// <param name="in_order">Traverse in order if True, post order if False.</param> | |||
public IEnumerable<(string, string[], string[])> Walk(string top, bool in_order = true) | |||
{ | |||
if (!Directory.Exists(top)) | |||
return Enumerable.Empty<(string, string[], string[])>(); | |||
return walk_v2(top, in_order); | |||
} | |||
@@ -141,7 +141,7 @@ namespace Tensorflow.Operations | |||
data, frame_name, is_constant, parallel_iterations, name: name); | |||
if (use_input_shape) | |||
result.SetShape(data.TensorShape); | |||
result.set_shape(data.TensorShape); | |||
return result; | |||
} | |||
@@ -233,7 +233,7 @@ namespace Tensorflow.Operations | |||
dims.AddRange(x_static_shape.dims.Skip(2)); | |||
var shape = new TensorShape(dims.ToArray()); | |||
x_t.SetShape(shape); | |||
x_t.set_shape(shape); | |||
return x_t; | |||
} | |||
@@ -351,7 +351,7 @@ namespace Tensorflow | |||
var input_shape = tensor_util.to_shape(input_tensor.shape); | |||
if (optimize && input_tensor.NDims > -1 && input_shape.is_fully_defined()) | |||
{ | |||
var nd = np.array(input_tensor.shape).astype(out_type.as_numpy_datatype()); | |||
var nd = np.array(input_tensor.shape).astype(out_type.as_numpy_dtype()); | |||
return constant_op.constant(nd, name: name); | |||
} | |||
} | |||
@@ -98,7 +98,7 @@ namespace Tensorflow | |||
// float to be selected, hence we use a >= comparison. | |||
var keep_mask = random_tensor >= rate; | |||
var ret = x * scale * math_ops.cast(keep_mask, x.dtype); | |||
ret.SetShape(x.TensorShape); | |||
ret.set_shape(x.TensorShape); | |||
return ret; | |||
}); | |||
} | |||
@@ -49,7 +49,7 @@ namespace Tensorflow | |||
// dispose newOpts | |||
if (opts == null) | |||
c_api.TF_DeleteSessionOptions(newOpts); | |||
newOpts.Dispose(); | |||
status.Check(true); | |||
} | |||
@@ -64,6 +64,11 @@ namespace Tensorflow | |||
return _run(fetche, feed_dict)[0]; | |||
} | |||
public virtual NDArray run(ITensorOrOperation fetche, params FeedItem[] feed_dict) | |||
{ | |||
return _run(fetche, feed_dict)[0]; | |||
} | |||
public virtual (NDArray, NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | |||
{ | |||
var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4 }, feed_dict); | |||
@@ -273,7 +278,7 @@ namespace Tensorflow | |||
{ | |||
var tensor = new Tensor(output); | |||
NDArray nd = null; | |||
Type type = tensor.dtype.as_numpy_datatype(); | |||
Type type = tensor.dtype.as_numpy_dtype(); | |||
var ndims = tensor.shape; | |||
var offset = c_api.TF_TensorData(output); | |||
@@ -285,7 +290,7 @@ namespace Tensorflow | |||
nd = NDArray.Scalar(*(bool*)offset); | |||
break; | |||
case TF_DataType.TF_STRING: | |||
var bytes = tensor.Data(); | |||
var bytes = tensor.BufferToArray(); | |||
// wired, don't know why we have to start from offset 9. | |||
// length in the begin | |||
var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]); | |||
@@ -324,7 +329,7 @@ namespace Tensorflow | |||
nd = np.array(bools).reshape(ndims); | |||
break; | |||
case TF_DataType.TF_STRING: | |||
var bytes = tensor.Data(); | |||
var bytes = tensor.BufferToArray(); | |||
// wired, don't know why we have to start from offset 9. | |||
// length in the begin | |||
var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]); | |||
@@ -396,7 +401,7 @@ namespace Tensorflow | |||
Dispose(); | |||
} | |||
protected override void DisposeUnManagedState(IntPtr handle) | |||
protected override void DisposeUnmanagedResources(IntPtr handle) | |||
{ | |||
using (var status = new Status()) | |||
{ | |||
@@ -32,7 +32,7 @@ namespace Tensorflow | |||
_handle = handle; | |||
} | |||
protected override void DisposeUnManagedState(IntPtr handle) | |||
protected override void DisposeUnmanagedResources(IntPtr handle) | |||
=> c_api.TF_DeleteSessionOptions(handle); | |||
public void SetConfig(ConfigProto config) | |||
@@ -17,6 +17,7 @@ | |||
using NumSharp; | |||
using System; | |||
using System.Collections.Generic; | |||
using NumSharp.Backends; | |||
namespace Tensorflow | |||
{ | |||
@@ -71,18 +72,18 @@ namespace Tensorflow | |||
{ | |||
if(tensor_values.Length > 0) | |||
{ | |||
switch (tensor_values[0].dtype.Name) | |||
switch (tensor_values[0].typecode) | |||
{ | |||
case "Int32": | |||
case NPTypeCode.Int32: | |||
full_values.Add(float.NaN); | |||
break; | |||
case "Single": | |||
case NPTypeCode.Single: | |||
full_values.Add(float.NaN); | |||
break; | |||
case "String": | |||
case NPTypeCode.String: | |||
full_values.Add(float.NaN); | |||
break; | |||
case "Char": | |||
case NPTypeCode.Char: | |||
full_values.Add(float.NaN); | |||
break; | |||
default: | |||
@@ -100,21 +101,21 @@ namespace Tensorflow | |||
j += 1; | |||
if (value.ndim == 0) | |||
{ | |||
switch (value.dtype.Name) | |||
switch (value.typecode) | |||
{ | |||
case "Int16": | |||
case NPTypeCode.Int16: | |||
full_values.Add(value.GetValue<short>(0)); | |||
break; | |||
case "Int32": | |||
case NPTypeCode.Int32: | |||
full_values.Add(value.GetValue<int>(0)); | |||
break; | |||
case "Int64": | |||
case NPTypeCode.Int64: | |||
full_values.Add(value.GetValue<long>(0)); | |||
break; | |||
case "Single": | |||
case NPTypeCode.Single: | |||
full_values.Add(value.GetValue<float>(0)); | |||
break; | |||
case "Double": | |||
case NPTypeCode.Double: | |||
full_values.Add(value.GetValue<double>(0)); | |||
break; | |||
/*case "String": | |||
@@ -50,7 +50,7 @@ namespace Tensorflow | |||
/// </summary> | |||
public void Check(bool throwException = false) | |||
{ | |||
if(Code != TF_Code.TF_OK) | |||
if (Code != TF_Code.TF_OK) | |||
{ | |||
Console.WriteLine(Message); | |||
if (throwException) | |||
@@ -65,7 +65,7 @@ namespace Tensorflow | |||
return status._handle; | |||
} | |||
protected override void DisposeUnManagedState(IntPtr handle) | |||
protected override void DisposeUnmanagedResources(IntPtr handle) | |||
=> c_api.TF_DeleteStatus(handle); | |||
} | |||
} | |||
} |
@@ -16,11 +16,13 @@ | |||
using NumSharp; | |||
using System; | |||
using System.Diagnostics.CodeAnalysis; | |||
using System.Linq; | |||
using System.Numerics; | |||
using System.Runtime.CompilerServices; | |||
using System.Runtime.InteropServices; | |||
using System.Text; | |||
using NumSharp.Backends; | |||
using NumSharp.Backends.Unmanaged; | |||
using static Tensorflow.c_api; | |||
@@ -462,7 +464,7 @@ namespace Tensorflow | |||
*v = value; | |||
_handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(Complex)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(Complex), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); | |||
IsMemoryOwner=true; | |||
} | |||
} | |||
#endif | |||
/// <summary> | |||
@@ -477,7 +479,7 @@ namespace Tensorflow | |||
IntPtr tensor = c_api.TF_TensorData(handle); | |||
Marshal.WriteInt64(tensor, 0); | |||
fixed (byte* src = &buffer[0]) | |||
fixed (byte* src = buffer) | |||
c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(Int64)), size, status); | |||
_handle = handle; | |||
status.Check(true); | |||
@@ -486,35 +488,55 @@ namespace Tensorflow | |||
public unsafe Tensor(NDArray nd, TF_DataType? tensorDType = null) | |||
{ | |||
// todo: handle nd of type "String" here too | |||
if (tensorDType == TF_DataType.TF_STRING && nd.dtype.Name == "Byte") | |||
if (tensorDType == TF_DataType.TF_STRING && nd.typecode == NPTypeCode.Byte) | |||
{ | |||
var buffer = nd.ToArray<byte>(); | |||
var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length); | |||
var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8)); | |||
IntPtr tensor = c_api.TF_TensorData(handle); | |||
Marshal.WriteInt64(tensor, 0); | |||
if (nd.Unsafe.Storage.Shape.IsContiguous) | |||
{ | |||
var bytesLength = (UIntPtr)nd.size; | |||
var size = c_api.TF_StringEncodedSize(bytesLength); | |||
var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr) ((ulong) size + 8)); | |||
IntPtr tensor = c_api.TF_TensorData(handle); | |||
Marshal.WriteInt64(tensor, 0); | |||
var status = new Status(); | |||
c_api.TF_StringEncode((byte*) nd.Unsafe.Address, bytesLength, (sbyte*) (tensor + sizeof(Int64)), size, status); | |||
status.Check(true); | |||
_handle = handle; | |||
IsMemoryOwner = false; | |||
} | |||
else | |||
{ | |||
var buffer = nd.ToArray<byte>(); | |||
var size = c_api.TF_StringEncodedSize((UIntPtr) buffer.Length); | |||
var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr) ((ulong) size + 8)); | |||
IntPtr tensor = c_api.TF_TensorData(handle); | |||
Marshal.WriteInt64(tensor, 0); | |||
var status = new Status(); | |||
fixed (byte* src = buffer) | |||
c_api.TF_StringEncode(src, (UIntPtr) buffer.Length, (sbyte*) (tensor + sizeof(Int64)), size, status); | |||
status.Check(true); | |||
_handle = handle; | |||
IsMemoryOwner = false; | |||
} | |||
var status = new Status(); | |||
fixed (byte* src = &buffer[0]) | |||
c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(Int64)), size, status); | |||
status.Check(true); | |||
_handle=handle; | |||
IsMemoryOwner = false; | |||
return; | |||
} | |||
_handle = CreateTensorFromNDArray(nd, tensorDType); | |||
IsMemoryOwner = true; | |||
} | |||
private unsafe IntPtr CreateTensorFromNDArray(NDArray nd, TF_DataType? given_dtype) | |||
{ | |||
if (nd.dtype.Name == "String") | |||
throw new NotImplementedException("Support for NDArray of type string not implemented yet"); | |||
if (nd.dtype.Name == "String") | |||
throw new NotImplementedException("Support for NDArray of type string not implemented yet"); | |||
IArraySlice arraySlice; | |||
var shape = nd.Unsafe.Storage.Shape; | |||
if (shape.IsSliced || shape.IsBroadcasted) | |||
if (nd.Unsafe.Storage.Shape.IsContiguous == false) | |||
{ | |||
// the memory is NOT contiguous, so we have to copy the view into a contiguous memory block. | |||
arraySlice = nd.CloneData(); | |||
@@ -527,51 +549,52 @@ namespace Tensorflow | |||
this.Tag = arraySlice; // keep a reference to the memory block to make sure it is not disposed while TF is using it | |||
var ptr = new IntPtr(arraySlice.Address); | |||
int num_bytes = (nd.size * nd.dtypesize); | |||
var dtype = given_dtype ?? ToTFDataType(nd.dtype); | |||
var dtype = given_dtype ?? nd.dtype.as_dtype(); | |||
var handle = TF_NewTensor(dtype, dims: nd.shape.Select(i=>(long)i).ToArray(), num_dims: nd.ndim, data: ptr, len: (UIntPtr)num_bytes, deallocator: _nothingDeallocator, ref _deallocatorArgs); | |||
IsMemoryOwner = false; | |||
return handle; | |||
} | |||
public unsafe Tensor(byte[][] buffer, long[] shape) | |||
{ | |||
int size = 0; | |||
foreach (var b in buffer) | |||
{ | |||
size += (int)TF_StringEncodedSize((UIntPtr)b.Length); | |||
} | |||
int totalSize = size + buffer.Length * 8; | |||
ulong offset = 0; | |||
IntPtr handle = TF_AllocateTensor(TF_DataType.TF_STRING, shape, shape.Length, (UIntPtr)totalSize); | |||
// Clear offset table | |||
IntPtr pOffset = TF_TensorData(handle); | |||
IntPtr dst = pOffset + buffer.Length * 8; | |||
IntPtr dstLimit = pOffset + totalSize; | |||
for (int i = 0; i < buffer.Length; i++) | |||
{ | |||
Marshal.WriteInt64(pOffset, (long)offset); | |||
using (var status = new Status()) | |||
{ | |||
fixed (byte* src = &buffer[i][0]) | |||
{ | |||
var written = TF_StringEncode(src, (UIntPtr)buffer[i].Length, (sbyte*)dst, (UIntPtr)(dstLimit.ToInt64() - dst.ToInt64()), status); | |||
status.Check(true); | |||
pOffset += 8; | |||
dst += (int)written; | |||
offset += written; | |||
} | |||
} | |||
} | |||
_handle = handle; | |||
} | |||
public unsafe Tensor(byte[][] buffer, long[] shape) | |||
{ | |||
int size = 0; | |||
foreach (var b in buffer) | |||
{ | |||
size += (int)TF_StringEncodedSize((UIntPtr)b.Length); | |||
} | |||
int totalSize = size + buffer.Length * 8; | |||
ulong offset = 0; | |||
IntPtr handle = TF_AllocateTensor(TF_DataType.TF_STRING, shape, shape.Length, (UIntPtr)totalSize); | |||
// Clear offset table | |||
IntPtr pOffset = TF_TensorData(handle); | |||
IntPtr dst = pOffset + buffer.Length * 8; | |||
IntPtr dstLimit = pOffset + totalSize; | |||
for (int i = 0; i < buffer.Length; i++) | |||
{ | |||
Marshal.WriteInt64(pOffset, (long)offset); | |||
using (var status = new Status()) | |||
{ | |||
fixed (byte* src = &buffer[i][0]) | |||
{ | |||
var written = TF_StringEncode(src, (UIntPtr)buffer[i].Length, (sbyte*)dst, (UIntPtr)(dstLimit.ToInt64() - dst.ToInt64()), status); | |||
status.Check(true); | |||
pOffset += 8; | |||
dst += (int)written; | |||
offset += written; | |||
} | |||
} | |||
} | |||
_handle = handle; | |||
} | |||
public Tensor(Operation op, int value_index, TF_DataType dtype) | |||
{ | |||
_op = op; | |||
_value_index = value_index; | |||
_dtype = dtype; | |||
_override_dtype = dtype; | |||
_id = ops.uid(); | |||
} | |||
@@ -589,11 +612,11 @@ namespace Tensorflow | |||
/// specified dimensions. | |||
/// </remarks> | |||
[MethodImpl(MethodImplOptions.AggressiveInlining)] | |||
[SuppressMessage("ReSharper", "LocalVariableHidesMember")] | |||
protected unsafe IntPtr CreateTensorWithoutCopying(TF_DataType dt, long[] shape, Array data, int element_size) | |||
{ | |||
if (dt == TF_DataType.TF_STRING && data is byte[]) | |||
if (dt == TF_DataType.TF_STRING && data is byte[] buffer) | |||
{ | |||
var buffer = (byte[])data; | |||
var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length); | |||
var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8)); | |||
@@ -1,4 +1,5 @@ | |||
using System; | |||
using System.Runtime.CompilerServices; | |||
namespace Tensorflow | |||
{ | |||
@@ -6,86 +7,142 @@ namespace Tensorflow | |||
{ | |||
public static explicit operator bool(Tensor tensor) | |||
{ | |||
EnsureScalar(tensor); | |||
return tensor.Data<bool>()[0]; | |||
unsafe | |||
{ | |||
EnsureScalar(tensor); | |||
EnsureDType(tensor, TF_DataType.TF_BOOL); | |||
return *(bool*) tensor.buffer; | |||
} | |||
} | |||
public static explicit operator sbyte(Tensor tensor) | |||
{ | |||
EnsureScalar(tensor); | |||
return tensor.Data<sbyte>()[0]; | |||
unsafe | |||
{ | |||
EnsureScalar(tensor); | |||
EnsureDType(tensor, TF_DataType.TF_INT8); | |||
return *(sbyte*) tensor.buffer; | |||
} | |||
} | |||
public static explicit operator byte(Tensor tensor) | |||
{ | |||
EnsureScalar(tensor); | |||
return tensor.Data<byte>()[0]; | |||
unsafe | |||
{ | |||
EnsureScalar(tensor); | |||
EnsureDType(tensor, TF_DataType.TF_UINT8); | |||
return *(byte*) tensor.buffer; | |||
} | |||
} | |||
public static explicit operator ushort(Tensor tensor) | |||
{ | |||
EnsureScalar(tensor); | |||
return tensor.Data<ushort>()[0]; | |||
unsafe | |||
{ | |||
EnsureScalar(tensor); | |||
EnsureDType(tensor, TF_DataType.TF_UINT16); | |||
return *(ushort*) tensor.buffer; | |||
} | |||
} | |||
public static explicit operator short(Tensor tensor) | |||
{ | |||
EnsureScalar(tensor); | |||
return tensor.Data<short>()[0]; | |||
unsafe | |||
{ | |||
EnsureScalar(tensor); | |||
EnsureDType(tensor, TF_DataType.TF_INT16); | |||
return *(short*) tensor.buffer; | |||
} | |||
} | |||
public static explicit operator int(Tensor tensor) | |||
{ | |||
EnsureScalar(tensor); | |||
return tensor.Data<int>()[0]; | |||
unsafe | |||
{ | |||
EnsureScalar(tensor); | |||
EnsureDType(tensor, TF_DataType.TF_INT32); | |||
return *(int*) tensor.buffer; | |||
} | |||
} | |||
public static explicit operator uint(Tensor tensor) | |||
{ | |||
EnsureScalar(tensor); | |||
return tensor.Data<uint>()[0]; | |||
unsafe | |||
{ | |||
EnsureScalar(tensor); | |||
EnsureDType(tensor, TF_DataType.TF_UINT32); | |||
return *(uint*) tensor.buffer; | |||
} | |||
} | |||
public static explicit operator long(Tensor tensor) | |||
{ | |||
EnsureScalar(tensor); | |||
return tensor.Data<long>()[0]; | |||
unsafe | |||
{ | |||
EnsureScalar(tensor); | |||
EnsureDType(tensor, TF_DataType.TF_INT64); | |||
return *(long*) tensor.buffer; | |||
} | |||
} | |||
public static explicit operator ulong(Tensor tensor) | |||
{ | |||
EnsureScalar(tensor); | |||
return tensor.Data<ulong>()[0]; | |||
unsafe | |||
{ | |||
EnsureScalar(tensor); | |||
EnsureDType(tensor, TF_DataType.TF_UINT64); | |||
return *(ulong*) tensor.buffer; | |||
} | |||
} | |||
public static explicit operator float(Tensor tensor) | |||
{ | |||
EnsureScalar(tensor); | |||
return tensor.Data<float>()[0]; | |||
unsafe | |||
{ | |||
EnsureScalar(tensor); | |||
EnsureDType(tensor, TF_DataType.TF_FLOAT); | |||
return *(float*) tensor.buffer; | |||
} | |||
} | |||
public static explicit operator double(Tensor tensor) | |||
{ | |||
EnsureScalar(tensor); | |||
return tensor.Data<double>()[0]; | |||
unsafe | |||
{ | |||
EnsureScalar(tensor); | |||
EnsureDType(tensor, TF_DataType.TF_DOUBLE); | |||
return *(double*) tensor.buffer; | |||
} | |||
} | |||
public static explicit operator string(Tensor tensor) | |||
{ | |||
unsafe | |||
{ | |||
EnsureScalar(tensor); | |||
EnsureDType(tensor, TF_DataType.TF_STRING); | |||
return new string((char*) tensor.buffer, 0, (int) tensor.size); | |||
} | |||
} | |||
[MethodImpl(MethodImplOptions.AggressiveInlining)] | |||
private static void EnsureDType(Tensor tensor, TF_DataType @is) | |||
{ | |||
if (tensor.dtype != @is) | |||
throw new InvalidCastException($"Unable to cast scalar tensor {tensor.dtype} to {@is}"); | |||
} | |||
[MethodImpl(MethodImplOptions.AggressiveInlining)] | |||
private static void EnsureScalar(Tensor tensor) | |||
{ | |||
if (tensor == null) | |||
{ | |||
throw new ArgumentNullException(nameof(tensor)); | |||
} | |||
if (tensor.TensorShape.ndim != 0) | |||
{ | |||
throw new ArgumentException("Tensor must have 0 dimensions in order to convert to scalar"); | |||
} | |||
if (tensor.TensorShape.size != 1) | |||
{ | |||
throw new ArgumentException("Tensor must have size 1 in order to convert to scalar"); | |||
} | |||
} | |||
} | |||
@@ -69,11 +69,12 @@ namespace Tensorflow | |||
TF_DataType.TF_QINT8, TF_DataType.TF_QINT16, TF_DataType.TF_QINT32, | |||
TF_DataType.TF_UINT8, TF_DataType.TF_UINT16, TF_DataType.TF_UINT32, TF_DataType.TF_UINT64 | |||
}; | |||
public static Tensor operator /(double x, Tensor y) => BinaryOpWrapper("truediv", x, y); | |||
public static Tensor operator /(float x, Tensor y) => BinaryOpWrapper("truediv", x, y); | |||
public static Tensor operator /(int x, Tensor y) => BinaryOpWrapper("floordiv", x, y); | |||
public static Tensor operator /(Tensor x, Tensor y) => | |||
_intTfDataTypes.Contains(x._dtype) | |||
_intTfDataTypes.Contains(x.dtype) | |||
? BinaryOpWrapper("floordiv", x, y) | |||
: BinaryOpWrapper("truediv", x, y); | |||
public static Tensor operator /(Tensor x, int y) => BinaryOpWrapper("floordiv", x, y); | |||
@@ -122,8 +123,7 @@ namespace Tensorflow | |||
if (y is Tensor tr) | |||
dtype = tr.dtype.as_base_dtype(); | |||
var namescope = ops.name_scope(null, name, new { x, y }); | |||
return tf_with(namescope, scope => | |||
return tf_with(ops.name_scope(null, name, new { x, y }), scope => | |||
{ | |||
Tensor result = null; | |||
var x1 = ops.convert_to_tensor(x, dtype: dtype, name: "x"); | |||
@@ -155,7 +155,6 @@ namespace Tensorflow | |||
return result; | |||
}); | |||
} | |||
} | |||
} |
@@ -17,9 +17,16 @@ | |||
using NumSharp; | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Diagnostics.CodeAnalysis; | |||
using System.Globalization; | |||
using System.Linq; | |||
using System.Runtime.CompilerServices; | |||
using System.Runtime.InteropServices; | |||
using System.Text; | |||
using System.Threading.Tasks; | |||
using NumSharp.Backends; | |||
using NumSharp.Backends.Unmanaged; | |||
using NumSharp.Utilities; | |||
using Tensorflow.Framework; | |||
using static Tensorflow.Binding; | |||
@@ -29,42 +36,68 @@ namespace Tensorflow | |||
/// A tensor is a generalization of vectors and matrices to potentially higher dimensions. | |||
/// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes. | |||
/// </summary> | |||
[SuppressMessage("ReSharper", "ConvertToAutoProperty")] | |||
public partial class Tensor : DisposableObject, ITensorOrOperation, _TensorLike | |||
{ | |||
private int _id; | |||
private Operation _op; | |||
private readonly int _id; | |||
private readonly Operation _op; | |||
private readonly int _value_index; | |||
private TF_Output? _tf_output; | |||
private readonly TF_DataType _override_dtype; | |||
public int Id => _id; | |||
/// <summary> | |||
/// The Graph that contains this tensor. | |||
/// </summary> | |||
public Graph graph => op?.graph; | |||
/// <summary> | |||
/// The Operation that produces this tensor as an output. | |||
/// </summary> | |||
public Operation op => _op; | |||
public Tensor[] outputs => op.outputs; | |||
/// <summary> | |||
/// The string name of this tensor. | |||
/// The string name of this tensor. | |||
/// </summary> | |||
public string name => $"{(op == null ? "<unnamed Operation>" : $"{op.name}:{_value_index}")}"; | |||
private int _value_index; | |||
/// <summary> | |||
/// The index of this tensor in the outputs of its Operation. | |||
/// </summary> | |||
public int value_index => _value_index; | |||
private TF_DataType _dtype = TF_DataType.DtInvalid; | |||
public TF_DataType dtype => _handle == IntPtr.Zero ? _dtype : c_api.TF_TensorType(_handle); | |||
/// <summary> | |||
/// The DType of elements in this tensor. | |||
/// </summary> | |||
public TF_DataType dtype => _handle == IntPtr.Zero ? _override_dtype : c_api.TF_TensorType(_handle); | |||
public ulong bytesize => _handle == IntPtr.Zero ? 0 : c_api.TF_TensorByteSize(_handle); | |||
public ulong itemsize => _handle == IntPtr.Zero ? 0 : c_api.TF_DataTypeSize(dtype); | |||
public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / itemsize; | |||
public IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle); | |||
public int num_consumers(TF_Output oper_out) => _handle == IntPtr.Zero ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out); | |||
public int NDims => rank; | |||
private TF_Output? _tf_output; | |||
/// <summary> | |||
/// The name of the device on which this tensor will be produced, or null. | |||
/// </summary> | |||
public string Device => op.Device; | |||
public int[] dims => shape; | |||
/// <summary> | |||
/// used for keep other pointer when do implicit operating | |||
/// Used for keep other pointer when do implicit operating | |||
/// </summary> | |||
public object Tag { get; set; } | |||
/// <summary> | |||
/// Returns the shape of a tensor. | |||
/// </summary> | |||
/// <remarks>https://www.tensorflow.org/api_docs/python/tf/shape</remarks> | |||
public int[] shape | |||
{ | |||
get | |||
@@ -76,14 +109,13 @@ namespace Tensorflow | |||
var status = new Status(); | |||
c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, status); | |||
status.Check(); | |||
} | |||
else | |||
} else | |||
{ | |||
for (int i = 0; i < rank; i++) | |||
dims[i] = c_api.TF_Dim(_handle, i); | |||
} | |||
return dims.Select(x => Convert.ToInt32(x)).ToArray(); | |||
return dims.Select(x => ((IConvertible) x).ToInt32(CultureInfo.InvariantCulture)).ToArray(); | |||
} | |||
set | |||
@@ -93,38 +125,52 @@ namespace Tensorflow | |||
if (value == null) | |||
c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), null, -1, status); | |||
else | |||
c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), value.Select(x => Convert.ToInt64(x)).ToArray(), value.Length, status); | |||
c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), value.Select(Convert.ToInt64).ToArray(), value.Length, status); | |||
} | |||
} | |||
public int[] _shape_tuple() | |||
{ | |||
if (shape == null) return null; | |||
return shape.Select(x => (int)x).ToArray(); | |||
return (int[]) shape.Clone(); | |||
} | |||
public TensorShape TensorShape => tensor_util.to_shape(shape); | |||
public void SetShape(TensorShape shape) | |||
/// <summary> | |||
/// Updates the shape of this tensor. | |||
/// </summary> | |||
public void set_shape(TensorShape shape) | |||
{ | |||
this.shape = shape.dims; | |||
this.shape = (int[]) shape.dims.Clone(); | |||
} | |||
/// <summary> | |||
/// Updates the shape of this tensor. | |||
/// </summary> | |||
[Obsolete("Please use set_shape(TensorShape shape) instead.", false)] | |||
public void SetShape(TensorShape shape) | |||
{ | |||
this.shape = (int[]) shape.dims.Clone(); | |||
} | |||
/// <summary> | |||
/// Updates the shape of this tensor. | |||
/// </summary> | |||
public void set_shape(Tensor shape) | |||
{ | |||
// ReSharper disable once MergeConditionalExpression | |||
this.shape = shape is null ? null : shape.shape; | |||
} | |||
public int[] dims => shape; | |||
/// <summary> | |||
/// number of dimensions | |||
/// 0 Scalar (magnitude only) | |||
/// 1 Vector (magnitude and direction) | |||
/// 2 Matrix (table of numbers) | |||
/// 3 3-Tensor (cube of numbers) | |||
/// number of dimensions <br></br> | |||
/// 0 Scalar (magnitude only) <br></br> | |||
/// 1 Vector (magnitude and direction) <br></br> | |||
/// 2 Matrix (table of numbers) <br></br> | |||
/// 3 3-Tensor (cube of numbers) <br></br> | |||
/// n n-Tensor (you get the idea) | |||
/// </summary> | |||
/// <remarks>https://www.tensorflow.org/api_docs/python/tf/rank</remarks> | |||
public int rank | |||
{ | |||
get | |||
@@ -137,17 +183,15 @@ namespace Tensorflow | |||
status.Check(); | |||
return ndim; | |||
} | |||
else | |||
{ | |||
return c_api.TF_NumDims(_handle); | |||
} | |||
return c_api.TF_NumDims(_handle); | |||
} | |||
} | |||
public int NDims => rank; | |||
public string Device => op.Device; | |||
/// <summary> | |||
/// Returns a list of Operations that consume this tensor. | |||
/// </summary> | |||
/// <returns></returns> | |||
public Operation[] consumers() | |||
{ | |||
var output = _as_tf_output(); | |||
@@ -157,37 +201,191 @@ namespace Tensorflow | |||
public TF_Output _as_tf_output() | |||
{ | |||
if(!_tf_output.HasValue) | |||
if (!_tf_output.HasValue) | |||
_tf_output = new TF_Output(op, value_index); | |||
return _tf_output.Value; | |||
} | |||
public T[] Data<T>() | |||
[Obsolete("Please use ToArray<T>() instead.", false)] | |||
public T[] Data<T>() where T : unmanaged | |||
{ | |||
return ToArray<T>(); | |||
} | |||
/// <summary> | |||
/// | |||
/// </summary> | |||
/// <typeparam name="T"></typeparam> | |||
/// <returns></returns> | |||
/// <exception cref="ArgumentException">When <typeparam name="T"> is string </typeparam></exception> | |||
public T[] ToArray<T>() where T : unmanaged | |||
{ | |||
// Column major order | |||
// https://en.wikipedia.org/wiki/File:Row_and_column_major_order.svg | |||
// matrix:[[1, 2, 3], [4, 5, 6]] | |||
// index: 0 2 4 1 3 5 | |||
// result: 1 4 2 5 3 6 | |||
var data = new T[size]; | |||
for (ulong i = 0; i < size; i++) | |||
//when T is string | |||
if (typeof(T) == typeof(string)) | |||
{ | |||
data[i] = Marshal.PtrToStructure<T>(buffer + (int)(i * itemsize)); | |||
if (dtype != TF_DataType.TF_STRING) | |||
throw new ArgumentException($"Given <{typeof(T).Name}> can't be converted to string."); | |||
return (T[]) (object) StringData(); | |||
} | |||
return data; | |||
//Are the types matching? | |||
if (typeof(T).as_dtype() == dtype) | |||
{ | |||
if (NDims == 0 && size == 1) //is it a scalar? | |||
{ | |||
unsafe | |||
{ | |||
return new T[] {*(T*) buffer}; | |||
} | |||
} | |||
//types match, no need to perform cast | |||
var ret = new T[size]; | |||
unsafe | |||
{ | |||
var len = (long) size; | |||
fixed (T* dstRet = ret) | |||
{ | |||
T* dst = dstRet; //local stack copy | |||
if (typeof(T).IsPrimitive) | |||
{ | |||
var src = (T*) buffer; | |||
len *= ((long) itemsize); | |||
System.Buffer.MemoryCopy(src, dst, len, len); | |||
} else | |||
{ | |||
var itemsize = (long) this.itemsize; | |||
var buffer = this.buffer.ToInt64(); | |||
Parallel.For(0L, len, i => dst[i] = Marshal.PtrToStructure<T>(new IntPtr(buffer + i * itemsize))); | |||
} | |||
} | |||
} | |||
return ret; | |||
} else | |||
{ | |||
//types do not match, need to perform cast | |||
if (NDims == 0 && size == 1) //is it a scalar? | |||
{ | |||
unsafe | |||
{ | |||
#if _REGEN | |||
#region Compute | |||
switch (dtype.as_numpy_dtype().GetTypeCode()) | |||
{ | |||
%foreach supported_dtypes,supported_dtypes_lowercase% | |||
case NPTypeCode.#1: return new T[] {Converts.ChangeType<T>(*(#2*) buffer, NPTypeCode.#1)}; | |||
% | |||
case NPTypeCode.String: return new T[] {Converts.ChangeType<T>((string)this, NPTypeCode.String)}; | |||
default: | |||
throw new NotSupportedException(); | |||
} | |||
#endregion | |||
#else | |||
#region Compute | |||
switch (dtype.as_numpy_dtype()?.GetTypeCode()) | |||
{ | |||
case NPTypeCode.Boolean: return new T[] {Converts.ChangeType<T>(*(bool*) buffer, NPTypeCode.Boolean)}; | |||
case NPTypeCode.Byte: return new T[] {Converts.ChangeType<T>(*(byte*) buffer, NPTypeCode.Byte)}; | |||
case NPTypeCode.Int16: return new T[] {Converts.ChangeType<T>(*(short*) buffer, NPTypeCode.Int16)}; | |||
case NPTypeCode.UInt16: return new T[] {Converts.ChangeType<T>(*(ushort*) buffer, NPTypeCode.UInt16)}; | |||
case NPTypeCode.Int32: return new T[] {Converts.ChangeType<T>(*(int*) buffer, NPTypeCode.Int32)}; | |||
case NPTypeCode.UInt32: return new T[] {Converts.ChangeType<T>(*(uint*) buffer, NPTypeCode.UInt32)}; | |||
case NPTypeCode.Int64: return new T[] {Converts.ChangeType<T>(*(long*) buffer, NPTypeCode.Int64)}; | |||
case NPTypeCode.UInt64: return new T[] {Converts.ChangeType<T>(*(ulong*) buffer, NPTypeCode.UInt64)}; | |||
case NPTypeCode.Char: return new T[] {Converts.ChangeType<T>(*(char*) buffer, NPTypeCode.Char)}; | |||
case NPTypeCode.Double: return new T[] {Converts.ChangeType<T>(*(double*) buffer, NPTypeCode.Double)}; | |||
case NPTypeCode.Single: return new T[] {Converts.ChangeType<T>(*(float*) buffer, NPTypeCode.Single)}; | |||
case NPTypeCode.String: return new T[] {Converts.ChangeType<T>((string)this, NPTypeCode.String)}; | |||
default: | |||
throw new NotSupportedException(); | |||
} | |||
#endregion | |||
#endif | |||
} | |||
} | |||
var ret = new T[size]; | |||
unsafe | |||
{ | |||
var len = (long) size; | |||
fixed (T* dstRet = ret) | |||
{ | |||
T* dst = dstRet; //local stack copy | |||
#if _REGEN | |||
#region Compute | |||
switch (dtype.as_numpy_dtype().GetTypeCode()) | |||
{ | |||
%foreach supported_dtypes,supported_dtypes_lowercase% | |||
case NPTypeCode.#1: new UnmanagedMemoryBlock<#2>((#2*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||
% | |||
default: | |||
throw new NotSupportedException(); | |||
} | |||
#endregion | |||
#else | |||
#region Compute | |||
switch (dtype.as_numpy_dtype().GetTypeCode()) | |||
{ | |||
case NPTypeCode.Boolean: new UnmanagedMemoryBlock<bool>((bool*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||
case NPTypeCode.Byte: new UnmanagedMemoryBlock<byte>((byte*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||
case NPTypeCode.Int16: new UnmanagedMemoryBlock<short>((short*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||
case NPTypeCode.UInt16: new UnmanagedMemoryBlock<ushort>((ushort*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||
case NPTypeCode.Int32: new UnmanagedMemoryBlock<int>((int*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||
case NPTypeCode.UInt32: new UnmanagedMemoryBlock<uint>((uint*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||
case NPTypeCode.Int64: new UnmanagedMemoryBlock<long>((long*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||
case NPTypeCode.UInt64: new UnmanagedMemoryBlock<ulong>((ulong*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||
case NPTypeCode.Char: new UnmanagedMemoryBlock<char>((char*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||
case NPTypeCode.Double: new UnmanagedMemoryBlock<double>((double*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||
case NPTypeCode.Single: new UnmanagedMemoryBlock<float>((float*) buffer, len).CastTo(new UnmanagedMemoryBlock<T>(dst, len), null, null); break; | |||
case NPTypeCode.String: throw new NotSupportedException("Unable to convert from string to other dtypes"); //TODO! this should call Converts.To<T> | |||
default: | |||
throw new NotSupportedException(); | |||
} | |||
#endregion | |||
#endif | |||
} | |||
} | |||
return ret; | |||
} | |||
} | |||
/// <summary> | |||
/// Copies the memory of current buffer onto newly allocated array. | |||
/// </summary> | |||
/// <returns></returns> | |||
[Obsolete("Please use set_shape(TensorShape shape) instead.", false)] | |||
public byte[] Data() | |||
{ | |||
var data = new byte[bytesize]; | |||
Marshal.Copy(buffer, data, 0, (int)bytesize); | |||
return data; | |||
return BufferToArray(); | |||
} | |||
/// <summary> | |||
/// Copies the memory of current buffer onto newly allocated array. | |||
/// </summary> | |||
/// <returns></returns> | |||
public byte[] BufferToArray() | |||
{ | |||
unsafe | |||
{ | |||
// ReSharper disable once LocalVariableHidesMember | |||
var bytesize = (long) this.bytesize; | |||
var data = new byte[bytesize]; | |||
fixed (byte* dst = data) | |||
System.Buffer.MemoryCopy(buffer.ToPointer(), dst, bytesize, bytesize); | |||
return data; | |||
} | |||
} | |||
public unsafe string[] StringData() | |||
/// Used internally in ToArray<T> | |||
private unsafe string[] StringData() | |||
{ | |||
// | |||
// TF_STRING tensors are encoded with a table of 8-byte offsets followed by TF_StringEncode-encoded bytes. | |||
@@ -199,19 +397,19 @@ namespace Tensorflow | |||
var buffer = new byte[size][]; | |||
var src = c_api.TF_TensorData(_handle); | |||
var srcLen = (IntPtr)(src.ToInt64() + (long)bytesize); | |||
src += (int)(size * 8); | |||
var srcLen = (IntPtr) (src.ToInt64() + (long) bytesize); | |||
src += (int) (size * 8); | |||
for (int i = 0; i < buffer.Length; i++) | |||
{ | |||
using (var status = new Status()) | |||
{ | |||
IntPtr dst = IntPtr.Zero; | |||
UIntPtr dstLen = UIntPtr.Zero; | |||
var read = c_api.TF_StringDecode((byte*)src, (UIntPtr)(srcLen.ToInt64() - src.ToInt64()), (byte**)&dst, &dstLen, status); | |||
var read = c_api.TF_StringDecode((byte*) src, (UIntPtr) (srcLen.ToInt64() - src.ToInt64()), (byte**) &dst, &dstLen, status); | |||
status.Check(true); | |||
buffer[i] = new byte[(int)dstLen]; | |||
buffer[i] = new byte[(int) dstLen]; | |||
Marshal.Copy(dst, buffer[i], 0, buffer[i].Length); | |||
src += (int)read; | |||
src += (int) read; | |||
} | |||
} | |||
@@ -229,51 +427,29 @@ namespace Tensorflow | |||
} | |||
/// <summary> | |||
/// Evaluates this tensor in a `Session`. | |||
/// Evaluates this tensor in a `Session`. | |||
/// </summary> | |||
/// <param name="feed_dict">A dictionary that maps `Tensor` objects to feed values.</param> | |||
/// <param name="session">The `Session` to be used to evaluate this tensor.</param> | |||
/// <returns></returns> | |||
/// <returns>A <see cref="NumSharp"/> array corresponding to the value of this tensor.</returns> | |||
public NDArray eval(params FeedItem[] feed_dict) | |||
{ | |||
return ops._eval_using_default_session(this, feed_dict, graph); | |||
} | |||
/// <summary> | |||
/// Evaluates this tensor in a `Session`. | |||
/// </summary> | |||
/// <param name="feed_dict">A dictionary that maps `Tensor` objects to feed values.</param> | |||
/// <param name="session">The `Session` to be used to evaluate this tensor.</param> | |||
/// <returns>A <see cref="NumSharp"/> array corresponding to the value of this tensor.</returns> | |||
public NDArray eval(Session session, FeedItem[] feed_dict = null) | |||
{ | |||
return ops._eval_using_default_session(this, feed_dict, graph, session); | |||
} | |||
public TF_DataType ToTFDataType(Type type) | |||
{ | |||
switch (type.Name) | |||
{ | |||
case "Char": | |||
return TF_DataType.TF_UINT8; | |||
case "Int16": | |||
return TF_DataType.TF_INT16; | |||
case "Int32": | |||
return TF_DataType.TF_INT32; | |||
case "Int64": | |||
return TF_DataType.TF_INT64; | |||
case "Single": | |||
return TF_DataType.TF_FLOAT; | |||
case "Double": | |||
return TF_DataType.TF_DOUBLE; | |||
case "Byte": | |||
return TF_DataType.TF_UINT8; | |||
case "String": | |||
return TF_DataType.TF_STRING; | |||
case "Boolean": | |||
return TF_DataType.TF_BOOL; | |||
default: | |||
throw new NotImplementedException("ToTFDataType error"); | |||
} | |||
} | |||
public Tensor slice(Slice slice) | |||
{ | |||
var slice_spec = new int[] { slice.Start.Value }; | |||
var slice_spec = new int[] {slice.Start.Value}; | |||
var begin = new List<int>(); | |||
var end = new List<int>(); | |||
var strides = new List<int>(); | |||
@@ -289,26 +465,26 @@ namespace Tensorflow | |||
if (slice.Stop.HasValue) | |||
{ | |||
end.Add(slice.Stop.Value); | |||
} | |||
else | |||
} else | |||
{ | |||
end.Add(0); | |||
end_mask |= (1 << index); | |||
} | |||
strides.Add(slice.Step); | |||
index += 1; | |||
} | |||
return tf_with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope => | |||
return tf_with(ops.name_scope(null, "strided_slice", new {begin, end, strides}), scope => | |||
{ | |||
string name = scope; | |||
if (begin != null) | |||
{ | |||
var (packed_begin, packed_end, packed_strides) = | |||
(array_ops.stack(begin.ToArray()), | |||
array_ops.stack(end.ToArray()), | |||
array_ops.stack(strides.ToArray())); | |||
array_ops.stack(end.ToArray()), | |||
array_ops.stack(strides.ToArray())); | |||
return gen_array_ops.strided_slice( | |||
this, | |||
@@ -320,7 +496,6 @@ namespace Tensorflow | |||
shrink_axis_mask: shrink_axis_mask, | |||
new_axis_mask: new_axis_mask, | |||
ellipsis_mask: ellipsis_mask, | |||
name: name); | |||
} | |||
@@ -330,7 +505,7 @@ namespace Tensorflow | |||
public Tensor slice(int start) | |||
{ | |||
var slice_spec = new int[] { start }; | |||
var slice_spec = new int[] {start}; | |||
var begin = new List<int>(); | |||
var end = new List<int>(); | |||
var strides = new List<int>(); | |||
@@ -349,15 +524,15 @@ namespace Tensorflow | |||
index += 1; | |||
} | |||
return tf_with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope => | |||
return tf_with(ops.name_scope(null, "strided_slice", new {begin, end, strides}), scope => | |||
{ | |||
string name = scope; | |||
if (begin != null) | |||
{ | |||
var (packed_begin, packed_end, packed_strides) = | |||
(array_ops.stack(begin.ToArray()), | |||
array_ops.stack(end.ToArray()), | |||
array_ops.stack(strides.ToArray())); | |||
array_ops.stack(end.ToArray()), | |||
array_ops.stack(strides.ToArray())); | |||
return gen_array_ops.strided_slice( | |||
this, | |||
@@ -369,7 +544,6 @@ namespace Tensorflow | |||
shrink_axis_mask: shrink_axis_mask, | |||
new_axis_mask: new_axis_mask, | |||
ellipsis_mask: ellipsis_mask, | |||
name: name); | |||
} | |||
@@ -392,15 +566,12 @@ namespace Tensorflow | |||
return $"tf.Tensor '{name}' shape=({string.Join(",", shape)}) dtype={dtype}"; | |||
} | |||
protected override void DisposeManagedState() | |||
{ | |||
} | |||
protected override void DisposeUnManagedState(IntPtr handle) | |||
protected override void DisposeUnmanagedResources(IntPtr handle) | |||
{ | |||
if(handle != IntPtr.Zero) | |||
if (handle != IntPtr.Zero) | |||
{ | |||
c_api.TF_DeleteTensor(handle); | |||
_handle = IntPtr.Zero; | |||
} | |||
} | |||
@@ -417,4 +588,4 @@ namespace Tensorflow | |||
public int tensor_int_val { get; set; } | |||
} | |||
} | |||
} |
@@ -1,35 +1,84 @@ | |||
using NumSharp; | |||
using System; | |||
using System.Diagnostics.CodeAnalysis; | |||
using System.Linq; | |||
using System.Runtime.CompilerServices; | |||
namespace Tensorflow | |||
{ | |||
/// <summary> | |||
/// Represents the shape of a `Tensor`. | |||
/// Represents the shape of a `Tensor`. | |||
/// </summary> | |||
/// <remarks>https://www.tensorflow.org/api_docs/python/tf/TensorShape</remarks> | |||
public class TensorShape | |||
{ | |||
private Shape shape; | |||
private readonly Shape shape; | |||
/// <summary> | |||
/// Returns a list of Dimensions, or None if the shape is unspecified. | |||
/// </summary> | |||
public int[] dims => shape.Dimensions; | |||
/// <summary> | |||
/// Returns the rank of this shape. | |||
/// </summary> | |||
public int ndim => shape.NDim; | |||
/// <summary> | |||
/// Returns the rank of this shape. | |||
/// </summary> | |||
public int rank => shape.NDim; | |||
/// <summary> | |||
/// Returns the size this shape represents. | |||
/// </summary> | |||
public int size => shape.Size; | |||
public TensorShape(TensorShapeProto proto) | |||
{ | |||
if (proto.UnknownRank) return; | |||
switch (proto.Dim.Count) | |||
{ | |||
case 0: shape = new Shape(new int[0]); break; | |||
case 1: shape = Shape.Vector((int) proto.Dim[0].Size); break; | |||
case 2: shape = Shape.Matrix((int) proto.Dim[0].Size, (int) proto.Dim[1].Size); break; | |||
default: | |||
var protodims = proto.Dim; | |||
var len = protodims.Count; | |||
var dims = new int[len]; | |||
for (int i = 0; i < len; i++) | |||
dims[i] = (int) protodims[i].Size; | |||
shape.reshape(proto.Dim.Select(x => (int)x.Size).ToArray()); | |||
shape = new Shape(dims); break; | |||
} | |||
} | |||
public TensorShape(params int[] dims) | |||
{ | |||
shape = new Shape(dims); | |||
switch (dims.Length) | |||
{ | |||
case 0: shape = new Shape(new int[0]); break; | |||
case 1: shape = Shape.Vector((int) dims[0]); break; | |||
case 2: shape = Shape.Matrix(dims[0], dims[1]); break; | |||
default: shape = new Shape(dims); break; | |||
} | |||
} | |||
/// <summary> | |||
/// | |||
/// </summary> | |||
/// <param name="slice"></param> | |||
/// <returns></returns> | |||
/// <exception cref="ArgumentException">When <see cref="Slice"/> is not an Index.</exception> | |||
[SuppressMessage("ReSharper", "PossibleInvalidOperationException")] | |||
public TensorShape this[Slice slice] | |||
{ | |||
get | |||
{ | |||
if (slice.Start.HasValue == false || slice.Length.HasValue == false) | |||
throw new ArgumentException("Slice must has Start and Length."); | |||
return new TensorShape(dims.Skip(slice.Start.Value) | |||
.Take(slice.Length.Value) | |||
.ToArray()); | |||
@@ -37,7 +86,7 @@ namespace Tensorflow | |||
} | |||
/// <summary> | |||
/// Returns True iff `self` is fully defined in every dimension. | |||
/// Returns True iff `self` is fully defined in every dimension. | |||
/// </summary> | |||
/// <returns></returns> | |||
public bool is_fully_defined() | |||
@@ -50,6 +99,7 @@ namespace Tensorflow | |||
throw new NotImplementedException("TensorShape is_compatible_with"); | |||
} | |||
[SuppressMessage("ReSharper", "ParameterHidesMember")] | |||
public TensorShape with_rank_at_least(int rank) | |||
{ | |||
if (rank != ndim) | |||
@@ -59,35 +109,63 @@ namespace Tensorflow | |||
} | |||
/// <summary> | |||
/// Returns the concatenation of the dimension in `self` and `other`. | |||
/// Returns the concatenation of the dimension in `self` and `other`. | |||
/// </summary> | |||
/// <param name="other"></param> | |||
/// <returns></returns> | |||
public TensorShape concatenate(int[] other_) | |||
[MethodImpl(MethodImplOptions.AggressiveInlining)] | |||
public TensorShape concatenate(int[] other) | |||
{ | |||
var other = new TensorShape(other_); | |||
return concatenate(new TensorShape(other)); | |||
} | |||
if (ndim < 0 || other.ndim < 0) | |||
/// <summary> | |||
/// Returns the concatenation of the dimension in `self` and `other`. | |||
/// </summary> | |||
/// <param name="other"></param> | |||
/// <returns></returns> | |||
public TensorShape concatenate(TensorShape other) | |||
{ | |||
var otherShape = other; | |||
if (ndim < 0 || otherShape.ndim < 0) | |||
return new TensorShape(); | |||
else | |||
{ | |||
var concatenate_dims = new int[ndim + other.ndim]; | |||
var concatenate_dims = new int[ndim + otherShape.ndim]; | |||
for (int i = 0; i < ndim; i++) | |||
concatenate_dims[i] = dims[i]; | |||
for (int i = 0; i < other.ndim; i++) | |||
concatenate_dims[ndim + i] = other.dims[i]; | |||
for (int i = 0; i < otherShape.ndim; i++) | |||
concatenate_dims[ndim + i] = otherShape.dims[i]; | |||
return new TensorShape(concatenate_dims); | |||
} | |||
} | |||
public static implicit operator TensorShape(Shape shape) => new TensorShape(shape.Dimensions); | |||
public static implicit operator Shape(TensorShape shape) => new Shape(shape.dims); | |||
public static implicit operator TensorShape(Shape shape) => new TensorShape((int[]) shape.Dimensions.Clone()); | |||
public static implicit operator Shape(TensorShape shape) => new Shape((int[]) shape.dims.Clone()); | |||
public static implicit operator int[](TensorShape shape) => (int[])shape.dims.Clone(); //we clone to avoid any changes | |||
public static implicit operator TensorShape(int[] dims) => new TensorShape(dims); | |||
public static implicit operator int[](TensorShape shape) => shape.dims; | |||
public static explicit operator int(TensorShape shape) => shape.size; | |||
public static explicit operator TensorShape(int dim) => new TensorShape(dim); | |||
public static explicit operator (int, int)(TensorShape shape) => shape.dims.Length == 2 ? (shape.dims[0], shape.dims[1]) : (0, 0); | |||
public static implicit operator TensorShape((int, int) dims) => new TensorShape(dims.Item1, dims.Item2); | |||
public static explicit operator (int, int, int)(TensorShape shape) => shape.dims.Length == 3 ? (shape.dims[0], shape.dims[1], shape.dims[2]) : (0, 0, 0); | |||
public static implicit operator TensorShape((int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3); | |||
public static explicit operator (int, int, int, int)(TensorShape shape) => shape.dims.Length == 4 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3]) : (0, 0, 0, 0); | |||
public static implicit operator TensorShape((int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4); | |||
public static explicit operator (int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 5 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4]) : (0, 0, 0, 0, 0); | |||
public static implicit operator TensorShape((int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5); | |||
public static explicit operator (int, int, int, int, int, int)(TensorShape shape) => shape.dims.Length == 6 ? (shape.dims[0], shape.dims[1], shape.dims[2], shape.dims[3], shape.dims[4], shape.dims[5]) : (0, 0, 0, 0, 0, 0); | |||
public static implicit operator TensorShape((int, int, int, int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3, dims.Item4, dims.Item5, dims.Item6); | |||
} | |||
} |
@@ -15,6 +15,8 @@ | |||
******************************************************************************/ | |||
using System; | |||
using System.Numerics; | |||
using NumSharp.Backends; | |||
namespace Tensorflow | |||
{ | |||
@@ -23,35 +25,100 @@ namespace Tensorflow | |||
public static TF_DataType int8 = TF_DataType.TF_INT8; | |||
public static TF_DataType int32 = TF_DataType.TF_INT32; | |||
public static TF_DataType int64 = TF_DataType.TF_INT64; | |||
public static TF_DataType uint8 = TF_DataType.TF_UINT8; | |||
public static TF_DataType uint32 = TF_DataType.TF_UINT32; | |||
public static TF_DataType uint64 = TF_DataType.TF_UINT64; | |||
public static TF_DataType float32 = TF_DataType.TF_FLOAT; // is that float32? | |||
public static TF_DataType float16 = TF_DataType.TF_HALF; | |||
public static TF_DataType float64 = TF_DataType.TF_DOUBLE; | |||
public static Type as_numpy_datatype(this TF_DataType type) | |||
/// <summary> | |||
/// | |||
/// </summary> | |||
/// <param name="type"></param> | |||
/// <returns><see cref="System.Type"/> equivalent to <paramref name="type"/>, if none exists, returns null.</returns> | |||
public static Type as_numpy_dtype(this TF_DataType type) | |||
{ | |||
switch (type) | |||
{ | |||
case TF_DataType.TF_BOOL: | |||
return typeof(bool); | |||
case TF_DataType.TF_UINT8: | |||
return typeof(byte); | |||
case TF_DataType.TF_INT64: | |||
return typeof(long); | |||
case TF_DataType.TF_UINT64: | |||
return typeof(ulong); | |||
case TF_DataType.TF_INT32: | |||
return typeof(int); | |||
case TF_DataType.TF_UINT32: | |||
return typeof(uint); | |||
case TF_DataType.TF_INT16: | |||
return typeof(short); | |||
case TF_DataType.TF_UINT16: | |||
return typeof(ushort); | |||
case TF_DataType.TF_FLOAT: | |||
return typeof(float); | |||
case TF_DataType.TF_DOUBLE: | |||
return typeof(double); | |||
case TF_DataType.TF_STRING: | |||
return typeof(string); | |||
case TF_DataType.TF_COMPLEX128: | |||
case TF_DataType.TF_COMPLEX64: //64 is also TF_COMPLEX | |||
return typeof(Complex); | |||
default: | |||
return null; | |||
} | |||
} | |||
// "sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex" | |||
public static TF_DataType as_dtype(Type type, TF_DataType? dtype = null) | |||
/// <summary> | |||
/// | |||
/// </summary> | |||
/// <param name="type"></param> | |||
/// <returns></returns> | |||
/// <exception cref="ArgumentException">When <paramref name="type"/> has no equivalent <see cref="NPTypeCode"/></exception> | |||
public static NPTypeCode as_numpy_typecode(this TF_DataType type) | |||
{ | |||
switch (type) | |||
{ | |||
case TF_DataType.TF_BOOL: | |||
return NPTypeCode.Boolean; | |||
case TF_DataType.TF_UINT8: | |||
return NPTypeCode.Byte; | |||
case TF_DataType.TF_INT64: | |||
return NPTypeCode.Int64; | |||
case TF_DataType.TF_INT32: | |||
return NPTypeCode.Int32; | |||
case TF_DataType.TF_INT16: | |||
return NPTypeCode.Int16; | |||
case TF_DataType.TF_UINT64: | |||
return NPTypeCode.UInt64; | |||
case TF_DataType.TF_UINT32: | |||
return NPTypeCode.UInt32; | |||
case TF_DataType.TF_UINT16: | |||
return NPTypeCode.UInt16; | |||
case TF_DataType.TF_FLOAT: | |||
return NPTypeCode.Single; | |||
case TF_DataType.TF_DOUBLE: | |||
return NPTypeCode.Double; | |||
case TF_DataType.TF_STRING: | |||
return NPTypeCode.String; | |||
case TF_DataType.TF_COMPLEX128: | |||
case TF_DataType.TF_COMPLEX64: //64 is also TF_COMPLEX | |||
return NPTypeCode.Complex; | |||
default: | |||
throw new NotSupportedException($"Unable to convert {type} to a NumSharp typecode."); | |||
} | |||
} | |||
/// <summary> | |||
/// | |||
/// </summary> | |||
/// <param name="type"></param> | |||
/// <param name="dtype"></param> | |||
/// <returns></returns> | |||
/// <exception cref="ArgumentException">When <paramref name="type"/> has no equivalent <see cref="TF_DataType"/></exception> | |||
public static TF_DataType as_dtype(this Type type, TF_DataType? dtype = null) | |||
{ | |||
switch (type.Name) | |||
{ | |||
@@ -98,7 +165,7 @@ namespace Tensorflow | |||
dtype = TF_DataType.TF_BOOL; | |||
break; | |||
default: | |||
throw new Exception("as_dtype Not Implemented"); | |||
throw new NotSupportedException($"Unable to convert {type} to a NumSharp typecode."); | |||
} | |||
return dtype.Value; | |||
@@ -106,16 +173,7 @@ namespace Tensorflow | |||
public static DataType as_datatype_enum(this TF_DataType type) | |||
{ | |||
DataType dtype = DataType.DtInvalid; | |||
switch (type) | |||
{ | |||
default: | |||
Enum.TryParse(((int)type).ToString(), out dtype); | |||
break; | |||
} | |||
return dtype; | |||
return Enum.TryParse(((int) type).ToString(), out DataType dtype) ? dtype : DataType.DtInvalid; | |||
} | |||
public static TF_DataType as_base_dtype(this TF_DataType type) | |||
@@ -132,7 +190,7 @@ namespace Tensorflow | |||
public static Type as_numpy_dtype(this DataType type) | |||
{ | |||
return type.as_tf_dtype().as_numpy_datatype(); | |||
return type.as_tf_dtype().as_numpy_dtype(); | |||
} | |||
public static DataType as_base_dtype(this DataType type) | |||
@@ -144,16 +202,7 @@ namespace Tensorflow | |||
public static TF_DataType as_tf_dtype(this DataType type) | |||
{ | |||
TF_DataType dtype = TF_DataType.DtInvalid; | |||
switch (type) | |||
{ | |||
default: | |||
Enum.TryParse(((int)type).ToString(), out dtype); | |||
break; | |||
} | |||
return dtype; | |||
return Enum.TryParse(((int) type).ToString(), out TF_DataType dtype) ? dtype : TF_DataType.DtInvalid; | |||
} | |||
public static TF_DataType as_ref(this TF_DataType type) | |||
@@ -17,6 +17,7 @@ | |||
using NumSharp; | |||
using System; | |||
using System.Linq; | |||
using NumSharp.Utilities; | |||
namespace Tensorflow | |||
{ | |||
@@ -109,7 +110,7 @@ namespace Tensorflow | |||
// We first convert value to a numpy array or scalar. | |||
NDArray nparray = null; | |||
var np_dt = dtype.as_numpy_datatype(); | |||
var np_dt = dtype.as_numpy_dtype(); | |||
if (values is NDArray nd) | |||
{ | |||
@@ -188,37 +189,37 @@ namespace Tensorflow | |||
if (values.GetType().IsArray) | |||
nparray = np.array((int[])values, np_dt); | |||
else | |||
nparray = Convert.ToInt32(values); | |||
nparray = Converts.ToInt32(values); | |||
break; | |||
case "Int64": | |||
if (values.GetType().IsArray) | |||
nparray = np.array((int[])values, np_dt); | |||
else | |||
nparray = Convert.ToInt64(values); | |||
nparray = Converts.ToInt64(values); | |||
break; | |||
case "Single": | |||
if (values.GetType().IsArray) | |||
nparray = np.array((float[])values, np_dt); | |||
else | |||
nparray = Convert.ToSingle(values); | |||
nparray = Converts.ToSingle(values); | |||
break; | |||
case "Double": | |||
if (values.GetType().IsArray) | |||
nparray = np.array((double[])values, np_dt); | |||
else | |||
nparray = Convert.ToDouble(values); | |||
nparray = Converts.ToDouble(values); | |||
break; | |||
case "String": | |||
if (values.GetType().IsArray) | |||
nparray = np.array((string[])values, np_dt); | |||
else | |||
nparray = NDArray.FromString(Convert.ToString(values)); | |||
nparray = NDArray.FromString(Converts.ToString(values)); | |||
break; | |||
case "Boolean": | |||
if (values.GetType().IsArray) | |||
nparray = np.array((bool[])values, np_dt); | |||
else | |||
nparray = Convert.ToBoolean(values); | |||
nparray = Converts.ToBoolean(values); | |||
break; | |||
default: | |||
throw new NotImplementedException($"make_tensor_proto: Support for type {np_dt.Name} Not Implemented"); | |||
@@ -0,0 +1,38 @@ | |||
%all_dtypes = ["NDArray","Complex","Boolean","Byte","Int16","UInt16","Int32","UInt32","Int64","UInt64","Char","Double","Single","String"] | |||
%all_dtypes_lowercase = ["NDArray","Complex","bool","byte","short","ushort","int","uint","long","ulong","char","double","float","string"] | |||
%supported_primitives = ["Boolean","Byte","Int16","UInt16","Int32","UInt32","Int64","UInt64","Char","Double","Single","String"] | |||
%supported_primitives_lowercase = ["bool","byte","short","ushort","int","uint","long","ulong","char","double","float","string"] | |||
%supported_numericals = ["Byte","Int16","UInt16","Int32","UInt32","Int64","UInt64","Char","Double","Single"] | |||
%supported_numericals_lowercase = ["byte","short","ushort","int","uint","long","ulong","char","double","float"] | |||
%supported_numericals_defaultvals = ["0","0","0","0","0u","0L","0UL","'\0'","0d","0f"] | |||
%supported_numericals_onevales = ["1","1","1","1","1u","1L","1UL",1,"1d","1f"] | |||
%supported_numericals_TF_DataType = ["TF_DataType.TF_UINT8","TF_DataType.TF_INT16","TF_DataType.TF_UINT16","TF_DataType.TF_INT32","TF_DataType.TF_UINT32","TF_DataType.TF_INT64","TF_DataType.TF_UINT64","TF_DataType.TF_UINT8","TF_DataType.TF_DOUBLE","TF_DataType.TF_FLOAT"] | |||
//this is the type we use in summerizing/reducting: | |||
%supported_numericals_accumulatingType = ["UInt32","Int32","UInt32","Int32","UInt32","Int64","UInt64","UInt32","Double","Single"] | |||
%supported_numericals_accumulatingType_defaultvals = ["0","0","0","0","0u","0L","0UL","'\0'","0d","0f"] | |||
%supported_numericals_signed = ["Int16","Int32","Int64","Double","Single"] | |||
%supported_numericals_signed_lowercase = ["short","int","long","double","float"] | |||
%supported_numericals_signed_defaultvals = ["0","0","0L","0d","0f"] | |||
%supported_numericals_signed_onevales = ["1","1","1L","1d","1f"] | |||
%supported_numericals_unsigned = ["Byte","UInt16","UInt32","UInt64","Char"] | |||
%supported_numericals_unsigned_lowercase = ["byte","ushort","uint","ulong","char"] | |||
%supported_numericals_unsigned_defaultvals = ["0","0","0U","0UL","'\0'"] | |||
%supported_numericals_unsigned_onevales = ["1","1","1U","1UL","'\1'"] | |||
%supported_dtypes = ["Boolean","Byte","Int16","UInt16","Int32","UInt32","Int64","UInt64","Char","Double","Single"] | |||
%supported_numericals_TF_DataType = ["TF_DataType.TF_UINT8","TF_DataType.TF_INT16","TF_DataType.TF_UINT16","TF_DataType.TF_INT32","TF_DataType.TF_UINT32","TF_DataType.TF_INT64","TF_DataType.TF_UINT64","TF_DataType.TF_UINT8","TF_DataType.TF_DOUBLE","TF_DataType.TF_FLOAT"] | |||
%supported_dtypes_lowercase = ["bool","byte","short","ushort","int","uint","long","ulong","char","double","float"] | |||
%supported_dtypes_defaultvals = [false,"0","0","0","0","0u","0L","0UL","'\0'","0d","0f"] | |||
%supported_dtypes_onevales = [true,"1","1","1","1","1u","1L","1UL","'\1'","1d","1f"] | |||
%supported_dtypes_dtype = ["bool","uint8","int16","uint16","int32","uint32","int64","uint64","uint8","float64","float32"] | |||
//this is the type we use in summerizing/reducting: | |||
%supported_dtypes_accumulatingType = ["Int32","UInt32","Int32","UInt32","Int32","UInt32","Int64","UInt64","UInt32","Double","Single"] | |||
%supported_dtypes_accumulatingType_defaultvals = [false, "0","0","0","0u","0L","0UL","'\0'","0d","0f"] | |||
@@ -29,55 +29,111 @@ namespace Tensorflow | |||
/// </summary> | |||
public class GraphKeys | |||
{ | |||
#region const | |||
/// <summary> | |||
/// the subset of `Variable` objects that will be trained by an optimizer. | |||
/// </summary> | |||
public const string TRAINABLE_VARIABLES_ = "trainable_variables"; | |||
/// <summary> | |||
/// Trainable resource-style variables. | |||
/// </summary> | |||
public const string TRAINABLE_RESOURCE_VARIABLES_ = "trainable_resource_variables"; | |||
/// <summary> | |||
/// Key for streaming model ports. | |||
/// </summary> | |||
public const string _STREAMING_MODEL_PORTS_ = "streaming_model_ports"; | |||
/// <summary> | |||
/// Key to collect losses | |||
/// </summary> | |||
public const string LOSSES_ = "losses"; | |||
/// <summary> | |||
/// Key to collect Variable objects that are global (shared across machines). | |||
/// Default collection for all variables, except local ones. | |||
/// </summary> | |||
public const string GLOBAL_VARIABLES_ = "variables"; | |||
public const string TRAIN_OP_ = "train_op"; | |||
public const string GLOBAL_STEP_ = "global_step"; | |||
public string[] _VARIABLE_COLLECTIONS_ = new string[] { "variables", "trainable_variables", "model_variables" }; | |||
/// <summary> | |||
/// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing. | |||
/// </summary> | |||
public const string SAVEABLE_OBJECTS_ = "saveable_objects"; | |||
/// <summary> | |||
/// Key to collect update_ops | |||
/// </summary> | |||
public const string UPDATE_OPS_ = "update_ops"; | |||
// Key to collect summaries. | |||
public const string SUMMARIES_ = "summaries"; | |||
// Used to store v2 summary names. | |||
public const string _SUMMARY_COLLECTION_ = "_SUMMARY_V2"; | |||
// Key for control flow context. | |||
public const string COND_CONTEXT_ = "cond_context"; | |||
public const string WHILE_CONTEXT_ = "while_context"; | |||
#endregion | |||
/// <summary> | |||
/// the subset of `Variable` objects that will be trained by an optimizer. | |||
/// </summary> | |||
public string TRAINABLE_VARIABLES = "trainable_variables"; | |||
public string TRAINABLE_VARIABLES => TRAINABLE_VARIABLES_; | |||
/// <summary> | |||
/// Trainable resource-style variables. | |||
/// </summary> | |||
public string TRAINABLE_RESOURCE_VARIABLES = "trainable_resource_variables"; | |||
public string TRAINABLE_RESOURCE_VARIABLES => TRAINABLE_RESOURCE_VARIABLES_; | |||
/// <summary> | |||
/// Key for streaming model ports. | |||
/// </summary> | |||
public string _STREAMING_MODEL_PORTS = "streaming_model_ports"; | |||
public string _STREAMING_MODEL_PORTS => _STREAMING_MODEL_PORTS_; | |||
/// <summary> | |||
/// Key to collect losses | |||
/// </summary> | |||
public string LOSSES = "losses"; | |||
public string LOSSES => LOSSES_; | |||
/// <summary> | |||
/// Key to collect Variable objects that are global (shared across machines). | |||
/// Default collection for all variables, except local ones. | |||
/// </summary> | |||
public string GLOBAL_VARIABLES = "variables"; | |||
public string GLOBAL_VARIABLES => GLOBAL_VARIABLES_; | |||
public string TRAIN_OP = "train_op"; | |||
public string TRAIN_OP => TRAIN_OP_; | |||
public string GLOBAL_STEP = "global_step"; | |||
public string GLOBAL_STEP => GLOBAL_STEP_; | |||
public string[] _VARIABLE_COLLECTIONS = new string[] { "variables", "trainable_variables", "model_variables" }; | |||
public string[] _VARIABLE_COLLECTIONS => _VARIABLE_COLLECTIONS_; | |||
/// <summary> | |||
/// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing. | |||
/// </summary> | |||
public string SAVEABLE_OBJECTS = "saveable_objects"; | |||
public string SAVEABLE_OBJECTS => SAVEABLE_OBJECTS_; | |||
/// <summary> | |||
/// Key to collect update_ops | |||
/// </summary> | |||
public string UPDATE_OPS = "update_ops"; | |||
public string UPDATE_OPS => UPDATE_OPS_; | |||
// Key to collect summaries. | |||
public string SUMMARIES = "summaries"; | |||
public string SUMMARIES => SUMMARIES_; | |||
// Used to store v2 summary names. | |||
public string _SUMMARY_COLLECTION = "_SUMMARY_V2"; | |||
public string _SUMMARY_COLLECTION => _SUMMARY_COLLECTION_; | |||
// Key for control flow context. | |||
public string COND_CONTEXT = "cond_context"; | |||
public string WHILE_CONTEXT = "while_context"; | |||
public string COND_CONTEXT => COND_CONTEXT_; | |||
public string WHILE_CONTEXT => WHILE_CONTEXT_; | |||
} | |||
} | |||
} |
@@ -6,6 +6,14 @@ | |||
<GeneratePackageOnBuild>false</GeneratePackageOnBuild> | |||
</PropertyGroup> | |||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | |||
<OutputPath>bin\debug-gpu</OutputPath> | |||
</PropertyGroup> | |||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'"> | |||
<OutputPath>bin\release-gpu</OutputPath> | |||
</PropertyGroup> | |||
<ItemGroup> | |||
<PackageReference Include="Colorful.Console" Version="1.2.9" /> | |||
<PackageReference Include="Newtonsoft.Json" Version="12.0.2" /> | |||
@@ -98,9 +98,9 @@ namespace TensorFlowNET.UnitTest | |||
{ | |||
var result = sess.run(tensor); | |||
Assert.AreEqual(result[0].shape[0], 3); | |||
Assert.AreEqual(result[0].shape[1], 2); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0, 0, 0 }, result[0].Data<int>())); | |||
Assert.AreEqual(result.shape[0], 3); | |||
Assert.AreEqual(result.shape[1], 2); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0, 0, 0 }, result.Data<int>())); | |||
} | |||
// big size | |||
@@ -109,13 +109,13 @@ namespace TensorFlowNET.UnitTest | |||
{ | |||
var result = sess.run(tensor); | |||
Assert.AreEqual(result[0].shape[0], 200); | |||
Assert.AreEqual(result[0].shape[1], 100); | |||
Assert.AreEqual(result.shape[0], 200); | |||
Assert.AreEqual(result.shape[1], 100); | |||
var data = result[0].Data<int>(); | |||
var data = result.Data<int>(); | |||
Assert.AreEqual(0, data[0]); | |||
Assert.AreEqual(0, data[500]); | |||
Assert.AreEqual(0, data[result[0].size - 1]); | |||
Assert.AreEqual(0, data[result.size - 1]); | |||
} | |||
} | |||
@@ -127,9 +127,9 @@ namespace TensorFlowNET.UnitTest | |||
{ | |||
var result = sess.run(ones); | |||
Assert.AreEqual(result[0].shape[0], 3); | |||
Assert.AreEqual(result[0].shape[1], 2); | |||
Assert.IsTrue(new[] { 1, 1, 1, 1, 1, 1 }.SequenceEqual(result[0].Data<int>())); | |||
Assert.AreEqual(result.shape[0], 3); | |||
Assert.AreEqual(result.shape[1], 2); | |||
Assert.IsTrue(new[] { 1, 1, 1, 1, 1, 1 }.SequenceEqual(result.Data<int>())); | |||
} | |||
} | |||
@@ -142,9 +142,9 @@ namespace TensorFlowNET.UnitTest | |||
{ | |||
var result = sess.run(halfes); | |||
Assert.AreEqual(result[0].shape[0], 3); | |||
Assert.AreEqual(result[0].shape[1], 2); | |||
Assert.IsTrue(new[] { .5, .5, .5, .5, .5, .5 }.SequenceEqual(result[0].Data<double>())); | |||
Assert.AreEqual(result.shape[0], 3); | |||
Assert.AreEqual(result.shape[1], 2); | |||
Assert.IsTrue(new[] { .5, .5, .5, .5, .5, .5 }.SequenceEqual(result.Data<double>())); | |||
} | |||
} | |||
@@ -161,10 +161,10 @@ namespace TensorFlowNET.UnitTest | |||
using (var sess = tf.Session()) | |||
{ | |||
var result = sess.run(tensor); | |||
var data = result[0].Data<int>(); | |||
var data = result.Data<int>(); | |||
Assert.AreEqual(result[0].shape[0], 2); | |||
Assert.AreEqual(result[0].shape[1], 3); | |||
Assert.AreEqual(result.shape[0], 2); | |||
Assert.AreEqual(result.shape[1], 3); | |||
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 1, 1, 2, 1, 3 }, data)); | |||
} | |||
} | |||
@@ -177,7 +177,7 @@ namespace TensorFlowNET.UnitTest | |||
var c = a * b; | |||
var sess = tf.Session(); | |||
double result = sess.run(c)[0]; | |||
double result = sess.run(c); | |||
sess.close(); | |||
Assert.AreEqual(6.0, result); | |||
@@ -41,7 +41,7 @@ namespace TensorFlowNET.UnitTest | |||
var grad = tf.gradients(y, x); | |||
Assert.AreEqual(grad[0].name, "gradients/AddN:0"); | |||
float r = sess.run(grad[0])[0]; | |||
float r = sess.run(grad[0]); | |||
Assert.AreEqual(r, 1.4f); | |||
} | |||
} | |||
@@ -57,7 +57,7 @@ namespace TensorFlowNET.UnitTest | |||
var grad = tf.gradients(y, x); | |||
Assert.AreEqual(grad[0].name, "gradients/AddN:0"); | |||
float r = sess.run(grad[0])[0]; | |||
float r = sess.run(grad[0]); | |||
Assert.AreEqual(r, 14.700001f); | |||
}); | |||
} | |||
@@ -94,7 +94,7 @@ namespace TensorFlowNET.UnitTest | |||
using (var sess = tf.Session(graph)) | |||
{ | |||
var r = sess.run(slice)[0]; | |||
var r = sess.run(slice); | |||
Assert.IsTrue(Enumerable.SequenceEqual(r.shape, new[] { 2, 1, 2 })); | |||
Assert.IsTrue(Enumerable.SequenceEqual(r[0].GetData<int>(), new[] { 11, 13 })); | |||
@@ -17,7 +17,7 @@ namespace TensorFlowNET.UnitTest | |||
{ | |||
var result = sess.run(y, | |||
new FeedItem(x, 2)); | |||
Assert.AreEqual((int)result[0], 6); | |||
Assert.AreEqual((int)result, 6); | |||
} | |||
} | |||
} | |||
@@ -45,7 +45,7 @@ namespace TensorFlowNET.UnitTest | |||
EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype); | |||
EXPECT_EQ(0, outTensor.NDims); | |||
ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize); | |||
var output_contents = outTensor.Data<int>(); | |||
var output_contents = outTensor.ToArray<int>(); | |||
EXPECT_EQ(3 + 2, output_contents[0]); | |||
// Add another operation to the graph. | |||
@@ -66,7 +66,7 @@ namespace TensorFlowNET.UnitTest | |||
EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype); | |||
EXPECT_EQ(0, outTensor.NDims); // scalar | |||
ASSERT_EQ((ulong)sizeof(uint), outTensor.bytesize); | |||
output_contents = outTensor.Data<int>(); | |||
output_contents = outTensor.ToArray<int>(); | |||
EXPECT_EQ(-(7 + 2), output_contents[0]); | |||
// Clean up | |||
@@ -112,7 +112,7 @@ namespace TensorFlowNET.UnitTest | |||
var nd = np.array(1f, 2f, 3f, 4f, 5f, 6f).reshape(2, 3); | |||
var tensor = new Tensor(nd); | |||
var array = tensor.Data<float>(); | |||
var array = tensor.ToArray<float>(); | |||
EXPECT_EQ(tensor.dtype, TF_DataType.TF_FLOAT); | |||
EXPECT_EQ(tensor.rank, nd.ndim); | |||
@@ -1,4 +1,5 @@ | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using NumSharp; | |||
using Tensorflow; | |||
using static Tensorflow.Binding; | |||
@@ -16,7 +17,7 @@ namespace TensorFlowNET.UnitTest | |||
{ | |||
session.run(x.initializer); | |||
var result = session.run(x); | |||
Assert.AreEqual(10, (int)result[0]); | |||
Assert.AreEqual(10, (int)result); | |||
} | |||
} | |||
@@ -81,7 +82,7 @@ namespace TensorFlowNET.UnitTest | |||
using (var session = tf.Session()) | |||
{ | |||
session.run(model); | |||
int result = session.run(y)[0]; | |||
int result = session.run(y); | |||
Assert.AreEqual(result, 4); | |||
} | |||
} | |||
@@ -97,12 +98,12 @@ namespace TensorFlowNET.UnitTest | |||
var sess = tf.Session(graph); | |||
sess.run(init); | |||
var result = sess.run(variable); | |||
Assert.IsTrue((int)result[0] == 31); | |||
NDArray result = sess.run(variable); | |||
Assert.IsTrue((int)result == 31); | |||
var assign = variable.assign(12); | |||
result = sess.run(assign); | |||
Assert.IsTrue((int)result[0] == 12); | |||
Assert.IsTrue((int)result == 12); | |||
} | |||
[TestMethod] | |||
@@ -139,7 +140,7 @@ namespace TensorFlowNET.UnitTest | |||
for(int i = 0; i < 5; i++) | |||
{ | |||
x = x + 1; | |||
result = session.run(x)[0]; | |||
result = session.run(x); | |||
print(result); | |||
} | |||
} | |||
@@ -31,7 +31,7 @@ namespace TensorFlowNET.UnitTest.nn_test | |||
var y_np = this._ZeroFraction(x_np); | |||
var x_tf = constant_op.constant(x_np); | |||
x_tf.SetShape(x_shape); | |||
x_tf.set_shape(x_shape); | |||
var y_tf = nn_impl.zero_fraction(x_tf); | |||
var y_tf_np = self.evaluate<NDArray>(y_tf); | |||