|
|
- {
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# Logistic Regression\n",
- "\n",
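- "The logistic regression (LR) model simply applies a logistic function on top of linear regression, and it is this logistic function that lets the model output class probabilities. In essence, logistic regression assumes that the labels follow a Bernoulli distribution whose parameter is given by the logistic function, and then estimates the model parameters by maximum likelihood.\n",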
- "\n",
- ""
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
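- "## 1. What Is Regression?\n",
- "\n",
- "The word regression first brings to mind the Terminator's line: I'll be back\n",
- "\n",
- "In *regress*, *re* means back and *gress* means go: values go back to the mean value, which is the sense of I'll be back.\n",
- "\n",
- "In mathematical statistics, regression is a method for determining the quantitative relationship among interdependent variables.\n",
- "\n",
- "> Intuitively: the process of getting closer and closer to the expected value, a ***regression*** to the essence of things.\n",
- "\n",
- "The simplest form is linear regression, where the model parameters are obtained with methods such as least squares. Linear regression assumes that the output variable is a linear combination of several input variables, and solves for the optimal coefficients of that combination.\n",
- "\n",
- "Intuitively: the model is a linear function such as $y=f(x; \\theta)$, and we look for the optimal parameters $\\theta$ so that the model's predictions match the observed data.\n",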
- "\n",
- ""
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "\n",
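- "## 2. The Logistic Regression Model\n",
- "Regression is a fairly intuitive model: it amounts to $y=f(x)$, which expresses the relationship between the independent variable $x$ and the dependent variable $y$.\n",
- "\n",
- "Take a visit to the doctor as an example. The doctor looks, listens, asks questions, and takes the pulse, and then decides whether the patient is ill and with what. The examination gathers the independent variable $x$ (the feature data), and the diagnosis produces the dependent variable $y$ (the predicted class). In the tumor example, $X$ is the data point (the tumor size) and $Y$ is the observation (whether the tumor is malignant). Once a linear regression model $h_\\theta(x)$ has been fitted, malignancy can be predicted from tumor size: $h_\\theta(x) \\ge 0.5$ means malignant and $h_\\theta(x) \\lt 0.5$ means benign.\n",
- "\n",
- "\n",
- "\n",
- "However, linear regression is not robust. For example, when a regression is fitted to the dataset in the figure above, the noisy point at the far right makes the model perform poorly even on the training set. The main reason is that linear regression is equally sensitive over the entire real line, whereas a classification output needs to stay within $(0,1)$.\n",
- "\n",
- "Logistic regression is a regression model that narrows the prediction range, restricting predictions to the interval $(0,1)$. Its regression equation and curve are shown in the figure below. The logistic curve is very sensitive near $z=0$ and insensitive where $z \\gg 0$ or $z \\ll 0$, which keeps the predictions within $(0,1)$.\n",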
- "\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "image/png": "\n",
- "text/plain": [
- "<Figure size 432x288 with 1 Axes>"
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "%matplotlib inline\n",
- "import matplotlib.pyplot as plt\n",
- "import numpy as np\n",
- "\n",
- "plt.figure()\n",
- "plt.axis([-10,10,0,1])\n",
- "plt.grid(True)\n",
- "X=np.arange(-10,10,0.1)\n",
- "y=1/(1+np.e**(-X))\n",
- "plt.plot(X,y,'b-')\n",
- "plt.title(\"Logistic function\")\n",
- "plt.savefig(\"fig-res-logstic_fuction.pdf\")\n",
- "plt.show()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### 2.1 The Logistic Regression Expression\n",
- "\n",
- "This function is called the logistic function, also known as the sigmoid function. It is defined as:\n",
- "\n",
- "$$\n",
- "g(z) = \\frac{1}{1+e^{-z}}\n",
- "$$\n",
- "\n",
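- "The logistic function has the following limits:\n",
- "* as $z \\to +\\infty$, $g(z)$ approaches 1;\n",
- "* as $z \\to -\\infty$, $g(z)$ approaches 0.\n",
- "\n",
- "The graph of the logistic function is shown in the figure above. Its derivative has a useful property that the derivation below relies on:\n",
- "$$\n",
- "\\begin{aligned}\n",
- "g'(z) &= \\frac{d}{dz} \\frac{1}{1+e^{-z}} \\\\\n",
- "      &= \\frac{1}{(1+e^{-z})^2} e^{-z} \\\\\n",
- "      &= \\frac{1}{1+e^{-z}} \\left(1 - \\frac{1}{1+e^{-z}}\\right) \\\\\n",
- "      &= g(z)(1-g(z))\n",
- "\\end{aligned}\n",
- "$$"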
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
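- "Logistic regression is essentially linear regression with one extra function applied in the mapping from features to result: the features are first combined linearly, and the function $g(z)$ is then applied to that sum to form the hypothesis used for prediction. $g(z)$ maps any real value into the interval between 0 and 1. Substituting the linear regression expression into $g(z)$ gives the logistic regression expression:\n",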
- "\n",
- "$$\n",
- "h_\\theta(x) = g(\\theta^T x) = \\frac{1}{1+e^{-\\theta^T x}}\n",
- "$$"
- ]
- },
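- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "One consequence is worth spelling out, assuming the 0.5 threshold from the tumor example is used to turn $h_\\theta(x)$ into a class label. Because $g$ is monotonically increasing and $g(0) = 0.5$, thresholding the probability is the same as thresholding the linear score:\n",
- "\n",
- "$$\n",
- "h_\\theta(x) \\ge 0.5 \\iff \\theta^T x \\ge 0\n",
- "$$\n",
- "\n",
- "So the decision boundary $\\theta^T x = 0$ is linear in the features, even though $h_\\theta(x)$ itself is a nonlinear function of $x$."
- ]
- },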
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
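- "### 2.2 Soft Classification with Logistic Regression\n",
- "\n",
- "The logistic function squashes the output $h_\\theta(x)$ into $(0,1)$. This value has a specific meaning: it is the probability that the result is 1. For an input $x$, the probabilities of class 1 and class 0 are therefore:\n",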
- "\n",
- "$$\n",
- "P(y=1|x,\\theta) = h_\\theta(x) \\\\\n",
- "P(y=0|x,\\theta) = 1 - h_\\theta(x)\n",
- "$$\n",
- "\n",
- "Combining the two expressions into one gives:\n",
- "\n",
- "$$\n",
- "P(y|x,\\theta) = (h_\\theta(x))^y (1 - h_\\theta(x))^{1-y}\n",
- "$$\n",
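- "\n",
- "As a quick sanity check (a worked substitution, not an additional assumption), plugging the two possible labels into the combined expression recovers the two separate cases:\n",
- "\n",
- "$$\n",
- "y = 1: \\quad (h_\\theta(x))^1 (1 - h_\\theta(x))^0 = h_\\theta(x) \\\\\n",
- "y = 0: \\quad (h_\\theta(x))^0 (1 - h_\\theta(x))^1 = 1 - h_\\theta(x)\n",
- "$$\n",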
- "\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
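- "### 2.3 Gradient Ascent\n",
- "\n",
- "With the logistic regression expression in hand, the next step mirrors linear regression: build the likelihood function, apply maximum likelihood estimation, and derive the iterative update rule for $\\theta$. The difference is that gradient ascent is used instead of gradient descent, because the likelihood is being maximized.\n",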
- "\n",
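- "Assuming the training samples are mutually independent, the likelihood function is:\n",
- "\n",
- "$$\n",
- "L(\\theta) = \\prod_{i=1}^{m} P(y^i|x^i,\\theta) = \\prod_{i=1}^{m} (h_\\theta(x^i))^{y^i} (1 - h_\\theta(x^i))^{1-y^i}\n",
- "$$\n",
- "\n",
- "where $m$ is the number of training samples. Taking the log of the likelihood gives:\n",
- "\n",
- "$$\n",
- "\\ell(\\theta) = \\log L(\\theta) = \\sum_{i=1}^{m} \\left[ y^i \\log h_\\theta(x^i) + (1 - y^i) \\log (1 - h_\\theta(x^i)) \\right]\n",
- "$$\n",
- "\n",
- "Next, take the partial derivative of the log-likelihood with respect to $\\theta_j$. For a single training sample $(x, y)$:\n",
- "\n",
- "$$\n",
- "\\begin{aligned}\n",
- "\\frac{\\partial}{\\partial \\theta_j} \\ell(\\theta) &= \\left( \\frac{y}{g(\\theta^T x)} - \\frac{1-y}{1 - g(\\theta^T x)} \\right) \\frac{\\partial}{\\partial \\theta_j} g(\\theta^T x) \\\\\n",
- "&= \\left( \\frac{y}{g(\\theta^T x)} - \\frac{1-y}{1 - g(\\theta^T x)} \\right) g(\\theta^T x) (1 - g(\\theta^T x)) \\frac{\\partial}{\\partial \\theta_j} \\theta^T x \\\\\n",
- "&= \\left( y (1 - g(\\theta^T x)) - (1-y) g(\\theta^T x) \\right) x_j \\\\\n",
- "&= (y - h_\\theta(x)) x_j\n",
- "\\end{aligned}\n",
- "$$\n",
- "\n",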
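- "In this derivation:\n",
- "* The first step applies the chain rule to the logarithm, using the derivative rule $y=\\ln x$, $y'=1/x$.\n",
- "* The second step uses the derivative property of $g(z)$: $g'(z) = g(z)(1 - g(z))$.\n",
- "* The third step is ordinary algebraic simplification.\n",
- "\n",
- "This gives the update direction for each gradient ascent iteration, so the update rule for each component $\\theta_j$ is:\n",
- "$$\n",
- "\\theta_j = \\theta_j + \\eta (y^i - h_\\theta(x^i)) x_j^i\n",
- "$$\n",
- "\n",
- "where $\\eta$ is the learning rate."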
- ]
- },
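- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "To make the update rule concrete, here is one step worked by hand. The numbers are made up purely for illustration and are not taken from any dataset in this notebook: start from $\\theta = (0, 0)$, take a single sample $x = (1, 2)$ with label $y = 1$, and use $\\eta = 0.1$.\n",
- "\n",
- "$$\n",
- "\\begin{aligned}\n",
- "h_\\theta(x) &= g(\\theta^T x) = g(0) = 0.5 \\\\\n",
- "\\theta_1 &= 0 + 0.1 \\cdot (1 - 0.5) \\cdot 1 = 0.05 \\\\\n",
- "\\theta_2 &= 0 + 0.1 \\cdot (1 - 0.5) \\cdot 2 = 0.1\n",
- "\\end{aligned}\n",
- "$$\n",
- "\n",
- "The label is 1 and the predicted probability is only 0.5, so the step moves $\\theta$ in the direction of $x$, which raises $\\theta^T x$ and therefore $h_\\theta(x)$ for this sample."
- ]
- },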
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### 2.4 Example Program"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [],
- "source": [
- "%matplotlib inline\n",
- "\n",
- "#from __future__ import division\n",
- "import numpy as np\n",
- "import sklearn.datasets\n",
- "import matplotlib.pyplot as plt\n",
- "\n",
- "np.random.seed(0)\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "Text(0.5, 1.0, 'Original Data')"
- ]
- },
- "execution_count": 3,
- "metadata": {},
- "output_type": "execute_result"
- },
- {
- "data": {
- "image/png": "\n",
- "text/plain": [
- "<Figure size 432x288 with 1 Axes>"
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "# load sample data\n",
- "data, label = sklearn.datasets.make_moons(200, noise=0.30)\n",
- "\n",
- "plt.scatter(data[:,0], data[:,1], c=label)\n",
- "plt.savefig(\"fig-res-logistic_train_data.pdf\")\n",
- "plt.title(\"Original Data\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [],
- "source": [
- "def plot_decision_boundary(predict_func, data, label, figName=None):\n",
- "    \"\"\"Plot the decision regions together with the data points.\n",
- "    Args:\n",
- "        predict_func (callable): prediction function\n",
- "        data (numpy.ndarray): training data\n",
- "        label (numpy.ndarray): training labels\n",
- "        figName (str, optional): file name to save the figure to\n",
- "    \"\"\"\n",
- "    x_min, x_max = data[:, 0].min() - .5, data[:, 0].max() + .5\n",
- "    y_min, y_max = data[:, 1].min() - .5, data[:, 1].max() + .5\n",
- "    h = 0.01\n",
- "\n",
- "    xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))\n",
- "\n",
- "    Z = predict_func(np.c_[xx.ravel(), yy.ravel()])\n",
- "    Z = Z.reshape(xx.shape)\n",
- "\n",
- "    plt.contourf(xx, yy, Z, cmap=plt.cm.Spectral)  # draw filled contours\n",
- "    plt.scatter(data[:, 0], data[:, 1], c=label, cmap=plt.cm.Spectral)\n",
- "    if figName is not None:\n",
- "        plt.savefig(figName)\n",
- "    plt.show()\n",
- "\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {},
- "outputs": [],
- "source": [
- "\n",
- "def sigmoid(x):\n",
- "    \"\"\"Logistic (sigmoid) function.\"\"\"\n",
- "    return 1.0 / (1 + np.exp(-x))\n",
- "\n",
- "class Logistic(object):\n",
- "    \"\"\"Logistic regression model.\"\"\"\n",
- "    def __init__(self, data, label):\n",
- "        self.data = data\n",
- "        self.label = label\n",
- "\n",
- "        # parameters\n",
- "        self.data_num, n = np.shape(data)\n",
- "        self.weights = np.ones(n)\n",
- "        self.b = 1.0\n",
- "\n",
- "    def train(self, num_iteration=150):\n",
- "        \"\"\"Stochastic gradient ascent.\n",
- "        FIXME: change to the same API as sklearn\n",
- "        Args:\n",
- "            num_iteration (int): number of passes over the data\n",
- "        \"\"\"\n",
- "        # learning rate\n",
- "        alpha = 0.01\n",
- "\n",
- "        for j in range(num_iteration):\n",
- "            # indices not yet visited in this pass (sampling without replacement)\n",
- "            data_index = list(range(self.data_num))\n",
- "            for i in range(self.data_num):\n",
- "                rand_index = int(np.random.uniform(0, len(data_index)))\n",
- "                sample = data_index[rand_index]\n",
- "\n",
- "                # error = y - h(x), with h(x) = sigmoid(w . x + b)\n",
- "                prob = sigmoid(np.dot(self.data[sample], self.weights) + self.b)\n",
- "                error = self.label[sample] - prob\n",
- "\n",
- "                self.weights += alpha * error * self.data[sample]\n",
- "                self.b += alpha * error\n",
- "\n",
- "                del data_index[rand_index]\n",
- "\n",
- "    def predict(self, predict_data):\n",
- "        \"\"\"Predict a class label (0 or 1) for each row of predict_data.\"\"\"\n",
- "        scores = np.dot(predict_data, self.weights) + self.b\n",
- "        return np.where(scores > 0, 1, 0)\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "image/png": "\n",
- "text/plain": [
- "<Figure size 432x288 with 1 Axes>"
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "logistic = Logistic(data, label)\n",
- "logistic.train(200)\n",
- "plot_decision_boundary(lambda x: logistic.predict(x), data, label, \"logistic_pred_res.pdf\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 3. How to Solve Logistic Regression Problems with sklearn"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "accuracy train = 0.891667\n",
- "accuracy test = 0.825000\n"
- ]
- },
- {
- "data": {
- "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQYAAAD4CAYAAAAO2kjhAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAWjklEQVR4nO3df5RcdX3/8ecr4UdQAhKCNEYgnIq0FCVovlaxWqBFg7XFVm2LlqKlRVvx2yqtou3XX7U/PPVXf6A9sVAoWKwiVgoIIsJBxApJGpCAiILUQCQECiZIQ7L7+v5x7+rs7GbnzmZm7p2d1+Oce3bunTufee+enfd8ft3PlW0iIlrNqzuAiGieJIaImCKJISKmSGKIiCmSGCJiiiSGiJgiiaEBJO0l6T8kPSrpM7tQzmslfbGXsdVB0hcknVp3HKMsiaELkl4jabWkrZI2lv/AP9eDol8FHAjsb/vVsy3E9idtv6QH8Uwi6VhJlvS5tuNHlcevq1jOeyRd2Ok82yfaPn+W4UYPJDFUJOmtwEeBv6T4EB8MfAw4qQfFHwJ8y/aOHpTVLw8CL5C0f8uxU4Fv9eoNVMj/ZBPYztZhA/YFtgKvnuGcPSkSx/3l9lFgz/K5Y4ENwJnAJmAj8PryufcCTwDby/c4DXgPcGFL2csAA7uV+68D7ga2APcAr205fkPL644BbgYeLX8e0/LcdcCfA18ty/kisHgnv9tE/P8IvKk8Nh+4D3gXcF3LuX8LfA/4AbAGeFF5fGXb73lLSxx/UcbxOPCM8tjvls9/HPhsS/kfAK4BVPf/xVzekp2reQGwAPjcDOf8KfB8YDlwFPA84M9anv8JigSzlOLDf7ak/Wy/m6IW8m+297Z9zkyBSHoy8HfAibYXUnz4101z3iLg8vLc/YEPA5e3feO/Bng98FRgD+CPZ3pv4F+A3y4fvxS4jSIJtrqZ4m+wCPhX4DOSFti+su33PKrlNacApwMLgXvbyjsTeJak10l6EcXf7lSXWSL6I4mhmv2BzZ65qv9a4H22N9l+kKImcErL89vL57fbvoLiW/PwWcYzDhwpaS/bG22vn+acXwLusn2B7R22LwK+Cfxyyzn/bPtbth8HPk3xgd4p2zcCiyQdTpEg/mWacy60/VD5nh+iqEl1+j3Ps72+fM32tvJ+SPF3/DBwIfBm2xs6lBe7KImhmoeAxZJ2m+GcpzH52+7e8tiPymhLLD8E9u42ENuPAb8BvBHYKOlyST9VIZ6JmJa27H9/FvFcAJwBHMc0NShJfyzpjnKE5RGKWtLiDmV+b6YnbX+doukkigQWfZbEUM3XgG3AK2Y4536KTsQJBzO1ml3VY8CTWvZ/ovVJ21fZPgFYQlEL+ESFeCZium+WMU24APgD4Iry2/xHyqr+24BfB/az/RSK/g1NhL6TMmdsFkh6E0XN4/6y/JiGpAWSbpJ0i6T1kt5bHj9P0j2S1pXb8k5lzfQNGCXbj0p6F0W/wA6KjrrtwC8Cx9l+G3AR8GeSbqb4R38XRdV3NtYBb5d0MMUH6x0TT0g6kKIv40sUnXVbKZoW7a4A/l7Sayi+ZV8JHAFcNsuYALB9j6Sfp/gGb7cQ2EExgrGbpLOAfVqefwA4QdI829PFPIWkZwLvp+gA/SFwk6Qv2F43+99iztoGHG97q6TdgRskfaF87k9sX1y1oNQYKirby2+l6FB8kKL6ewbw7+Up7wdWA7cC3wDWlsdm815XA/9WlrWGyR/meWUc9wMPAz8P/P40ZTwEvJyi8+4him/al9vePJuY2sq+wfZ0taGrgCsphjDvBf6Xyc2EiclbD0la2+l9yqbbhcAHbN9i+y7gncAFkvbcld9hLnJha7m7e7nNqpNW6dyNmDskzaf4MnkGcLbtt0s6j2JkbRvFUO9ZtrfNWE4SQ0R9Xnrck7z54UqtKtbeum09RS1swirbq6Y7V9JTKDqH30xRY/w+xZD0KuA7tt8303uljyGi
RpsfHuPGK5d2PhFY8LR7/tf2iirn2n5E0rXAStsfLA9vk/TPdJ6vkj6GiDoZGMeVtk4kHVDWFJC0F3AC8E1JS8pjohhZu61TWakxRNRsfNpBpVlZApxf9jPMAz5t+zJJX5Z0AMWw8TqKOTAzSmKIqJExYz3q57N9K3D0NMeP77asNCVmSdJKSXdK+nY5Xh89JOlcSZskdaz2DrteNSV6KYlhFsqq2tnAiRSThk6WdES9Uc0551FckTmnGRjDlbZBSlNidp4HfNv23QCSPkWxLsPttUY1h9i+XtKyuuPoNwPbq00CHajUGGZnKZNn9G1g8sVJEZWNV9wGKTWGiBq5hmZCFUkMs3MfcFDL/tPZ9asWYxQZxpqXF9KUmKWbgcMkHSppD+A3gUtrjimGUDHBqXlNiSSGWSgXXDmD4mrCOygmkky3ilLMkqSLKNbBOFzSBkmn1R1Tf4ixitsgpSkxS+XybFfUHcdcZfvkumMYBAPjDWxKJDFE1MjAEw2suCcxRNRs3INtJlSRxBBRo2LmYxJDRLQwYixNiYho18SmRPNS1RCRdHrdMcx1c/1vPNGUaNpwZRLDrpnT/7QNMcf/xmLM8yptg5SmRESNDGxnft1hTNGoxLB40XwvO2j3usOo7OClu7HiqAUNnJ6yc3fdsU/nkxpkwfy92XePpw7V3/jxHVt4YvzxSnV/WwOvDVTRqMSw7KDduemqgzqfGLP2S895ad0hzHk3bu7u9prjGa6MiFZF52NqDBExSZoSEdGmuOw6iSEi2ow1cIJTEkNEjYzY7uZ9DJsXUcQISedjRExhlKZEREyVzseImMQmw5UR0U6Z+RgRkxl4okejEpIWANcDe1J8ti+2/W5JhwKfAvYH1gCn2H5iprKaV4eJGCFGjLvaVsE24HjbRwHLgZWSng98APiI7WcA/wN0XIo/iSGiZmPMq7R14sLWcnf3cjNwPHBxefx84BWdykpiiKhRcV+JeZW2KiTNl7QO2ARcDXwHeKS8SRJUvAFz+hgiatXVsm2LJa1u2V9le1XrCbbHgOWSngJ8Dvip2USVxBBRo4kaQ0Wbba+oVK79iKRrgRcAT5G0W1lrqHQD5jQlImrWq8VgJR1Q1hSQtBdwAsW9Va8FXlWedirw+U5lpcYQUSNbbB/v2cdwCXC+pPkUX/qftn2ZpNuBT0l6P/BfwDmdCkpiiKhRsR5DbyY42b4VOHqa43cDz+umrCSGiFplBaeIaFN0PmZKdES0yXoMETHJxJTopkliiKhZ1mOIiEls2D6exBARLYqmRBJDRLQZ9C3uq0hiiKhRhisjYhppSkTENLLmY0RMUqwSncQQES2M2DE+v+4wpkhiiKhZmhIRMUlGJSJiWhmViIjJqt8zYqCSGCJq1MsVnHopiSGiZqkxRMQkBnY08OrKvkYkaaWkOyV9W9JZ/XyviGHU43tX9kzfEkO5hPXZwInAEcDJko7o1/tFDKtxVGkbpH7WGJ4HfNv23eUttz8FnNTH94sYPqaRNYZ+9jEsBb7Xsr8B+Nk+vl/E0MkEp52QdDpwOsDBS2sPJ2LgRi0x3Acc1LI/7c00y7v1rgJYcdQC9zGeiMYxYmzERiVuBg6TdKikPYDfBC7t4/tFDKUmdj72rcZge4ekM4CrgPnAubbX9+v9IoaRPXpNCWxfAVzRz/eIGHYetcQQEZ008yKq5vV6RIwYW5W2TiQdJOlaSbdLWi/pD8vj75F0n6R15fayTmWlxhBRox7PY9gBnGl7raSFwBpJV5fPfcT2B6sWlMQQUaceLgZreyOwsXy8RdIdFBMNu5amRESNTO+aEq0kLQOOBr5eHjpD0q2SzpW0X6fXJzFE1KqrqysXS1rdsp0+bYnS3sBngT+y/QPg48BPAsspahQf6hRVmhIRNXP1+b6bba+Y6QRJu1MkhU/avqQo3w+0PP8J4LJOb5QaQ0TNejgqIeAc4A7bH245vqTltF8FbutUVmoMETWyezrB6YXAKcA3JK0r
j72TYi2U5RRdGt8F3tCpoCSGiJr1arjS9g0w7UUVXc8+TmKIqNn4ePNmPlZKDJL2BF4JLGt9je339SesiNFguh+KHISqNYbPA48Ca4Bt/QsnYvQ0cRGSqonh6bZX9jWSiFHU287Hnqk6XHmjpGf1NZKIUeWK2wDNWGOQ9A2KkHYDXi/pboqmhADbfnb/Q4yY25pYY+jUlHj5QKKIGGFdzHwcmBkTg+17ASRdYPuU1uckXUAxmSIiZskGN3Ax2Kqdjz/TulPeZeq5vQ8nYvQ0scYwY6qS9A5JW4BnS/qBpC3l/iaKIcyI2FUN7HycMTHY/ivbC4G/sb2P7YXltr/tdwwoxog5rNoFVIPuoKzalPiCpBe3H7R9fY/jiRg9DWxKVE0Mf9LyeAHFDWvXAMf3PKKIUdLQCU6VEoPtX27dl3QQ8NF+BBQxcoa4xtBuA/DTvQwkYmQNa41B0t/z47w2j2LtuLV9iilitAxxjWF1y+MdwEW2v9qHeCJGixnOGkM5mekltl87gHgiRs7QTXACsD0GHFLeyj4ieq2BE5yqNiXuBr4q6VLgsYmDrSvRRsQsDWNTovSdcpsHLCyPNbACFDFkDBqvO4ipqiaG221/pvWApFf3IZ6IEaNG1hiqXu853XURuVYioheGrY9B0onAy4Clkv6u5al9KIYtI2JXNbBR3qkpcT/FNRG/Uv6csAV4S7+Cihgpw5YYbN8C3CLpk7a3DyimiNExjBOcWhaDpbhf5mRZDDZi12nYagz8eDHYN5U/Lyh//haNrABFDKEGfpKqLgZ7gu2jW556u6S1wFm9DOZbtz6Jlz5teS+LjDa/fefqzifFLvnmr/2wq/ObWGOoOlwpSS9s2Tmmi9dGxEysalsHkg6SdK2k2yWtl/SH5fFFkq6WdFf5c79OZVX9cJ8GfEzSdyXdC3wM+J2Kr42Inak6h6FarWIHcKbtI4DnA2+SdARFzf4a24cB11Chpl91Bac1wFGS9i33H60UZkR01qOmhO2NwMby8RZJdwBLgZOAY8vTzgeuA94+U1lVF2rZE3glsAzYbWKEwvb7ug0+IibrRx+DpGXA0cDXgQPLpAHwfeDATq+veq3E54FHKSY5bes+zIjYqeqJYbGk1t7jVbZXtZ8kaW/gs8Af2f5B61QD25Y6p6KqieHptldWPDciKlJ3V1dutr1ixvKk3SmSwidtX1IefkDSEtsbJS2huGHUjKp2Pt4o6VkVz42IbvRuVELAOcAdbWulXAqcWj4+lQp3kataY/g54HWS7qFoSoiiVpKZjxG7qnd9DC+kuNH0NyStK4+9E/hr4NOSTgPuBX69U0FVE8OJswgyIiroVeej7RsovrSn8wvdlFU1MTRwblbEHNHAT1fVxHA5RfiiuEXdocCdwM/0Ka6I0eBmTomuOsFpUsejpOcAf9CXiCJGzbAmhna210r62V4HEzGKhnYxWElvbdmdBzyHYnWniJiDqtYYFrY83kHR5/DZ3ocTMYKGtSlh+73wo6mW2N7az6AiRkZDOx8rzXyUdKSk/wLWA+slrZF0ZH9DixgRDVw+vuqU6FXAW20fYvsQ4MzyWETsqgYmhqp9DE+2fe3Eju3rJD25TzFFjAzRzKZE5ZvaSvp/TF4M9u7+hBQxQhp678qqTYnfAQ4ALqEYjVhMlnaL6I1hbEpImg9cYvu4AcQTMXoa2JToWGOwPQaMT6z3GBG9JVfbBqlqH8NWimu8rwYemzho+//2JaqIUdLAGkPVxHBJucGPf43m3XAvYtjU0H9QRad7V55Esd7j2eX+TRSdkKbD8tMRUc0wjkq8jWK9uAl7AM+lWKP+jX2KKWKkDGMfwx62v9eyf4Pth4GHM8EpokeGrSkBTLrHne0zWnYP6H04ESOmoX0MnZoSX5f0e+0HJb0BuKk/IUWMDnWxDVKnGsNbgH+X9BpgbXnsucCewCv6GFfE6GhgjWHGxGB7E3CMpOP58cKvl9v+ct8jixgRQ3sRVZkIkgwi
+qGBw5WzWgw2InqkoSs4JTFE1C2JISLapcYQEVMlMUREu9QYImKyIZ35GBF9JIqrK6tslcqTzpW0SdJtLcfeI+k+SevK7WWdykliiKhbb9d8PA9YOc3xj9heXm5XdCqkb4lhuswVEVPJrrRVYft64OFdjamfNYbzmD5zRcSEqrWFXe+HOEPSreUX9n6dTu5bYuhV5oqY67pYqGWxpNUt2+kV3+LjwE8Cy4GNwIc6vSCjEhF1q14b2Gx7RdfF2w9MPJb0CeCyTq+pvfNR0ukTGXA72+oOJ2Lg+r20m6QlLbu/CnTs96u9xmB7FeUNcvfRogaO6Eb0UY9vUSfpIoo1WRdL2gC8GzhW0vLi3fgu8IZO5dSeGCJGXg+/Dm2fPM3hc7otp5/DlRcBXwMOl7RB0mn9eq+IYTVxt+thWyV61naSuSKiXcU5CoOUpkREzXIRVURM1tCLqJIYImrWxFvUJTFE1CyJISImM+l8jIip0vkYEVMlMUREq4kJTk2TxBBRJzt9DBExVUYlImKKNCUiYjID483LDEkMEXVrXl5IYoioW5oSETFVRiUiol1qDBExiQxK52NETJF5DBHRrurt5wYpiSGiTlnBKSKmyrUSETGNjEpExFSpMUTEJAaNJTFERLvm5YUkhoi6ZbgyIqZqYGLo201tI6ICU8x8rLJVIOlcSZsk3dZybJGkqyXdVf7cr1M5SQwRNRJGrrZVdB6wsu3YWcA1tg8Drin3Z5TEEFG3iQVhO22VivL1wMNth08Czi8fnw+8olM56WOIqJOB/g9XHmh7Y/n4+8CBnV6QxBBRsy6aCYslrW7ZX2V7VTfvZdtS57mWSQwRdaueGDbbXjGLd3hA0hLbGyUtATZ1ekH6GCJqVbF/YdeGNC8FTi0fnwp8vtMLkhgi6jRxt+seJQZJFwFfAw6XtEHSacBfAydIugv4xXJ/RmlKRNSthys42T55J0/9QjflJDFE1CxToiNiMgNjzVv0MYkholZZwamjLfzP5i/54nvrjqMLi4HNdQfRjS89s+4IujZ0f2PgkK7OTmKYme0D6o6hG5JWz3JcOSoaib9xEkNETJK7XUfEVAan83Gu6WqeeszK3P4bN3RUIjMfd0G3F7D0gqQxSesk3SbpM5KetAtlnSfpVeXjf5J0xAznHivpmNm+12zV8TceuP5Pie5aEsPwedz2cttHAk8Ab2x9UtKsaoG2f9f27TOcciww8MQwEpIYose+Ajyj/Db/iqRLgdslzZf0N5JulnSrpDcAqPAPku6U9CXgqRMFSbpO0ory8UpJayXdIukaScsoEtBbytrKiwb/q85VA7mIqmvpYxhSZc3gRODK8tBzgCNt3yPpdOBR2/9H0p7AVyV9ETgaOBw4gmKxjtuBc9vKPQD4BPDisqxFth+W9I/AVtsfHMgvOCoMjDevjyGJYfjsJWld+fgrwDkUVfybbN9THn8J8OyJ/gNgX+Aw4MXARbbHgPslfXma8p8PXD9Rlu32ZcKi1zKPIXrgcdvLWw9IAnis9RDwZttXtZ33sr5HF91rYGJIH8PcdBXw+5J2B5D0TElPBq4HfqPsg1gCHDfNa/8TeLGkQ8vXLiqPbwEW9j/0EWPjsbFK2yClxjA3/ROwDFirojrxIMXKwJ8DjqfoW/hvigU9JrH9YNlHcYmkeRTLgJ0A/AdwsaSTKGojXxnA7zEaGjjzUW5gNSZiVOy72wF+wcKTKp171SPnrBnUdSOpMUTUyc6oRERMo4G19iSGiJo5NYaImCwrOEVEOwMDHoqsIokhokYG3MDhyiSGiDo5C7VExDRSY4iIqRpYY8jMx4gaSbqSYon8KjbbXtnPeCYkMUTEFLm6MiKmSGKIiCmSGCJiiiSGiJgiiSEipvj/xGX+BfMFGU4AAAAASUVORK5CYII=\n",
- "text/plain": [
- "<Figure size 288x288 with 2 Axes>"
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "%matplotlib inline\n",
- "\n",
- "import sklearn.datasets\n",
- "from sklearn.linear_model import LogisticRegression\n",
- "from sklearn.metrics import confusion_matrix\n",
- "from sklearn.metrics import accuracy_score\n",
- "import matplotlib.pyplot as plt\n",
- "\n",
- "# generate synthetic data\n",
- "data, label = sklearn.datasets.make_moons(200, noise=0.30)\n",
- "\n",
- "# compute the number of training and test samples\n",
- "N = len(data)\n",
- "N_train = int(N*0.6)\n",
- "N_test = N - N_train\n",
- "\n",
- "# split into training and test sets\n",
- "x_train = data[:N_train, :]\n",
- "y_train = label[:N_train]\n",
- "x_test = data[N_train:, :]\n",
- "y_test = label[N_train:]\n",
- "\n",
- "# fit the logistic regression model\n",
- "lr = LogisticRegression()\n",
- "lr.fit(x_train,y_train)\n",
- "\n",
- "# predict\n",
- "pred_train = lr.predict(x_train)\n",
- "pred_test = lr.predict(x_test)\n",
- "\n",
- "# compute training/test accuracy\n",
- "acc_train = accuracy_score(y_train, pred_train)\n",
- "acc_test = accuracy_score(y_test, pred_test)\n",
- "print(\"accuracy train = %f\" % acc_train)\n",
- "print(\"accuracy test = %f\" % acc_test)\n",
- "\n",
- "# plot the confusion matrix\n",
- "cm = confusion_matrix(y_test,pred_test)\n",
- "\n",
- "plt.matshow(cm)\n",
- "plt.title('Confusion Matrix')\n",
- "plt.colorbar()\n",
- "plt.ylabel('Groundtruth')\n",
- "plt.xlabel(u'Predict')\n",
- "plt.savefig('fig-res-logistic_confusion_matrix.pdf')\n",
- "plt.show()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## 4. Multi-class Classification"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### 4.1 Loading and Displaying the Data"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "metadata": {
- "scrolled": false
- },
- "outputs": [
- {
- "data": {
- "image/png": "\n",
- "text/plain": [
- "<Figure size 432x432 with 64 Axes>"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "import matplotlib.pyplot as plt \n",
- "from sklearn.datasets import load_digits\n",
- "\n",
- "# load data\n",
- "digits = load_digits()\n",
- "\n",
- "# copied from notebook 02_sklearn_data.ipynb\n",
- "fig = plt.figure(figsize=(6, 6)) # figure size in inches\n",
- "fig.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0.05, wspace=0.05)\n",
- "\n",
- "# plot the digits: each image is 8x8 pixels\n",
- "for i in range(64):\n",
- "    ax = fig.add_subplot(8, 8, i + 1, xticks=[], yticks=[])\n",
- "    ax.imshow(digits.images[i], cmap=plt.cm.binary)\n",
- "    \n",
- "    # label the image with the target value\n",
- "    ax.text(0, 7, str(digits.target[i]))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 12,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "(1797, 64)\n",
- "accuracy train = 1.000000, accuracy_test = 0.905556\n",
- "score_train = 1.000000, score_test = 0.905556\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/home/bushuhui/anaconda3/envs/dl/lib/python3.7/site-packages/sklearn/linear_model/_logistic.py:765: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
- "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
- "\n",
- "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
- " https://scikit-learn.org/stable/modules/preprocessing.html\n",
- "Please also refer to the documentation for alternative solver options:\n",
- " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
- " extra_warning_msg=_LOGISTIC_SOLVER_CONVERGENCE_MSG)\n"
- ]
- }
- ],
- "source": [
- "from sklearn.datasets import load_digits\n",
- "from sklearn.linear_model import LogisticRegression\n",
- "from sklearn.metrics import accuracy_score\n",
- "from sklearn.manifold import Isomap\n",
- "\n",
- "import matplotlib.pyplot as plt \n",
- "\n",
- "# load the sample data\n",
- "digits, dig_label = load_digits(return_X_y=True)\n",
- "print(digits.shape)\n",
- "\n",
- "# compute the number of training/test samples\n",
- "N = len(digits)\n",
- "N_train = int(N*0.8)\n",
- "N_test = N - N_train\n",
- "\n",
- "# split into training/test sets\n",
- "x_train = digits[:N_train, :]\n",
- "y_train = dig_label[:N_train]\n",
- "x_test = digits[N_train:, :]\n",
- "y_test = dig_label[N_train:]\n",
- "\n",
- "# fit the logistic regression classifier\n",
- "lr = LogisticRegression()\n",
- "lr.fit(x_train, y_train)\n",
- "\n",
- "pred_train = lr.predict(x_train)\n",
- "pred_test = lr.predict(x_test)\n",
- "\n",
- "# compute training and test accuracy\n",
- "acc_train = accuracy_score(y_train, pred_train)\n",
- "acc_test = accuracy_score(y_test, pred_test)\n",
- "print(\"accuracy train = %f, accuracy_test = %f\" % (acc_train, acc_test))\n",
- "\n",
- "score_train = lr.score(x_train, y_train)\n",
- "score_test = lr.score(x_test, y_test)\n",
- "print(\"score_train = %f, score_test = %f\" % (score_train, score_test))\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
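- "### 4.2 Visualizing the Features\n",
- "\n",
- "For a machine learning problem, a good first step is to reduce the original high-dimensional features to 2 or 3 dimensions and visualize them, which gives an initial feel for the data. Here we introduce the simplest dimensionality reduction method, Principal Component Analysis (PCA). PCA seeks orthogonal linear combinations of the features that have maximum variance, and so helps reveal the structure of the data. We will use randomized PCA here because it is more computationally efficient when the number of samples $N$ is large.\n"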
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 13,
- "metadata": {},
- "outputs": [
- {
- "data": {
|