using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using LLama.Abstractions;
using LLama.Web.Common;

namespace LLama.Web.Models
{
    public class ModelSession
    {
        private readonly string _sessionId;
        private readonly LLamaModel _model;
        private readonly LLamaContext _context;
        private readonly ILLamaExecutor _executor;
        private readonly Common.SessionOptions _sessionParams;
        private readonly ITextStreamTransform _outputTransform;
        private readonly InferenceOptions _defaultInferenceConfig;
        private CancellationTokenSource _cancellationTokenSource;

        public ModelSession(LLamaModel model, LLamaContext context, string sessionId, Common.SessionOptions sessionOptions, InferenceOptions inferenceOptions = null)
        {
            _model = model;
            _context = context;
            _sessionId = sessionId;
            _sessionParams = sessionOptions;
            _defaultInferenceConfig = inferenceOptions ?? new InferenceOptions();
            _outputTransform = CreateOutputFilter(_sessionParams);
            _executor = CreateExecutor(_model, _context, _sessionParams);
        }

        /// <summary>
        /// Gets the session identifier.
        /// </summary>
        public string SessionId => _sessionId;

        /// <summary>
        /// Gets the name of the model.
        /// </summary>
        public string ModelName => _sessionParams.Model;

        /// <summary>
        /// Gets the context.
        /// </summary>
        public LLamaContext Context => _context;

        /// <summary>
        /// Gets the session configuration.
        /// </summary>
        public Common.SessionOptions SessionConfig => _sessionParams;

        /// <summary>
        /// Gets the default inference parameters.
        /// </summary>
        public InferenceOptions InferenceParams => _defaultInferenceConfig;

        /// <summary>
        /// Initializes the session by running the configured prompt.
        /// </summary>
        /// <param name="inferenceConfig">The inference configuration, or null to use the session default.</param>
        /// <param name="cancellationToken">The cancellation token.</param>
        internal async Task InitializePrompt(InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default)
        {
            if (_sessionParams.ExecutorType == LLamaExecutorType.Stateless)
                return;

            if (string.IsNullOrEmpty(_sessionParams.Prompt))
                return;

            // Run the initial prompt
            var inferenceParams = ConfigureInferenceParams(inferenceConfig);
            _cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
            await foreach (var _ in _executor.InferAsync(_sessionParams.Prompt, inferenceParams, _cancellationTokenSource.Token))
            {
                // We don't need the response to the initial prompt, so exit on the first token
                break;
            }
        }

        /// <summary>
        /// Runs inference on the model context.
        /// </summary>
        /// <param name="message">The message.</param>
        /// <param name="inferenceConfig">The inference configuration, or null to use the session default.</param>
        /// <param name="cancellationToken">The cancellation token.</param>
        /// <returns>An asynchronous stream of generated text, passed through the output filter when one is configured.</returns>
        internal IAsyncEnumerable<string> InferAsync(string message, InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default)
        {
            var inferenceParams = ConfigureInferenceParams(inferenceConfig);
            _cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);

            var inferenceStream = _executor.InferAsync(message, inferenceParams, _cancellationTokenSource.Token);
            if (_outputTransform is not null)
                return _outputTransform.TransformAsync(inferenceStream);

            return inferenceStream;
        }

        /// <summary>
        /// Cancels the current inference, if one is running.
        /// </summary>
        public void CancelInfer()
        {
            _cancellationTokenSource?.Cancel();
        }

        /// <summary>
        /// Returns true if the current inference has been canceled.
        /// </summary>
        public bool IsInferCanceled()
        {
            return _cancellationTokenSource?.IsCancellationRequested ?? false;
        }

        /// <summary>
        /// Configures the inference parameters, applying the session's anti-prompts.
        /// </summary>
        /// <param name="inferenceConfig">The inference configuration, or null to use the session default.</param>
        private IInferenceParams ConfigureInferenceParams(InferenceOptions inferenceConfig)
        {
            var inferenceParams = inferenceConfig ?? _defaultInferenceConfig;
            inferenceParams.AntiPrompts = _sessionParams.GetAntiPrompts();
            return inferenceParams;
        }

        /// <summary>
        /// Creates the keyword output filter, or null when no output filters are configured.
        /// </summary>
        private ITextStreamTransform CreateOutputFilter(Common.SessionOptions sessionConfig)
        {
            var outputFilters = sessionConfig.GetOutputFilters();
            if (outputFilters.Count > 0)
                return new LLamaTransforms.KeywordTextOutputStreamTransform(outputFilters);

            return null;
        }

        /// <summary>
        /// Creates the executor configured for this session.
        /// </summary>
        private ILLamaExecutor CreateExecutor(LLamaModel model, LLamaContext context, Common.SessionOptions sessionConfig)
        {
            return sessionConfig.ExecutorType switch
            {
                LLamaExecutorType.Interactive => new InteractiveExecutor(context),
                LLamaExecutorType.Instruct => new InstructExecutor(context),
                LLamaExecutorType.Stateless => new StatelessExecutor(model.LLamaWeights, model.ModelParams),
                _ => default
            };
        }
    }
}
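
// Usage sketch (illustrative only): how code within this assembly might drive a
// ModelSession end to end — InitializePrompt and InferAsync are internal, so this
// only works from inside LLama.Web. The `model`, `context`, `sessionOptions`, and
// `ct` names are placeholder assumptions created elsewhere, not part of this class.
//
//   var session = new ModelSession(model, context, Guid.NewGuid().ToString(), sessionOptions);
//
//   // Prime the context with the session's configured prompt (no-op for stateless executors)
//   await session.InitializePrompt(cancellationToken: ct);
//
//   // Stream generated text back token by token
//   await foreach (var token in session.InferAsync("Hello!", cancellationToken: ct))
//       Console.Write(token);
//
//   // From another caller, e.g. a SignalR disconnect handler, stop generation early:
//   session.CancelInfer();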