using LLama.Abstractions;
using LLama.Web.Common;

namespace LLama.Web.Models
{
    /// <summary>
    /// A chat session bound to a model, a context and an executor selected from the
    /// session configuration.
    /// </summary>
    public class ModelSession
    {
        private readonly string _sessionId;
        private readonly LLamaModel _model;
        private readonly LLamaContext _context;
        private readonly ILLamaExecutor _executor;
        private readonly ISessionConfig _sessionConfig;
        private readonly ITextStreamTransform _outputTransform;
        private readonly InferenceOptions _defaultInferenceConfig;

        // Linked to the caller's token for the most recent inference run; replaced on every run.
        // NOTE(review): the previously linked source is never disposed when it is replaced.
        // Disposing it here would race with a concurrent CancelInfer() call, so it is left as-is.
        private CancellationTokenSource _cancellationTokenSource;

        /// <summary>
        /// Initializes a new instance of the <see cref="ModelSession"/> class.
        /// </summary>
        /// <param name="model">The model; its weights and parameters are used by the stateless executor.</param>
        /// <param name="context">The context used by the interactive and instruct executors.</param>
        /// <param name="sessionId">The session identifier.</param>
        /// <param name="sessionConfig">The session configuration (model name, executor type, prompt, anti-prompts, output filters).</param>
        /// <param name="inferenceOptions">Default inference options; a new <see cref="InferenceOptions"/> is used when <c>null</c>.</param>
        /// <exception cref="ArgumentNullException"><paramref name="sessionConfig"/> is <c>null</c>.</exception>
        /// <exception cref="NotSupportedException">The configured executor type is not recognized.</exception>
        public ModelSession(LLamaModel model, LLamaContext context, string sessionId, ISessionConfig sessionConfig, InferenceOptions inferenceOptions = null)
        {
            _model = model;
            _context = context;
            _sessionId = sessionId;
            _sessionConfig = sessionConfig ?? throw new ArgumentNullException(nameof(sessionConfig));
            _defaultInferenceConfig = inferenceOptions ?? new InferenceOptions();
            _outputTransform = CreateOutputFilter();
            _executor = CreateExecutor();
        }

        /// <summary>
        /// Gets the session identifier.
        /// </summary>
        public string SessionId => _sessionId;

        /// <summary>
        /// Gets the name of the model.
        /// </summary>
        public string ModelName => _sessionConfig.Model;

        /// <summary>
        /// Gets the context.
        /// </summary>
        public LLamaContext Context => _context;

        /// <summary>
        /// Gets the session configuration.
        /// </summary>
        public ISessionConfig SessionConfig => _sessionConfig;

        /// <summary>
        /// Gets the default inference parameters.
        /// </summary>
        public InferenceOptions InferenceParams => _defaultInferenceConfig;

        /// <summary>
        /// Runs the configured initial prompt through the executor.
        /// Does nothing for stateless sessions or when no prompt is configured.
        /// </summary>
        /// <param name="inferenceConfig">The inference configuration; the session default is used when <c>null</c>.</param>
        /// <param name="cancellationToken">The cancellation token.</param>
        internal async Task InitializePrompt(InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default)
        {
            if (_sessionConfig.ExecutorType == LLamaExecutorType.Stateless)
                return;

            if (string.IsNullOrEmpty(_sessionConfig.Prompt))
                return;

            var inferenceParams = ConfigureInferenceParams(inferenceConfig);
            _cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);

            await foreach (var _ in _executor.InferAsync(_sessionConfig.Prompt, inferenceParams, _cancellationTokenSource.Token))
            {
                // The response to the initial prompt is not needed, so stop at the first token.
                break;
            }
        }

        /// <summary>
        /// Runs inference on the model context.
        /// </summary>
        /// <param name="message">The message.</param>
        /// <param name="inferenceConfig">The inference configuration; the session default is used when <c>null</c>.</param>
        /// <param name="cancellationToken">The cancellation token.</param>
        /// <returns>
        /// The streamed output, passed through the keyword output filter when the session
        /// configures one.
        /// </returns>
        internal IAsyncEnumerable<string> InferAsync(string message, InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default)
        {
            var inferenceParams = ConfigureInferenceParams(inferenceConfig);
            _cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);

            var inferenceStream = _executor.InferAsync(message, inferenceParams, _cancellationTokenSource.Token);
            return _outputTransform is null
                ? inferenceStream
                : _outputTransform.TransformAsync(inferenceStream);
        }

        /// <summary>
        /// Requests cancellation of the most recently started inference run, if any.
        /// </summary>
        public void CancelInfer()
        {
            _cancellationTokenSource?.Cancel();
        }

        /// <summary>
        /// Gets whether cancellation was requested for the most recently started inference run.
        /// </summary>
        /// <returns><c>true</c> if cancellation was requested; <c>false</c> otherwise, including when no run has started.</returns>
        public bool IsInferCanceled()
        {
            // Before the first run there is no token source, and nothing can have been canceled.
            return _cancellationTokenSource?.IsCancellationRequested ?? false;
        }

        /// <summary>
        /// Picks the inference parameters for a run and applies the session's anti-prompts.
        /// </summary>
        /// <param name="inferenceConfig">The inference configuration; the session default is used when <c>null</c>.</param>
        /// <returns>The options object that was used, with <c>AntiPrompts</c> set.</returns>
        /// <remarks>
        /// This assigns <c>AntiPrompts</c> on the supplied (or default) options instance in place,
        /// so the caller's object is modified.
        /// </remarks>
        private IInferenceParams ConfigureInferenceParams(InferenceOptions inferenceConfig)
        {
            var inferenceParams = inferenceConfig ?? _defaultInferenceConfig;
            inferenceParams.AntiPrompts = _sessionConfig.GetAntiPrompts();
            return inferenceParams;
        }

        /// <summary>
        /// Creates the keyword output filter from the session configuration.
        /// </summary>
        /// <returns>The output transform, or <c>null</c> when no output filters are configured.</returns>
        private ITextStreamTransform CreateOutputFilter()
        {
            var outputFilters = _sessionConfig.GetOutputFilters();
            if (outputFilters is { Count: > 0 })
                return new LLamaTransforms.KeywordTextOutputStreamTransform(outputFilters);

            return null;
        }

        /// <summary>
        /// Creates the executor matching the configured executor type.
        /// </summary>
        /// <returns>The executor for this session.</returns>
        /// <exception cref="NotSupportedException">The configured executor type is not recognized.</exception>
        private ILLamaExecutor CreateExecutor()
        {
            return _sessionConfig.ExecutorType switch
            {
                LLamaExecutorType.Interactive => new InteractiveExecutor(_context),
                LLamaExecutorType.Instruct => new InstructExecutor(_context),
                LLamaExecutorType.Stateless => new StatelessExecutor(_model.LLamaWeights, _model.ModelParams),
                // Fail fast: a null executor would only surface later as a NullReferenceException during inference.
                _ => throw new NotSupportedException($"Executor type '{_sessionConfig.ExecutorType}' is not supported.")
            };
        }
    }
}