diff --git a/src/TensorFlowNET.Core/Buffers/Buffer.cs b/src/TensorFlowNET.Core/Buffers/Buffer.cs
index 396fb311..c08d3175 100644
--- a/src/TensorFlowNET.Core/Buffers/Buffer.cs
+++ b/src/TensorFlowNET.Core/Buffers/Buffer.cs
@@ -15,58 +15,116 @@
******************************************************************************/
using System;
+using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
+using NumSharp.Backends.Unmanaged;
+using static Tensorflow.c_api;
namespace Tensorflow
{
+ ///
+ /// Represents a TF_Buffer that can be passed to Tensorflow.
+ ///
public class Buffer : DisposableObject
{
- private TF_Buffer buffer => Marshal.PtrToStructure(_handle);
+ private unsafe TF_Buffer buffer
+ {
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ get => *bufferptr;
+ }
+
+ private unsafe TF_Buffer* bufferptr
+ {
+ [MethodImpl(MethodImplOptions.AggressiveInlining)]
+ get => (TF_Buffer*) _handle;
+ }
- public byte[] Data
+ ///
+ /// The memory block representing this buffer.
+ ///
+ /// The deallocator is set to null.
+ public UnmanagedMemoryBlock MemoryBlock
{
- get
+ get
{
- var data = new byte[buffer.length];
- if (data.Length > 0)
- Marshal.Copy(buffer.data, data, 0, data.Length);
- return data;
+ unsafe
+ {
+ EnsureNotDisposed();
+ var buff = (TF_Buffer*) _handle;
+ return new UnmanagedMemoryBlock((byte*) buff->data.ToPointer(), (long) buff->length);
+ }
}
}
- public int Length => (int)buffer.length;
-
- public Buffer()
+ ///
+ /// The bytes length of this buffer.
+ ///
+ public ulong Length
{
- _handle = c_api.TF_NewBuffer();
+ get
+ {
+ EnsureNotDisposed();
+ return buffer.length;
+ }
}
- public Buffer(IntPtr handle)
+ public Buffer() => _handle = TF_NewBuffer();
+
+ internal Buffer(IntPtr handle)
{
+ if (handle == IntPtr.Zero)
+ throw new ArgumentException("Handle (IntPtr) can't be zero.", nameof(handle));
+
_handle = handle;
}
- public Buffer(byte[] data)
- {
- var dst = Marshal.AllocHGlobal(data.Length);
- Marshal.Copy(data, 0, dst, data.Length);
+ public Buffer(byte[] data) : this(_toBuffer(data))
+ { }
- _handle = c_api.TF_NewBufferFromString(dst, (ulong)data.Length);
+ private static IntPtr _toBuffer(byte[] data)
+ {
+ if (data == null)
+ throw new ArgumentNullException(nameof(data));
- Marshal.FreeHGlobal(dst);
+ unsafe
+ {
+ fixed (byte* src = data)
+ return TF_NewBufferFromString(new IntPtr(src), (ulong) data.LongLength);
+ }
}
public static implicit operator IntPtr(Buffer buffer)
{
+ buffer.EnsureNotDisposed();
return buffer._handle;
}
- public static implicit operator byte[](Buffer buffer)
+ public static explicit operator byte[](Buffer buffer) => buffer.ToArray(); //has to be explicit, developer will assume it doesn't cost.
+
+ ///
+ /// Copies this buffer's contents onto a array.
+ ///
+ public byte[] ToArray()
{
- return buffer.Data;
+ EnsureNotDisposed();
+
+ unsafe
+ {
+ var len = buffer.length;
+ if (len == 0)
+ return Array.Empty();
+
+ byte[] data = new byte[len];
+ fixed (byte* dst = data)
+ System.Buffer.MemoryCopy((void*) bufferptr->data, dst, len, len);
+
+ return data;
+ }
}
protected override void DisposeUnmanagedResources(IntPtr handle)
- => c_api.TF_DeleteBuffer(handle);
+ {
+ TF_DeleteBuffer(handle);
+ }
}
-}
+}
\ No newline at end of file
diff --git a/src/TensorFlowNET.Core/Framework/op_def_registry.py.cs b/src/TensorFlowNET.Core/Framework/op_def_registry.py.cs
index 9f9b4ad7..8a2bc5c3 100644
--- a/src/TensorFlowNET.Core/Framework/op_def_registry.py.cs
+++ b/src/TensorFlowNET.Core/Framework/op_def_registry.py.cs
@@ -15,6 +15,8 @@
******************************************************************************/
using System.Collections.Generic;
+using System.IO;
+using Tensorflow.Util;
namespace Tensorflow
{
@@ -27,12 +29,12 @@ namespace Tensorflow
if(_registered_ops == null)
{
_registered_ops = new Dictionary();
- var handle = c_api.TF_GetAllOpList();
- var buffer = new Buffer(handle);
- var op_list = OpList.Parser.ParseFrom(buffer);
-
- foreach (var op_def in op_list.Op)
- _registered_ops[op_def.Name] = op_def;
+ using (var buffer = new Buffer(c_api.TF_GetAllOpList()))
+ {
+ var op_list = OpList.Parser.ParseFrom(buffer.MemoryBlock.Stream());
+ foreach (var op_def in op_list.Op)
+ _registered_ops[op_def.Name] = op_def;
+ }
}
return _registered_ops;
diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Control.cs b/src/TensorFlowNET.Core/Graphs/Graph.Control.cs
index 4a3ac793..c97e1b6f 100644
--- a/src/TensorFlowNET.Core/Graphs/Graph.Control.cs
+++ b/src/TensorFlowNET.Core/Graphs/Graph.Control.cs
@@ -15,6 +15,7 @@
******************************************************************************/
using System.Collections.Generic;
+using System.Diagnostics.CodeAnalysis;
using System.Linq;
using Tensorflow.Operations;
@@ -66,8 +67,9 @@ namespace Tensorflow
/// within the context should have control dependencies on
/// `control_inputs`.
///
+ [SuppressMessage("ReSharper", "CoVariantArrayConversion")]
public _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs)
- => control_dependencies(control_inputs == null ? null : control_inputs.OfType