You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

vectorset_tsne.go 1.3 kB

3 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. package object
  2. import (
  3. "fmt"
  4. "github.com/danaugrs/go-tsne/tsne"
  5. "gonum.org/v1/gonum/mat"
  6. )
  7. func testTsne() {
  8. b := mat.NewDense(5, 3, []float64{
  9. 0.1, 0.1, 0.1,
  10. 0.7, 0.7, 0.7,
  11. 0.1, 0.7, 0.5,
  12. 0.7, 0.1, 0.2,
  13. 0.1, 0.7, 0.5,
  14. })
  15. t := tsne.NewTSNE(2, 300, 100, 300, true)
  16. Y := t.EmbedData(b, func(iter int, divergence float64, embedding mat.Matrix) bool {
  17. fmt.Printf("Iteration %d: divergence is %v\n", iter, divergence)
  18. return false
  19. })
  20. println(Y)
  21. }
  22. func (vectorset *Vectorset) DoTsne(dimension int) {
  23. floatArray := []float64{}
  24. for _, vector := range vectorset.Vectors {
  25. floatArray = append(floatArray, vector.Data...)
  26. }
  27. X := mat.NewDense(len(vectorset.Vectors), vectorset.Dimension, floatArray)
  28. t := tsne.NewTSNE(dimension, 300, 100, 300, true)
  29. Y := t.EmbedData(X, func(iter int, divergence float64, embedding mat.Matrix) bool {
  30. fmt.Printf("Iteration %d: divergence is %v\n", iter, divergence)
  31. return false
  32. })
  33. rowCount, columnCount := Y.Dims()
  34. if rowCount != len(vectorset.Vectors) {
  35. panic("rowCount != len(vectorset.Vectors)")
  36. }
  37. if columnCount != dimension {
  38. panic("columnCount != dimension")
  39. }
  40. for i, vector := range vectorset.Vectors {
  41. arr := []float64{}
  42. for j := 0; j < dimension; j++ {
  43. arr = append(arr, Y.At(i, j))
  44. }
  45. vector.Data = arr
  46. }
  47. }

基于Casbin的开源AI领域知识库平台

Contributors (1)