|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130 |
- {
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Color Quantization by K-Means\n",
- "\n",
- "Performs a pixel-wise **Vector Quantization (VQ)** of an image of the Summer Palace (China), reducing the number of colors required to show the image from 96,615 unique colors to 64, while preserving the overall appearance quality.\n",
- "\n",
- "In this example, pixels are represented in a 3D space (one axis per RGB channel) and K-means is used to find 64 color clusters. In the image processing literature, the codebook obtained from K-means (the cluster centers) is called the color palette. Using a single byte, up to 256 colors can be addressed, whereas an RGB encoding requires 3 bytes per pixel. The GIF file format, for example, uses such a palette.\n",
- "\n",
- "\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [],
- "source": [
- "%matplotlib inline\n",
- "import numpy as np\n",
- "import matplotlib.pyplot as plt\n",
- "from sklearn.cluster import KMeans\n",
- "from sklearn.metrics import pairwise_distances_argmin\n",
- "from sklearn.datasets import load_sample_image\n",
- "from sklearn.utils import shuffle\n",
- "from time import time"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/home/bushuhui/.virtualenv/dl/lib/python3.5/site-packages/sklearn/datasets/base.py:762: DeprecationWarning: `imread` is deprecated!\n",
- "`imread` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.\n",
- "Use ``imageio.imread`` instead.\n",
- " images = [imread(filename) for filename in filenames]\n",
- "/home/bushuhui/.virtualenv/dl/lib/python3.5/site-packages/sklearn/datasets/base.py:762: DeprecationWarning: `imread` is deprecated!\n",
- "`imread` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.\n",
- "Use ``imageio.imread`` instead.\n",
- " images = [imread(filename) for filename in filenames]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Fitting model on a small sub-sample of the data\n",
- " done in 0.185s.\n",
- "Predicting color indices on the full image (k-means)\n",
- " done in 0.128s.\n",
- "Predicting color indices on the full image (random)\n",
- " done in 0.095s.\n"
- ]
- }
- ],
- "source": [
- "n_colors = 64\n",
- "\n",
- "# Load the Summer Palace photo\n",
- "china = load_sample_image(\"china.jpg\")\n",
- "\n",
- "# Convert to floats instead of the default 8-bit integer coding. Dividing by\n",
- "# 255 matters because plt.imshow expects float data to be in the range [0, 1].\n",
- "china = np.array(china, dtype=np.float64) / 255\n",
- "\n",
- "# Flatten the image into a 2D array: one row per pixel, one column per channel.\n",
- "# NOTE(review): for the usual (rows, columns, channels) image layout `w` holds\n",
- "# the row count and `h` the column count; the names are kept unchanged so that\n",
- "# code relying on them keeps working.\n",
- "w, h, d = original_shape = tuple(china.shape)\n",
- "assert d == 3\n",
- "image_array = np.reshape(china, (w * h, d))\n",
- "\n",
- "# Fit K-Means on 1000 randomly chosen pixels only, to keep the fit fast.\n",
- "print(\"Fitting model on a small sub-sample of the data\")\n",
- "t0 = time()\n",
- "image_array_sample = shuffle(image_array, random_state=0)[:1000]\n",
- "kmeans = KMeans(n_clusters=n_colors, random_state=0).fit(image_array_sample)\n",
- "print(\" done in %0.3fs.\" % (time() - t0))\n",
- "\n",
- "# Get labels for all points\n",
- "print(\"Predicting color indices on the full image (k-means)\")\n",
- "t0 = time()\n",
- "labels = kmeans.predict(image_array)\n",
- "print(\" done in %0.3fs.\" % (time() - t0))\n",
- "\n",
- "\n",
- "# Baseline palette: exactly n_colors pixels picked at random, so that it has the\n",
- "# same size as the K-Means palette it is compared against.\n",
- "codebook_random = shuffle(image_array, random_state=0)[:n_colors]\n",
- "print(\"Predicting color indices on the full image (random)\")\n",
- "t0 = time()\n",
- "# With axis=0 the result has one entry per row of image_array: the index of the\n",
- "# closest row of codebook_random.\n",
- "labels_random = pairwise_distances_argmin(codebook_random,\n",
- "                                          image_array,\n",
- "                                          axis=0)\n",
- "print(\" done in %0.3fs.\" % (time() - t0))\n",
- "\n",
- "\n",
- "def recreate_image(codebook, labels, w, h):\n",
- "    \"\"\"Recreate the (compressed) image from the code book & labels.\n",
- "\n",
- "    Parameters\n",
- "    ----------\n",
- "    codebook : 2D array, one row per palette entry\n",
- "    labels : 1D sequence of row indices into `codebook`, listed in row-major\n",
- "        order (the second image axis varies fastest)\n",
- "    w, h : sizes of the first and second axes of the returned image\n",
- "\n",
- "    Returns\n",
- "    -------\n",
- "    float64 array of shape (w, h, codebook.shape[1]) in which pixel (i, j)\n",
- "    holds codebook[labels[i * h + j]].\n",
- "    \"\"\"\n",
- "    # float64 matches the np.zeros buffer that the loop-based version filled.\n",
- "    palette = np.asarray(codebook, dtype=np.float64)\n",
- "    # Only the first w * h labels are consumed; any surplus is ignored.\n",
- "    pixel_labels = np.asarray(labels)[: w * h]\n",
- "    # A single fancy-indexing lookup replaces the per-pixel Python double loop.\n",
- "    return palette[pixel_labels].reshape(w, h, palette.shape[1])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "<matplotlib.image.AxesImage at 0x7f7bcdf9aba8>"
- ]
- },
- "execution_count": 3,
- "metadata": {},
- "output_type": "execute_result"
- },
- {
- "data": {
|