You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_cluster.py 2.1 kB

6 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. # -*-coding:utf-8-*-
  2. import unittest
  3. import numpy as np
  4. import matplotlib.pyplot as plt
  5. from pprint import pprint
  6. from jiagu.cluster.kmeans import KMeans
  7. from jiagu.cluster.dbscan import DBSCAN
  8. def load_dataset():
  9. # 西瓜数据集4.0 编号,密度,含糖率
  10. # 数据集来源:《机器学习》第九章 周志华教授
  11. data = '''
  12. 1,0.697,0.460,
  13. 2,0.774,0.376,
  14. 3,0.634,0.264,
  15. 4,0.608,0.318,
  16. 5,0.556,0.215,
  17. 6,0.403,0.237,
  18. 7,0.481,0.149,
  19. 8,0.437,0.211,
  20. 9,0.666,0.091,
  21. 10,0.243,0.267,
  22. 11,0.245,0.057,
  23. 12,0.343,0.099,
  24. 13,0.639,0.161,
  25. 14,0.657,0.198,
  26. 15,0.360,0.370,
  27. 16,0.593,0.042,
  28. 17,0.719,0.103,
  29. 18,0.359,0.188,
  30. 19,0.339,0.241,
  31. 20,0.282,0.257,
  32. 21,0.748,0.232,
  33. 22,0.714,0.346,
  34. 23,0.483,0.312,
  35. 24,0.478,0.437,
  36. 25,0.525,0.369,
  37. 26,0.751,0.489,
  38. 27,0.532,0.472,
  39. 28,0.473,0.376,
  40. 29,0.725,0.445,
  41. 30,0.446,0.459'''
  42. data_ = data.strip().split(',')
  43. dataset = [(float(data_[i]), float(data_[i + 1])) for i in range(1, len(data_) - 1, 3)]
  44. return np.array(dataset)
  45. def show_dataset():
  46. dataset = load_dataset()
  47. fig = plt.figure()
  48. ax = fig.add_subplot(111)
  49. ax.scatter(dataset[:, 0], dataset[:, 1])
  50. plt.title("Dataset")
  51. plt.show()
  52. class TestCluster(unittest.TestCase):
  53. def test_a_kmeans(self):
  54. print("=" * 68, '\n')
  55. print("test k-means ... ")
  56. X = load_dataset()
  57. print("shape of X: ", X.shape)
  58. k = 4
  59. km = KMeans(k=k, max_iter=100)
  60. clusters = km.train(X)
  61. pprint(clusters)
  62. self.assertEqual(len(clusters), k)
  63. pprint({k: len(v) for k, v in clusters.items()})
  64. print("\n\n")
  65. def test_b_dbscan(self):
  66. print("=" * 68, '\n')
  67. print("test dbscan ... ")
  68. X = load_dataset()
  69. ds = DBSCAN(eps=0.11, min_pts=5)
  70. clusters = ds.train(X)
  71. pprint(clusters)
  72. self.assertTrue(len(clusters) < len(X))
  73. # self.assertEqual(len(clusters), 6)
  74. pprint({k: len(v) for k, v in clusters.items()})
  75. if __name__ == '__main__':
  76. unittest.main()

Jiagu使用大规模语料训练而成。将提供中文分词、词性标注、命名实体识别、情感分析、知识图谱关系抽取、关键词抽取、文本摘要、新词发现、文本聚类等常用自然语言处理功能。参考了各大工具优缺点制作,将Jiagu回馈给大家。