Browse Source

Refactor out vector_util.go

HEAD
Yang Luo 2 years ago
parent
commit
12acd815d7
5 changed files with 64 additions and 50 deletions
  1. +0
    -15
      model/embedding.go
  2. +1
    -0
      model/query_test.go
  3. +1
    -34
      model/util.go
  4. +1
    -1
      object/vector_embedding.go
  5. +61
    -0
      object/vector_util.go

+ 0
- 15
model/embedding.go View File

@@ -59,18 +59,3 @@ func GetEmbeddingSafe(authToken string, text string) ([]float32, error) {
return embedding, nil
}
}

func GetNearestVectorIndex(target []float32, vectors [][]float32) int {
targetNorm := norm(target)

var res int
max := float32(-1.0)
for i, vector := range vectors {
similarity := cosineSimilarity(target, vector, targetNorm)
if similarity > max {
max = similarity
res = i
}
}
return res
}

+ 1
- 0
model/query_test.go View File

@@ -20,6 +20,7 @@ package model_test
import (
"testing"

"github.com/casbin/casibase/model"
"github.com/casbin/casibase/object"
"github.com/casbin/casibase/proxy"
"github.com/sashabaranov/go-openai"


+ 1
- 34
model/util.go View File

@@ -14,11 +14,7 @@

package model

import (
"math"

"github.com/pkoukk/tiktoken-go"
)
import "github.com/pkoukk/tiktoken-go"

func GetTokenSize(model string, prompt string) (int, error) {
tkm, err := tiktoken.EncodingForModel(model)
@@ -30,32 +26,3 @@ func GetTokenSize(model string, prompt string) (int, error) {
res := len(token)
return res, nil
}

func cosineSimilarity(vec1, vec2 []float32, vec1Norm float32) float32 {
dotProduct := dot(vec1, vec2)
vec2Norm := norm(vec2)
if vec2Norm == 0 {
return 0.0
}
return dotProduct / (vec1Norm * vec2Norm)
}

func dot(vec1, vec2 []float32) float32 {
if len(vec1) != len(vec2) {
panic("Vector lengths do not match")
}

dotProduct := float32(0.0)
for i := range vec1 {
dotProduct += vec1[i] * vec2[i]
}
return dotProduct
}

func norm(vec []float32) float32 {
normSquared := float32(0.0)
for _, val := range vec {
normSquared += val * val
}
return float32(math.Sqrt(float64(normSquared)))
}

+ 1
- 1
object/vector_embedding.go View File

@@ -146,6 +146,6 @@ func GetNearestVectorText(authToken string, owner string, question string) (stri
nVectors = append(nVectors, candidate.Data)
}

i := model.GetNearestVectorIndex(qVector, nVectors)
i := getNearestVectorIndex(qVector, nVectors)
return vectors[i].Text, nil
}

+ 61
- 0
object/vector_util.go View File

@@ -0,0 +1,61 @@
// Copyright 2023 The casbin Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package object

import "math"

func dot(vec1, vec2 []float32) float32 {
if len(vec1) != len(vec2) {
panic("Vector lengths do not match")
}

dotProduct := float32(0.0)
for i := range vec1 {
dotProduct += vec1[i] * vec2[i]
}
return dotProduct
}

func norm(vec []float32) float32 {
normSquared := float32(0.0)
for _, val := range vec {
normSquared += val * val
}
return float32(math.Sqrt(float64(normSquared)))
}

func cosineSimilarity(vec1, vec2 []float32, vec1Norm float32) float32 {
dotProduct := dot(vec1, vec2)
vec2Norm := norm(vec2)
if vec2Norm == 0 {
return 0.0
}
return dotProduct / (vec1Norm * vec2Norm)
}

func getNearestVectorIndex(target []float32, vectors [][]float32) int {
targetNorm := norm(target)

var res int
max := float32(-1.0)
for i, vector := range vectors {
similarity := cosineSimilarity(target, vector, targetNorm)
if similarity > max {
max = similarity
res = i
}
}
return res
}

Loading…
Cancel
Save