You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

dataset_graph.go 2.7 kB

3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. package object
  2. import (
  3. "fmt"
  4. "image/color"
  5. "math"
  6. "strconv"
  7. "github.com/casbin/casbase/util"
  8. )
  9. var graphCache map[string]*Graph
  10. func init() {
  11. graphCache = map[string]*Graph{}
  12. }
  13. func GetDatasetGraph(id string, clusterNumber int) *Graph {
  14. cacheId := fmt.Sprintf("%s|%d", id, clusterNumber)
  15. g, ok := graphCache[cacheId]
  16. if ok {
  17. return g
  18. }
  19. dataset := GetDataset(id)
  20. if dataset == nil {
  21. return nil
  22. }
  23. runKmeans(dataset.Vectors, clusterNumber)
  24. g = generateGraph(dataset.Vectors)
  25. graphCache[cacheId] = g
  26. return g
  27. }
  28. func getDistance(v1 *Vector, v2 *Vector) float64 {
  29. res := 0.0
  30. for i := range v1.Data {
  31. res += (v1.Data[i] - v2.Data[i]) * (v1.Data[i] - v2.Data[i])
  32. }
  33. return math.Sqrt(res)
  34. }
  35. func refineVectors(vectors []*Vector) []*Vector {
  36. res := []*Vector{}
  37. for _, vector := range vectors {
  38. if len(vector.Data) > 0 {
  39. res = append(res, vector)
  40. }
  41. }
  42. return res
  43. }
  44. func getNodeColor(weight int) string {
  45. if weight > 10 {
  46. weight = 10
  47. }
  48. f := (10.0 - float64(weight)) / 10.0
  49. color1 := color.RGBA{R: 232, G: 67, B: 62}
  50. color2 := color.RGBA{R: 24, G: 144, B: 255}
  51. myColor := util.MixColor(color1, color2, f)
  52. return fmt.Sprintf("rgb(%d,%d,%d)", myColor.R, myColor.G, myColor.B)
  53. }
  54. var DistanceLimit = 14
  55. func generateGraph(vectors []*Vector) *Graph {
  56. vectors = refineVectors(vectors)
  57. //vectors = vectors[:100]
  58. g := newGraph()
  59. nodeWeightMap := map[string]int{}
  60. for i := 0; i < len(vectors); i++ {
  61. for j := i + 1; j < len(vectors); j++ {
  62. v1 := vectors[i]
  63. v2 := vectors[j]
  64. distance := int(getDistance(v1, v2))
  65. if distance >= DistanceLimit {
  66. continue
  67. }
  68. if v, ok := nodeWeightMap[v1.Name]; !ok {
  69. nodeWeightMap[v1.Name] = 1
  70. } else {
  71. nodeWeightMap[v1.Name] = v + 1
  72. }
  73. if v, ok := nodeWeightMap[v2.Name]; !ok {
  74. nodeWeightMap[v2.Name] = 1
  75. } else {
  76. nodeWeightMap[v2.Name] = v + 1
  77. }
  78. linkValue := (1*(distance-7) + 10*(DistanceLimit-1-distance)) / (DistanceLimit - 8)
  79. linkColor := "rgb(44,160,44,0.6)"
  80. linkName := fmt.Sprintf("Edge [%s] - [%s]: distance = %d, linkValue = %d", v1.Name, v2.Name, distance, linkValue)
  81. fmt.Println(linkName)
  82. g.addLink(linkName, v1.Name, v2.Name, linkValue, linkColor, strconv.Itoa(distance))
  83. }
  84. }
  85. for _, vector := range vectors {
  86. //value := 5
  87. value := int(math.Sqrt(float64(nodeWeightMap[vector.Name]))) + 3
  88. weight := nodeWeightMap[vector.Name]
  89. //nodeColor := "rgb(232,67,62)"
  90. //nodeColor := getNodeColor(value)
  91. nodeColor := vector.Color
  92. fmt.Printf("Node [%s]: weight = %d, nodeValue = %d\n", vector.Name, nodeWeightMap[vector.Name], value)
  93. g.addNode(vector.Name, vector.Name, value, nodeColor, vector.Category, weight)
  94. }
  95. return g
  96. }

基于Casbin的开源AI领域知识库平台

Contributors (1)