You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

sessionConnectionChat.js 7.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  /**
   * Wires the chat page to the "/SessionConnectionHub" SignalR hub.
   *
   * Reads the Mustache templates from the page, builds the hub connection,
   * registers the click/keydown/change/input handlers on the page controls,
   * registers the "OnStatus" / "OnError" / "OnResponse" hub handlers and
   * finally starts the connection.
   *
   * Relies on globals that are not defined in this file: `$` (jQuery),
   * `signalR`, `Mustache` and `Enums`.
   */
  const createConnectionSessionChat = () => {
    const outputErrorTemplate = $("#outputErrorTemplate").html();
    const outputInfoTemplate = $("#outputInfoTemplate").html();
    const outputUserTemplate = $("#outputUserTemplate").html();
    const outputBotTemplate = $("#outputBotTemplate").html();
    const signatureTemplate = $("#signatureTemplate").html();

    // Subscription returned by the last "SendPrompt" stream; disposed to cancel it.
    let inferenceSession;

    const connection = new signalR.HubConnectionBuilder().withUrl("/SessionConnectionHub").build();

    const scrollContainer = $("#scroll-container");
    const outputContainer = $("#output-container");
    const chatInput = $("#input");

    // Auto-scroll only while the view is within this many pixels of the bottom.
    const AUTO_SCROLL_THRESHOLD = 70;

    /**
     * Hub "OnStatus" handler.
     * NOTE(review): loose equality is kept on purpose - `Enums` and the wire
     * type of `status` are defined outside this file.
     * @param {*} _sender - first hub argument; not used here.
     * @param {*} status - compared against Enums.SessionConnectionStatus.
     */
    const onStatus = (_sender, status) => {
      if (status == Enums.SessionConnectionStatus.Connected) {
        $("#socket").text("Connected").addClass("text-success");
      }
      else if (status == Enums.SessionConnectionStatus.Loaded) {
        loaderHide();
        enableControls();
        $("#load").hide();
        $("#unload").show();
        onInfo(`New model session successfully started`);
      }
    };

    /** Renders an error entry in the output without touching control state. */
    const appendError = (error) => {
      outputContainer.append(Mustache.render(outputErrorTemplate, { text: error, date: getDateTime() }));
    };

    /** Hub / stream error handler: re-enables the input controls and shows the error. */
    const onError = (error) => {
      enableControls();
      appendError(error);
    };

    /** Renders an informational entry in the output. */
    const onInfo = (message) => {
      outputContainer.append(Mustache.render(outputInfoTemplate, { text: message, date: getDateTime() }));
    };

    // State of the bot message currently being streamed.
    let responseContent;
    let responseContainer;
    let responseFirstToken;

    /**
     * Handles one streamed token. A Begin token creates the bot message element,
     * End/Cancel appends the signature, anything else is appended as content.
     * Also used as the stream's `complete` callback, which passes no argument.
     * @param {object} [response] - token payload with `tokenType` and `content`.
     */
    const onResponse = (response) => {
      if (!response)
        return;

      if (response.tokenType == Enums.TokenType.Begin) {
        const uniqueId = randomString();
        outputContainer.append(Mustache.render(outputBotTemplate, { id: uniqueId, ...response }));
        responseContainer = $(`#${uniqueId}`);
        responseContent = responseContainer.find(".content");
        responseFirstToken = true;
        scrollToBottom(true);
        return;
      }

      if (response.tokenType == Enums.TokenType.End || response.tokenType == Enums.TokenType.Cancel) {
        enableControls();
        // A terminal token can arrive without a preceding Begin; there is then
        // no message element to sign.
        if (responseContainer) {
          responseContainer.find(".signature").append(Mustache.render(signatureTemplate, response));
        }
        scrollToBottom();
        return;
      }

      // Content token without a Begin: nothing to append to.
      if (!responseContent)
        return;

      if (responseFirstToken) {
        // The first token replaces whatever the bot template put in ".content".
        responseContent.empty();
        responseFirstToken = false;
        responseContainer.find(".date").append(getDateTime());
      }
      // NOTE(review): jQuery .append() parses strings as HTML. Confirm the server
      // escapes `content`, otherwise switch to a text node.
      responseContent.append(response.content);
      scrollToBottom();
    };

    /** Sends the chat input text to the hub's "SendPrompt" stream. */
    const sendPrompt = async () => {
      const text = chatInput.val();
      if (text) {
        chatInput.val(null);
        disableControls();
        outputContainer.append(Mustache.render(outputUserTemplate, { text: text, date: getDateTime() }));
        // subscribe() returns the subscription synchronously, so no await is needed.
        inferenceSession = connection
          .stream("SendPrompt", text, serializeFormToJson('SessionParameters'))
          .subscribe({
            next: onResponse,
            complete: onResponse,
            error: onError,
          });
        scrollToBottom(true);
      }
    };

    /** Cancels the running prompt, if any, by disposing its stream subscription. */
    const cancelPrompt = async () => {
      if (inferenceSession)
        inferenceSession.dispose();
    };

    /**
     * Invokes "LoadModel" on the hub with the serialized 'SessionParameters' form.
     * If the invocation fails, the loading UI is rolled back and the error shown.
     */
    const loadModel = async () => {
      const sessionParams = serializeFormToJson('SessionParameters');
      loaderShow();
      disableControls();
      disablePromptControls();
      $("#load").attr("disabled", "disabled");
      try {
        // TODO: Split parameters sets
        await connection.invoke('LoadModel', sessionParams, sessionParams);
      } catch (error) {
        // Without this the spinner stays visible and Load stays disabled forever.
        loaderHide();
        $("#load").removeAttr("disabled");
        enablePromptControls();
        appendError(error);
      }
    };

    /** Resets the UI to the "no model loaded" state. */
    const unloadModel = async () => {
      disableControls();
      enablePromptControls();
      $("#load").removeAttr("disabled");
    };

    /**
     * Serializes a form into a plain object.
     * Keys containing "." keep only the segment after the first dot; numeric
     * strings become numbers, "true"/"false" become booleans, everything else
     * is stored unchanged.
     * @param {string} form - id of the form element.
     * @returns {object} field name -> converted value.
     */
    const serializeFormToJson = (form) => {
      const formDataJson = {};
      const formData = new FormData(document.getElementById(form));
      formData.forEach((value, key) => {
        const name = key.includes(".") ? key.split(".")[1] : key;
        // Number() is used for both the check and the conversion so that they
        // agree: parseFloat("0x1A") would yield 0 although the check accepts it.
        const numeric = Number(value);
        if (!Number.isNaN(numeric) && value.trim() !== "") {
          formDataJson[name] = numeric;
        }
        else if (value === "true" || value === "false") {
          formDataJson[name] = (value === "true");
        }
        else {
          formDataJson[name] = value;
        }
      });
      return formDataJson;
    };

    const enableControls = () => {
      $(".input-control").removeAttr("disabled");
    };

    const disableControls = () => {
      $(".input-control").attr("disabled", "disabled");
    };

    /** Shows Load, hides Unload, enables the prompt controls and opens the prompt tab. */
    const enablePromptControls = () => {
      $("#load").show();
      $("#unload").hide();
      $(".prompt-control").removeAttr("disabled");
      activatePromptTab();
    };

    /** Disables the prompt controls and opens the parameters tab. */
    const disablePromptControls = () => {
      $(".prompt-control").attr("disabled", "disabled");
      activateParamsTab();
    };

    const clearOutput = () => {
      outputContainer.empty();
    };

    /** Copies the selected option's `data-prompt` value into "#PromptText". */
    const updatePrompt = () => {
      const customPrompt = $("#PromptText");
      const selection = $("option:selected", "#Prompt");
      const selectedValue = selection.data("prompt");
      customPrompt.text(selectedValue);
    };

    /** @returns {string} the current date and time in the user's locale. */
    const getDateTime = () => {
      return new Date().toLocaleString();
    };

    /**
     * @returns {string} a random id for a streamed message element. The letter
     * prefix guarantees the id is never empty and never starts with a digit.
     */
    const randomString = () => {
      return `r${Math.random().toString(36).slice(2)}`;
    };

    /**
     * Scrolls the output to the bottom.
     * @param {boolean} [force] - scroll unconditionally; otherwise only when the
     * view is already within AUTO_SCROLL_THRESHOLD pixels of the bottom.
     */
    const scrollToBottom = (force) => {
      const element = scrollContainer[0];
      if (!element)
        return;
      const scrollHeight = element.scrollHeight;
      const distanceFromBottom = scrollHeight - scrollContainer.innerHeight() - scrollContainer.scrollTop();
      if (force || distanceFromBottom <= AUTO_SCROLL_THRESHOLD) {
        scrollContainer.scrollTop(scrollHeight);
      }
    };

    const activatePromptTab = () => {
      $("#nav-prompt-tab").trigger("click");
    };

    const activateParamsTab = () => {
      $("#nav-params-tab").trigger("click");
    };

    const loaderShow = () => {
      $(".spinner").show();
    };

    const loaderHide = () => {
      $(".spinner").hide();
    };

    // Map UI functions
    $("#load").on("click", loadModel);
    $("#unload").on("click", unloadModel);
    $("#send").on("click", sendPrompt);
    $("#clear").on("click", clearOutput);
    $("#cancel").on("click", cancelPrompt);
    $("#Prompt").on("change", updatePrompt);

    // Enter sends the prompt; Shift+Enter keeps the default behaviour.
    chatInput.on('keydown', function (event) {
      if (event.key === 'Enter' && !event.shiftKey) {
        event.preventDefault();
        sendPrompt();
      }
    });

    // Mirror each slider's value into the element that follows it, once up front
    // and on every change.
    $(".slider").on("input", function () {
      const slider = $(this);
      slider.next().text(slider.val());
    }).trigger("input");

    // Map signalr functions
    connection.on("OnStatus", onStatus);
    connection.on("OnError", onError);
    connection.on("OnResponse", onResponse);

    // Surface a failed connection instead of leaving an unhandled rejection.
    connection.start().catch(appendError);
  };