You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

LLama3ChatSession.cs 4.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. using LLama.Abstractions;
  2. using LLama.Common;
  3. namespace LLama.Examples.Examples;
  4. // When using ChatSession, you usually want to strip the role names from the output
  5. // rather than display them. This example shows how to use transforms to strip them.
  /// <summary>
  /// Console example that runs an interactive <see cref="ChatSession"/> whose history is
  /// rendered with Llama 3 style special tokens by <see cref="LLama3HistoryTransform"/>.
  /// </summary>
  public class LLama3ChatSession
  {
      /// <summary>
      /// Loads the model, seeds the session from <c>Assets/chat-with-bob.json</c>, then
      /// reads user lines from the console and streams the replies until the user types
      /// <c>exit</c> or the input stream ends.
      /// </summary>
      /// <returns>A task that completes when the chat loop ends.</returns>
      public static async Task Run()
      {
          string modelPath = UserSettings.GetModelPath();
          var parameters = new ModelParams(modelPath)
          {
              Seed = 1337,
              GpuLayerCount = 10
          };

          using var model = LLamaWeights.LoadFromFile(parameters);
          using var context = model.CreateContext(parameters);
          var executor = new InteractiveExecutor(context);

          // Fall back to an empty history when the JSON file does not yield one.
          var chatHistoryJson = File.ReadAllText("Assets/chat-with-bob.json");
          ChatHistory chatHistory = ChatHistory.FromJson(chatHistoryJson) ?? new ChatHistory();

          ChatSession session = new(executor, chatHistory);
          session.WithHistoryTransform(new LLama3HistoryTransform());
          // NOTE(review): presumably this removes the listed keywords from the streamed
          // output so role names are not echoed to the user - confirm against the
          // KeywordTextOutputStreamTransform documentation.
          session.WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(
              new string[] { "User:", "Assistant:", "�" },
              redundancyLength: 5));

          InferenceParams inferenceParams = new InferenceParams()
          {
              Temperature = 0.6f,
              AntiPrompts = new List<string> { "User:" }
          };

          Console.ForegroundColor = ConsoleColor.Yellow;
          Console.WriteLine("The chat session has started.");

          // User input is shown in green, model output in white.
          Console.ForegroundColor = ConsoleColor.Green;

          // Console.ReadLine returns null once the input stream is exhausted (closed or
          // redirected stdin). Treat that like "exit" instead of mapping it to an empty
          // message, which would make the loop run forever.
          var userInput = Console.ReadLine();
          while (userInput is not null && userInput != "exit")
          {
              await foreach (
                  var text
                  in session.ChatAsync(
                      new ChatHistory.Message(AuthorRole.User, userInput),
                      inferenceParams))
              {
                  Console.ForegroundColor = ConsoleColor.White;
                  Console.Write(text);
              }

              Console.WriteLine();

              Console.ForegroundColor = ConsoleColor.Green;
              userInput = Console.ReadLine();
              Console.ForegroundColor = ConsoleColor.White;
          }
      }

      /// <summary>
      /// History transform that renders a <see cref="ChatHistory"/> as a prompt framed by
      /// Llama 3 style special tokens.
      /// </summary>
      class LLama3HistoryTransform : IHistoryTransform
      {
          private const string StartHeaderId = "<|start_header_id|>";
          private const string EndHeaderId = "<|end_header_id|>";
          private const string Bos = "<|begin_of_text|>";
          // Not referenced by this transform; kept alongside the other token constants.
          private const string Eos = "<|end_of_text|>";
          private const string EndofTurn = "<|eot_id|>";

          /// <summary>
          /// Convert a ChatHistory instance to plain text.
          /// </summary>
          /// <param name="history">The ChatHistory instance</param>
          /// <returns>
          /// The begin-of-text token, every message encoded in order, and finally an open
          /// assistant header with no content.
          /// </returns>
          public string HistoryToText(ChatHistory history)
          {
              // Fully qualified so the file needs no additional using directive.
              var prompt = new System.Text.StringBuilder(Bos);
              foreach (var message in history.Messages)
              {
                  prompt.Append(EncodeMessage(message));
              }

              // End with an assistant header that has no body or end-of-turn token.
              prompt.Append(EncodeHeader(new ChatHistory.Message(AuthorRole.Assistant, "")));
              return prompt.ToString();
          }

          /// <summary>
          /// Builds the header for a message: start token, role name, end token and a
          /// blank line.
          /// </summary>
          /// <param name="message">The message whose author role is written.</param>
          /// <returns>The encoded header.</returns>
          private static string EncodeHeader(ChatHistory.Message message)
          {
              // The Llama 3 template spells roles in lower case ("user", "assistant",
              // "system"), whereas the enum names are capitalised.
              string role = message.AuthorRole.ToString().ToLowerInvariant();
              return $"{StartHeaderId}{role}{EndHeaderId}\n\n";
          }

          /// <summary>
          /// Encodes one complete message: header, content and end-of-turn token.
          /// </summary>
          /// <param name="message">The message to encode.</param>
          /// <returns>The encoded message.</returns>
          private static string EncodeMessage(ChatHistory.Message message)
          {
              return EncodeHeader(message) + message.Content + EndofTurn;
          }

          /// <summary>
          /// Converts plain text to a ChatHistory instance.
          /// </summary>
          /// <param name="role">The role for the author.</param>
          /// <param name="text">The chat history as plain text.</param>
          /// <returns>A history containing a single message with the given role and text.</returns>
          public ChatHistory TextToHistory(AuthorRole role, string text)
          {
              return new ChatHistory(new ChatHistory.Message[] { new ChatHistory.Message(role, text) });
          }

          /// <summary>
          /// Copy the transform.
          /// </summary>
          /// <returns>A new <see cref="LLama3HistoryTransform"/> instance.</returns>
          public IHistoryTransform Clone()
          {
              return new LLama3HistoryTransform();
          }
      }
  }