Browse Source

Check vector length in getNearestVectors()

master
Yang Luo 2 years ago
parent
commit
4426466d4f
3 changed files with 15 additions and 7 deletions
  1. +5
    -1
      object/search_default.go
  2. +9
    -5
      object/search_default_util.go
  3. +1
    -1
      object/vector_embedding.go

+ 5
- 1
object/search_default.go View File

@@ -33,8 +33,12 @@ func (p *DefaultSearchProvider) Search(qVector []float32) ([]Vector, error) {
vectorData = append(vectorData, candidate.Data)
}

similarities, err := getNearestVectors(qVector, vectorData, 5)
if err != nil {
return nil, err
}

res := []Vector{}
similarities := getNearestVectors(qVector, vectorData, 5)
for _, similarity := range similarities {
vector := vectors[similarity.Index]
vector.Score = similarity.Similarity


+ 9
- 5
object/search_default_util.go View File

@@ -15,6 +15,7 @@
package object

import (
"fmt"
"math"
"sort"
)
@@ -53,11 +54,15 @@ type SimilarityIndex struct {
Index int
}

func getNearestVectors(target []float32, vectors [][]float32, n int) []SimilarityIndex {
func getNearestVectors(target []float32, vectors [][]float32, n int) ([]SimilarityIndex, error) {
targetNorm := norm(target)

similarities := []SimilarityIndex{}
for i, vector := range vectors {
if len(target) != len(vector) {
return nil, fmt.Errorf("The target vector's length: [%d] should equal to knowledge vector's length: [%d], target vector = %v, knowledge vector = %v", len(target), len(vector), target, vector)
}

similarity := cosineSimilarity(target, vector, targetNorm)
similarities = append(similarities, SimilarityIndex{similarity, i})
}
@@ -66,10 +71,9 @@ func getNearestVectors(target []float32, vectors [][]float32, n int) []Similarit
return similarities[i].Similarity > similarities[j].Similarity
})

if len(vectors) < n {
n = len(vectors)
if n > len(similarities) {
n = len(similarities)
}

res := similarities[:n]
return res
return res, nil
}

+ 1
- 1
object/vector_embedding.go View File

@@ -166,7 +166,7 @@ func GetNearestKnowledge(embeddingProvider *Provider, embeddingProviderObj embed
if err != nil {
return "", nil, err
}
if qVector == nil {
if qVector == nil || len(qVector) == 0 {
return "", nil, fmt.Errorf("no qVector found")
}



Loading…
Cancel
Save