You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

openai.go 2.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. // Copyright 2023 The casbin Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package ai
  15. import (
  16. "context"
  17. "fmt"
  18. "io"
  19. "net/http"
  20. "strings"
  21. "github.com/sashabaranov/go-openai"
  22. )
  23. type OpenAiModelProvider struct {
  24. SubType string
  25. SecretKey string
  26. }
  27. func NewOpenAiModelProvider(subType string, secretKey string) (*OpenAiModelProvider, error) {
  28. p := &OpenAiModelProvider{
  29. SubType: subType,
  30. SecretKey: secretKey,
  31. }
  32. return p, nil
  33. }
  34. func (p *OpenAiModelProvider) QueryText(question string, writer io.Writer, builder *strings.Builder) error {
  35. client := getProxyClientFromToken(p.SecretKey)
  36. ctx := context.Background()
  37. flusher, ok := writer.(http.Flusher)
  38. if !ok {
  39. return fmt.Errorf("writer does not implement http.Flusher")
  40. }
  41. model := p.SubType
  42. if model == "" {
  43. model = openai.GPT3TextDavinci003
  44. }
  45. // https://platform.openai.com/tokenizer
  46. // https://github.com/pkoukk/tiktoken-go#available-encodings
  47. promptTokens, err := GetTokenSize(model, question)
  48. if err != nil {
  49. return err
  50. }
  51. // https://platform.openai.com/docs/models/gpt-3-5
  52. maxTokens := 4097 - promptTokens
  53. respStream, err := client.CreateCompletionStream(
  54. ctx,
  55. openai.CompletionRequest{
  56. Model: model,
  57. Prompt: question,
  58. MaxTokens: maxTokens,
  59. Stream: true,
  60. },
  61. )
  62. if err != nil {
  63. return err
  64. }
  65. defer respStream.Close()
  66. isLeadingReturn := true
  67. for {
  68. completion, streamErr := respStream.Recv()
  69. if streamErr != nil {
  70. if streamErr == io.EOF {
  71. break
  72. }
  73. return streamErr
  74. }
  75. data := completion.Choices[0].Text
  76. if isLeadingReturn && len(data) != 0 {
  77. if strings.Count(data, "\n") == len(data) {
  78. continue
  79. } else {
  80. isLeadingReturn = false
  81. }
  82. }
  83. fmt.Printf("%s", data)
  84. // Write the streamed data as Server-Sent Events
  85. if _, err = fmt.Fprintf(writer, "event: message\ndata: %s\n\n", data); err != nil {
  86. return err
  87. }
  88. flusher.Flush()
  89. // Append the response to the strings.Builder
  90. builder.WriteString(data)
  91. }
  92. return nil
  93. }