You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

ModelSession.cs 5.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. using LLama.Abstractions;
  2. using LLama.Web.Common;
  3. namespace LLama.Web.Models
  4. {
  5. public class ModelSession
  6. {
  7. private readonly string _sessionId;
  8. private readonly LLamaModel _model;
  9. private readonly LLamaContext _context;
  10. private readonly ILLamaExecutor _executor;
  11. private readonly Common.SessionOptions _sessionParams;
  12. private readonly ITextStreamTransform _outputTransform;
  13. private readonly InferenceOptions _defaultInferenceConfig;
  14. private CancellationTokenSource _cancellationTokenSource;
  15. public ModelSession(LLamaModel model, LLamaContext context, string sessionId, Common.SessionOptions sessionOptions, InferenceOptions inferenceOptions = null)
  16. {
  17. _model = model;
  18. _context = context;
  19. _sessionId = sessionId;
  20. _sessionParams = sessionOptions;
  21. _defaultInferenceConfig = inferenceOptions ?? new InferenceOptions();
  22. _outputTransform = CreateOutputFilter(_sessionParams);
  23. _executor = CreateExecutor(_model, _context, _sessionParams);
  24. }
  25. /// <summary>
  26. /// Gets the session identifier.
  27. /// </summary>
  28. public string SessionId => _sessionId;
  29. /// <summary>
  30. /// Gets the name of the model.
  31. /// </summary>
  32. public string ModelName => _sessionParams.Model;
  33. /// <summary>
  34. /// Gets the context.
  35. /// </summary>
  36. public LLamaContext Context => _context;
  37. /// <summary>
  38. /// Gets the session configuration.
  39. /// </summary>
  40. public Common.SessionOptions SessionConfig => _sessionParams;
  41. /// <summary>
  42. /// Gets the inference parameters.
  43. /// </summary>
  44. public InferenceOptions InferenceParams => _defaultInferenceConfig;
  45. /// <summary>
  46. /// Initializes the prompt.
  47. /// </summary>
  48. /// <param name="inferenceConfig">The inference configuration.</param>
  49. /// <param name="cancellationToken">The cancellation token.</param>
  50. internal async Task InitializePrompt(InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default)
  51. {
  52. if (_sessionParams.ExecutorType == LLamaExecutorType.Stateless)
  53. return;
  54. if (string.IsNullOrEmpty(_sessionParams.Prompt))
  55. return;
  56. // Run Initial prompt
  57. var inferenceParams = ConfigureInferenceParams(inferenceConfig);
  58. _cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
  59. await foreach (var _ in _executor.InferAsync(_sessionParams.Prompt, inferenceParams, _cancellationTokenSource.Token))
  60. {
  61. // We dont really need the response of the initial prompt, so exit on first token
  62. break;
  63. };
  64. }
  65. /// <summary>
  66. /// Runs inference on the model context
  67. /// </summary>
  68. /// <param name="message">The message.</param>
  69. /// <param name="inferenceConfig">The inference configuration.</param>
  70. /// <param name="cancellationToken">The cancellation token.</param>
  71. /// <returns></returns>
  72. internal IAsyncEnumerable<string> InferAsync(string message, InferenceOptions inferenceConfig = null, CancellationToken cancellationToken = default)
  73. {
  74. var inferenceParams = ConfigureInferenceParams(inferenceConfig);
  75. _cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
  76. var inferenceStream = _executor.InferAsync(message, inferenceParams, _cancellationTokenSource.Token);
  77. if (_outputTransform is not null)
  78. return _outputTransform.TransformAsync(inferenceStream);
  79. return inferenceStream;
  80. }
  81. public void CancelInfer()
  82. {
  83. _cancellationTokenSource?.Cancel();
  84. }
  85. public bool IsInferCanceled()
  86. {
  87. return _cancellationTokenSource.IsCancellationRequested;
  88. }
  89. /// <summary>
  90. /// Configures the inference parameters.
  91. /// </summary>
  92. /// <param name="inferenceConfig">The inference configuration.</param>
  93. private IInferenceParams ConfigureInferenceParams(InferenceOptions inferenceConfig)
  94. {
  95. var inferenceParams = inferenceConfig ?? _defaultInferenceConfig;
  96. inferenceParams.AntiPrompts = _sessionParams.GetAntiPrompts();
  97. return inferenceParams;
  98. }
  99. private ITextStreamTransform CreateOutputFilter(Common.SessionOptions sessionConfig)
  100. {
  101. var outputFilters = sessionConfig.GetOutputFilters();
  102. if (outputFilters.Count > 0)
  103. return new LLamaTransforms.KeywordTextOutputStreamTransform(outputFilters);
  104. return null;
  105. }
  106. private ILLamaExecutor CreateExecutor(LLamaModel model, LLamaContext context, Common.SessionOptions sessionConfig)
  107. {
  108. return sessionConfig.ExecutorType switch
  109. {
  110. LLamaExecutorType.Interactive => new InteractiveExecutor(_context),
  111. LLamaExecutorType.Instruct => new InstructExecutor(_context),
  112. LLamaExecutorType.Stateless => new StatelessExecutor(_model.LLamaWeights, _model.ModelParams),
  113. _ => default
  114. };
  115. }
  116. }
  117. }