using LLama.Abstractions;
using LLama.Web.Common;
namespace LLama.Web.Models
{
public class ModelSession
{
private readonly string _sessionId;
private readonly LLamaModel _model;
private readonly LLamaContext _context;
private readonly ILLamaExecutor _executor;
private readonly Common.SessionOptions _sessionParams;
private readonly ITextStreamTransform _outputTransform;
private readonly InferenceOptions _defaultInferenceConfig;
private CancellationTokenSource _cancellationTokenSource;
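/// <summary>
/// Initializes a new instance of the <see cref="ModelSession"/> class.
/// </summary>
/// <param name="model">The loaded model.</param>
/// <param name="context">The model context to run inference against.</param>
/// <param name="sessionId">The session identifier.</param>
/// <param name="sessionOptions">The session configuration.</param>
/// <param name="inferenceOptions">The default inference configuration, or null to use defaults.</param>
/// <example>
/// A minimal creation sketch; the <c>model</c>, <c>context</c> and option values are illustrative, not prescribed by this class:
/// <code>
/// var session = new ModelSession(model, context, Guid.NewGuid().ToString(), new Common.SessionOptions
/// {
///     Model = "my-model",
///     ExecutorType = LLamaExecutorType.Interactive,
///     Prompt = "You are a helpful assistant."
/// });
/// </code>
/// </example>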
public ModelSession(LLamaModel model, LLamaContext context, string sessionId, Common.SessionOptions sessionOptions, InferenceOptions inferenceOptions = null)
{
_model = model;
_context = context;
_sessionId = sessionId;
_sessionParams = sessionOptions;
_defaultInferenceConfig = inferenceOptions ?? new InferenceOptions();
_outputTransform = CreateOutputFilter(_sessionParams);
_executor = CreateExecutor(_model, _context, _sessionParams);
}
/// <summary>
/// Gets the session identifier.
/// </summary>
public string SessionId => _sessionId;
/// <summary>
/// Gets the name of the model.
/// </summary>
public string ModelName => _sessionParams.Model;
/// <summary>
/// Gets the context.
/// </summary>
public LLamaContext Context => _context;
/// <summary>
/// Gets the session configuration.
/// </summary>
public Common.SessionOptions SessionConfig => _sessionParams;
/// <summary>
/// Gets the default inference parameters.
/// </summary>
public InferenceOptions InferenceParams => _defaultInferenceConfig;
/// <summary>
/// Initializes the prompt.
/// </summary>
/// <param name="inferenceConfig">The inference configuration.</param>
/// <param name="cancellationToken">The cancellation token.</param>
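/// <remarks>
/// Stateless executors keep no conversation state, so the prompt is skipped. For stateful executors the
/// configured session prompt is run through the executor once so it enters the conversation state; the
/// generated response is discarded after the first token.
/// </remarks>
/// <example>
/// A minimal sketch of priming a new session before handling user messages (<c>session</c> and <c>requestToken</c> are illustrative):
/// <code>
/// await session.InitializePrompt(cancellationToken: requestToken);
/// </code>
/// </example>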
internal async Task InitializePrompt(InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default)
{
if (_sessionParams.ExecutorType == LLamaExecutorType.Stateless)
return;
if (string.IsNullOrEmpty(_sessionParams.Prompt))
return;
// Run Initial prompt
var inferenceParams = ConfigureInferenceParams(inferenceConfig);
_cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
await foreach (var _ in _executor.InferAsync(_sessionParams.Prompt, inferenceParams, _cancellationTokenSource.Token))
{
// We don't really need the response to the initial prompt, so exit on the first token
break;
}
}
/// <summary>
/// Runs inference on the model context.
/// </summary>
/// <param name="message">The message.</param>
/// <param name="inferenceConfig">The inference configuration.</param>
/// <param name="cancellationToken">The cancellation token.</param>
/// <returns>An asynchronous stream of generated tokens, passed through the output transform when one is configured.</returns>
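/// <example>
/// A minimal consumption sketch; <c>session</c>, the prompt text and <c>requestToken</c> are illustrative:
/// <code>
/// await foreach (var token in session.InferAsync("What is an alpaca?", cancellationToken: requestToken))
/// {
///     Console.Write(token);
/// }
/// </code>
/// </example>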
internal IAsyncEnumerable<string> InferAsync(string message, InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default)
{
var inferenceParams = ConfigureInferenceParams(inferenceConfig);
_cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
var inferenceStream = _executor.InferAsync(message, inferenceParams, _cancellationTokenSource.Token);
if (_outputTransform is not null)
return _outputTransform.TransformAsync(inferenceStream);
return inferenceStream;
}
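/// <summary>
/// Cancels the inference currently in progress, if any.
/// </summary>
/// <example>
/// A hypothetical sketch, assuming the session has been retrieved elsewhere:
/// <code>
/// // e.g. from a cancel endpoint or a timeout handler
/// session.CancelInfer();
/// </code>
/// </example>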
public void CancelInfer()
{
_cancellationTokenSource?.Cancel();
}
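/// <summary>
/// Determines whether the current inference has been canceled.
/// </summary>
/// <returns>True if cancellation has been requested; false if no inference has started or it is still running.</returns>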
public bool IsInferCanceled()
{
// The token source is null until the first inference has started
return _cancellationTokenSource?.IsCancellationRequested ?? false;
}
/// <summary>
/// Configures the inference parameters.
/// </summary>
/// <param name="inferenceConfig">The inference configuration.</param>
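/// <remarks>
/// Note that the session's anti-prompts are assigned onto the supplied (or default) options instance in place.
/// </remarks>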
private IInferenceParams ConfigureInferenceParams(InferenceOptions inferenceConfig)
{
var inferenceParams = inferenceConfig ?? _defaultInferenceConfig;
inferenceParams.AntiPrompts = _sessionParams.GetAntiPrompts();
return inferenceParams;
}
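/// <summary>
/// Creates the output transform for the session, or null when no output filters are configured.
/// </summary>
/// <param name="sessionConfig">The session configuration.</param>
/// <remarks>
/// <see cref="LLamaTransforms.KeywordTextOutputStreamTransform"/> strips the configured filter keywords from the streamed output.
/// </remarks>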
private ITextStreamTransform CreateOutputFilter(Common.SessionOptions sessionConfig)
{
var outputFilters = sessionConfig.GetOutputFilters();
if (outputFilters.Count > 0)
return new LLamaTransforms.KeywordTextOutputStreamTransform(outputFilters);
return null;
}
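/// <summary>
/// Creates the executor matching the session's configured executor type.
/// </summary>
/// <param name="model">The loaded model.</param>
/// <param name="context">The model context.</param>
/// <param name="sessionConfig">The session configuration.</param>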
private ILLamaExecutor CreateExecutor(LLamaModel model, LLamaContext context, Common.SessionOptions sessionConfig)
{
return sessionConfig.ExecutorType switch
{
LLamaExecutorType.Interactive => new InteractiveExecutor(context),
LLamaExecutorType.Instruct => new InstructExecutor(context),
LLamaExecutorType.Stateless => new StatelessExecutor(model.LLamaWeights, model.ModelParams),
_ => default
};
}
}
}
}