diff --git a/controllers/message.go b/controllers/message.go index 08e6c14..ab67f40 100644 --- a/controllers/message.go +++ b/controllers/message.go @@ -132,16 +132,16 @@ func (c *ApiController) GetMessageAnswer() { question := questionMessage.Text - nearestText, err := object.GetNearestVectorText(embeddingProviderObj, chat.Owner, question) + knowledge, vectorScores, err := object.GetNearestKnowledge(embeddingProviderObj, chat.Owner, question) if err != nil && err.Error() != "no knowledge vectors found" { c.ResponseErrorStream(err.Error()) return } - realQuestion := object.GetRefinedQuestion(nearestText, question) + realQuestion := object.GetRefinedQuestion(knowledge, question) fmt.Printf("Question: [%s]\n", question) - fmt.Printf("Context: [%s]\n", nearestText) + fmt.Printf("Knowledge: [%s]\n", knowledge) // fmt.Printf("Refined Question: [%s]\n", realQuestion) fmt.Printf("Answer: [") @@ -165,6 +165,7 @@ func (c *ApiController) GetMessageAnswer() { answer := writer.String() message.Text = answer + message.VectorScores = vectorScores _, err = object.UpdateMessage(message.GetId(), message) if err != nil { c.ResponseErrorStream(err.Error()) @@ -227,10 +228,11 @@ func (c *ApiController) AddMessage() { Name: fmt.Sprintf("message_%s", util.GetRandomName()), CreatedTime: util.GetCurrentTimeEx(message.CreatedTime), // Organization: message.Organization, - Chat: message.Chat, - ReplyTo: message.GetId(), - Author: "AI", - Text: "", + Chat: message.Chat, + ReplyTo: message.GetId(), + Author: "AI", + Text: "", + VectorScores: []object.VectorScore{}, } _, err = object.AddMessage(answerMessage) if err != nil { diff --git a/object/message.go b/object/message.go index 5f19110..8260dac 100644 --- a/object/message.go +++ b/object/message.go @@ -21,16 +21,22 @@ import ( "xorm.io/core" ) +type VectorScore struct { + Vector string `xorm:"varchar(100)" json:"vector"` + Score float32 `json:"score"` +} + type Message struct { Owner string `xorm:"varchar(100) notnull pk" json:"owner"` Name 
string `xorm:"varchar(100) notnull pk" json:"name"` CreatedTime string `xorm:"varchar(100)" json:"createdTime"` // Organization string `xorm:"varchar(100)" json:"organization"` - Chat string `xorm:"varchar(100) index" json:"chat"` - ReplyTo string `xorm:"varchar(100) index" json:"replyTo"` - Author string `xorm:"varchar(100)" json:"author"` - Text string `xorm:"mediumtext" json:"text"` + Chat string `xorm:"varchar(100) index" json:"chat"` + ReplyTo string `xorm:"varchar(100) index" json:"replyTo"` + Author string `xorm:"varchar(100)" json:"author"` + Text string `xorm:"mediumtext" json:"text"` + VectorScores []VectorScore `xorm:"mediumtext" json:"vectorScores"` } func GetGlobalMessages() ([]*Message, error) { diff --git a/object/search.go b/object/search.go index fece54c..d102d16 100644 --- a/object/search.go +++ b/object/search.go @@ -15,7 +15,7 @@ package object type SearchProvider interface { - Search(qVector []float32) (string, error) + Search(qVector []float32) ([]Vector, error) } func GetSearchProvider(typ string, owner string) (SearchProvider, error) { diff --git a/object/search_default.go b/object/search_default.go index 88c2f46..a6b7283 100644 --- a/object/search_default.go +++ b/object/search_default.go @@ -22,17 +22,24 @@ func NewDefaultSearchProvider(owner string) (*DefaultSearchProvider, error) { return &DefaultSearchProvider{owner: owner}, nil } -func (p *DefaultSearchProvider) Search(qVector []float32) (string, error) { +func (p *DefaultSearchProvider) Search(qVector []float32) ([]Vector, error) { vectors, err := getRelatedVectors(p.owner) if err != nil { - return "", err + return nil, err } - var nVectors [][]float32 + var vectorData [][]float32 for _, candidate := range vectors { - nVectors = append(nVectors, candidate.Data) + vectorData = append(vectorData, candidate.Data) } - i := getNearestVectorIndex(qVector, nVectors) - return vectors[i].Text, nil + res := []Vector{} + similarities := getNearestVectors(qVector, vectorData, 5) + for _, 
similarity := range similarities { + vector := vectors[similarity.Index] + vector.Score = similarity.Similarity + res = append(res, *vector) + } + + return res, nil } diff --git a/object/vector_util.go b/object/search_default_util.go similarity index 74% rename from object/vector_util.go rename to object/search_default_util.go index c785e78..8f4d656 100644 --- a/object/vector_util.go +++ b/object/search_default_util.go @@ -14,7 +14,10 @@ package object -import "math" +import ( + "math" + "sort" +) func dot(vec1, vec2 []float32) float32 { if len(vec1) != len(vec2) { @@ -45,17 +48,27 @@ func cosineSimilarity(vec1, vec2 []float32, vec1Norm float32) float32 { return dotProduct / (vec1Norm * vec2Norm) } -func getNearestVectorIndex(target []float32, vectors [][]float32) int { +type SimilarityIndex struct { + Similarity float32 + Index int +} + +func getNearestVectors(target []float32, vectors [][]float32, n int) []SimilarityIndex { targetNorm := norm(target) - var res int - max := float32(-1.0) + similarities := []SimilarityIndex{} for i, vector := range vectors { similarity := cosineSimilarity(target, vector, targetNorm) - if similarity > max { - max = similarity - res = i - } + similarities = append(similarities, SimilarityIndex{similarity, i}) + } + + sort.Slice(similarities, func(i, j int) bool { + return similarities[i].Similarity > similarities[j].Similarity + }) + + if len(vectors) < n { + n = len(vectors) } - return res + + return similarities[:n] } diff --git a/object/search_hnsw.go b/object/search_hnsw.go index 5ffb7f9..b093d8e 100644 --- a/object/search_hnsw.go +++ b/object/search_hnsw.go @@ -29,13 +29,8 @@ func NewHnswSearchProvider() (*HnswSearchProvider, error) { return &HnswSearchProvider{}, nil } -func (p *HnswSearchProvider) Search(qVector []float32) (string, error) { - search, err := Index.Search(qVector) - if err != nil { - return "", err - } - - return search.Text, nil +func (p *HnswSearchProvider) Search(qVector []float32) ([]Vector, error) { + return
Index.Search(qVector) } var Index *HNSWIndex @@ -75,11 +70,16 @@ func (h *HNSWIndex) Add(name string, vector []float32) error { return h.save() } -func (h *HNSWIndex) Search(vector []float32) (*Vector, error) { +func (h *HNSWIndex) Search(vector []float32) ([]Vector, error) { result := h.Hnsw.Search(vector, 100, 4) item := result.Pop() + owner, name := util.GetOwnerAndNameFromId(h.IdToStr[item.ID]) - return getVector(owner, name) + v, err := getVector(owner, name) + if err != nil { + return nil, err + } + return []Vector{*v}, nil } func (h *HNSWIndex) save() error { diff --git a/object/vector.go b/object/vector.go index 11e17ac..4e2491f 100644 --- a/object/vector.go +++ b/object/vector.go @@ -26,12 +26,13 @@ type Vector struct { Name string `xorm:"varchar(100) notnull pk" json:"name"` CreatedTime string `xorm:"varchar(100)" json:"createdTime"` - DisplayName string `xorm:"varchar(100)" json:"displayName"` - Store string `xorm:"varchar(100)" json:"store"` - Provider string `xorm:"varchar(100)" json:"provider"` - File string `xorm:"varchar(100)" json:"file"` - Index int `json:"index"` - Text string `xorm:"mediumtext" json:"text"` + DisplayName string `xorm:"varchar(100)" json:"displayName"` + Store string `xorm:"varchar(100)" json:"store"` + Provider string `xorm:"varchar(100)" json:"provider"` + File string `xorm:"varchar(100)" json:"file"` + Index int `json:"index"` + Text string `xorm:"mediumtext" json:"text"` + Score float32 `json:"score"` Data []float32 `xorm:"mediumtext" json:"data"` Dimension int `json:"dimension"` diff --git a/object/vector_embedding.go b/object/vector_embedding.go index cc95a76..c7aee14 100644 --- a/object/vector_embedding.go +++ b/object/vector_embedding.go @@ -18,6 +18,7 @@ import ( "context" "fmt" "path/filepath" + "strings" "time" "github.com/casbin/casibase/embedding" @@ -149,19 +150,35 @@ func queryVectorSafe(embeddingProvider embedding.EmbeddingProvider, text string) } } -func GetNearestVectorText(embeddingProvider 
embedding.EmbeddingProvider, owner string, text string) (string, error) { +func GetNearestKnowledge(embeddingProvider embedding.EmbeddingProvider, owner string, text string) (string, []VectorScore, error) { qVector, err := queryVectorSafe(embeddingProvider, text) if err != nil { - return "", err + return "", nil, err } if qVector == nil { - return "", fmt.Errorf("no qVector found") + return "", nil, fmt.Errorf("no qVector found") } searchProvider, err := GetSearchProvider("Default", owner) if err != nil { - return "", err + return "", nil, err } - return searchProvider.Search(qVector) + vectors, err := searchProvider.Search(qVector) + if err != nil { + return "", nil, err + } + + vectorScores := []VectorScore{} + texts := []string{} + for _, vector := range vectors { + vectorScores = append(vectorScores, VectorScore{ + Vector: vector.Name, + Score: vector.Score, + }) + texts = append(texts, vector.Text) + } + + res := strings.Join(texts, "\n\n") + return res, vectorScores, nil }