You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

vector_util.go 1.6 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. // Copyright 2023 The casbin Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package object
  15. import "math"
  16. func dot(vec1, vec2 []float32) float32 {
  17. if len(vec1) != len(vec2) {
  18. panic("Vector lengths do not match")
  19. }
  20. dotProduct := float32(0.0)
  21. for i := range vec1 {
  22. dotProduct += vec1[i] * vec2[i]
  23. }
  24. return dotProduct
  25. }
  26. func norm(vec []float32) float32 {
  27. normSquared := float32(0.0)
  28. for _, val := range vec {
  29. normSquared += val * val
  30. }
  31. return float32(math.Sqrt(float64(normSquared)))
  32. }
  33. func cosineSimilarity(vec1, vec2 []float32, vec1Norm float32) float32 {
  34. dotProduct := dot(vec1, vec2)
  35. vec2Norm := norm(vec2)
  36. if vec2Norm == 0 {
  37. return 0.0
  38. }
  39. return dotProduct / (vec1Norm * vec2Norm)
  40. }
  41. func getNearestVectorIndex(target []float32, vectors [][]float32) int {
  42. targetNorm := norm(target)
  43. var res int
  44. max := float32(-1.0)
  45. for i, vector := range vectors {
  46. similarity := cosineSimilarity(target, vector, targetNorm)
  47. if similarity > max {
  48. max = similarity
  49. res = i
  50. }
  51. }
  52. return res
  53. }