You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

dataset_graph.go 2.5 kB

3 years ago
3 years ago
3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. package object
  import (
  	"fmt"
  	"image/color"
  	"math"
  	"sync"

  	"github.com/casbin/casbase/util"
  )
  8. var graphCache map[string]*Graph
  9. func init() {
  10. graphCache = map[string]*Graph{}
  11. }
  12. func GetDatasetGraph(id string) *Graph {
  13. g, ok := graphCache[id]
  14. if ok {
  15. return g
  16. }
  17. dataset := GetDataset(id)
  18. if dataset == nil {
  19. return nil
  20. }
  21. g = generateGraph(dataset.Vectors)
  22. graphCache[id] = g
  23. return g
  24. }
  25. func getDistance(v1 *Vector, v2 *Vector) float64 {
  26. res := 0.0
  27. for i := range v1.Data {
  28. res += (v1.Data[i] - v2.Data[i]) * (v1.Data[i] - v2.Data[i])
  29. }
  30. return math.Sqrt(res)
  31. }
  32. func refineVectors(vectors []*Vector) []*Vector {
  33. res := []*Vector{}
  34. for _, vector := range vectors {
  35. if len(vector.Data) > 0 {
  36. res = append(res, vector)
  37. }
  38. }
  39. return res
  40. }
  41. func getNodeColor(weight int) string {
  42. if weight > 10 {
  43. weight = 10
  44. }
  45. f := (10.0 - float64(weight)) / 10.0
  46. color1 := color.RGBA{R: 232, G: 67, B: 62}
  47. color2 := color.RGBA{R: 24, G: 144, B: 255}
  48. myColor := util.MixColor(color1, color2, f)
  49. return fmt.Sprintf("rgb(%d,%d,%d)", myColor.R, myColor.G, myColor.B)
  50. }
  51. var DistanceLimit = 12
  52. func generateGraph(vectors []*Vector) *Graph {
  53. vectors = refineVectors(vectors)
  54. //vectors = vectors[:100]
  55. g := newGraph()
  56. nodeWeightMap := map[string]int{}
  57. for i := 0; i < len(vectors); i++ {
  58. for j := i + 1; j < len(vectors); j++ {
  59. v1 := vectors[i]
  60. v2 := vectors[j]
  61. distance := int(getDistance(v1, v2))
  62. if distance >= DistanceLimit {
  63. continue
  64. }
  65. if v, ok := nodeWeightMap[v1.Name]; !ok {
  66. nodeWeightMap[v1.Name] = 1
  67. } else {
  68. nodeWeightMap[v1.Name] = v + 1
  69. }
  70. if v, ok := nodeWeightMap[v2.Name]; !ok {
  71. nodeWeightMap[v2.Name] = 1
  72. } else {
  73. nodeWeightMap[v2.Name] = v + 1
  74. }
  75. linkValue := (1*(distance-7) + 10*(DistanceLimit-1-distance)) / (DistanceLimit - 8)
  76. linkColor := "rgb(44,160,44,0.6)"
  77. linkName := fmt.Sprintf("Edge [%s] - [%s]: distance = %d, linkValue = %d", v1.Name, v2.Name, distance, linkValue)
  78. fmt.Println(linkName)
  79. g.addLink(linkName, v1.Name, v2.Name, linkValue, linkColor, "")
  80. }
  81. }
  82. for _, vector := range vectors {
  83. //value := 5
  84. value := int(math.Sqrt(float64(nodeWeightMap[vector.Name]))) + 3
  85. //nodeColor := "rgb(232,67,62)"
  86. //nodeColor := getNodeColor(value)
  87. nodeColor := vector.Color
  88. fmt.Printf("Node [%s]: weight = %d, nodeValue = %d\n", vector.Name, nodeWeightMap[vector.Name], value)
  89. g.addNode(vector.Name, vector.Name, value, nodeColor, "")
  90. }
  91. return g
  92. }

基于Casbin的开源AI领域知识库平台

Contributors (1)