
BatchedExecutor Fixed Forking (#621)

* Previously, when a conversation was forked, the parent and the child shared exactly the same logits. Since sampling is allowed to modify logits, this could corrupt sampling (e.g. one conversation is sampled and overwrites the logits with zeros, then the second conversation is sampled from those zeroed logits and generates nonsense). Fixed by setting a "forked" flag: logits are copied when sampled if the flag is set. The flag is cleared the next time the conversation is prompted, so the extra copy only happens once after a fork occurs (see the usage sketch below).

* Removed the finalizer from `BatchedExecutor`. The class does not directly own any unmanaged resources, so a finalizer (and the matching `GC.SuppressFinalize` call) is not necessary.
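
A minimal usage sketch of the fork-then-sample scenario this commit fixes. It assumes the API surface visible in the diff below (`Create()`, `Prompt(...)`, `Fork()`, a `Sample()` that returns the raw logits span) plus an `Infer()` method on the executor to run the queued batch; exact names and signatures may differ from the library.

// Hedged sketch (not part of this commit): fork-then-sample, the case being fixed.
// Executor/Conversation members are taken from this diff; Infer() is assumed.
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using LLama.Batched;
using LLama.Native;

static class ForkSketch
{
    public static async Task RunAsync(BatchedExecutor executor, List<LLamaToken> prompt)
    {
        // One conversation, prompted and decoded once.
        var main = executor.Create();
        main.Prompt(prompt);
        await executor.Infer();

        SampleBothBranches(main);
    }

    private static void SampleBothBranches(Conversation main)
    {
        // Fork after inference: both conversations can now be sampled at the same
        // position. Before this fix they read (and could mutate) the *same* logits
        // buffer; with the "forked" flag each one gets its own copy when sampled.
        var branch = main.Fork();

        var logitsA = main.Sample();    // a sampler may modify this span in place
        var logitsB = branch.Sample();  // independent copy, unaffected by the above

        // The extra copying stops after the next Prompt(), which clears the flag.
        Console.WriteLine($"{logitsA.Length} logits in one branch, {logitsB.Length} in the fork");
    }
}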
tags/0.11.0
Martin Evans (GitHub), 1 year ago
parent commit 268f3a6b07
2 changed files with 72 additions and 30 deletions
  1. LLama/Batched/BatchedExecutor.cs  (+1 −11)
  2. LLama/Batched/Conversation.cs     (+71 −19)

LLama/Batched/BatchedExecutor.cs (+1 −11)

@@ -55,14 +55,6 @@ public sealed class BatchedExecutor
         Epoch = 1;
     }
 
-    /// <summary>
-    /// Finalizer for BatchedExecutor
-    /// </summary>
-    ~BatchedExecutor()
-    {
-        Dispose();
-    }
-
     /// <summary>
     /// Start a new <see cref="Conversation"/> with the given prompt
     /// </summary>
@@ -89,7 +81,7 @@ public sealed class BatchedExecutor
         if (IsDisposed)
             throw new ObjectDisposedException(nameof(BatchedExecutor));
 
-        return new Conversation(this, GetNextSequenceId(), 0);
+        return new Conversation(this, GetNextSequenceId());
     }
 
     /// <summary>
@@ -123,8 +115,6 @@ public sealed class BatchedExecutor
             return;
         IsDisposed = true;
 
-        GC.SuppressFinalize(this);
-
         Context.Dispose();
     }
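
Because the executor only owns other managed IDisposable objects (its context), `Dispose` can simply cascade, and neither a finalizer nor `GC.SuppressFinalize` is needed. A minimal sketch of that pattern, with illustrative names (`ManagedWrapper`) that are not from the library:

// Sketch only: a wrapper whose unmanaged state is owned entirely by a child object.
using System;

public sealed class ManagedWrapper : IDisposable
{
    private readonly IDisposable _context;   // child that actually owns the unmanaged handle
    public bool IsDisposed { get; private set; }

    public ManagedWrapper(IDisposable context) => _context = context;

    public void Dispose()
    {
        if (IsDisposed)
            return;
        IsDisposed = true;

        // No finalizer and no GC.SuppressFinalize(this): this class owns no unmanaged
        // state of its own; the handle is released by the child's Dispose (and, as a
        // safety net, by the child's own finalizer).
        _context.Dispose();
    }
}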



LLama/Batched/Conversation.cs (+71 −19)

@@ -1,5 +1,7 @@
 using System;
+using System.Buffers;
 using System.Collections.Generic;
+using System.Runtime.InteropServices;
 using LLama.Native;
 
 namespace LLama.Batched;
@@ -14,6 +16,7 @@ public sealed class Conversation
     private LLamaPos _end;
     private int _batchIndex;
     private bool _disposed;
+    private bool _forked;
 
     /// <summary>
     /// The executor which this conversation belongs to
@@ -46,12 +49,10 @@ public sealed class Conversation
     public bool RequiresSampling => _requiredEpoch == Executor.Epoch;
 
     #region construction/destruction
-    internal Conversation(BatchedExecutor batch, LLamaSeqId id, LLamaPos end)
+    internal Conversation(BatchedExecutor batch, LLamaSeqId id)
     {
         ConversationId = id;
         Executor = batch;
-
-        _end = end;
     }
 
     /// <summary>
@@ -98,16 +99,24 @@ public sealed class Conversation
     {
         AssertNotDisposed();
 
         if (RequiresInference)
             throw new CannotForkWhileRequiresInferenceException();
 
         // Create a new conversation which references the current position in this one
-        var c = new Conversation(Executor, Executor.GetNextSequenceId(), _end)
+        var c = new Conversation(Executor, Executor.GetNextSequenceId())
         {
-            _batchIndex = _batchIndex,
+            // Because these values are copied to the forked conversation it means that it will share the exact same output
+            // logits next time sampling is done. This is a problem, because the sampling process is allowed to modify those
+            // logits, so sampling one conversation may mess up the fork! Setting the "forked" flag on both sequences ensures
+            // they both copy the logits before the next sampling run, to fix this issue.
             _requiredEpoch = _requiredEpoch,
+            _batchIndex = _batchIndex,
+            _forked = true,
+
+            _end = _end,
         };
 
+        // Setting this flag means that logits will be copied next time sampling is called, ensuring that the forked
+        // conversation doesn't share logits with this one.
+        _forked = true;
+
         // Assign tokens to the new sequence
         NativeApi.llama_kv_cache_seq_cp(Executor.Context.NativeHandle, ConversationId, c.ConversationId, 0, _end);

@@ -131,7 +140,14 @@ public sealed class Conversation
         if (_requiredEpoch > Executor.Epoch)
             throw new CannotSampleRequiresInferenceException();
 
-        return Executor.Context.NativeHandle.GetLogitsIth(_batchIndex);
+        var span = Executor.Context.NativeHandle.GetLogitsIth(_batchIndex);
+
+        // If necessary copy the span, to protect it from modification. This is only done when
+        // this conversation has been forked in this epoch.
+        if (_forked)
+            span = span.ToArray();
+
+        return span;
     }
#endregion
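
The copy-on-fork rule above can be seen in isolation in the following sketch (hypothetical types, not library code): a shared buffer is handed out as a span, but a forked sequence receives its own copy so that destructive, in-place sampling on one branch cannot corrupt the other.

// Hypothetical illustration of the copy-on-fork rule used in the sampling code above.
using System;

sealed class LogitsOwner
{
    private readonly float[] _shared = { 0.1f, 2.3f, -1.0f }; // stands in for the native logits buffer
    public bool Forked;                                        // set when the sequence is forked

    public Span<float> GetLogits()
    {
        Span<float> span = _shared;

        // Same rule as the diff: if this sequence was forked, return a private copy
        // so that in-place sampling cannot leak across sequences.
        if (Forked)
            span = span.ToArray();

        return span;
    }
}

static class Demo
{
    static void Main()
    {
        var owner = new LogitsOwner { Forked = true };

        var a = owner.GetLogits();
        var b = owner.GetLogits();

        a.Clear();                // destructive "sampling" on one copy...
        Console.WriteLine(b[1]);  // ...leaves the other intact: prints 2.3
    }
}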

@@ -162,20 +178,56 @@ public sealed class Conversation
     /// <param name="tokens"></param>
     /// <returns></returns>
     /// <exception cref="ObjectDisposedException"></exception>
-    public void Prompt(IReadOnlyList<LLamaToken> tokens)
+    /// <exception cref="AlreadyPromptedConversationException"></exception>
+    public void Prompt(List<LLamaToken> tokens)
     {
         AssertCanBePrompted();
 
+#if NET6_0_OR_GREATER
+        var span = CollectionsMarshal.AsSpan(tokens);
+        Prompt(span);
+#else
+        // Borrow an array and copy tokens into it
+        var arr = ArrayPool<LLamaToken>.Shared.Rent(tokens.Count);
+        try
+        {
+            for (var i = 0; i < tokens.Count; i++)
+                arr[i] = tokens[i];
+
+            Prompt(arr.AsSpan(0, tokens.Count));
+        }
+        finally
+        {
+            ArrayPool<LLamaToken>.Shared.Return(arr);
+        }
+#endif
+    }
+
+    /// <summary>
+    /// Add tokens to this conversation
+    /// </summary>
+    /// <param name="tokens"></param>
+    /// <returns></returns>
+    /// <exception cref="ObjectDisposedException"></exception>
+    /// <exception cref="AlreadyPromptedConversationException"></exception>
+    public void Prompt(ReadOnlySpan<LLamaToken> tokens)
+    {
+        AssertCanBePrompted();
+
         // No point doing anything if there is no actual prompt!
-        if (tokens.Count == 0)
+        if (tokens.Length == 0)
             return;
 
         // Add the prompt to the batch
-        for (var i = 0; i < tokens.Count; i++)
-            _batchIndex = Executor.Batch.Add(tokens[i], _end++, ConversationId, i == tokens.Count - 1);
+        for (var i = 0; i < tokens.Length; i++)
+            _batchIndex = Executor.Batch.Add(tokens[i], _end++, ConversationId, i == tokens.Length - 1);
 
         // Mark this conversation as needing inference/sampling
         _requiredEpoch = Executor.Epoch + 1;
+
+        // Unset the forked flag. Since this conversation has just been prompted it's no longer
+        // sharing anything with any other conversations.
+        _forked = false;
     }

/// <summary>
@@ -184,16 +236,16 @@ public sealed class Conversation
     /// <param name="token"></param>
     /// <returns></returns>
     /// <exception cref="ObjectDisposedException"></exception>
-    /// <exception cref="InvalidOperationException"></exception>
+    /// <exception cref="AlreadyPromptedConversationException"></exception>
    public void Prompt(LLamaToken token)
     {
         AssertCanBePrompted();
 
-        // Add this token as input
-        _batchIndex = Executor.Batch.Add(token, _end++, ConversationId, true);
-        // Mark this conversation as needing inference/sampling
-        _requiredEpoch = Executor.Epoch + 1;
+        unsafe
+        {
+            Span<LLamaToken> span = stackalloc LLamaToken[1] { token };
+            Prompt(span);
+        }
     }
#endregion
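
All of the new `Prompt` overloads funnel into the single span-based implementation. The sketch below is illustrative only (plain `int` tokens, hypothetical names) and shows the same three routes used in the diff: `CollectionsMarshal.AsSpan` over a `List<T>` on .NET 6+, a rented `ArrayPool<T>` buffer as the fallback, and `stackalloc` for a single token.

// Illustrative sketch of the span-funnelling pattern used by the new Prompt overloads.
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Runtime.InteropServices;

static class SpanFunnel
{
    // The single "real" implementation; everything else converts to a span.
    static void Consume(ReadOnlySpan<int> tokens)
    {
        foreach (var t in tokens)
            Console.Write($"{t} ");
        Console.WriteLine();
    }

    static void ConsumeList(List<int> tokens)
    {
#if NET6_0_OR_GREATER
        // View the list's backing array directly; no copy is made.
        // (The list must not be resized while the span is in use.)
        Consume(CollectionsMarshal.AsSpan(tokens));
#else
        // Rent a buffer and copy; slice to Count because Rent may return a longer array.
        var arr = ArrayPool<int>.Shared.Rent(tokens.Count);
        try
        {
            tokens.CopyTo(arr);
            Consume(arr.AsSpan(0, tokens.Count));
        }
        finally
        {
            ArrayPool<int>.Shared.Return(arr);
        }
#endif
    }

    static void ConsumeSingle(int token)
    {
        // A one-element span on the stack, no heap allocation at all.
        Span<int> span = stackalloc int[1] { token };
        Consume(span);
    }

    static void Main()
    {
        ConsumeList(new List<int> { 1, 2, 3 });
        ConsumeSingle(42);
    }
}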


