You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

factorset_tsne.go 1.9 kB

3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. // Copyright 2023 The casbin Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package object
  15. import (
  16. "fmt"
  17. "github.com/danaugrs/go-tsne/tsne"
  18. "gonum.org/v1/gonum/mat"
  19. )
  20. func testTsne() {
  21. b := mat.NewDense(5, 3, []float64{
  22. 0.1, 0.1, 0.1,
  23. 0.7, 0.7, 0.7,
  24. 0.1, 0.7, 0.5,
  25. 0.7, 0.1, 0.2,
  26. 0.1, 0.7, 0.5,
  27. })
  28. t := tsne.NewTSNE(2, 300, 100, 300, true)
  29. Y := t.EmbedData(b, func(iter int, divergence float64, embedding mat.Matrix) bool {
  30. fmt.Printf("Iteration %d: divergence is %v\n", iter, divergence)
  31. return false
  32. })
  33. println(Y)
  34. }
  35. func (factorset *Factorset) DoTsne(dimension int) {
  36. floatArray := []float64{}
  37. for _, factor := range factorset.AllFactors {
  38. floatArray = append(floatArray, factor.Data...)
  39. }
  40. X := mat.NewDense(len(factorset.AllFactors), factorset.Dimension, floatArray)
  41. t := tsne.NewTSNE(dimension, 300, 100, 300, true)
  42. Y := t.EmbedData(X, func(iter int, divergence float64, embedding mat.Matrix) bool {
  43. fmt.Printf("Iteration %d: divergence is %v\n", iter, divergence)
  44. return false
  45. })
  46. rowCount, columnCount := Y.Dims()
  47. if rowCount != len(factorset.AllFactors) {
  48. panic("rowCount != len(factorset.AllFactors)")
  49. }
  50. if columnCount != dimension {
  51. panic("columnCount != dimension")
  52. }
  53. for i, factor := range factorset.AllFactors {
  54. arr := []float64{}
  55. for j := 0; j < dimension; j++ {
  56. arr = append(arr, Y.At(i, j))
  57. }
  58. factor.Data = arr
  59. }
  60. }