You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

StatefulChatService.cs 2.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. 
  2. using LLama.WebAPI.Models;
  3. using Microsoft;
  4. using System.Runtime.CompilerServices;
  5. namespace LLama.WebAPI.Services;
  /// <summary>
  /// Chat service that owns a single <see cref="ChatSession"/> and reuses it for every
  /// call, so the session's history carries over from one message to the next.
  /// </summary>
  public class StatefulChatService : IDisposable
  {
      private readonly ChatSession _session;
      private readonly LLamaContext _context;
      private readonly ILogger<StatefulChatService> _logger;

      // False until the first message has been handled; only used so that the
      // system prompt is written to the log once instead of on every call.
      private bool _continue = false;

      private const string SystemPrompt = "Transcript of a dialog, where the User interacts with an Assistant. Assistant is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.";

      /// <summary>
      /// Loads the model named by the <c>ModelPath</c> configuration key, creates the
      /// context and chat session, and seeds the session history with the system prompt.
      /// </summary>
      /// <param name="configuration">Application configuration; must contain <c>ModelPath</c>.</param>
      /// <param name="logger">Logger used for prompt, input and output tracing.</param>
      public StatefulChatService(IConfiguration configuration, ILogger<StatefulChatService> logger)
      {
          var @params = new Common.ModelParams(configuration["ModelPath"]!)
          {
              ContextSize = 512,
          };

          // todo: share weights from a central service
          // The weights object is disposed as soon as the constructor returns.
          // NOTE(review): presumably the context keeps whatever it needs from the
          // weights alive on its own - confirm against the LLamaSharp version in use.
          using var weights = LLamaWeights.LoadFromFile(@params);

          _logger = logger;
          _context = new LLamaContext(weights, @params);

          _session = new ChatSession(new InteractiveExecutor(_context));
          _session.History.AddMessage(Common.AuthorRole.System, SystemPrompt);
      }

      /// <summary>
      /// Disposes the underlying <see cref="LLamaContext"/>.
      /// </summary>
      public void Dispose()
      {
          _context?.Dispose();
      }

      /// <summary>
      /// Sends one user message to the session and returns the complete reply.
      /// </summary>
      /// <param name="input">Request whose <c>Text</c> is sent as the user message.</param>
      /// <returns>All output fragments produced by the session, concatenated in order.</returns>
      public async Task<string> Send(SendMessageInput input)
      {
          LogTurnStart(input);

          var outputs = _session.ChatAsync(
              new Common.ChatHistory.Message(Common.AuthorRole.User, input.Text),
              CreateInferenceParams());

          // Accumulate in a builder rather than with repeated string concatenation.
          var reply = new System.Text.StringBuilder();
          await foreach (var output in outputs)
          {
              _logger.LogInformation("Message: {output}", output);
              reply.Append(output);
          }

          return reply.ToString();
      }

      /// <summary>
      /// Sends one user message to the session and yields each output fragment as soon
      /// as the session produces it.
      /// </summary>
      /// <param name="input">Request whose <c>Text</c> is sent as the user message.</param>
      /// <returns>An asynchronous stream of output fragments.</returns>
      public async IAsyncEnumerable<string> SendStream(SendMessageInput input)
      {
          LogTurnStart(input);

          var outputs = _session.ChatAsync(
              new Common.ChatHistory.Message(Common.AuthorRole.User, input.Text),
              CreateInferenceParams());

          await foreach (var output in outputs)
          {
              // Constant template + argument: model output must never be used as the
              // message template itself (braces in it would be parsed as placeholders).
              _logger.LogInformation("Message: {output}", output);
              yield return output;
          }
      }

      /// <summary>
      /// Logs the system prompt on the first call only, then logs the incoming user text.
      /// </summary>
      private void LogTurnStart(SendMessageInput input)
      {
          if (!_continue)
          {
              _logger.LogInformation("Prompt: {text}", SystemPrompt);
              _continue = true;
          }

          _logger.LogInformation("Input: {text}", input.Text);
      }

      /// <summary>
      /// Builds the inference settings shared by <see cref="Send"/> and <see cref="SendStream"/>.
      /// </summary>
      private static Common.InferenceParams CreateInferenceParams()
      {
          return new Common.InferenceParams()
          {
              RepeatPenalty = 1.0f,
              // NOTE(review): "User:" is presumably the marker at which generation should
              // stop so the model does not write the user's next turn - confirm.
              AntiPrompts = new string[] { "User:" },
          };
      }
  }