Browse Source

Get top 5 knowlesge

master
Yang Luo 2 years ago
parent
commit
2c1a616ad6
8 changed files with 93 additions and 47 deletions
  1. +9
    -7
      controllers/message.go
  2. +10
    -4
      object/message.go
  3. +1
    -1
      object/search.go
  4. +13
    -6
      object/search_default.go
  5. +22
    -9
      object/search_default_util.go
  6. +9
    -9
      object/search_hnsw.go
  7. +7
    -6
      object/vector.go
  8. +22
    -5
      object/vector_embedding.go

+ 9
- 7
controllers/message.go View File

@@ -132,16 +132,16 @@ func (c *ApiController) GetMessageAnswer() {

question := questionMessage.Text

nearestText, err := object.GetNearestVectorText(embeddingProviderObj, chat.Owner, question)
knowledge, vectorScores, err := object.GetNearestKnowledge(embeddingProviderObj, chat.Owner, question)
if err != nil && err.Error() != "no knowledge vectors found" {
c.ResponseErrorStream(err.Error())
return
}

realQuestion := object.GetRefinedQuestion(nearestText, question)
realQuestion := object.GetRefinedQuestion(knowledge, question)

fmt.Printf("Question: [%s]\n", question)
fmt.Printf("Context: [%s]\n", nearestText)
fmt.Printf("Knowledge: [%s]\n", knowledge)
// fmt.Printf("Refined Question: [%s]\n", realQuestion)
fmt.Printf("Answer: [")

@@ -165,6 +165,7 @@ func (c *ApiController) GetMessageAnswer() {
answer := writer.String()

message.Text = answer
message.VectorScores = vectorScores
_, err = object.UpdateMessage(message.GetId(), message)
if err != nil {
c.ResponseErrorStream(err.Error())
@@ -227,10 +228,11 @@ func (c *ApiController) AddMessage() {
Name: fmt.Sprintf("message_%s", util.GetRandomName()),
CreatedTime: util.GetCurrentTimeEx(message.CreatedTime),
// Organization: message.Organization,
Chat: message.Chat,
ReplyTo: message.GetId(),
Author: "AI",
Text: "",
Chat: message.Chat,
ReplyTo: message.GetId(),
Author: "AI",
Text: "",
VectorScores: []object.VectorScore{},
}
_, err = object.AddMessage(answerMessage)
if err != nil {


+ 10
- 4
object/message.go View File

@@ -21,16 +21,22 @@ import (
"xorm.io/core"
)

type VectorScore struct {
Vector string `xorm:"varchar(100)" json:"vector"`
Score float32 `json:"score"`
}

type Message struct {
Owner string `xorm:"varchar(100) notnull pk" json:"owner"`
Name string `xorm:"varchar(100) notnull pk" json:"name"`
CreatedTime string `xorm:"varchar(100)" json:"createdTime"`

// Organization string `xorm:"varchar(100)" json:"organization"`
Chat string `xorm:"varchar(100) index" json:"chat"`
ReplyTo string `xorm:"varchar(100) index" json:"replyTo"`
Author string `xorm:"varchar(100)" json:"author"`
Text string `xorm:"mediumtext" json:"text"`
Chat string `xorm:"varchar(100) index" json:"chat"`
ReplyTo string `xorm:"varchar(100) index" json:"replyTo"`
Author string `xorm:"varchar(100)" json:"author"`
Text string `xorm:"mediumtext" json:"text"`
VectorScores []VectorScore `xorm:"mediumtext" json:"vectorScores"`
}

func GetGlobalMessages() ([]*Message, error) {


+ 1
- 1
object/search.go View File

@@ -15,7 +15,7 @@
package object

type SearchProvider interface {
Search(qVector []float32) (string, error)
Search(qVector []float32) ([]Vector, error)
}

func GetSearchProvider(typ string, owner string) (SearchProvider, error) {


+ 13
- 6
object/search_default.go View File

@@ -22,17 +22,24 @@ func NewDefaultSearchProvider(owner string) (*DefaultSearchProvider, error) {
return &DefaultSearchProvider{owner: owner}, nil
}

func (p *DefaultSearchProvider) Search(qVector []float32) (string, error) {
func (p *DefaultSearchProvider) Search(qVector []float32) ([]Vector, error) {
vectors, err := getRelatedVectors(p.owner)
if err != nil {
return "", err
return nil, err
}

var nVectors [][]float32
var vectorData [][]float32
for _, candidate := range vectors {
nVectors = append(nVectors, candidate.Data)
vectorData = append(vectorData, candidate.Data)
}

i := getNearestVectorIndex(qVector, nVectors)
return vectors[i].Text, nil
res := []Vector{}
similarities := getNearestVectors(qVector, vectorData, 5)
for _, similarity := range similarities {
vector := vectors[similarity.Index]
vector.Score = similarity.Similarity
res = append(res, *vector)
}

return res, nil
}

object/vector_util.go → object/search_default_util.go View File

@@ -14,7 +14,10 @@

package object

import "math"
import (
"math"
"sort"
)

func dot(vec1, vec2 []float32) float32 {
if len(vec1) != len(vec2) {
@@ -45,17 +48,27 @@ func cosineSimilarity(vec1, vec2 []float32, vec1Norm float32) float32 {
return dotProduct / (vec1Norm * vec2Norm)
}

func getNearestVectorIndex(target []float32, vectors [][]float32) int {
type SimilarityIndex struct {
Similarity float32
Index int
}

func getNearestVectors(target []float32, vectors [][]float32, n int) []SimilarityIndex {
targetNorm := norm(target)

var res int
max := float32(-1.0)
similarities := []SimilarityIndex{}
for i, vector := range vectors {
similarity := cosineSimilarity(target, vector, targetNorm)
if similarity > max {
max = similarity
res = i
}
similarities = append(similarities, SimilarityIndex{similarity, i})
}

sort.Slice(similarities, func(i, j int) bool {
return similarities[i].Similarity > similarities[j].Similarity
})

if len(vectors) < n {
n = len(vectors)
}
return res

return similarities
}

+ 9
- 9
object/search_hnsw.go View File

@@ -29,13 +29,8 @@ func NewHnswSearchProvider() (*HnswSearchProvider, error) {
return &HnswSearchProvider{}, nil
}

func (p *HnswSearchProvider) Search(qVector []float32) (string, error) {
search, err := Index.Search(qVector)
if err != nil {
return "", err
}

return search.Text, nil
func (p *HnswSearchProvider) Search(qVector []float32) ([]Vector, error) {
return Index.Search(qVector)
}

var Index *HNSWIndex
@@ -75,11 +70,16 @@ func (h *HNSWIndex) Add(name string, vector []float32) error {
return h.save()
}

func (h *HNSWIndex) Search(vector []float32) (*Vector, error) {
func (h *HNSWIndex) Search(vector []float32) ([]Vector, error) {
result := h.Hnsw.Search(vector, 100, 4)
item := result.Pop()

owner, name := util.GetOwnerAndNameFromId(h.IdToStr[item.ID])
return getVector(owner, name)
v, err := getVector(owner, name)
if err != nil {
return nil, err
}
return []Vector{*v}, nil
}

func (h *HNSWIndex) save() error {


+ 7
- 6
object/vector.go View File

@@ -26,12 +26,13 @@ type Vector struct {
Name string `xorm:"varchar(100) notnull pk" json:"name"`
CreatedTime string `xorm:"varchar(100)" json:"createdTime"`

DisplayName string `xorm:"varchar(100)" json:"displayName"`
Store string `xorm:"varchar(100)" json:"store"`
Provider string `xorm:"varchar(100)" json:"provider"`
File string `xorm:"varchar(100)" json:"file"`
Index int `json:"index"`
Text string `xorm:"mediumtext" json:"text"`
DisplayName string `xorm:"varchar(100)" json:"displayName"`
Store string `xorm:"varchar(100)" json:"store"`
Provider string `xorm:"varchar(100)" json:"provider"`
File string `xorm:"varchar(100)" json:"file"`
Index int `json:"index"`
Text string `xorm:"mediumtext" json:"text"`
Score float32 `json:"score"`

Data []float32 `xorm:"mediumtext" json:"data"`
Dimension int `json:"dimension"`


+ 22
- 5
object/vector_embedding.go View File

@@ -18,6 +18,7 @@ import (
"context"
"fmt"
"path/filepath"
"strings"
"time"

"github.com/casbin/casibase/embedding"
@@ -149,19 +150,35 @@ func queryVectorSafe(embeddingProvider embedding.EmbeddingProvider, text string)
}
}

func GetNearestVectorText(embeddingProvider embedding.EmbeddingProvider, owner string, text string) (string, error) {
func GetNearestKnowledge(embeddingProvider embedding.EmbeddingProvider, owner string, text string) (string, []VectorScore, error) {
qVector, err := queryVectorSafe(embeddingProvider, text)
if err != nil {
return "", err
return "", nil, err
}
if qVector == nil {
return "", fmt.Errorf("no qVector found")
return "", nil, fmt.Errorf("no qVector found")
}

searchProvider, err := GetSearchProvider("Default", owner)
if err != nil {
return "", err
return "", nil, err
}

return searchProvider.Search(qVector)
vectors, err := searchProvider.Search(qVector)
if err != nil {
return "", nil, err
}

vectorScores := []VectorScore{}
texts := []string{}
for _, vector := range vectors {
vectorScores = append(vectorScores, VectorScore{
Vector: vector.Name,
Score: vector.Score,
})
texts = append(texts, vector.Text)
}

res := strings.Join(texts, "\n\n")
return res, vectorScores, nil
}

Loading…
Cancel
Save