|
|
|
@@ -7,6 +7,9 @@ using System.Text; |
|
|
|
using System.IO; |
|
|
|
using System.IO.MemoryMappedFiles; |
|
|
|
using LLama.Common; |
|
|
|
using System.Runtime.InteropServices; |
|
|
|
using LLama.Extensions; |
|
|
|
using Microsoft.Win32.SafeHandles; |
|
|
|
|
|
|
|
namespace LLama |
|
|
|
{ |
|
|
|
@@ -117,6 +120,7 @@ namespace LLama |
|
|
|
/// Get the state data as a byte array. |
|
|
|
/// </summary> |
|
|
|
/// <returns></returns> |
|
|
|
[Obsolete("Use `GetState` instead, this supports larger states (over 2GB)")] |
|
|
|
public byte[] GetStateData() |
|
|
|
{ |
|
|
|
var stateSize = NativeApi.llama_get_state_size(_ctx); |
|
|
|
@@ -125,6 +129,44 @@ namespace LLama |
|
|
|
return stateMemory; |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Get the state data as an opaque handle |
|
|
|
/// </summary> |
|
|
|
/// <returns></returns> |
|
|
|
public State GetState() |
|
|
|
{ |
|
|
|
var stateSize = NativeApi.llama_get_state_size(_ctx); |
|
|
|
|
|
|
|
unsafe |
|
|
|
{ |
|
|
|
var bigMemory = Marshal.AllocHGlobal((nint)stateSize); |
|
|
|
var smallMemory = IntPtr.Zero; |
|
|
|
try |
|
|
|
{ |
|
|
|
// Copy the state data into "big memory", discover the actual size required |
|
|
|
var actualSize = NativeApi.llama_copy_state_data(_ctx, (byte*)bigMemory); |
|
|
|
|
|
|
|
// Allocate a smaller buffer |
|
|
|
smallMemory = Marshal.AllocHGlobal((nint)actualSize); |
|
|
|
|
|
|
|
// Copy into the smaller buffer and free the large one to save excess memory usage |
|
|
|
Buffer.MemoryCopy(bigMemory.ToPointer(), smallMemory.ToPointer(), actualSize, actualSize); |
|
|
|
Marshal.FreeHGlobal(bigMemory); |
|
|
|
bigMemory = IntPtr.Zero; |
|
|
|
|
|
|
|
return new State(smallMemory); |
|
|
|
} |
|
|
|
catch |
|
|
|
{ |
|
|
|
if (bigMemory != IntPtr.Zero) |
|
|
|
Marshal.FreeHGlobal(bigMemory); |
|
|
|
if (smallMemory != IntPtr.Zero) |
|
|
|
Marshal.FreeHGlobal(smallMemory); |
|
|
|
throw; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Load the state from specified path. |
|
|
|
/// </summary> |
|
|
|
@@ -161,6 +203,19 @@ namespace LLama |
|
|
|
NativeApi.llama_set_state_data(_ctx, stateData); |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Load the state from memory. |
|
|
|
/// </summary> |
|
|
|
/// <param name="state"></param> |
|
|
|
/// <exception cref="RuntimeError"></exception> |
|
|
|
public void LoadState(State state) |
|
|
|
{ |
|
|
|
unsafe |
|
|
|
{ |
|
|
|
NativeApi.llama_set_state_data(_ctx, (byte*)state.DangerousGetHandle().ToPointer()); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Perform the sampling. Please don't use it unless you fully know what it does. |
|
|
|
/// </summary> |
|
|
|
@@ -304,12 +359,30 @@ namespace LLama |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
/// <inheritdoc /> |
|
|
|
public virtual void Dispose() |
|
|
|
{ |
|
|
|
_ctx.Dispose(); |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// |
|
|
|
/// The state of this model, which can be reloaded later |
|
|
|
/// </summary> |
|
|
|
public void Dispose() |
|
|
|
public class State |
|
|
|
: SafeHandleZeroOrMinusOneIsInvalid |
|
|
|
{ |
|
|
|
_ctx.Dispose(); |
|
|
|
internal State(IntPtr memory) |
|
|
|
: base(true) |
|
|
|
{ |
|
|
|
SetHandle(memory); |
|
|
|
} |
|
|
|
|
|
|
|
/// <inheritdoc /> |
|
|
|
protected override bool ReleaseHandle() |
|
|
|
{ |
|
|
|
Marshal.FreeHGlobal(handle); |
|
|
|
return true; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |