2-kmeans-color-vq.ipynb 830 kB
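The notebook's introduction says a palette index fits in one byte (up to 256 colors), while RGB needs three bytes per pixel. The short Python sketch below works through that arithmetic. It assumes uncompressed storage with no file-format headers, rounds each index up to whole bytes, and uses the 427×640 size that scikit-learn documents for `china.jpg`. `palette_storage_bytes` and `rgb_storage_bytes` are hypothetical helpers written for illustration. They are not part of the notebook.

```python
import math


def palette_storage_bytes(n_pixels: int, n_colors: int, channels: int = 3) -> int:
    """Bytes for a palette-encoded image: one index per pixel plus the palette.

    Assumption: each index is stored in the smallest whole number of bytes
    that can address n_colors entries (1 byte for up to 256 colors), and the
    palette stores each color as `channels` one-byte values. Headers and any
    file-level compression are ignored.
    """
    bits_per_index = max(1, math.ceil(math.log2(n_colors)))
    bytes_per_index = math.ceil(bits_per_index / 8)
    palette_bytes = n_colors * channels
    return n_pixels * bytes_per_index + palette_bytes


def rgb_storage_bytes(n_pixels: int, channels: int = 3) -> int:
    """Bytes for uncompressed 24-bit RGB: one byte per channel per pixel."""
    return n_pixels * channels


if __name__ == "__main__":
    # Assumption: china.jpg is 427 x 640 pixels (the shape scikit-learn documents).
    n_pixels = 427 * 640
    print(rgb_storage_bytes(n_pixels))          # 819840
    print(palette_storage_bytes(n_pixels, 64))  # 273472
```

With 64 colors, the indices plus the palette take roughly one third of the raw RGB size. Real GIF files also apply LZW compression to the index stream, and this sketch ignores that step.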

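{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Color Quantization Using K-Means\n",
    "\n",
    "Perform a pixel-wise **vector quantization (VQ)** of an image of Yuanmingyuan (the Old Summer Palace), reducing the number of colors needed to display the image from 96,615 to 64 while preserving the overall appearance quality.\n",
    "\n",
    "\n",
    "In this example, pixels are represented in a 3D space, and K-means is used to find 64 color clusters. In the image processing literature, the codebook obtained from K-means (the cluster centers) is called the color palette. A single byte can address up to 256 colors, whereas an RGB encoding requires 3 bytes per pixel. The GIF file format, for example, uses such a palette.\n",
    "\n"
   ]
  },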
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.cluster import KMeans\n",
    "from sklearn.metrics import pairwise_distances_argmin\n",
    "from sklearn.datasets import load_sample_image\n",
    "from sklearn.utils import shuffle\n",
    "from time import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/bushuhui/.virtualenv/dl/lib/python3.5/site-packages/sklearn/datasets/base.py:762: DeprecationWarning: `imread` is deprecated!\n",
      "`imread` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.\n",
      "Use ``imageio.imread`` instead.\n",
      " images = [imread(filename) for filename in filenames]\n",
      "/home/bushuhui/.virtualenv/dl/lib/python3.5/site-packages/sklearn/datasets/base.py:762: DeprecationWarning: `imread` is deprecated!\n",
      "`imread` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.\n",
      "Use ``imageio.imread`` instead.\n",
      " images = [imread(filename) for filename in filenames]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fitting model on a small sub-sample of the data\n",
      " done in 0.185s.\n",
      "Predicting color indices on the full image (k-means)\n",
      " done in 0.128s.\n",
      "Predicting color indices on the full image (random)\n",
      " done in 0.095s.\n"
     ]
    }
   ],
   "source": [
    "n_colors = 64\n",
    "\n",
    "# Load the image of Yuanmingyuan (the Old Summer Palace)\n",
    "china = load_sample_image(\"china.jpg\")\n",
    "\n",
    "# Convert to floats instead of the default 8-bit integer coding.\n",
    "# Dividing by 255 is important so that plt.imshow behaves well on float data (which must be in the range [0, 1]).\n",
    "\n",
    "china = np.array(china, dtype=np.float64) / 255\n",
    "\n",
    "# Load the image and transform it into a 2D numpy array.\n",
    "w, h, d = original_shape = tuple(china.shape)\n",
    "assert d == 3\n",
    "image_array = np.reshape(china, (w * h, d))\n",
    "\n",
    "print(\"Fitting model on a small sub-sample of the data\")\n",
    "t0 = time()\n",
    "image_array_sample = shuffle(image_array, random_state=0)[:1000]\n",
    "kmeans = KMeans(n_clusters=n_colors, random_state=0).fit(image_array_sample)\n",
    "print(\" done in %0.3fs.\" % (time() - t0))\n",
    "\n",
    "# Get labels for all points\n",
    "print(\"Predicting color indices on the full image (k-means)\")\n",
    "t0 = time()\n",
    "labels = kmeans.predict(image_array)\n",
    "print(\" done in %0.3fs.\" % (time() - t0))\n",
    "\n",
    "\n",
    "# Baseline: a codebook of exactly n_colors pixels picked at random from the image\n",
    "codebook_random = shuffle(image_array, random_state=0)[:n_colors]\n",
    "print(\"Predicting color indices on the full image (random)\")\n",
    "t0 = time()\n",
    "labels_random = pairwise_distances_argmin(codebook_random,\n",
    "                                          image_array,\n",
    "                                          axis=0)\n",
    "print(\" done in %0.3fs.\" % (time() - t0))\n",
    "\n",
    "\n",
    "def recreate_image(codebook, labels, w, h):\n",
    "    \"\"\"Recreate the (compressed) image from the code book & labels\"\"\"\n",
    "    d = codebook.shape[1]\n",
    "    image = np.zeros((w, h, d))\n",
    "    label_idx = 0\n",
    "    for i in range(w):\n",
    "        for j in range(h):\n",
    "            image[i][j] = codebook[labels[label_idx]]\n",
    "            label_idx += 1\n",
    "    return image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7f7bcdf9aba8>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {

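Machine learning is increasingly applied in fields such as aircraft and robotics, where the goal is to use computers to achieve human-like intelligence and thereby make equipment intelligent and unmanned. This course aims to help students master the fundamentals, typical methods, and techniques of machine learning, to spark their interest in the subject through concrete application cases, and to encourage them to analyze and solve the problems and challenges faced by aircraft and robots from an artificial intelligence perspective. The main topics include Python programming basics, machine learning models, and the fundamentals and implementation of unsupervised learning, supervised learning, and deep learning, as well as how to apply machine learning to practical problems, thereby improving students' overall abilities.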
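The notebook's `recreate_image` helper fills the output pixel by pixel in a Python double loop. Because `labels` is in the same row-major order that `np.reshape(china, (w * h, d))` produced, the same image can be rebuilt with NumPy fancy indexing: `codebook[labels]` gathers one palette row per pixel, and a reshape restores the image shape. The sketch below is an illustrative alternative, not code from the notebook. `recreate_image_vectorized` is a hypothetical name, and the check runs on small synthetic data rather than on the sample image. In the notebook, `codebook` would be `kmeans.cluster_centers_` or `codebook_random`.

```python
import numpy as np


def recreate_image_vectorized(codebook: np.ndarray, labels: np.ndarray, w: int, h: int) -> np.ndarray:
    """Rebuild a (w, h, d) image by looking up each pixel's palette entry.

    Assumes labels are in the row-major order produced by
    np.reshape(image, (w * h, d)), as in the notebook.
    """
    # Fancy indexing gathers one codebook row per label: shape (w * h, d)
    pixels = codebook[labels]
    return pixels.reshape(w, h, codebook.shape[1])


def recreate_image_loop(codebook, labels, w, h):
    """Copy of the notebook's loop-based helper, kept for comparison."""
    d = codebook.shape[1]
    image = np.zeros((w, h, d))
    label_idx = 0
    for i in range(w):
        for j in range(h):
            image[i][j] = codebook[labels[label_idx]]
            label_idx += 1
    return image


# Hypothetical synthetic data: 4 palette colors and 6 pixels forming a 2 x 3 image
rng = np.random.default_rng(0)
codebook = rng.random((4, 3))        # RGB values in [0, 1)
labels = rng.integers(0, 4, size=6)  # one palette index per pixel
fast = recreate_image_vectorized(codebook, labels, 2, 3)
slow = recreate_image_loop(codebook, labels, 2, 3)
print(fast.shape, np.allclose(fast, slow))  # (2, 3, 3) True
```

Fancy indexing moves the per-pixel work into NumPy's compiled code, which is usually much faster than the double loop on a full-size image.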