@@ -132,16 +132,16 @@ func (c *ApiController) GetMessageAnswer() { | |||
question := questionMessage.Text | |||
nearestText, err := object.GetNearestVectorText(embeddingProviderObj, chat.Owner, question) | |||
knowledge, vectorScores, err := object.GetNearestKnowledge(embeddingProviderObj, chat.Owner, question) | |||
if err != nil && err.Error() != "no knowledge vectors found" { | |||
c.ResponseErrorStream(err.Error()) | |||
return | |||
} | |||
realQuestion := object.GetRefinedQuestion(nearestText, question) | |||
realQuestion := object.GetRefinedQuestion(knowledge, question) | |||
fmt.Printf("Question: [%s]\n", question) | |||
fmt.Printf("Context: [%s]\n", nearestText) | |||
fmt.Printf("Knowledge: [%s]\n", knowledge) | |||
// fmt.Printf("Refined Question: [%s]\n", realQuestion) | |||
fmt.Printf("Answer: [") | |||
@@ -165,6 +165,7 @@ func (c *ApiController) GetMessageAnswer() { | |||
answer := writer.String() | |||
message.Text = answer | |||
message.VectorScores = vectorScores | |||
_, err = object.UpdateMessage(message.GetId(), message) | |||
if err != nil { | |||
c.ResponseErrorStream(err.Error()) | |||
@@ -227,10 +228,11 @@ func (c *ApiController) AddMessage() { | |||
Name: fmt.Sprintf("message_%s", util.GetRandomName()), | |||
CreatedTime: util.GetCurrentTimeEx(message.CreatedTime), | |||
// Organization: message.Organization, | |||
Chat: message.Chat, | |||
ReplyTo: message.GetId(), | |||
Author: "AI", | |||
Text: "", | |||
Chat: message.Chat, | |||
ReplyTo: message.GetId(), | |||
Author: "AI", | |||
Text: "", | |||
VectorScores: []object.VectorScore{}, | |||
} | |||
_, err = object.AddMessage(answerMessage) | |||
if err != nil { | |||
@@ -21,16 +21,22 @@ import ( | |||
"xorm.io/core" | |||
) | |||
type VectorScore struct { | |||
Vector string `xorm:"varchar(100)" json:"vector"` | |||
Score float32 `json:"score"` | |||
} | |||
type Message struct { | |||
Owner string `xorm:"varchar(100) notnull pk" json:"owner"` | |||
Name string `xorm:"varchar(100) notnull pk" json:"name"` | |||
CreatedTime string `xorm:"varchar(100)" json:"createdTime"` | |||
// Organization string `xorm:"varchar(100)" json:"organization"` | |||
Chat string `xorm:"varchar(100) index" json:"chat"` | |||
ReplyTo string `xorm:"varchar(100) index" json:"replyTo"` | |||
Author string `xorm:"varchar(100)" json:"author"` | |||
Text string `xorm:"mediumtext" json:"text"` | |||
Chat string `xorm:"varchar(100) index" json:"chat"` | |||
ReplyTo string `xorm:"varchar(100) index" json:"replyTo"` | |||
Author string `xorm:"varchar(100)" json:"author"` | |||
Text string `xorm:"mediumtext" json:"text"` | |||
VectorScores []VectorScore `xorm:"mediumtext" json:"vectorScores"` | |||
} | |||
func GetGlobalMessages() ([]*Message, error) { | |||
@@ -15,7 +15,7 @@ | |||
package object | |||
type SearchProvider interface { | |||
Search(qVector []float32) (string, error) | |||
Search(qVector []float32) ([]Vector, error) | |||
} | |||
func GetSearchProvider(typ string, owner string) (SearchProvider, error) { | |||
@@ -22,17 +22,24 @@ func NewDefaultSearchProvider(owner string) (*DefaultSearchProvider, error) { | |||
return &DefaultSearchProvider{owner: owner}, nil | |||
} | |||
func (p *DefaultSearchProvider) Search(qVector []float32) (string, error) { | |||
func (p *DefaultSearchProvider) Search(qVector []float32) ([]Vector, error) { | |||
vectors, err := getRelatedVectors(p.owner) | |||
if err != nil { | |||
return "", err | |||
return nil, err | |||
} | |||
var nVectors [][]float32 | |||
var vectorData [][]float32 | |||
for _, candidate := range vectors { | |||
nVectors = append(nVectors, candidate.Data) | |||
vectorData = append(vectorData, candidate.Data) | |||
} | |||
i := getNearestVectorIndex(qVector, nVectors) | |||
return vectors[i].Text, nil | |||
res := []Vector{} | |||
similarities := getNearestVectors(qVector, vectorData, 5) | |||
for _, similarity := range similarities { | |||
vector := vectors[similarity.Index] | |||
vector.Score = similarity.Similarity | |||
res = append(res, *vector) | |||
} | |||
return res, nil | |||
} |
@@ -14,7 +14,10 @@ | |||
package object | |||
import "math" | |||
import ( | |||
"math" | |||
"sort" | |||
) | |||
func dot(vec1, vec2 []float32) float32 { | |||
if len(vec1) != len(vec2) { | |||
@@ -45,17 +48,27 @@ func cosineSimilarity(vec1, vec2 []float32, vec1Norm float32) float32 { | |||
return dotProduct / (vec1Norm * vec2Norm) | |||
} | |||
func getNearestVectorIndex(target []float32, vectors [][]float32) int { | |||
type SimilarityIndex struct { | |||
Similarity float32 | |||
Index int | |||
} | |||
func getNearestVectors(target []float32, vectors [][]float32, n int) []SimilarityIndex { | |||
targetNorm := norm(target) | |||
var res int | |||
max := float32(-1.0) | |||
similarities := []SimilarityIndex{} | |||
for i, vector := range vectors { | |||
similarity := cosineSimilarity(target, vector, targetNorm) | |||
if similarity > max { | |||
max = similarity | |||
res = i | |||
} | |||
similarities = append(similarities, SimilarityIndex{similarity, i}) | |||
} | |||
sort.Slice(similarities, func(i, j int) bool { | |||
return similarities[i].Similarity > similarities[j].Similarity | |||
}) | |||
if len(vectors) < n { | |||
n = len(vectors) | |||
} | |||
return res | |||
return similarities | |||
} |
@@ -29,13 +29,8 @@ func NewHnswSearchProvider() (*HnswSearchProvider, error) { | |||
return &HnswSearchProvider{}, nil | |||
} | |||
func (p *HnswSearchProvider) Search(qVector []float32) (string, error) { | |||
search, err := Index.Search(qVector) | |||
if err != nil { | |||
return "", err | |||
} | |||
return search.Text, nil | |||
func (p *HnswSearchProvider) Search(qVector []float32) ([]Vector, error) { | |||
return Index.Search(qVector) | |||
} | |||
var Index *HNSWIndex | |||
@@ -75,11 +70,16 @@ func (h *HNSWIndex) Add(name string, vector []float32) error { | |||
return h.save() | |||
} | |||
func (h *HNSWIndex) Search(vector []float32) (*Vector, error) { | |||
func (h *HNSWIndex) Search(vector []float32) ([]Vector, error) { | |||
result := h.Hnsw.Search(vector, 100, 4) | |||
item := result.Pop() | |||
owner, name := util.GetOwnerAndNameFromId(h.IdToStr[item.ID]) | |||
return getVector(owner, name) | |||
v, err := getVector(owner, name) | |||
if err != nil { | |||
return nil, err | |||
} | |||
return []Vector{*v}, nil | |||
} | |||
func (h *HNSWIndex) save() error { | |||
@@ -26,12 +26,13 @@ type Vector struct { | |||
Name string `xorm:"varchar(100) notnull pk" json:"name"` | |||
CreatedTime string `xorm:"varchar(100)" json:"createdTime"` | |||
DisplayName string `xorm:"varchar(100)" json:"displayName"` | |||
Store string `xorm:"varchar(100)" json:"store"` | |||
Provider string `xorm:"varchar(100)" json:"provider"` | |||
File string `xorm:"varchar(100)" json:"file"` | |||
Index int `json:"index"` | |||
Text string `xorm:"mediumtext" json:"text"` | |||
DisplayName string `xorm:"varchar(100)" json:"displayName"` | |||
Store string `xorm:"varchar(100)" json:"store"` | |||
Provider string `xorm:"varchar(100)" json:"provider"` | |||
File string `xorm:"varchar(100)" json:"file"` | |||
Index int `json:"index"` | |||
Text string `xorm:"mediumtext" json:"text"` | |||
Score float32 `json:"score"` | |||
Data []float32 `xorm:"mediumtext" json:"data"` | |||
Dimension int `json:"dimension"` | |||
@@ -18,6 +18,7 @@ import ( | |||
"context" | |||
"fmt" | |||
"path/filepath" | |||
"strings" | |||
"time" | |||
"github.com/casbin/casibase/embedding" | |||
@@ -149,19 +150,35 @@ func queryVectorSafe(embeddingProvider embedding.EmbeddingProvider, text string) | |||
} | |||
} | |||
func GetNearestVectorText(embeddingProvider embedding.EmbeddingProvider, owner string, text string) (string, error) { | |||
func GetNearestKnowledge(embeddingProvider embedding.EmbeddingProvider, owner string, text string) (string, []VectorScore, error) { | |||
qVector, err := queryVectorSafe(embeddingProvider, text) | |||
if err != nil { | |||
return "", err | |||
return "", nil, err | |||
} | |||
if qVector == nil { | |||
return "", fmt.Errorf("no qVector found") | |||
return "", nil, fmt.Errorf("no qVector found") | |||
} | |||
searchProvider, err := GetSearchProvider("Default", owner) | |||
if err != nil { | |||
return "", err | |||
return "", nil, err | |||
} | |||
return searchProvider.Search(qVector) | |||
vectors, err := searchProvider.Search(qVector) | |||
if err != nil { | |||
return "", nil, err | |||
} | |||
vectorScores := []VectorScore{} | |||
texts := []string{} | |||
for _, vector := range vectors { | |||
vectorScores = append(vectorScores, VectorScore{ | |||
Vector: vector.Name, | |||
Score: vector.Score, | |||
}) | |||
texts = append(texts, vector.Text) | |||
} | |||
res := strings.Join(texts, "\n\n") | |||
return res, vectorScores, nil | |||
} |