You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ernie.go 3.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. // Copyright 2023 The casbin Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package model
  15. import (
  16. "context"
  17. "errors"
  18. "fmt"
  19. "io"
  20. "net/http"
  21. "strings"
  22. ernie "github.com/anhao/go-ernie"
  23. )
  24. type ErnieModelProvider struct {
  25. subType string
  26. apiKey string
  27. secretKey string
  28. temperature float32
  29. topP float32
  30. presencePenalty float32
  31. }
  32. func NewErnieModelProvider(subType string, apiKey string, secretKey string, temperature float32, topP float32, presencePenalty float32) (*ErnieModelProvider, error) {
  33. return &ErnieModelProvider{
  34. subType: subType,
  35. apiKey: apiKey,
  36. secretKey: secretKey,
  37. temperature: temperature,
  38. topP: topP,
  39. presencePenalty: presencePenalty,
  40. }, nil
  41. }
  42. func (p *ErnieModelProvider) QueryText(question string, writer io.Writer, builder *strings.Builder) error {
  43. client := ernie.NewDefaultClient(p.apiKey, p.secretKey)
  44. ctx := context.Background()
  45. flusher, ok := writer.(http.Flusher)
  46. if !ok {
  47. return fmt.Errorf("writer does not implement http.Flusher")
  48. }
  49. messages := []ernie.ChatCompletionMessage{
  50. {
  51. Role: "user",
  52. Content: question,
  53. },
  54. }
  55. flushData := func(data string) error {
  56. if _, err := fmt.Fprintf(writer, "event: message\ndata: %s\n\n", data); err != nil {
  57. return err
  58. }
  59. flusher.Flush()
  60. builder.WriteString(data)
  61. return nil
  62. }
  63. temperature := p.temperature
  64. topP := p.topP
  65. presencePenalty := p.presencePenalty
  66. if p.subType == "ERNIE-Bot" {
  67. stream, err := client.CreateErnieBotChatCompletionStream(ctx,
  68. ernie.ErnieBotRequest{
  69. Messages: messages,
  70. Temperature: temperature,
  71. TopP: topP,
  72. PresencePenalty: presencePenalty,
  73. })
  74. if err != nil {
  75. return err
  76. }
  77. defer stream.Close()
  78. for {
  79. response, err := stream.Recv()
  80. if errors.Is(err, io.EOF) {
  81. return nil
  82. }
  83. if err != nil {
  84. return err
  85. }
  86. err = flushData(response.Result)
  87. if err != nil {
  88. return err
  89. }
  90. }
  91. } else if p.subType == "ERNIE-Bot-turbo" {
  92. stream, err := client.CreateErnieBotTurboChatCompletionStream(ctx,
  93. ernie.ErnieBotTurboRequest{
  94. Messages: messages,
  95. Temperature: temperature,
  96. TopP: topP,
  97. PresencePenalty: presencePenalty,
  98. })
  99. if err != nil {
  100. return err
  101. }
  102. defer stream.Close()
  103. for {
  104. response, err := stream.Recv()
  105. if errors.Is(err, io.EOF) {
  106. return nil
  107. }
  108. if err != nil {
  109. return err
  110. }
  111. err = flushData(response.Result)
  112. if err != nil {
  113. return err
  114. }
  115. }
  116. } else if p.subType == "BLOOMZ-7B" {
  117. stream, err := client.CreateBloomz7b1ChatCompletionStream(ctx, ernie.Bloomz7b1Request{Messages: messages})
  118. if err != nil {
  119. return err
  120. }
  121. defer stream.Close()
  122. for {
  123. response, err := stream.Recv()
  124. if errors.Is(err, io.EOF) {
  125. return nil
  126. }
  127. if err != nil {
  128. return err
  129. }
  130. err = flushData(response.Result)
  131. if err != nil {
  132. return err
  133. }
  134. }
  135. } else if p.subType == "Llama-2" {
  136. stream, err := client.CreateLlamaChatCompletionStream(ctx, ernie.LlamaChatRequest{Messages: messages})
  137. if err != nil {
  138. return err
  139. }
  140. defer stream.Close()
  141. for {
  142. response, err := stream.Recv()
  143. if errors.Is(err, io.EOF) {
  144. return nil
  145. }
  146. if err != nil {
  147. return err
  148. }
  149. err = flushData(response.Result)
  150. if err != nil {
  151. return err
  152. }
  153. }
  154. }
  155. return nil
  156. }