Browse Source

Merge pull request #57 from martindevans/larger_states

Larger states
tags/v0.4.2-preview
Rinne GitHub 2 years ago
parent
commit
66d6b00b49
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 88 additions and 9 deletions
  1. +76
    -3
      LLama/LLamaModel.cs
  2. +2
    -4
      LLama/LLamaStatelessExecutor.cs
  3. +10
    -2
      LLama/ResettableLLamaModel.cs

+ 76
- 3
LLama/LLamaModel.cs View File

@@ -7,6 +7,9 @@ using System.Text;
using System.IO;
using System.IO.MemoryMappedFiles;
using LLama.Common;
using System.Runtime.InteropServices;
using LLama.Extensions;
using Microsoft.Win32.SafeHandles;

namespace LLama
{
@@ -117,6 +120,7 @@ namespace LLama
/// Get the state data as a byte array.
/// </summary>
/// <returns></returns>
[Obsolete("Use `GetState` instead, this supports larger states (over 2GB)")]
public byte[] GetStateData()
{
var stateSize = NativeApi.llama_get_state_size(_ctx);
@@ -125,6 +129,44 @@ namespace LLama
return stateMemory;
}

/// <summary>
/// Get the state data as an opaque handle
/// </summary>
/// <returns></returns>
public State GetState()
{
var stateSize = NativeApi.llama_get_state_size(_ctx);

unsafe
{
var bigMemory = Marshal.AllocHGlobal((nint)stateSize);
var smallMemory = IntPtr.Zero;
try
{
// Copy the state data into "big memory", discover the actual size required
var actualSize = NativeApi.llama_copy_state_data(_ctx, (byte*)bigMemory);

// Allocate a smaller buffer
smallMemory = Marshal.AllocHGlobal((nint)actualSize);

// Copy into the smaller buffer and free the large one to save excess memory usage
Buffer.MemoryCopy(bigMemory.ToPointer(), smallMemory.ToPointer(), actualSize, actualSize);
Marshal.FreeHGlobal(bigMemory);
bigMemory = IntPtr.Zero;

return new State(smallMemory);
}
catch
{
if (bigMemory != IntPtr.Zero)
Marshal.FreeHGlobal(bigMemory);
if (smallMemory != IntPtr.Zero)
Marshal.FreeHGlobal(smallMemory);
throw;
}
}
}

/// <summary>
/// Load the state from specified path.
/// </summary>
@@ -161,6 +203,19 @@ namespace LLama
NativeApi.llama_set_state_data(_ctx, stateData);
}

/// <summary>
/// Load the state from memory.
/// </summary>
/// <param name="state"></param>
/// <exception cref="RuntimeError"></exception>
public void LoadState(State state)
{
unsafe
{
NativeApi.llama_set_state_data(_ctx, (byte*)state.DangerousGetHandle().ToPointer());
}
}

/// <summary>
/// Perform the sampling. Please don't use it unless you fully know what it does.
/// </summary>
@@ -304,12 +359,30 @@ namespace LLama
}
}

/// <inheritdoc />
public virtual void Dispose()
{
_ctx.Dispose();
}

/// <summary>
///
/// The state of this model, which can be reloaded later
/// </summary>
public void Dispose()
public class State
: SafeHandleZeroOrMinusOneIsInvalid
{
_ctx.Dispose();
internal State(IntPtr memory)
: base(true)
{
SetHandle(memory);
}

/// <inheritdoc />
protected override bool ReleaseHandle()
{
Marshal.FreeHGlobal(handle);
return true;
}
}
}
}

+ 2
- 4
LLama/LLamaStatelessExecutor.cs View File

@@ -3,10 +3,8 @@ using LLama.Common;
using LLama.Native;
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;

namespace LLama
@@ -19,7 +17,7 @@ namespace LLama
public class StatelessExecutor : ILLamaExecutor
{
private LLamaModel _model;
private byte[] _originalState;
private LLamaModel.State _originalState;
/// <summary>
/// The mode used by the executor when running the inference.
/// </summary>
@@ -33,7 +31,7 @@ namespace LLama
_model = model;
var tokens = model.Tokenize(" ", true);
Utils.Eval(_model.NativeHandle, tokens.ToArray(), 0, tokens.Count(), 0, _model.Params.Threads);
_originalState = model.GetStateData();
_originalState = model.GetState();
}

/// <inheritdoc />


+ 10
- 2
LLama/ResettableLLamaModel.cs View File

@@ -13,7 +13,7 @@ namespace LLama
/// <summary>
/// The initial state of the model
/// </summary>
public byte[] OriginalState { get; set; }
public State OriginalState { get; set; }
/// <summary>
///
/// </summary>
@@ -21,7 +21,7 @@ namespace LLama
/// <param name="encoding"></param>
public ResettableLLamaModel(ModelParams Params, string encoding = "UTF-8") : base(Params, encoding)
{
OriginalState = GetStateData();
OriginalState = GetState();
}

/// <summary>
@@ -31,5 +31,13 @@ namespace LLama
{
LoadState(OriginalState);
}

/// <inheritdoc />
public override void Dispose()
{
OriginalState.Dispose();

base.Dispose();
}
}
}

Loading…
Cancel
Save