diff --git a/model/embedding.go b/model/embedding.go index 30a463f..51e9787 100644 --- a/model/embedding.go +++ b/model/embedding.go @@ -59,18 +59,3 @@ func GetEmbeddingSafe(authToken string, text string) ([]float32, error) { return embedding, nil } } - -func GetNearestVectorIndex(target []float32, vectors [][]float32) int { - targetNorm := norm(target) - - var res int - max := float32(-1.0) - for i, vector := range vectors { - similarity := cosineSimilarity(target, vector, targetNorm) - if similarity > max { - max = similarity - res = i - } - } - return res -} diff --git a/model/query_test.go b/model/query_test.go index be5c5ba..4aef4ff 100644 --- a/model/query_test.go +++ b/model/query_test.go @@ -20,6 +20,7 @@ package model_test import ( "testing" + "github.com/casbin/casibase/model" "github.com/casbin/casibase/object" "github.com/casbin/casibase/proxy" "github.com/sashabaranov/go-openai" diff --git a/model/util.go b/model/util.go index b04ffc3..d7db0f2 100644 --- a/model/util.go +++ b/model/util.go @@ -14,11 +14,7 @@ package model -import ( - "math" - - "github.com/pkoukk/tiktoken-go" -) +import "github.com/pkoukk/tiktoken-go" func GetTokenSize(model string, prompt string) (int, error) { tkm, err := tiktoken.EncodingForModel(model) @@ -30,32 +26,3 @@ func GetTokenSize(model string, prompt string) (int, error) { res := len(token) return res, nil } - -func cosineSimilarity(vec1, vec2 []float32, vec1Norm float32) float32 { - dotProduct := dot(vec1, vec2) - vec2Norm := norm(vec2) - if vec2Norm == 0 { - return 0.0 - } - return dotProduct / (vec1Norm * vec2Norm) -} - -func dot(vec1, vec2 []float32) float32 { - if len(vec1) != len(vec2) { - panic("Vector lengths do not match") - } - - dotProduct := float32(0.0) - for i := range vec1 { - dotProduct += vec1[i] * vec2[i] - } - return dotProduct -} - -func norm(vec []float32) float32 { - normSquared := float32(0.0) - for _, val := range vec { - normSquared += val * val - } - return float32(math.Sqrt(float64(normSquared))) 
-} diff --git a/object/vector_embedding.go b/object/vector_embedding.go index d53b682..de018e5 100644 --- a/object/vector_embedding.go +++ b/object/vector_embedding.go @@ -146,6 +146,6 @@ func GetNearestVectorText(authToken string, owner string, question string) (stri nVectors = append(nVectors, candidate.Data) } - i := model.GetNearestVectorIndex(qVector, nVectors) + i := getNearestVectorIndex(qVector, nVectors) return vectors[i].Text, nil } diff --git a/object/vector_util.go b/object/vector_util.go new file mode 100644 index 0000000..c785e78 --- /dev/null +++ b/object/vector_util.go @@ -0,0 +1,61 @@ +// Copyright 2023 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
// dot returns the inner (dot) product of vec1 and vec2.
//
// It panics when the two vectors have different lengths; callers are expected
// to pass vectors of the same dimension.
func dot(vec1, vec2 []float32) float32 {
	if len(vec1) != len(vec2) {
		panic("Vector lengths do not match")
	}

	var sum float32
	for i, v := range vec1 {
		sum += v * vec2[i]
	}
	return sum
}

// norm returns the Euclidean (L2) length of vec. An empty or nil vector has
// norm 0.
func norm(vec []float32) float32 {
	var sumSquares float32
	for _, v := range vec {
		sumSquares += v * v
	}
	return float32(math.Sqrt(float64(sumSquares)))
}

// cosineSimilarity returns the cosine of the angle between vec1 and vec2.
//
// vec1Norm must be norm(vec1); it is passed in so that a caller comparing one
// vector against many candidates computes it only once. If either vector has
// zero norm the similarity is undefined and 0 is returned instead of NaN/Inf.
// Like dot, it panics when the vectors differ in length.
func cosineSimilarity(vec1, vec2 []float32, vec1Norm float32) float32 {
	// Compute the dot product first so that a length mismatch is always
	// reported, even when one of the vectors is all zeros.
	dotProduct := dot(vec1, vec2)
	vec2Norm := norm(vec2)
	if vec1Norm == 0 || vec2Norm == 0 {
		return 0.0
	}
	return dotProduct / (vec1Norm * vec2Norm)
}

// getNearestVectorIndex returns the index of the element of vectors with the
// highest cosine similarity to target. Ties are resolved in favour of the
// lowest index.
//
// It returns 0 when vectors is empty, so callers must check len(vectors)
// before using the result as an index. It panics if any candidate's length
// differs from len(target).
func getNearestVectorIndex(target []float32, vectors [][]float32) int {
	targetNorm := norm(target)

	nearest := 0
	// Start from -Inf rather than -1: a similarity of exactly -1 (or slightly
	// below it because of float32 rounding) must still be able to win.
	best := float32(math.Inf(-1))
	for i, vector := range vectors {
		similarity := cosineSimilarity(target, vector, targetNorm)
		if similarity > best {
			best = similarity
			nearest = i
		}
	}
	return nearest
}