Browse Source

Merge pull request #563 from sharwell/safe-session-options-handle

Implement SafeSessionOptionsHandle as a wrapper for TF_SessionOptions
tags/v0.20
Haiping GitHub 5 years ago
parent
commit
d42bc071d8
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 62 additions and 30 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Graphs/c_api.graph.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Sessions/BaseSession.cs
  3. +40
    -0
      src/TensorFlowNET.Core/Sessions/SafeSessionOptionsHandle.cs
  4. +2
    -2
      src/TensorFlowNET.Core/Sessions/Session.cs
  5. +14
    -22
      src/TensorFlowNET.Core/Sessions/SessionOptions.cs
  6. +4
    -4
      src/TensorFlowNET.Core/Sessions/c_api.session.cs

+ 1
- 1
src/TensorFlowNET.Core/Graphs/c_api.graph.cs View File

@@ -287,7 +287,7 @@ namespace Tensorflow
/// <param name="status">TF_Status*</param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern IntPtr TF_LoadSessionFromSavedModel(IntPtr session_options, IntPtr run_options,
public static extern IntPtr TF_LoadSessionFromSavedModel(SafeSessionOptionsHandle session_options, IntPtr run_options,
string export_dir, string[] tags, int tags_len,
IntPtr graph, ref TF_Buffer meta_graph_def, SafeStatusHandle status);



+ 1
- 1
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -47,7 +47,7 @@ namespace Tensorflow
lock (Locks.ProcessWide)
{
status = status ?? new Status();
_handle = c_api.TF_NewSession(_graph, opts, status.Handle);
_handle = c_api.TF_NewSession(_graph, opts.Handle, status.Handle);
status.Check(true);
}
}


+ 40
- 0
src/TensorFlowNET.Core/Sessions/SafeSessionOptionsHandle.cs View File

@@ -0,0 +1,40 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using Tensorflow.Util;

namespace Tensorflow
{
public sealed class SafeSessionOptionsHandle : SafeTensorflowHandle
{
public SafeSessionOptionsHandle()
{
}

public SafeSessionOptionsHandle(IntPtr handle)
: base(handle)
{
}

protected override bool ReleaseHandle()
{
c_api.TF_DeleteSessionOptions(handle);
SetHandle(IntPtr.Zero);
return true;
}
}
}

+ 2
- 2
src/TensorFlowNET.Core/Sessions/Session.cs View File

@@ -55,7 +55,7 @@ namespace Tensorflow
IntPtr sess;
try
{
sess = c_api.TF_LoadSessionFromSavedModel(opt,
sess = c_api.TF_LoadSessionFromSavedModel(opt.Handle,
IntPtr.Zero,
path,
tags,
@@ -66,7 +66,7 @@ namespace Tensorflow
status.Check(true);
} catch (TensorflowException ex) when (ex.Message.Contains("Could not find SavedModel"))
{
sess = c_api.TF_LoadSessionFromSavedModel(opt,
sess = c_api.TF_LoadSessionFromSavedModel(opt.Handle,
IntPtr.Zero,
Path.GetFullPath(path),
tags,


+ 14
- 22
src/TensorFlowNET.Core/Sessions/SessionOptions.cs View File

@@ -16,44 +16,36 @@

using Google.Protobuf;
using System;
using System.Runtime.InteropServices;

namespace Tensorflow
{
internal class SessionOptions : DisposableObject
internal sealed class SessionOptions : IDisposable
{
public SafeSessionOptionsHandle Handle { get; }

public SessionOptions(string target = "", ConfigProto config = null)
{
_handle = c_api.TF_NewSessionOptions();
c_api.TF_SetTarget(_handle, target);
Handle = c_api.TF_NewSessionOptions();
c_api.TF_SetTarget(Handle, target);
if (config != null)
SetConfig(config);
}

public SessionOptions(IntPtr handle)
{
_handle = handle;
}

protected override void DisposeUnmanagedResources(IntPtr handle)
=> c_api.TF_DeleteSessionOptions(handle);
public void Dispose()
=> Handle.Dispose();

private void SetConfig(ConfigProto config)
private unsafe void SetConfig(ConfigProto config)
{
var bytes = config.ToByteArray();
var proto = Marshal.AllocHGlobal(bytes.Length);
Marshal.Copy(bytes, 0, proto, bytes.Length);

using (var status = new Status())
fixed (byte* proto2 = bytes)
{
c_api.TF_SetConfig(_handle, proto, (ulong)bytes.Length, status.Handle);
status.Check(false);
using (var status = new Status())
{
c_api.TF_SetConfig(Handle, (IntPtr)proto2, (ulong)bytes.Length, status.Handle);
status.Check(false);
}
}

Marshal.FreeHGlobal(proto);
}

public static implicit operator IntPtr(SessionOptions opts) => opts._handle;
public static implicit operator SessionOptions(IntPtr handle) => new SessionOptions(handle);
}
}

+ 4
- 4
src/TensorFlowNET.Core/Sessions/c_api.session.cs View File

@@ -50,14 +50,14 @@ namespace Tensorflow
/// <param name="status">TF_Status*</param>
/// <returns>TF_Session*</returns>
[DllImport(TensorFlowLibName)]
public static extern IntPtr TF_NewSession(IntPtr graph, IntPtr opts, SafeStatusHandle status);
public static extern IntPtr TF_NewSession(IntPtr graph, SafeSessionOptionsHandle opts, SafeStatusHandle status);

/// <summary>
/// Return a new options object.
/// </summary>
/// <returns>TF_SessionOptions*</returns>
[DllImport(TensorFlowLibName)]
public static extern unsafe IntPtr TF_NewSessionOptions();
public static extern SafeSessionOptionsHandle TF_NewSessionOptions();

/// <summary>
/// Run the graph associated with the session starting with the supplied inputs
@@ -116,9 +116,9 @@ namespace Tensorflow
/// <param name="proto_len">size_t</param>
/// <param name="status">TF_Status*</param>
[DllImport(TensorFlowLibName)]
public static extern void TF_SetConfig(IntPtr options, IntPtr proto, ulong proto_len, SafeStatusHandle status);
public static extern void TF_SetConfig(SafeSessionOptionsHandle options, IntPtr proto, ulong proto_len, SafeStatusHandle status);

[DllImport(TensorFlowLibName)]
public static extern void TF_SetTarget(IntPtr options, string target);
public static extern void TF_SetTarget(SafeSessionOptionsHandle options, string target);
}
}

Loading…
Cancel
Save