You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

wordset_graph.go 3.7 kB

3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. // Copyright 2023 The casbin Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package object
  15. import (
  16. "fmt"
  17. "image/color"
  18. "math"
  19. "strconv"
  20. "github.com/casbin/casibase/util"
  21. )
  22. var graphCache map[string]*Graph
  23. func init() {
  24. graphCache = map[string]*Graph{}
  25. }
  26. func GetWordsetGraph(id string, clusterNumber int, distanceLimit int) (*Graph, error) {
  27. cacheId := fmt.Sprintf("%s|%d|%d", id, clusterNumber, distanceLimit)
  28. g, ok := graphCache[cacheId]
  29. if ok {
  30. return g, nil
  31. }
  32. wordset, err := GetWordset(id)
  33. if err != nil {
  34. return nil, err
  35. }
  36. if wordset == nil {
  37. return nil, nil
  38. }
  39. if len(wordset.Factors) == 0 {
  40. return nil, nil
  41. }
  42. allZero := true
  43. for _, factor := range wordset.Factors {
  44. if len(factor.Data) != 0 {
  45. allZero = false
  46. break
  47. }
  48. }
  49. if allZero {
  50. return nil, nil
  51. }
  52. runKmeans(wordset.Factors, clusterNumber)
  53. g = generateGraph(wordset.Factors, distanceLimit)
  54. // graphCache[cacheId] = g
  55. return g, nil
  56. }
  57. func getDistance(v1 *Factor, v2 *Factor) float64 {
  58. res := 0.0
  59. for i := range v1.Data {
  60. res += (v1.Data[i] - v2.Data[i]) * (v1.Data[i] - v2.Data[i])
  61. }
  62. return math.Sqrt(res)
  63. }
  64. func refineFactors(factors []*Factor) []*Factor {
  65. res := []*Factor{}
  66. for _, factor := range factors {
  67. if len(factor.Data) > 0 {
  68. res = append(res, factor)
  69. }
  70. }
  71. return res
  72. }
  73. func getNodeColor(weight int) string {
  74. if weight > 10 {
  75. weight = 10
  76. }
  77. f := (10.0 - float64(weight)) / 10.0
  78. color1 := color.RGBA{R: 232, G: 67, B: 62}
  79. color2 := color.RGBA{R: 24, G: 144, B: 255}
  80. myColor := util.MixColor(color1, color2, f)
  81. return fmt.Sprintf("rgb(%d,%d,%d)", myColor.R, myColor.G, myColor.B)
  82. }
  83. func generateGraph(factors []*Factor, distanceLimit int) *Graph {
  84. factors = refineFactors(factors)
  85. // factors = factors[:100]
  86. g := newGraph()
  87. g.Nodes = []*Node{}
  88. g.Links = []*Link{}
  89. nodeWeightMap := map[string]int{}
  90. for i := 0; i < len(factors); i++ {
  91. for j := i + 1; j < len(factors); j++ {
  92. v1 := factors[i]
  93. v2 := factors[j]
  94. distance := int(getDistance(v1, v2))
  95. if distance >= distanceLimit {
  96. continue
  97. }
  98. if v, ok := nodeWeightMap[v1.Name]; !ok {
  99. nodeWeightMap[v1.Name] = 1
  100. } else {
  101. nodeWeightMap[v1.Name] = v + 1
  102. }
  103. if v, ok := nodeWeightMap[v2.Name]; !ok {
  104. nodeWeightMap[v2.Name] = 1
  105. } else {
  106. nodeWeightMap[v2.Name] = v + 1
  107. }
  108. linkValue := (1*(distance-7) + 10*(distanceLimit-1-distance)) / (distanceLimit - 8)
  109. linkColor := "rgb(44,160,44,0.6)"
  110. linkName := fmt.Sprintf("Edge [%s] - [%s]: distance = %d, linkValue = %d", v1.Name, v2.Name, distance, linkValue)
  111. fmt.Println(linkName)
  112. g.addLink(linkName, v1.Name, v2.Name, linkValue, linkColor, strconv.Itoa(distance))
  113. }
  114. }
  115. for _, factor := range factors {
  116. // value := 5
  117. value := int(math.Sqrt(float64(nodeWeightMap[factor.Name]))) + 3
  118. weight := nodeWeightMap[factor.Name]
  119. // nodeColor := "rgb(232,67,62)"
  120. // nodeColor := getNodeColor(value)
  121. nodeColor := factor.Color
  122. fmt.Printf("Node [%s]: weight = %d, nodeValue = %d\n", factor.Name, nodeWeightMap[factor.Name], value)
  123. g.addNode(factor.Name, factor.Name, value, nodeColor, factor.Category, weight)
  124. }
  125. return g
  126. }