| @@ -132,16 +132,16 @@ func (c *ApiController) GetMessageAnswer() { | |||
| question := questionMessage.Text | |||
| nearestText, err := object.GetNearestVectorText(embeddingProviderObj, chat.Owner, question) | |||
| knowledge, vectorScores, err := object.GetNearestKnowledge(embeddingProviderObj, chat.Owner, question) | |||
| if err != nil && err.Error() != "no knowledge vectors found" { | |||
| c.ResponseErrorStream(err.Error()) | |||
| return | |||
| } | |||
| realQuestion := object.GetRefinedQuestion(nearestText, question) | |||
| realQuestion := object.GetRefinedQuestion(knowledge, question) | |||
| fmt.Printf("Question: [%s]\n", question) | |||
| fmt.Printf("Context: [%s]\n", nearestText) | |||
| fmt.Printf("Knowledge: [%s]\n", knowledge) | |||
| // fmt.Printf("Refined Question: [%s]\n", realQuestion) | |||
| fmt.Printf("Answer: [") | |||
| @@ -165,6 +165,7 @@ func (c *ApiController) GetMessageAnswer() { | |||
| answer := writer.String() | |||
| message.Text = answer | |||
| message.VectorScores = vectorScores | |||
| _, err = object.UpdateMessage(message.GetId(), message) | |||
| if err != nil { | |||
| c.ResponseErrorStream(err.Error()) | |||
| @@ -227,10 +228,11 @@ func (c *ApiController) AddMessage() { | |||
| Name: fmt.Sprintf("message_%s", util.GetRandomName()), | |||
| CreatedTime: util.GetCurrentTimeEx(message.CreatedTime), | |||
| // Organization: message.Organization, | |||
| Chat: message.Chat, | |||
| ReplyTo: message.GetId(), | |||
| Author: "AI", | |||
| Text: "", | |||
| Chat: message.Chat, | |||
| ReplyTo: message.GetId(), | |||
| Author: "AI", | |||
| Text: "", | |||
| VectorScores: []object.VectorScore{}, | |||
| } | |||
| _, err = object.AddMessage(answerMessage) | |||
| if err != nil { | |||
| @@ -21,16 +21,22 @@ import ( | |||
| "xorm.io/core" | |||
| ) | |||
// VectorScore records one retrieved vector together with its score.
// GetNearestKnowledge fills Vector with the vector's Name and Score with the
// vector's Score, and the resulting list is stored on Message.VectorScores.
| type VectorScore struct { | |||
| Vector string `xorm:"varchar(100)" json:"vector"` | |||
| Score float32 `json:"score"` | |||
| } | |||
// Message is a single chat message row, keyed by (Owner, Name).
// The Chat/ReplyTo/Author/Text rows appear twice below because this view
// shows both the pre-change and the re-aligned post-change lines of the diff;
// the only new field is VectorScores, which GetMessageAnswer assigns from the
// scores returned by GetNearestKnowledge and AddMessage initializes to an
// empty slice for the AI reply.
| type Message struct { | |||
| Owner string `xorm:"varchar(100) notnull pk" json:"owner"` | |||
| Name string `xorm:"varchar(100) notnull pk" json:"name"` | |||
| CreatedTime string `xorm:"varchar(100)" json:"createdTime"` | |||
| // Organization string `xorm:"varchar(100)" json:"organization"` | |||
| Chat string `xorm:"varchar(100) index" json:"chat"` | |||
| ReplyTo string `xorm:"varchar(100) index" json:"replyTo"` | |||
| Author string `xorm:"varchar(100)" json:"author"` | |||
| Text string `xorm:"mediumtext" json:"text"` | |||
| Chat string `xorm:"varchar(100) index" json:"chat"` | |||
| ReplyTo string `xorm:"varchar(100) index" json:"replyTo"` | |||
| Author string `xorm:"varchar(100)" json:"author"` | |||
| Text string `xorm:"mediumtext" json:"text"` | |||
| VectorScores []VectorScore `xorm:"mediumtext" json:"vectorScores"` | |||
| } | |||
| func GetGlobalMessages() ([]*Message, error) { | |||
| @@ -15,7 +15,7 @@ | |||
| package object | |||
// SearchProvider looks up stored vectors for a query embedding.
// Both the pre-change signature (returning a single text) and the post-change
// signature (returning the matched vectors) are shown below; the
// implementations in this change return []Vector.
| type SearchProvider interface { | |||
| Search(qVector []float32) (string, error) | |||
| Search(qVector []float32) ([]Vector, error) | |||
| } | |||
| func GetSearchProvider(typ string, owner string) (SearchProvider, error) { | |||
| @@ -22,17 +22,24 @@ func NewDefaultSearchProvider(owner string) (*DefaultSearchProvider, error) { | |||
| return &DefaultSearchProvider{owner: owner}, nil | |||
| } | |||
// Search scores the vectors returned by getRelatedVectors(p.owner) against
// qVector and returns the matches as Vector values with Score filled in.
// Pre-change and post-change lines are interleaved below. In the post-change
// version the embedding data of every candidate is collected, getNearestVectors
// is asked for 5 entries, and each returned entry is mapped back to its
// candidate through similarity.Index.
// Any error from getRelatedVectors is returned unchanged with a nil slice.
| func (p *DefaultSearchProvider) Search(qVector []float32) (string, error) { | |||
| func (p *DefaultSearchProvider) Search(qVector []float32) ([]Vector, error) { | |||
| vectors, err := getRelatedVectors(p.owner) | |||
| if err != nil { | |||
| return "", err | |||
| return nil, err | |||
| } | |||
| var nVectors [][]float32 | |||
| var vectorData [][]float32 | |||
| for _, candidate := range vectors { | |||
| nVectors = append(nVectors, candidate.Data) | |||
| vectorData = append(vectorData, candidate.Data) | |||
| } | |||
| i := getNearestVectorIndex(qVector, nVectors) | |||
| return vectors[i].Text, nil | |||
| res := []Vector{} | |||
| similarities := getNearestVectors(qVector, vectorData, 5) | |||
| for _, similarity := range similarities { | |||
| vector := vectors[similarity.Index] | |||
// The score is written through the element taken from vectors before the
// value is copied into res (the element is dereferenced on the next line).
| vector.Score = similarity.Similarity | |||
| res = append(res, *vector) | |||
| } | |||
| return res, nil | |||
| } | |||
| @@ -14,7 +14,10 @@ | |||
| package object | |||
| import "math" | |||
| import ( | |||
| "math" | |||
| "sort" | |||
| ) | |||
| func dot(vec1, vec2 []float32) float32 { | |||
| if len(vec1) != len(vec2) { | |||
| @@ -45,17 +48,27 @@ func cosineSimilarity(vec1, vec2 []float32, vec1Norm float32) float32 { | |||
| return dotProduct / (vec1Norm * vec2Norm) | |||
| } | |||
| func getNearestVectorIndex(target []float32, vectors [][]float32) int { | |||
| type SimilarityIndex struct { | |||
| Similarity float32 | |||
| Index int | |||
| } | |||
| func getNearestVectors(target []float32, vectors [][]float32, n int) []SimilarityIndex { | |||
| targetNorm := norm(target) | |||
| var res int | |||
| max := float32(-1.0) | |||
| similarities := []SimilarityIndex{} | |||
| for i, vector := range vectors { | |||
| similarity := cosineSimilarity(target, vector, targetNorm) | |||
| if similarity > max { | |||
| max = similarity | |||
| res = i | |||
| } | |||
| similarities = append(similarities, SimilarityIndex{similarity, i}) | |||
| } | |||
| sort.Slice(similarities, func(i, j int) bool { | |||
| return similarities[i].Similarity > similarities[j].Similarity | |||
| }) | |||
| if len(vectors) < n { | |||
| n = len(vectors) | |||
| } | |||
| return res | |||
| return similarities | |||
| } | |||
| @@ -29,13 +29,8 @@ func NewHnswSearchProvider() (*HnswSearchProvider, error) { | |||
| return &HnswSearchProvider{}, nil | |||
| } | |||
// Search delegates the query to the package-level HNSW Index.
// Pre-change and post-change lines are interleaved below: the earlier version
// unwrapped the single result and returned its Text, while the post-change
// version returns whatever Index.Search returns.
| func (p *HnswSearchProvider) Search(qVector []float32) (string, error) { | |||
| search, err := Index.Search(qVector) | |||
| if err != nil { | |||
| return "", err | |||
| } | |||
| return search.Text, nil | |||
| func (p *HnswSearchProvider) Search(qVector []float32) ([]Vector, error) { | |||
| return Index.Search(qVector) | |||
| } | |||
| var Index *HNSWIndex | |||
| @@ -75,11 +70,16 @@ func (h *HNSWIndex) Add(name string, vector []float32) error { | |||
| return h.save() | |||
| } | |||
// Search queries the underlying HNSW graph for vector, pops one item from the
// result, maps its numeric ID back to an "owner/name" id through h.IdToStr and
// loads that record with getVector.
// Pre-change and post-change lines are interleaved below; the post-change
// version wraps the loaded record in a one-element slice.
// NOTE(review): Score is not set on the returned Vector here, unlike the
// default provider — confirm callers tolerate a zero score.
// NOTE(review): behavior of result.Pop() on an empty result is not visible
// from this file — verify it cannot panic when the index has no entries.
| func (h *HNSWIndex) Search(vector []float32) (*Vector, error) { | |||
| func (h *HNSWIndex) Search(vector []float32) ([]Vector, error) { | |||
| result := h.Hnsw.Search(vector, 100, 4) | |||
| item := result.Pop() | |||
| owner, name := util.GetOwnerAndNameFromId(h.IdToStr[item.ID]) | |||
| return getVector(owner, name) | |||
| v, err := getVector(owner, name) | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| return []Vector{*v}, nil | |||
| } | |||
| func (h *HNSWIndex) save() error { | |||
| @@ -26,12 +26,13 @@ type Vector struct { | |||
| Name string `xorm:"varchar(100) notnull pk" json:"name"` | |||
| CreatedTime string `xorm:"varchar(100)" json:"createdTime"` | |||
| DisplayName string `xorm:"varchar(100)" json:"displayName"` | |||
| Store string `xorm:"varchar(100)" json:"store"` | |||
| Provider string `xorm:"varchar(100)" json:"provider"` | |||
| File string `xorm:"varchar(100)" json:"file"` | |||
| Index int `json:"index"` | |||
| Text string `xorm:"mediumtext" json:"text"` | |||
| DisplayName string `xorm:"varchar(100)" json:"displayName"` | |||
| Store string `xorm:"varchar(100)" json:"store"` | |||
| Provider string `xorm:"varchar(100)" json:"provider"` | |||
| File string `xorm:"varchar(100)" json:"file"` | |||
| Index int `json:"index"` | |||
| Text string `xorm:"mediumtext" json:"text"` | |||
| Score float32 `json:"score"` | |||
| Data []float32 `xorm:"mediumtext" json:"data"` | |||
| Dimension int `json:"dimension"` | |||
| @@ -18,6 +18,7 @@ import ( | |||
| "context" | |||
| "fmt" | |||
| "path/filepath" | |||
| "strings" | |||
| "time" | |||
| "github.com/casbin/casibase/embedding" | |||
| @@ -149,19 +150,35 @@ func queryVectorSafe(embeddingProvider embedding.EmbeddingProvider, text string) | |||
| } | |||
| } | |||
// GetNearestKnowledge embeds text with queryVectorSafe, runs the "Default"
// search provider for owner, and returns the matched vectors' texts joined by
// blank lines together with one VectorScore (vector name and score) per match.
// Pre-change (GetNearestVectorText) and post-change lines are interleaved
// below.
//
// It returns an error when embedding fails, when no query vector is produced
// ("no qVector found"), when the search provider cannot be created, or when
// the search itself fails; in all of those cases the text is "" and the score
// list is nil.
| func GetNearestVectorText(embeddingProvider embedding.EmbeddingProvider, owner string, text string) (string, error) { | |||
| func GetNearestKnowledge(embeddingProvider embedding.EmbeddingProvider, owner string, text string) (string, []VectorScore, error) { | |||
| qVector, err := queryVectorSafe(embeddingProvider, text) | |||
| if err != nil { | |||
| return "", err | |||
| return "", nil, err | |||
| } | |||
| if qVector == nil { | |||
| return "", fmt.Errorf("no qVector found") | |||
| return "", nil, fmt.Errorf("no qVector found") | |||
| } | |||
| searchProvider, err := GetSearchProvider("Default", owner) | |||
| if err != nil { | |||
| return "", err | |||
| return "", nil, err | |||
| } | |||
| return searchProvider.Search(qVector) | |||
| vectors, err := searchProvider.Search(qVector) | |||
| if err != nil { | |||
| return "", nil, err | |||
| } | |||
| vectorScores := []VectorScore{} | |||
| texts := []string{} | |||
| for _, vector := range vectors { | |||
| vectorScores = append(vectorScores, VectorScore{ | |||
| Vector: vector.Name, | |||
| Score: vector.Score, | |||
| }) | |||
| texts = append(texts, vector.Text) | |||
| } | |||
// Matched texts are concatenated in the order the provider returned them.
| res := strings.Join(texts, "\n\n") | |||
| return res, vectorScores, nil | |||
| } | |||