|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130 |
- {
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# 用K-means进行颜色量化\n",
- "\n",
- "对颐和园(Summer Palace)的图像进行逐像素的**矢量量化(VQ)**,将显示图像所需的颜色数从 96,615 种减少到 64 种,同时保持整体外观质量。\n",
- "\n",
- "\n",
- "在本例中,像素被表示为三维(RGB)空间中的点,并使用 K-means 找到 64 个颜色簇。在图像处理文献中,由 K-means 得到的码本(即聚类中心)称为调色板。使用单个字节最多可以寻址 256 种颜色,而 RGB 编码每个像素需要 3 个字节。例如,GIF 文件格式就使用了这样的调色板。\n",
- "\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [],
- "source": [
- "%matplotlib inline\n",
- "import numpy as np\n",
- "import matplotlib.pyplot as plt\n",
- "from sklearn.cluster import KMeans\n",
- "from sklearn.metrics import pairwise_distances_argmin\n",
- "from sklearn.datasets import load_sample_image\n",
- "from sklearn.utils import shuffle\n",
- "from time import time"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/home/bushuhui/.virtualenv/dl/lib/python3.5/site-packages/sklearn/datasets/base.py:762: DeprecationWarning: `imread` is deprecated!\n",
- "`imread` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.\n",
- "Use ``imageio.imread`` instead.\n",
- " images = [imread(filename) for filename in filenames]\n",
- "/home/bushuhui/.virtualenv/dl/lib/python3.5/site-packages/sklearn/datasets/base.py:762: DeprecationWarning: `imread` is deprecated!\n",
- "`imread` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.\n",
- "Use ``imageio.imread`` instead.\n",
- " images = [imread(filename) for filename in filenames]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Fitting model on a small sub-sample of the data\n",
- " done in 0.185s.\n",
- "Predicting color indices on the full image (k-means)\n",
- " done in 0.128s.\n",
- "Predicting color indices on the full image (random)\n",
- " done in 0.095s.\n"
- ]
- }
- ],
- "source": [
- "n_colors = 64\n",
- "\n",
- "# Load the bundled sample photo (china.jpg, the Summer Palace).\n",
- "china = load_sample_image(\"china.jpg\")\n",
- "\n",
- "# Convert from the default 8-bit integer encoding to floats.\n",
- "# Dividing by 255 matters: plt.imshow expects float RGB data in the\n",
- "# [0, 1] range.\n",
- "china = np.array(china, dtype=np.float64) / 255\n",
- "\n",
- "# Flatten the image into a 2D array with one row per pixel.\n",
- "# NOTE(review): for an image array the first axis is normally the row\n",
- "# (height) axis, so `w` is presumably the row count; the names are kept\n",
- "# because later cells may rely on them.\n",
- "w, h, d = original_shape = tuple(china.shape)\n",
- "assert d == 3\n",
- "image_array = np.reshape(china, (w * h, d))\n",
- "\n",
- "print(\"Fitting model on a small sub-sample of the data\")\n",
- "t0 = time()\n",
- "# Fit on 1000 randomly chosen pixels only, to keep training fast.\n",
- "image_array_sample = shuffle(image_array, random_state=0)[:1000]\n",
- "kmeans = KMeans(n_clusters=n_colors, random_state=0).fit(image_array_sample)\n",
- "print(\" done in %0.3fs.\" % (time() - t0))\n",
- "\n",
- "# Get the cluster label (palette index) of every pixel in the full image.\n",
- "print(\"Predicting color indices on the full image (k-means)\")\n",
- "t0 = time()\n",
- "labels = kmeans.predict(image_array)\n",
- "print(\" done in %0.3fs.\" % (time() - t0))\n",
- "\n",
- "\n",
- "# Baseline codebook: exactly n_colors pixels picked at random from the image.\n",
- "# (The slice used to be n_colors + 1, which produced a 65-colour palette.)\n",
- "codebook_random = shuffle(image_array, random_state=0)[:n_colors]\n",
- "print(\"Predicting color indices on the full image (random)\")\n",
- "t0 = time()\n",
- "# axis=0: for each pixel in image_array, index of the closest codebook row.\n",
- "labels_random = pairwise_distances_argmin(codebook_random,\n",
- "                                          image_array,\n",
- "                                          axis=0)\n",
- "print(\" done in %0.3fs.\" % (time() - t0))\n",
- "\n",
- "\n",
- "def recreate_image(codebook, labels, w, h):\n",
- "    \"\"\"Rebuild the (compressed) image from a codebook and per-pixel labels.\n",
- "\n",
- "    Parameters\n",
- "    ----------\n",
- "    codebook : ndarray of shape (n_colors, d)\n",
- "        Palette; row ``k`` is the colour written for every label ``k``.\n",
- "    labels : array-like of int, length ``w * h``\n",
- "        Codebook index of each pixel, in row-major order.\n",
- "    w, h : int\n",
- "        Sizes of the first and second axes of the returned image.\n",
- "\n",
- "    Returns\n",
- "    -------\n",
- "    ndarray of shape (w, h, d), dtype float64\n",
- "        Image in which every pixel is replaced by its codebook colour.\n",
- "\n",
- "    Raises\n",
- "    ------\n",
- "    ValueError\n",
- "        If ``labels`` does not contain exactly ``w * h`` entries.\n",
- "    \"\"\"\n",
- "    codebook = np.asarray(codebook)\n",
- "    d = codebook.shape[1]\n",
- "    # One fancy-indexing lookup replaces the per-pixel Python double loop;\n",
- "    # reshape then restores the row-major (w, h) pixel grid.\n",
- "    image = codebook[np.asarray(labels)].reshape(w, h, d)\n",
- "    # The loop version filled an np.zeros buffer, so keep float64 output.\n",
- "    return image.astype(np.float64, copy=False)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "<matplotlib.image.AxesImage at 0x7f7bcdf9aba8>"
- ]
- },
- "execution_count": 3,
- "metadata": {},
- "output_type": "execute_result"
- },
- {
- "data": {
|