|
|
@@ -16,44 +16,36 @@ |
|
|
|
|
|
|
|
using Google.Protobuf; |
|
|
|
using System; |
|
|
|
using System.Runtime.InteropServices; |
|
|
|
|
|
|
|
namespace Tensorflow |
|
|
|
{ |
|
|
|
internal class SessionOptions : DisposableObject |
|
|
|
internal sealed class SessionOptions : IDisposable |
|
|
|
{ |
|
|
|
public SafeSessionOptionsHandle Handle { get; } |
|
|
|
|
|
|
|
public SessionOptions(string target = "", ConfigProto config = null) |
|
|
|
{ |
|
|
|
_handle = c_api.TF_NewSessionOptions(); |
|
|
|
c_api.TF_SetTarget(_handle, target); |
|
|
|
Handle = c_api.TF_NewSessionOptions(); |
|
|
|
c_api.TF_SetTarget(Handle, target); |
|
|
|
if (config != null) |
|
|
|
SetConfig(config); |
|
|
|
} |
|
|
|
|
|
|
|
public SessionOptions(IntPtr handle) |
|
|
|
{ |
|
|
|
_handle = handle; |
|
|
|
} |
|
|
|
|
|
|
|
protected override void DisposeUnmanagedResources(IntPtr handle) |
|
|
|
=> c_api.TF_DeleteSessionOptions(handle); |
|
|
|
public void Dispose() |
|
|
|
=> Handle.Dispose(); |
|
|
|
|
|
|
|
private void SetConfig(ConfigProto config) |
|
|
|
private unsafe void SetConfig(ConfigProto config) |
|
|
|
{ |
|
|
|
var bytes = config.ToByteArray(); |
|
|
|
var proto = Marshal.AllocHGlobal(bytes.Length); |
|
|
|
Marshal.Copy(bytes, 0, proto, bytes.Length); |
|
|
|
|
|
|
|
using (var status = new Status()) |
|
|
|
fixed (byte* proto2 = bytes) |
|
|
|
{ |
|
|
|
c_api.TF_SetConfig(_handle, proto, (ulong)bytes.Length, status.Handle); |
|
|
|
status.Check(false); |
|
|
|
using (var status = new Status()) |
|
|
|
{ |
|
|
|
c_api.TF_SetConfig(Handle, (IntPtr)proto2, (ulong)bytes.Length, status.Handle); |
|
|
|
status.Check(false); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
Marshal.FreeHGlobal(proto); |
|
|
|
} |
|
|
|
|
|
|
|
public static implicit operator IntPtr(SessionOptions opts) => opts._handle; |
|
|
|
public static implicit operator SessionOptions(IntPtr handle) => new SessionOptions(handle); |
|
|
|
} |
|
|
|
} |