You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

vector_embedding.go 4.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. // Copyright 2023 The casbin Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package object
  15. import (
  16. "context"
  17. "fmt"
  18. "io"
  19. "net/http"
  20. "path/filepath"
  21. "time"
  22. "github.com/casbin/casibase/ai"
  23. "github.com/casbin/casibase/storage"
  24. "github.com/casbin/casibase/util"
  25. "golang.org/x/time/rate"
  26. )
  27. func filterTextFiles(files []*storage.Object) []*storage.Object {
  28. extSet := map[string]bool{
  29. ".txt": true,
  30. ".md": true,
  31. ".docx": true,
  32. ".doc": false,
  33. ".pdf": true,
  34. }
  35. var res []*storage.Object
  36. for _, file := range files {
  37. ext := filepath.Ext(file.Key)
  38. if extSet[ext] {
  39. res = append(res, file)
  40. }
  41. }
  42. return res
  43. }
  44. func getFilteredFileObjects(provider string, prefix string) ([]*storage.Object, error) {
  45. files, err := storage.ListObjects(provider, prefix)
  46. if err != nil {
  47. return nil, err
  48. }
  49. return filterTextFiles(files), nil
  50. }
  51. func getObjectFile(object *storage.Object) (io.ReadCloser, error) {
  52. resp, err := http.Get(object.Url)
  53. if err != nil {
  54. return nil, err
  55. }
  56. if resp.StatusCode != http.StatusOK {
  57. resp.Body.Close()
  58. return nil, fmt.Errorf("HTTP request failed with status code: %d", resp.StatusCode)
  59. }
  60. return resp.Body, nil
  61. }
  62. func addEmbeddedVector(authToken string, text string, storeName string, fileName string) (bool, error) {
  63. embedding, err := ai.GetEmbeddingSafe(authToken, text)
  64. if err != nil {
  65. return false, err
  66. }
  67. displayName := text
  68. if len(text) > 25 {
  69. displayName = text[:25]
  70. }
  71. vector := &Vector{
  72. Owner: "admin",
  73. Name: fmt.Sprintf("vector_%s", util.GetRandomName()),
  74. CreatedTime: util.GetCurrentTime(),
  75. DisplayName: displayName,
  76. Store: storeName,
  77. File: fileName,
  78. Text: text,
  79. Data: embedding,
  80. }
  81. return AddVector(vector)
  82. }
  83. func addVectorsForStore(authToken string, provider string, key string, storeName string) (bool, error) {
  84. timeLimiter := rate.NewLimiter(rate.Every(time.Minute), 3)
  85. objs, err := getFilteredFileObjects(provider, key)
  86. if err != nil {
  87. return false, err
  88. }
  89. if len(objs) == 0 {
  90. return false, nil
  91. }
  92. for _, obj := range objs {
  93. f, err := getObjectFile(obj)
  94. if err != nil {
  95. return false, err
  96. }
  97. defer f.Close()
  98. filename := obj.Key
  99. text, err := ai.ReadFileToString(f, filename)
  100. if err != nil {
  101. return false, err
  102. }
  103. textSections := ai.SplitText(text)
  104. for _, textSection := range textSections {
  105. if timeLimiter.Allow() {
  106. ok, err := addEmbeddedVector(authToken, textSection, storeName, obj.Key)
  107. if err != nil {
  108. return false, err
  109. }
  110. if !ok {
  111. return false, nil
  112. }
  113. } else {
  114. err := timeLimiter.Wait(context.Background())
  115. if err != nil {
  116. return false, err
  117. }
  118. ok, err := addEmbeddedVector(authToken, textSection, storeName, obj.Key)
  119. if err != nil {
  120. return false, err
  121. }
  122. if !ok {
  123. return false, nil
  124. }
  125. }
  126. }
  127. }
  128. return true, nil
  129. }
  130. func getRelatedVectors(owner string) ([]*Vector, error) {
  131. vectors, err := GetVectors(owner)
  132. if err != nil {
  133. return nil, err
  134. }
  135. if len(vectors) == 0 {
  136. return nil, fmt.Errorf("no knowledge vectors found")
  137. }
  138. return vectors, nil
  139. }
  140. func GetNearestVectorText(authToken string, owner string, question string) (string, error) {
  141. qVector, err := ai.GetEmbeddingSafe(authToken, question)
  142. if err != nil {
  143. return "", err
  144. }
  145. if qVector == nil {
  146. return "", fmt.Errorf("no qVector found")
  147. }
  148. vectors, err := getRelatedVectors(owner)
  149. if err != nil {
  150. return "", err
  151. }
  152. var nVectors [][]float32
  153. for _, candidate := range vectors {
  154. nVectors = append(nVectors, candidate.Data)
  155. }
  156. i := ai.GetNearestVectorIndex(qVector, nVectors)
  157. return vectors[i].Text, nil
  158. }