kmeans-color-vq.ipynb 830 kB

7 years ago
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "metadata": {},
  6. "source": [
  7. "# Color Quantization by K-Means\n",
  8. "\n",
  9. "Performs a pixel-wise **Vector Quantization (VQ)** of an image of the Summer Palace (China), reducing the number of colors required to show the image from 96,615 unique colors to 64, while preserving the overall appearance quality.\n",
  10. "\n",
  11. "In this example, pixels are represented in a 3D-space and K-means is used to find 64 color clusters. In the image processing literature, the codebook obtained from K-means (the cluster centers) is called the color palette. Using a single byte, up to 256 colors can be addressed, whereas an RGB encoding requires 3 bytes per pixel. The GIF file format, for example, uses such a palette.\n",
  12. "\n",
  13. "\n"
  14. ]
  15. },
  16. {
  17. "cell_type": "code",
  18. "execution_count": 1,
  19. "metadata": {},
  20. "outputs": [],
  21. "source": [
  22. "%matplotlib inline\n",
  23. "import numpy as np\n",
  24. "import matplotlib.pyplot as plt\n",
  25. "from sklearn.cluster import KMeans\n",
  26. "from sklearn.metrics import pairwise_distances_argmin\n",
  27. "from sklearn.datasets import load_sample_image\n",
  28. "from sklearn.utils import shuffle\n",
  29. "from time import time"
  30. ]
  31. },
  32. {
  33. "cell_type": "code",
  34. "execution_count": 2,
  35. "metadata": {},
  36. "outputs": [
  37. {
  38. "name": "stderr",
  39. "output_type": "stream",
  40. "text": [
  41. "/home/bushuhui/.virtualenv/dl/lib/python3.5/site-packages/sklearn/datasets/base.py:762: DeprecationWarning: `imread` is deprecated!\n",
  42. "`imread` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.\n",
  43. "Use ``imageio.imread`` instead.\n",
  44. " images = [imread(filename) for filename in filenames]\n",
  45. "/home/bushuhui/.virtualenv/dl/lib/python3.5/site-packages/sklearn/datasets/base.py:762: DeprecationWarning: `imread` is deprecated!\n",
  46. "`imread` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.\n",
  47. "Use ``imageio.imread`` instead.\n",
  48. " images = [imread(filename) for filename in filenames]\n"
  49. ]
  50. },
  51. {
  52. "name": "stdout",
  53. "output_type": "stream",
  54. "text": [
  55. "Fitting model on a small sub-sample of the data\n",
  56. " done in 0.185s.\n",
  57. "Predicting color indices on the full image (k-means)\n",
  58. " done in 0.128s.\n",
  59. "Predicting color indices on the full image (random)\n",
  60. " done in 0.095s.\n"
  61. ]
  62. }
  63. ],
  64. "source": [
  65. "n_colors = 64\n",
  66. "\n",
  67. "# Load the Summer Palace photo\n",
  68. "china = load_sample_image(\"china.jpg\")\n",
  69. "\n",
  70. "# Convert to floats instead of the default 8-bit integer coding. Dividing by\n",
  71. "# 255 is important so that plt.imshow works well on float data (values need\n",
  72. "# to be in the range [0, 1]).\n",
  73. "china = np.array(china, dtype=np.float64) / 255\n",
  74. "\n",
  75. "# Load Image and transform to a 2D numpy array.\n",
  76. "w, h, d = original_shape = tuple(china.shape)\n",
  77. "assert d == 3\n",
  78. "image_array = np.reshape(china, (w * h, d))\n",
  79. "\n",
  80. "print(\"Fitting model on a small sub-sample of the data\")\n",
  81. "t0 = time()\n",
  82. "image_array_sample = shuffle(image_array, random_state=0)[:1000]\n",
  83. "kmeans = KMeans(n_clusters=n_colors, random_state=0).fit(image_array_sample)\n",
  84. "print(\" done in %0.3fs.\" % (time() - t0))\n",
  85. "\n",
  86. "# Get labels for all points\n",
  87. "print(\"Predicting color indices on the full image (k-means)\")\n",
  88. "t0 = time()\n",
  89. "labels = kmeans.predict(image_array)\n",
  90. "print(\" done in %0.3fs.\" % (time() - t0))\n",
  91. "\n",
  92. "\n",
  93. "codebook_random = shuffle(image_array, random_state=0)[:n_colors]\n",
  94. "print(\"Predicting color indices on the full image (random)\")\n",
  95. "t0 = time()\n",
  96. "labels_random = pairwise_distances_argmin(codebook_random,\n",
  97. "                                          image_array,\n",
  98. "                                          axis=0)\n",
  99. "print(\" done in %0.3fs.\" % (time() - t0))\n",
  100. "\n",
  101. "\n",
  102. "def recreate_image(codebook, labels, w, h):\n",
  103. "    \"\"\"Recreate the (compressed) image from the code book & labels\"\"\"\n",
  104. "    d = codebook.shape[1]\n",
  105. "    image = np.zeros((w, h, d))\n",
  106. "    label_idx = 0\n",
  107. "    for i in range(w):\n",
  108. "        for j in range(h):\n",
  109. "            image[i][j] = codebook[labels[label_idx]]\n",
  110. "            label_idx += 1\n",
  111. "    return image"
  112. ]
  113. },
  114. {
  115. "cell_type": "code",
  116. "execution_count": 3,
  117. "metadata": {},
  118. "outputs": [
  119. {
  120. "data": {
  121. "text/plain": [
  122. "<matplotlib.image.AxesImage at 0x7f7bcdf9aba8>"
  123. ]
  124. },
  125. "execution_count": 3,
  126. "metadata": {},
  127. "output_type": "execute_result"
  128. },
  129. {
  130. "data": {
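
The notebook excerpt above assigns each pixel to its nearest palette color and then rebuilds the image with a double loop. As a minimal sketch of those two steps in plain NumPy (not the notebook's own code), the example below uses a tiny synthetic image and a hand-picked three-color palette. The names `nearest_palette_index` and `rebuild_image` are hypothetical helpers introduced here for illustration only.

```python
import numpy as np


def nearest_palette_index(pixels, palette):
    """Return, for each pixel row, the index of the closest palette color."""
    # Squared Euclidean distance between every pixel and every palette entry,
    # computed by broadcasting: the result has shape (n_pixels, n_colors).
    diff = pixels[:, np.newaxis, :] - palette[np.newaxis, :, :]
    sq_dist = np.sum(diff ** 2, axis=2)
    return np.argmin(sq_dist, axis=1)


def rebuild_image(palette, labels, rows, cols):
    """Rebuild a (rows, cols, channels) image from palette indices."""
    # Fancy indexing looks up every pixel's color at once, replacing the
    # explicit double loop while keeping the same row-major layout.
    return palette[labels].reshape(rows, cols, palette.shape[1])


# Toy data (an assumption for illustration): a 2x2 RGB image in [0, 1].
image = np.array([[[0.9, 0.1, 0.1], [0.1, 0.8, 0.2]],
                  [[0.0, 0.1, 0.9], [1.0, 0.0, 0.0]]])
palette = np.array([[1.0, 0.0, 0.0],   # red
                    [0.0, 1.0, 0.0],   # green
                    [0.0, 0.0, 1.0]])  # blue

rows, cols, channels = image.shape
pixels = image.reshape(rows * cols, channels)
labels = nearest_palette_index(pixels, palette)
quantized = rebuild_image(palette, labels, rows, cols)
```

For the full-size photo, this broadcasting approach builds a temporary array of shape (n_pixels, n_colors, 3), so a chunked routine such as scikit-learn's `pairwise_distances_argmin`, which the notebook already uses, is likely the better fit there.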

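Machine learning is increasingly applied in fields such as aircraft and robotics, with the aim of using computers to achieve human-like intelligence and thereby make equipment intelligent and unmanned. This course aims to guide students to master the basic knowledge, typical methods, and techniques of machine learning, to spark their interest in the subject through concrete application cases, and to encourage them to analyze and solve the problems and challenges facing aircraft and robots from the perspective of artificial intelligence. The course mainly covers Python programming basics, machine learning models, and the fundamentals and implementation of unsupervised learning, supervised learning, and deep learning. It also teaches how to use machine learning to solve practical problems, so that students can comprehensively improve their "overall competence".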