diff --git a/object/search_default.go b/object/search_default.go index a6b7283..399d0a2 100644 --- a/object/search_default.go +++ b/object/search_default.go @@ -33,8 +33,12 @@ func (p *DefaultSearchProvider) Search(qVector []float32) ([]Vector, error) { vectorData = append(vectorData, candidate.Data) } + similarities, err := getNearestVectors(qVector, vectorData, 5) + if err != nil { + return nil, err + } + res := []Vector{} - similarities := getNearestVectors(qVector, vectorData, 5) for _, similarity := range similarities { vector := vectors[similarity.Index] vector.Score = similarity.Similarity diff --git a/object/search_default_util.go b/object/search_default_util.go index 59a61f2..107679a 100644 --- a/object/search_default_util.go +++ b/object/search_default_util.go @@ -15,6 +15,7 @@ package object import ( + "fmt" "math" "sort" ) @@ -53,11 +54,15 @@ type SimilarityIndex struct { Index int } -func getNearestVectors(target []float32, vectors [][]float32, n int) []SimilarityIndex { +func getNearestVectors(target []float32, vectors [][]float32, n int) ([]SimilarityIndex, error) { targetNorm := norm(target) similarities := []SimilarityIndex{} for i, vector := range vectors { + if len(target) != len(vector) { + return nil, fmt.Errorf("The target vector's length: [%d] should equal to knowledge vector's length: [%d], target vector = %v, knowledge vector = %v", len(target), len(vector), target, vector) + } + similarity := cosineSimilarity(target, vector, targetNorm) similarities = append(similarities, SimilarityIndex{similarity, i}) } @@ -66,10 +71,9 @@ func getNearestVectors(target []float32, vectors [][]float32, n int) []Similarit return similarities[i].Similarity > similarities[j].Similarity }) - if len(vectors) < n { - n = len(vectors) + if n > len(similarities) { + n = len(similarities) } - res := similarities[:n] - return res + return res, nil } diff --git a/object/vector_embedding.go b/object/vector_embedding.go index 7e1f22b..f07db19 100644 --- a/object/vector_embedding.go +++ b/object/vector_embedding.go @@ -166,7 +166,7 @@ func GetNearestKnowledge(embeddingProvider *Provider, embeddingProviderObj embed if err != nil { return "", nil, err } - if qVector == nil { + if qVector == nil || len(qVector) == 0 { return "", nil, fmt.Errorf("no qVector found") }