|
- using LLama.Abstractions;
- using LLama.Web.Common;
-
- namespace LLama.Web.Models
- {
/// <summary>
/// Ties a model, its context and a session configuration together behind a single
/// executor, and exposes prompt initialization, streaming inference and cancellation
/// for that session.
/// </summary>
public class ModelSession
{
    private readonly string _sessionId;
    private readonly LLamaModel _model;
    private readonly LLamaContext _context;
    private readonly ILLamaExecutor _executor;
    private readonly ISessionConfig _sessionConfig;
    private readonly ITextStreamTransform _outputTransform;
    private readonly InferenceOptions _defaultInferenceConfig;

    // Cancellation source of the most recently started inference; null until
    // InitializePrompt or InferAsync has been called at least once.
    private CancellationTokenSource _cancellationTokenSource;

    /// <summary>
    /// Initializes a new instance of the <see cref="ModelSession"/> class.
    /// </summary>
    /// <param name="model">The model; its weights are used when the session is configured for a stateless executor.</param>
    /// <param name="context">The context handed to the executor.</param>
    /// <param name="sessionId">The session identifier.</param>
    /// <param name="sessionConfig">The session configuration (executor type, prompt, anti-prompts, output filters).</param>
    /// <param name="inferenceOptions">Default inference options; a new <see cref="InferenceOptions"/> is used when null.</param>
    public ModelSession(LLamaModel model, LLamaContext context, string sessionId, ISessionConfig sessionConfig, InferenceOptions inferenceOptions = null)
    {
        _model = model;
        _context = context;
        _sessionId = sessionId;
        _sessionConfig = sessionConfig;
        _defaultInferenceConfig = inferenceOptions ?? new InferenceOptions();

        // Both factories read _sessionConfig (and CreateExecutor reads _model/_context),
        // so they must run after the assignments above.
        _outputTransform = CreateOutputFilter();
        _executor = CreateExecutor();
    }

    /// <summary>
    /// Gets the session identifier.
    /// </summary>
    public string SessionId => _sessionId;

    /// <summary>
    /// Gets the name of the model.
    /// </summary>
    public string ModelName => _sessionConfig.Model;

    /// <summary>
    /// Gets the context.
    /// </summary>
    public LLamaContext Context => _context;

    /// <summary>
    /// Gets the session configuration.
    /// </summary>
    public ISessionConfig SessionConfig => _sessionConfig;

    /// <summary>
    /// Gets the default inference parameters of this session.
    /// </summary>
    public InferenceOptions InferenceParams => _defaultInferenceConfig;

    /// <summary>
    /// Runs the configured initial prompt through the executor.
    /// Does nothing when the executor type is stateless or when no prompt is configured.
    /// </summary>
    /// <param name="inferenceConfig">The inference configuration; the session default is used when null.</param>
    /// <param name="cancellationToken">The cancellation token.</param>
    internal async Task InitializePrompt(InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default)
    {
        if (_sessionConfig.ExecutorType == LLamaExecutorType.Stateless)
            return;

        if (string.IsNullOrEmpty(_sessionConfig.Prompt))
            return;

        // Run the initial prompt
        var inferenceParams = ConfigureInferenceParams(inferenceConfig);
        var token = ResetCancellationSource(cancellationToken);
        await foreach (var _ in _executor.InferAsync(_sessionConfig.Prompt, inferenceParams, token))
        {
            // We don't really need the response of the initial prompt, so exit on the first token
            break;
        }
    }

    /// <summary>
    /// Runs inference on the model context.
    /// </summary>
    /// <param name="message">The message.</param>
    /// <param name="inferenceConfig">The inference configuration; the session default is used when null.</param>
    /// <param name="cancellationToken">The cancellation token.</param>
    /// <returns>
    /// The executor's output stream, passed through the output transform when the
    /// session configuration supplied output filters.
    /// </returns>
    internal IAsyncEnumerable<string> InferAsync(string message, InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default)
    {
        var inferenceParams = ConfigureInferenceParams(inferenceConfig);
        var token = ResetCancellationSource(cancellationToken);

        var inferenceStream = _executor.InferAsync(message, inferenceParams, token);
        if (_outputTransform is not null)
            return _outputTransform.TransformAsync(inferenceStream);

        return inferenceStream;
    }

    /// <summary>
    /// Requests cancellation of the most recently started inference.
    /// Does nothing when no inference has been started yet.
    /// </summary>
    public void CancelInfer()
    {
        _cancellationTokenSource?.Cancel();
    }

    /// <summary>
    /// Determines whether cancellation has been requested for the most recently started inference.
    /// </summary>
    /// <returns>
    /// <c>true</c> if cancellation was requested; <c>false</c> otherwise, including when
    /// no inference has been started yet.
    /// </returns>
    public bool IsInferCanceled()
    {
        // The source is only created by InitializePrompt/InferAsync, so it may still be null here.
        return _cancellationTokenSource?.IsCancellationRequested ?? false;
    }

    /// <summary>
    /// Replaces the current cancellation source with a new one linked to
    /// <paramref name="cancellationToken"/> and disposes the one it replaces.
    /// </summary>
    /// <param name="cancellationToken">The caller's token to link to.</param>
    /// <returns>The token of the newly created source.</returns>
    private CancellationToken ResetCancellationSource(CancellationToken cancellationToken)
    {
        var previous = _cancellationTokenSource;
        var current = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
        _cancellationTokenSource = current;

        // A linked source holds a registration on the caller's token until it is disposed.
        // Dispose only after the swap so the field never refers to a disposed source.
        // NOTE(review): assumes a session runs one inference at a time - confirm with callers.
        previous?.Dispose();

        return current.Token;
    }

    /// <summary>
    /// Configures the inference parameters.
    /// </summary>
    /// <param name="inferenceConfig">The inference configuration; the session default is used when null.</param>
    /// <returns>The selected options instance with its anti-prompts set from the session configuration.</returns>
    private IInferenceParams ConfigureInferenceParams(InferenceOptions inferenceConfig)
    {
        // Note: this writes to the selected instance itself (the caller's object or the
        // session default); no copy is made.
        var inferenceParams = inferenceConfig ?? _defaultInferenceConfig;
        inferenceParams.AntiPrompts = _sessionConfig.GetAntiPrompts();
        return inferenceParams;
    }

    /// <summary>
    /// Creates the output transform from the output filters of the session configuration.
    /// </summary>
    /// <returns>A keyword transform over the configured filters, or null when there are none.</returns>
    private ITextStreamTransform CreateOutputFilter()
    {
        var outputFilters = _sessionConfig.GetOutputFilters();
        if (outputFilters.Count > 0)
            return new LLamaTransforms.KeywordTextOutputStreamTransform(outputFilters);

        return null;
    }

    /// <summary>
    /// Creates the executor that matches the executor type of the session configuration.
    /// </summary>
    /// <returns>The executor, or null when the executor type is not one of the handled values.</returns>
    private ILLamaExecutor CreateExecutor()
    {
        return _sessionConfig.ExecutorType switch
        {
            LLamaExecutorType.Interactive => new InteractiveExecutor(_context),
            LLamaExecutorType.Instruct => new InstructExecutor(_context),
            LLamaExecutorType.Stateless => new StatelessExecutor(_model.LLamaWeights, _context.Params),
            _ => default
        };
    }
}
- }
|