|
|
- {
- "cells": [
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {
- "scrolled": true
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "2.2.0\n",
- "sys.version_info(major=3, minor=6, micro=9, releaselevel='final', serial=0)\n",
- "matplotlib 3.3.4\n",
- "numpy 1.19.5\n",
- "pandas 1.1.5\n",
- "sklearn 0.24.2\n",
- "tensorflow 2.2.0\n",
- "tensorflow.keras 2.3.0-tf\n"
- ]
- }
- ],
- "source": [
- "import matplotlib as mpl\n",
- "import matplotlib.pyplot as plt\n",
- "%matplotlib inline\n",
- "import numpy as np\n",
- "import sklearn\n",
- "import pandas as pd\n",
- "import os\n",
- "import sys\n",
- "import time\n",
- "import tensorflow as tf\n",
- "\n",
- "from tensorflow import keras\n",
- "\n",
- "# Record the interpreter and library versions this notebook was run with.\n",
- "print(tf.__version__)\n",
- "print(sys.version_info)\n",
- "print(*(f'{lib.__name__} {lib.__version__}' for lib in (mpl, np, pd, sklearn, tf, keras)),\n",
- "      sep='\\n')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- " survived sex age n_siblings_spouses parch fare class deck \\\n",
- "0 0 male 22.0 1 0 7.2500 Third unknown \n",
- "1 1 female 38.0 1 0 71.2833 First C \n",
- "2 1 female 26.0 0 0 7.9250 Third unknown \n",
- "3 1 female 35.0 1 0 53.1000 First C \n",
- "4 0 male 28.0 0 0 8.4583 Third unknown \n",
- "\n",
- " embark_town alone \n",
- "0 Southampton n \n",
- "1 Cherbourg n \n",
- "2 Southampton y \n",
- "3 Southampton n \n",
- "4 Queenstown y \n",
- " survived sex age n_siblings_spouses parch fare class \\\n",
- "0 0 male 35.0 0 0 8.0500 Third \n",
- "1 0 male 54.0 0 0 51.8625 First \n",
- "2 1 female 58.0 0 0 26.5500 First \n",
- "3 1 female 55.0 0 0 16.0000 Second \n",
- "4 1 male 34.0 0 0 13.0000 Second \n",
- "\n",
- " deck embark_town alone \n",
- "0 unknown Southampton y \n",
- "1 E Southampton y \n",
- "2 C Southampton y \n",
- "3 unknown Southampton y \n",
- "4 D Southampton y \n",
- "--------------------------------------------------\n",
- "(627, 10)\n",
- "(264, 10)\n"
- ]
- }
- ],
- "source": [
- "# Data source (Titanic passenger records):\n",
- "#   https://storage.googleapis.com/tf-datasets/titanic/train.csv\n",
- "#   https://storage.googleapis.com/tf-datasets/titanic/eval.csv\n",
- "# Column notes:\n",
- "#   fare               - ticket price\n",
- "#   n_siblings_spouses - total number of siblings and spouses\n",
- "#   parch              - direct relatives of a different generation\n",
- "#   class              - cabin class\n",
- "train_file = \"./data/titanic/train.csv\"\n",
- "eval_file = \"./data/titanic/eval.csv\"\n",
- "\n",
- "train_df, eval_df = (pd.read_csv(csv_path) for csv_path in (train_file, eval_file))\n",
- "\n",
- "# Preview both splits, then report their (rows, columns) shapes.\n",
- "print(train_df.head(), eval_df.head(), sep='\\n')\n",
- "print('-' * 50)\n",
- "print(train_df.shape, eval_df.shape, sep='\\n')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- " sex age n_siblings_spouses parch fare class deck \\\n",
- "0 male 22.0 1 0 7.2500 Third unknown \n",
- "1 female 38.0 1 0 71.2833 First C \n",
- "2 female 26.0 0 0 7.9250 Third unknown \n",
- "3 female 35.0 1 0 53.1000 First C \n",
- "4 male 28.0 0 0 8.4583 Third unknown \n",
- "\n",
- " embark_town alone \n",
- "0 Southampton n \n",
- "1 Cherbourg n \n",
- "2 Southampton y \n",
- "3 Southampton n \n",
- "4 Queenstown y \n",
- "--------------------------------------------------\n",
- " sex age n_siblings_spouses parch fare class deck \\\n",
- "0 male 35.0 0 0 8.0500 Third unknown \n",
- "1 male 54.0 0 0 51.8625 First E \n",
- "2 female 58.0 0 0 26.5500 First C \n",
- "3 female 55.0 0 0 16.0000 Second unknown \n",
- "4 male 34.0 0 0 13.0000 Second D \n",
- "\n",
- " embark_town alone \n",
- "0 Southampton y \n",
- "1 Southampton y \n",
- "2 Southampton y \n",
- "3 Southampton y \n",
- "4 Southampton y \n",
- "--------------------------------------------------\n",
- "0 0\n",
- "1 1\n",
- "2 1\n",
- "3 1\n",
- "4 0\n",
- "Name: survived, dtype: int64\n",
- "--------------------------------------------------\n",
- "0 0\n",
- "1 0\n",
- "2 1\n",
- "3 1\n",
- "4 1\n",
- "Name: survived, dtype: int64\n"
- ]
- }
- ],
- "source": [
- "# Extract the target column: pop() removes 'survived' from each frame in place,\n",
- "# leaving only the feature columns behind.\n",
- "y_train = train_df.pop('survived')\n",
- "y_eval = eval_df.pop('survived')\n",
- "\n",
- "# Show features and labels for both splits, separated by divider lines.\n",
- "print(train_df.head(), eval_df.head(), y_train.head(), y_eval.head(),\n",
- "      sep='\\n' + '-' * 50 + '\\n')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "<div>\n",
- "<style scoped>\n",
- " .dataframe tbody tr th:only-of-type {\n",
- " vertical-align: middle;\n",
- " }\n",
- "\n",
- " .dataframe tbody tr th {\n",
- " vertical-align: top;\n",
- " }\n",
- "\n",
- " .dataframe thead th {\n",
- " text-align: right;\n",
- " }\n",
- "</style>\n",
- "<table border=\"1\" class=\"dataframe\">\n",
- " <thead>\n",
- " <tr style=\"text-align: right;\">\n",
- " <th></th>\n",
- " <th>age</th>\n",
- " <th>n_siblings_spouses</th>\n",
- " <th>parch</th>\n",
- " <th>fare</th>\n",
- " </tr>\n",
- " </thead>\n",
- " <tbody>\n",
- " <tr>\n",
- " <th>count</th>\n",
- " <td>627.000000</td>\n",
- " <td>627.000000</td>\n",
- " <td>627.000000</td>\n",
- " <td>627.000000</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <th>mean</th>\n",
- " <td>29.631308</td>\n",
- " <td>0.545455</td>\n",
- " <td>0.379585</td>\n",
- " <td>34.385399</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <th>std</th>\n",
- " <td>12.511818</td>\n",
- " <td>1.151090</td>\n",
- " <td>0.792999</td>\n",
- " <td>54.597730</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <th>min</th>\n",
- " <td>0.750000</td>\n",
- " <td>0.000000</td>\n",
- " <td>0.000000</td>\n",
- " <td>0.000000</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <th>25%</th>\n",
- " <td>23.000000</td>\n",
- " <td>0.000000</td>\n",
- " <td>0.000000</td>\n",
- " <td>7.895800</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <th>50%</th>\n",
- " <td>28.000000</td>\n",
- " <td>0.000000</td>\n",
- " <td>0.000000</td>\n",
- " <td>15.045800</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <th>75%</th>\n",
- " <td>35.000000</td>\n",
- " <td>1.000000</td>\n",
- " <td>0.000000</td>\n",
- " <td>31.387500</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <th>max</th>\n",
- " <td>80.000000</td>\n",
- " <td>8.000000</td>\n",
- " <td>5.000000</td>\n",
- " <td>512.329200</td>\n",
- " </tr>\n",
- " </tbody>\n",
- "</table>\n",
- "</div>"
- ],
- "text/plain": [
- " age n_siblings_spouses parch fare\n",
- "count 627.000000 627.000000 627.000000 627.000000\n",
- "mean 29.631308 0.545455 0.379585 34.385399\n",
- "std 12.511818 1.151090 0.792999 54.597730\n",
- "min 0.750000 0.000000 0.000000 0.000000\n",
- "25% 23.000000 0.000000 0.000000 7.895800\n",
- "50% 28.000000 0.000000 0.000000 15.045800\n",
- "75% 35.000000 1.000000 0.000000 31.387500\n",
- "max 80.000000 8.000000 5.000000 512.329200"
- ]
- },
- "execution_count": 4,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "train_df.describe() # inspect the data distribution (summary statistics)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- "<div>\n",
- "<style scoped>\n",
- " .dataframe tbody tr th:only-of-type {\n",
- " vertical-align: middle;\n",
- " }\n",
- "\n",
- " .dataframe tbody tr th {\n",
- " vertical-align: top;\n",
- " }\n",
- "\n",
- " .dataframe thead th {\n",
- " text-align: right;\n",
- " }\n",
- "</style>\n",
- "<table border=\"1\" class=\"dataframe\">\n",
- " <thead>\n",
- " <tr style=\"text-align: right;\">\n",
- " <th></th>\n",
- " <th>sex</th>\n",
- " <th>class</th>\n",
- " <th>deck</th>\n",
- " <th>embark_town</th>\n",
- " <th>alone</th>\n",
- " </tr>\n",
- " </thead>\n",
- " <tbody>\n",
- " <tr>\n",
- " <th>count</th>\n",
- " <td>627</td>\n",
- " <td>627</td>\n",
- " <td>627</td>\n",
- " <td>627</td>\n",
- " <td>627</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <th>unique</th>\n",
- " <td>2</td>\n",
- " <td>3</td>\n",
- " <td>8</td>\n",
- " <td>4</td>\n",
- " <td>2</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <th>top</th>\n",
- " <td>male</td>\n",
- " <td>Third</td>\n",
- " <td>unknown</td>\n",
- " <td>Southampton</td>\n",
- " <td>y</td>\n",
- " </tr>\n",
- " <tr>\n",
- " <th>freq</th>\n",
- " <td>410</td>\n",
- " <td>341</td>\n",
- " <td>481</td>\n",
- " <td>450</td>\n",
- " <td>372</td>\n",
- " </tr>\n",
- " </tbody>\n",
- "</table>\n",
- "</div>"
- ],
- "text/plain": [
- " sex class deck embark_town alone\n",
- "count 627 627 627 627 627\n",
- "unique 2 3 8 4 2\n",
- "top male Third unknown Southampton y\n",
- "freq 410 341 481 450 372"
- ]
- },
- "execution_count": 5,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "train_df.describe(include='object') # the training set has 627 samples in total"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "(627, 9) (264, 9)\n"
- ]
- }
- ],
- "source": [
- "# (rows, columns) of the training and evaluation feature frames\n",
- "print(train_df.shape, eval_df.shape)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "<AxesSubplot:>"
- ]
- },
- "execution_count": 7,
- "metadata": {},
- "output_type": "execute_result"
- },
- {
- "data": {
- "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD4CAYAAAAXUaZHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAVd0lEQVR4nO3df7Bcd13/8efbFhFzmYTaeifftHphjHVKI5Hs1DowzL3UH6E4FBynttPBRqoXZuqI2hlN0RGUYabf75cf4qBosLVFMbdIW6hp/VFjrxXHgrm1NiltoYWAzTcm0KYJtzAMKW//2HO/XS97c+/u2b177qfPx8zO3f2cc/a8srt53b2fPbsbmYkkqSzfMeoAkqTBs9wlqUCWuyQVyHKXpAJZ7pJUoNNHHQDgzDPPzImJiZ62efrpp1m3bt1wAtVgrt41NVtTc0FzszU1FzQ3W51cc3NzX8nMs7ouzMxTnoBzgLuBzwAPAm+txs8A7gI+V/18UTUewB8AjwIPAC9fbh/btm3LXt199909b7MazNW7pmZraq7M5mZraq7M5markwvYl0v06kqmZU4C12TmecCFwNURcR6wE9ibmZuBvdVlgNcAm6vTNPDBHn4RSZIGYNlyz8zDmXlfdf6rwEPAJuAS4KZqtZuA11fnLwE+XP1iuRfYEBEbBx1ckrS0yB7eoRoRE8A9wPnAlzJzQzUewLHM3BARe4DrMvOT1bK9wG9m5r5F1zVN+5k94+Pj22ZmZnoKPj8/z9jYWE/brAZz9a6p2ZqaC5qbram5oLnZ6uSampqay8xW14VLzdcsPgFjwBzwM9XlpxYtP1b93AO8smN8L9A61XU75z58Tc2V2dxsTc2V2dxsTc2V2dxso5xzJyKeB9wCfCQzb62GjyxMt1Q/j1bjh2i/CLvg7GpMkrRKli33asrleuChzHxvx6LbgSur81cCn+gY//louxA4npmHB5hZkrSMlRzn/grgjcD+iLi/GnsbcB3w0Yi4CvgicGm17E7gYtqHQn4N+IVBBpYkLW/Zcs/2C6OxxOKLuqyfwNU1c0mSavDjBySpQI34+AGtHRM77+h724PXvXaASSSdis/cJalAlrskFchyl6QCWe6SVCDLXZIKZLlLUoEsd0kqkOUuSQWy3CWpQJa7JBXIcpekAlnuklQgy12SCmS5S1KBLHdJKpDlLkkFWskXZN8QEUcj4kDH2M0RcX91Orjw3aoRMRERX+9Y9sdDzC5JWsJKvonpRuADwIcXBjLz5xbOR8R7gOMd6z+WmVsHlE+S1IeVfEH2PREx0W1ZRARwKfDqAeeSJNUQmbn8Su1y35OZ5y8afxXw3sxsdaz3IPBZ4ATw25n5z0tc5zQwDTA+Pr5tZmamp+Dz8/OMjY31tM1qKD3X/kPHl19pCVs2re86XvptNgxNzdbUXNDcbHVyTU1NzS3072J1vyD7cmB3x+XDwPdl5hMRsQ34eES8NDNPLN4wM3cBuwBarVZOTk72tOPZ2Vl63WY1lJ5rR50vyL6i+/5Lv82GoanZmpoLmpttWLn6PlomIk4Hfga4eWEsM7+RmU9U5+eAx4AfrBtSktSbOodC/jjwcGY+vjAQEWdFxGnV+ZcAm4HP14soSerVSg6F3A38K3BuRDweEVdViy7jf07JALwKeKA6NPJjwFsy88kB5pUkrcBKjpa5fInxHV3GbgFuqR9LklSH71CVpAJZ7pJUIMtdkgpkuUtSgSx3SSqQ5S5JBbLcJalAlrskFchyl6QCWe6SVCDLXZIKZLlLUoEsd0kqkOUuSQWy3CWpQJa7JBXIcpekAlnuklSglXyH6g0RcTQiDnSMvSMiDkXE/dXp4o5l10bEoxHxSET81LCCS5KWtpJn7jcC27uMvy8zt1anOwEi4jzaX5z90mqbP4qI0wYVVpK0MsuWe2beAzy5wuu7BJjJzG9k5heAR4ELauSTJPUhMnP5lSImgD2ZeX51+R3ADuAE
sA+4JjOPRcQHgHsz8y+q9a4H/iYzP9blOqeBaYDx8fFtMzMzPQWfn59nbGysp21WQ+m59h863ve2Wzat7zpe+m02DE3N1tRc0NxsdXJNTU3NZWar27LT+8zzQeCdQFY/3wO8qZcryMxdwC6AVquVk5OTPQWYnZ2l121WQ+m5duy8o+9tD17Rff+l32bD0NRsTc0Fzc02rFx9HS2TmUcy85nM/BbwIZ6dejkEnNOx6tnVmCRpFfVV7hGxsePiG4CFI2luBy6LiOdHxIuBzcCn60WUJPVq2WmZiNgNTAJnRsTjwNuByYjYSnta5iDwZoDMfDAiPgp8BjgJXJ2ZzwwluSRpScuWe2Ze3mX4+lOs/y7gXXVCSZLq8R2qklQgy12SCmS5S1KBLHdJKpDlLkkFstwlqUCWuyQVyHKXpAJZ7pJUIMtdkgpkuUtSgSx3SSqQ5S5JBbLcJalAlrskFchyl6QCWe6SVKBlyz0iboiIoxFxoGPs/0bEwxHxQETcFhEbqvGJiPh6RNxfnf54iNklSUtYyTP3G4Hti8buAs7PzB8GPgtc27HssczcWp3eMpiYkqReLFvumXkP8OSisb/PzJPVxXuBs4eQTZLUp8jM5VeKmAD2ZOb5XZb9NXBzZv5Ftd6DtJ/NnwB+OzP/eYnrnAamAcbHx7fNzMz0FHx+fp6xsbGetlkNpefaf+h439tu2bS+63jpt9kwNDVbU3NBc7PVyTU1NTWXma1uy06vEyoifgs4CXykGjoMfF9mPhER24CPR8RLM/PE4m0zcxewC6DVauXk5GRP+56dnaXXbVZD6bl27Lyj720PXtF9/6XfZsPQ1GxNzQXNzTasXH0fLRMRO4CfBq7I6ul/Zn4jM5+ozs8BjwE/OICckqQe9FXuEbEd+A3gdZn5tY7xsyLitOr8S4DNwOcHEVSStHLLTstExG5gEjgzIh4H3k776JjnA3dFBMC91ZExrwJ+LyK+CXwLeEtmPtn1iiVJQ7NsuWfm5V2Gr19i3VuAW+qGkiTV4ztUJalAlrskFchyl6QCWe6SVCDLXZIKZLlLUoEsd0kqkOUuSQWy3CWpQJa7JBXIcpekAlnuklQgy12SCmS5S1KBLHdJKpDlLkkFstwlqUCWuyQVaEXlHhE3RMTRiDjQMXZGRNwVEZ+rfr6oGo+I+IOIeDQiHoiIlw8rvCSpu5U+c78R2L5obCewNzM3A3urywCvATZXp2ngg/VjSpJ6saJyz8x7gCcXDV8C3FSdvwl4fcf4h7PtXmBDRGwcQFZJ0gpFZq5sxYgJYE9mnl9dfiozN1TnAziWmRsiYg9wXWZ+slq2F/jNzNy36PqmaT+zZ3x8fNvMzExPwefn5xkbG+tpm9VQeq79h473ve2WTeu7jpd+mw1DU7M1NRc0N1udXFNTU3OZ2eq27PRaqSqZmRGxst8Sz26zC9gF0Gq1cnJysqd9zs7O0us2q6H0XDt23tH3tgev6L7/0m+zYWhqtqbmguZmG1auOkfLHFmYbql+Hq3GDwHndKx3djUmSVoldcr9duDK6vyVwCc6xn++OmrmQuB4Zh6usR9JUo9WNC0TEbuBSeDMiHgceDtwHfDRiLgK+CJwabX6ncDFwKPA14BfGHBmSdIyVlTumXn5Eosu6rJuAlfXCSVJqsd3qEpSgSx3SSqQ5S5JBbLcJalAlrskFchyl6QCWe6SVCDLXZIKZLlLUoEsd0kqkOUuSQWy3CWpQJa7JBXIcpekAlnuklQgy12SCmS5S1KBLHdJKtCKvmavm4g4F7i5Y+glwO8AG4BfAr5cjb8tM+/sdz+SpN71Xe6Z+QiwFSAiTgMOAbfR/kLs92XmuwcRUJLUu0FNy1wEPJaZXxzQ9UmSaojMrH8lETcA92XmByLiHcAO4ASwD7gmM4912WYamAYYHx/fNjMz09M+5+fnGRsbq5l88ErPtf/Q8b633bJpfdfx0m+zYWhqtqbmguZmq5NrampqLjNb3ZbVLveI+E7g
/wEvzcwjETEOfAVI4J3Axsx806muo9Vq5b59+3ra7+zsLJOTk/2FHqLSc03svKPvbQ9e99qu46XfZsPQ1GxNzQXNzVYnV0QsWe6DmJZ5De1n7UcAMvNIZj6Tmd8CPgRcMIB9SJJ6MIhyvxzYvXAhIjZ2LHsDcGAA+5Ak9aDvo2UAImId8BPAmzuG/09EbKU9LXNw0TJJ0iqoVe6Z+TTwPYvG3lgrkSSpNt+hKkkFstwlqUCWuyQVyHKXpALVekFVa1OdNyJJWht85i5JBbLcJalAlrskFchyl6QCWe6SVCDLXZIK5KGQWjVLHYJ5zZaT7Bjy4ZlLfZa8VCqfuUtSgSx3SSqQ5S5JBbLcJalAvqC6BvXz2TCr8aKlpOaoXe4RcRD4KvAMcDIzWxFxBnAzMEH7q/YuzcxjdfclSVqZQU3LTGXm1sxsVZd3AnszczOwt7osSVolw5pzvwS4qTp/E/D6Ie1HktRFZGa9K4j4AnAMSOBPMnNXRDyVmRuq5QEcW7jcsd00MA0wPj6+bWZmpqf9zs/PMzY2Viv7MKxGrv2Hjve8zfgL4MjXhxBmAFYj25ZN63vepqmPMWhutqbmguZmq5NrampqrmPG5H8YxAuqr8zMQxHxvcBdEfFw58LMzIj4tt8gmbkL2AXQarVycnKyp53Ozs7S6zarYTVy9fPC6DVbTvKe/c18/Xw1sh28YrLnbZr6GIPmZmtqLmhutmHlqj0tk5mHqp9HgduAC4AjEbERoPp5tO5+JEkrV6vcI2JdRLxw4Tzwk8AB4Hbgymq1K4FP1NmPJKk3df8WHgdua0+rczrwl5n5txHxb8BHI+Iq4IvApTX3I0nqQa1yz8zPAy/rMv4EcFGd65Yk9c+PH5CkAlnuklQgy12SCmS5S1KBLHdJKpDlLkkFstwlqUCWuyQVyHKXpAJZ7pJUIMtdkgpkuUtSgSx3SSqQ5S5JBWrm965JAzbR51cT7th5Bweve+0QEknD5TN3SSqQ5S5JBbLcJalAfc+5R8Q5wIdpf49qArsy8/0R8Q7gl4AvV6u+LTPvrBtUWov6metf4Fy/6qjzgupJ4JrMvC8iXgjMRcRd1bL3Zea768eTJPWj73LPzMPA4er8VyPiIWDToIJJkvoXmVn/SiImgHuA84FfB3YAJ4B9tJ/dH+uyzTQwDTA+Pr5tZmamp33Oz88zNjZWK/cwrEau/YeO97zN+AvgyNeHEGYAmpptIdeWTev7vo5+7qsFp9rvc/nx36+mZquTa2pqai4zW92W1S73iBgD/gl4V2beGhHjwFdoz8O/E9iYmW861XW0Wq3ct29fT/udnZ1lcnISaNa8ZmeuYen3mO337G/m2xqamm0hV53HyLAem6vxOOtHU3NBc7PVyRURS5Z7rf9REfE84BbgI5l5K0BmHulY/iFgT519SM9Vp/rFsPAGq6X4Yqz6PhQyIgK4HngoM9/bMb6xY7U3AAf6jydJ6kedZ+6vAN4I7I+I+6uxtwGXR8RW2tMyB4E319hHser8ua7V5X2ltajO0TKfBKLLIo9pl6QR8x2qklQgy12SCmS5S1KBLHdJKpDlLkkFstwlqUCWuyQVyHKXpAJZ7pJUoOZ9FN8a0u1t6ct9oJO0FvT7kQvXbDnJ5GCjqE8+c5ekAlnuklQgy12SCvScn3P341wlleg5X+6SBqtJX3v5XOa0jCQVyHKXpAI5LSMV6Ln4WtJy/+ZTvQelxOmgoZV7RGwH3g+cBvxpZl43rH1JKsNz8ZfSsAxlWiYiTgP+EHgNcB7tL80+bxj7kiR9u2E9c78AeDQzPw8QETPAJcBnhrQ/SRqZOn9x3Lh93QCTPCsyc/BXGvGzwPbM/MXq8huBH83MX+5YZxqYri6eCzzS427OBL4ygLiDZq7eNTVbU3NBc7M1NRc0N1udXN+fmWd1WzCyF1Qzcxewq9/tI2JfZrYGGGkgzNW7pmZrai5obram5oLmZhtWrmEdCnkIOKfj8tnVmCRpFQyr
3P8N2BwRL46I7wQuA24f0r4kSYsMZVomM09GxC8Df0f7UMgbMvPBAe+m7ymdITNX75qaram5oLnZmpoLmpttKLmG8oKqJGm0/PgBSSqQ5S5JBVpz5R4R2yPikYh4NCJ2jjjLDRFxNCIOdIydERF3RcTnqp8vGkGucyLi7oj4TEQ8GBFvbUK2iPiuiPh0RPxHlet3q/EXR8Snqvv05upF+FUXEadFxL9HxJ6G5ToYEfsj4v6I2FeNjfxxVuXYEBEfi4iHI+KhiPixUWeLiHOr22rhdCIifnXUuapsv1Y99g9ExO7q/8RQHmdrqtwb+LEGNwLbF43tBPZm5mZgb3V5tZ0ErsnM84ALgaur22nU2b4BvDozXwZsBbZHxIXA/wbel5k/ABwDrlrlXAveCjzUcbkpuQCmMnNrx/HQo74vF7wf+NvM/CHgZbRvv5Fmy8xHqttqK7AN+Bpw26hzRcQm4FeAVmaeT/tgk8sY1uMsM9fMCfgx4O86Ll8LXDviTBPAgY7LjwAbq/MbgUcacLt9AviJJmUDvhu4D/hR2u/OO73bfbyKec6m/R/+1cAeIJqQq9r3QeDMRWMjvy+B9cAXqA7MaFK2jiw/CfxLE3IBm4D/BM6gfaTiHuCnhvU4W1PP3Hn2xlnweDXWJOOZebg6/1/A+CjDRMQE8CPAp2hAtmrq437gKHAX8BjwVGaerFYZ1X36+8BvAN+qLn9PQ3IBJPD3ETFXfWwHNOC+BF4MfBn4s2o6608jYl1Dsi24DNhdnR9prsw8BLwb+BJwGDgOzDGkx9laK/c1Jdu/ikd2rGlEjAG3AL+amSc6l40qW2Y+k+0/l8+m/QFzP7TaGRaLiJ8Gjmbm3KizLOGVmfly2tORV0fEqzoXjvBxdjrwcuCDmfkjwNMsmuoY5f+Bau76dcBfLV42ilzVHP8ltH8p/i9gHd8+rTswa63c18LHGhyJiI0A1c+jowgREc+jXewfycxbm5QNIDOfAu6m/WfohohYeEPdKO7TVwCvi4iDwAztqZn3NyAX8P+f8ZGZR2nPHV9AM+7Lx4HHM/NT1eWP0S77JmSD9i/D+zLzSHV51Ll+HPhCZn45M78J3Er7sTeUx9laK/e18LEGtwNXVuevpD3fvaoiIoDrgYcy871NyRYRZ0XEhur8C2i/DvAQ7ZL/2VHlysxrM/PszJyg/Zj6x8y8YtS5ACJiXUS8cOE87TnkAzTgcZaZ/wX8Z0ScWw1dRPtjvUeerXI5z07JwOhzfQm4MCK+u/o/unB7DedxNqoXOmq8KHEx8Fnac7W/NeIsu2nPnX2T9rOYq2jP1e4FPgf8A3DGCHK9kvafnA8A91eni0edDfhh4N+rXAeA36nGXwJ8GniU9p/Qzx/hfToJ7GlKrirDf1SnBxce86O+LzvybQX2Vffpx4EXNSEb7SmPJ4D1HWNNyPW7wMPV4//PgecP63Hmxw9IUoHW2rSMJGkFLHdJKpDlLkkFstwlqUCWuyQVyHKXpAJZ7pJUoP8GIfrJfg0S5+QAAAAASUVORK5CYII=\n",
- "text/plain": [
- "<Figure size 432x288 with 1 Axes>"
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "# Histogram of passenger ages in the training set, using 20 bins\n",
- "train_df['age'].hist(bins=20)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "<AxesSubplot:>"
- ]
- },
- "execution_count": 8,
- "metadata": {},
- "output_type": "execute_result"
- },
- {
- "data": {
- "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYYAAAD4CAYAAADo30HgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAMdUlEQVR4nO3cf4xld1nH8c8D225NS4rQhmxacChuJKRAW0tFRQKICF1DQTAhEigJoVEUNabRIpHUVLSCKJqgpCgWFQVBDAghiLTGBLF11/7Y1nah2jVSKw0SlpomVenXP+5ZmGec2XbbmXtmy+uVTPbcc+/e88x3cve959y7W2OMAMBhj5h7AAC2F2EAoBEGABphAKARBgCaHXMPsBlOOeWUsbKyMvcYAMeUffv2fWmMcera/Q+LMKysrGTv3r1zjwFwTKmqf11vv0tJADTCAEAjDAA0wgBAIwwANMIAQCMMADTCAEAjDAA0wgBAIwwANMIAQCMMADTCAEAjDAA0wgBAIwwANMIAQCMMADTCAEAjDAA0wgBAIwwANMIAQCMMADTCAECzY+4BNsP+Ow5l5ZKPzz0GrOvg5XvmHgGOijMGABphAKARBgAaYQCgEQYAGmEAoBEGABphAKARBgAaYQCgEQYAGmEAoBEGABphAKARBgAaYQCgEQYAGmEAoBEGABphAKARBgAaYQCgEQYAmvsNQ1X9VFXdUlXv24oBqurSqrp4K54bgKO34wE85vVJnj/G+MJWDwPA/I4Yhqp6V5Izknyiqt6f5ElJzkxyXJJLxxgfqarXJHlJkhOT7E7y60mOT/KqJPcmOX+M8eWqel2Si6b7bkvyqjHGPWuO96Qk70xyapJ7krxujHHr5nyrADwQR7yUNMb4sST/nuS5WfzBf9UY47zp9tuq6sTpoWcm+eEkz0jyliT3jDHOTvLZJK+eHvPhMcYzxhhPT3JLkteuc8grkrxhjPGdSS5O8jsbzVZVF1XV3qra+7V7Dj2w7xaA+/VALiUd9oIkL171fsAJSZ4wbV89xrg7yd1VdSjJX0779yd52rR9ZlX9cpJHJzkpySdXP3lVnZTke5J8sKoO79650TBjjCuyCEl27to9juL7AOAIjiYMleRlY4wDbWfVd2Vxyeiw+1bdvm/VMa5M8pIxxg3T5afnrHn+RyT5yhjjrKOYCYBNdjQfV/1kkjfU9Nf5qjr7KI/1qCR3VtVxSV659s4xxleT3F5VPzI9f1XV04/yGAA8REcThsuyeNP5xqq6ebp9NH4xyTVJPpNkozeUX5nktVV1Q5Kbk1xwlMcA4CGqMY79y/M7d+0euy58x9xjwLoOXr5n7hFgXVW1b4xx7tr9/uUzAI0wANAIAwCNMADQCAMAjTAA0AgDAI0wANAIAwCNMADQCAMAjTAA0AgDAI0wANAIAwCNMADQCAMAjTAA0AgDAI0wANAIAwDNjrkH2AxPPe3k7L18z9xjADwsOGMAoBEGABphAKARBgAaYQCgEQYAGmEAoBEGABphAKARBgAaYQCgEQYAGmEAoBEGABphAKARBgAaYQCgEQYAGmEAoBEGABphAKARBgAaYQCgEQYAGmEAoBEGABphAKARBgAaYQCgEQYAGmEAoBEGABphAKARBgAaYQCgEQYAGmEAoBEGABphAKARBgAaYQCgEQYAGmEAoBEGABphAKARBgAaYQCgEQYAGmEAoBEGABphAKARBgAaYQCg2TH3AJth/x2HsnLJx+ceA2CpDl6+Z0ue1xkDAI0wANAIAwCNMADQCAMAjTAA0AgDAI0wANAIAwCNMADQCAMAjTAA0AgDAI0wANAIAwCNMADQCAMAjTAA0AgDAI0wANAIAwCNMADQCAMAzbYIQ1U9p6o+NvccAGyTMACwfWxaGKpqpapuraorq+pzVfW+qnp+VX2mqj5fVedNX5+tquuq6u+q6jvWeZ4Tq+o9VXXt9LgLNmtGAO7fZp8xfHuStyd58vT1
o0meleTiJL+Q5NYk3zfGODvJm5P8yjrP8aYkV40xzkvy3CRvq6oT1z6oqi6qqr1Vtfdr9xza5G8D4JvXjk1+vtvHGPuTpKpuTvLpMcaoqv1JVpKcnOS9VbU7yUhy3DrP8YIkL66qi6fbJyR5QpJbVj9ojHFFkiuSZOeu3WOTvw+Ab1qbHYZ7V23ft+r2fdOxLkty9RjjpVW1kuRv1nmOSvKyMcaBTZ4NgAdg2W8+n5zkjmn7NRs85pNJ3lBVlSRVdfYS5gJgsuwwvDXJr1bVddn4bOWyLC4x3ThdjrpsWcMBkNQYx/7l+Z27do9dF75j7jEAlurg5Xse0u+vqn1jjHPX7vfvGABohAGARhgAaIQBgEYYAGiEAYBGGABohAGARhgAaIQBgEYYAGiEAYBGGABohAGARhgAaIQBgEYYAGiEAYBGGABohAGARhgAaIQBgGbH3ANshqeednL2Xr5n7jEAHhacMQDQCAMAjTAA0AgDAI0wANAIAwCNMADQCAMAjTAA0AgDAI0wANAIAwCNMADQCAMAjTAA0AgDAI0wANAIAwCNMADQCAMAjTAA0AgDAI0wANAIAwCNMADQCAMATY0x5p7hIauqu5McmHuODZyS5EtzD7GO7TpXYrYHy2wPzjfzbN82xjh17c4dW3jAZTowxjh37iHWU1V7t+Ns23WuxGwPltkeHLP9fy4lAdAIAwDNwyUMV8w9wBFs19m261yJ2R4ssz04ZlvjYfHmMwCb5+FyxgDAJhEGAJpjOgxV9cKqOlBVt1XVJdtgnoNVtb+qrq+qvdO+x1TVp6rq89Ov37qkWd5TVXdV1U2r9q07Sy389rSON1bVOTPMdmlV3TGt3fVVdf6q+944zXagqn5wi2d7fFVdXVX/VFU3V9VPT/tnXbsjzDX7ulXVCVV1bVXdMM32S9P+J1bVNdMMH6iq46f9O6fbt033r8ww25VVdfuqdTtr2r/U18J0zEdW1XVV9bHp9uzrljHGMfmV5JFJ/jnJGUmOT3JDkqfMPNPBJKes2ffWJJdM25ck+bUlzfLsJOckuen+ZklyfpJPJKkkz0xyzQyzXZrk4nUe+5TpZ7szyROnn/kjt3C2XUnOmbYfleRz0wyzrt0R5pp93abv/aRp+7gk10xr8WdJXjHtf1eSH5+2X5/kXdP2K5J8YAt/nhvNdmWSl6/z+KW+FqZj/mySP0nysen27Ot2LJ8xnJfktjHGv4wx/jvJ+5NcMPNM67kgyXun7fcmeckyDjrG+NskX36As1yQ5A/Hwt8neXRV7VrybBu5IMn7xxj3jjFuT3JbFj/7rZrtzjHGP07bdye5JclpmXntjjDXRpa2btP3/l/TzeOmr5HkeUk+NO1fu2aH1/JDSb6/qmrJs21kqa+Fqjo9yZ4kvzfdrmyDdTuWw3Bakn9bdfsLOfILZRlGkr+qqn1VddG073FjjDun7f9I8rh5RjviLNtlLX9yOn1/z6pLbrPNNp2qn53F3zK3zdqtmSvZBus2XQ65PsldST6VxRnKV8YY/7vO8b8+23T/oSSPXdZsY4zD6/aWad1+s6p2rp1tnbm3wjuS/FyS+6bbj802WLdjOQzb0bPGGOckeVGSn6iqZ6++cyzOAbfF54O30yyT303ypCRnJbkzydvnHKaqTkry50l+Zozx1dX3zbl268y1LdZtjPG1McZZSU7P4szkyXPMsZ61s1XVmUnemMWMz0jymCQ/v+y5quqHktw1xti37GPfn2M5DHckefyq26dP+2Yzxrhj+vWuJH+RxQvki4dPRadf75pvwg1nmX0txxhfnF7A9yV5d75x2WPps1XVcVn84fu+McaHp92zr916c22ndZvm+UqSq5N8dxaXYQ7/f2yrj//12ab7T07yn0uc7YXTpbkxxrg3yR9knnX73iQvrqqDWVwKf16S38o2WLdjOQz/kGT39A7+8Vm8GfPRuYapqhOr6lGHt5O8IMlN00wXTg+7MMlH5pkwOcIsH03y6ukT
Gc9McmjVZZOlWHMd96VZrN3h2V4xfSLjiUl2J7l2C+eoJL+f5JYxxm+sumvWtdtoru2wblV1alU9etr+liQ/kMV7IFcnefn0sLVrdngtX57kquksbFmz3boq8pXFNfzV67aU18IY441jjNPHGCtZ/Pl11RjjldkG67al77Zv9VcWnyD4XBbXM9808yxnZPEpkBuS3Hx4niyuAX46yeeT/HWSxyxpnj/N4tLC/2RxnfK1G82SxScw3jmt4/4k584w2x9Nx74xixfArlWPf9M024EkL9ri2Z6VxWWiG5NcP32dP/faHWGu2dctydOSXDfNcFOSN696TVybxRvfH0yyc9p/wnT7tun+M2aY7app3W5K8sf5xieXlvpaWDXnc/KNTyXNvm7+SwwAmmP5UhIAW0AYAGiEAYBGGABohAGARhgAaIQBgOb/AEYEJAXn01RlAAAAAElFTkSuQmCC\n",
- "text/plain": [
- "<Figure size 432x288 with 1 Axes>"
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "# Categorical values are well suited to a bar chart\n",
- "train_df['sex'].value_counts().plot.barh()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "<class 'pandas.core.frame.DataFrame'>\n",
- "RangeIndex: 627 entries, 0 to 626\n",
- "Data columns (total 9 columns):\n",
- " # Column Non-Null Count Dtype \n",
- "--- ------ -------------- ----- \n",
- " 0 sex 627 non-null object \n",
- " 1 age 627 non-null float64\n",
- " 2 n_siblings_spouses 627 non-null int64 \n",
- " 3 parch 627 non-null int64 \n",
- " 4 fare 627 non-null float64\n",
- " 5 class 627 non-null object \n",
- " 6 deck 627 non-null object \n",
- " 7 embark_town 627 non-null object \n",
- " 8 alone 627 non-null object \n",
- "dtypes: float64(2), int64(2), object(5)\n",
- "memory usage: 44.2+ KB\n"
- ]
- }
- ],
- "source": [
- "# Column dtypes, non-null counts and memory usage of the training features\n",
- "train_df.info()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "0 n\n",
- "1 n\n",
- "2 y\n",
- "3 n\n",
- "4 y\n",
- " ..\n",
- "622 y\n",
- "623 y\n",
- "624 y\n",
- "625 n\n",
- "626 y\n",
- "Name: alone, Length: 627, dtype: object"
- ]
- },
- "execution_count": 10,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# Peek at the raw values of the 'alone' column\n",
- "train_df['alone']"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "<AxesSubplot:>"
- ]
- },
- "execution_count": 11,
- "metadata": {},
- "output_type": "execute_result"
- },
- {
- "data": {
- "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYoAAAD4CAYAAADy46FuAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAN6klEQVR4nO3de4yld13H8ffHbbtQWxahG1gKcdraSKCVZbsoIGBR0MJKyqUG+EMx0WyCGG2M0RKSpiokreAlEpS0EUFLoIiihEZu0qIJkbpbt90WetNdI0uhKdil5VJh+frHeZYexz3fvXR2nnOm71dyMs/5Pc+c+czvnJnPPpedk6pCkqRZvm/sAJKk+WZRSJJaFoUkqWVRSJJaFoUkqXXC2AFW0mmnnVZLS0tjx5CkhbJz5857q2rjrPVrqiiWlpbYsWPH2DEkaaEk+c9uvYeeJEkti0KS1LIoJEkti0KS1LIoJEkti0KS1LIoJEkti0KS1LIoJEkti0KS1LIoJEkti0KS1LIoJEkti0KS1LIoJEkti0KS1FpTb1y0e99+li65duwYOg72Xr5t7AjSI5Z7FJKklkUhSWpZFJKklkUhSWpZFJKklkUhSWpZFJKklkUhSWpZFJKklkUhSWpZFJKklkUhSWpZFJKklkUhSWodUVEkeVOSW5PcnGRXkh873sGWff3zk3xkNb+mJGnisO9HkeQ5wM8CW6rqwSSnAScd92SSpLlwJHsUm4B7q+pBgKq6t6q+mOS8JJ9OsjPJx5JsAkjyQ0k+meSmJDcmOSsTb01yS5LdSV49bHt+kuuTfDDJbUnemyTDuguGsRuBVx6n71+SdBhHUhQfB56S5I4kf5rkJ5KcCLwduKiqzgPeBbxl2P69wDuq6hnAc4G7mfyi3ww8A3gR8NaDxQI8E7gYeBpwJvDjSR4FXAW8DDgPeOLD/UYlScfmsIeequqBJOcBzwdeCFwDvBk4B/jEsAOwDrg7yanA6VX1oeFzvwWQ5HnA+6rqAPDlJJ8GngV8Dbihqr4wbLcLWAIeAPZU1Z3D+NXA9kPlS7L94Lp1j9l49DMgSWod0XtmD7/grweuT7IbeANwa1U9Z3q7oSiO1oNTyweONNNUtiuBKwHWbzq7juHrS5Iahz30lOSHk5w9NbQZ+DywcTjRTZITkzy9qu4HvpDk5cP4+iQnA/8MvDrJuiQbgRcANzRf9jZgKclZw/3XHuX3JUlaIUdyjuIU4D1JPpfkZibnEi4FLgKuSHITsIvJ+QiAnwd+bdj2M0zOL3wIuBm4CfgU8FtV9aVZX3A4ZLUduHY4mX3PMXxvkqQVkKq1c7Rm/aaza9Pr/njsGDoO9l6+bewI0pqVZGdVbZ213v+ZLUlqWRSSpJZFIUlqWRSSpJZFIUlqWRSSpJZFIUlqWRSSpJZFIUlqWRSSpJZFIUlqWRSSpJZFIUlqHdWbBM27c0/fwA7/yqgkrSj3KCRJLYtCktSyKCRJLYtCktSyKCRJLYtCktSyKCRJLYtCktSyKCRJLYtCktSyKCRJLYtCktSyKCRJLYtCktSyKCRJLYtCktSyKCRJLYtCktSyKCRJLYtCktSyKCRJLYtCktSyKCRJLYtCktSyKCRJLYtCktSyKCRJLYtCktSyKCRJLYtCktSyKCRJLYtCktSyKCRJLYtCktSyKCRJrRPGDrCSdu/bz9Il144dQ2vI3su3jR1BGp17FJKklkUhSWpZFJKklkUhSWpZFJKklkUhSWpZFJKklkUhSWpZFJKklkUhSWpZFJKklkUhSWpZFJKklkUhSWod96JIciDJrqnbUpLPHOVjXJzk5OOVUZI022q8H8U3q2rzsrHnLt8oyQlV9Z0Zj3ExcDXwjZWNJkk6nFHeuCjJA1V1SpLzgd8D/ht4apJnAh8AngysG9Y9AXgScF2Se6vqhWNklqRHqtUoikcn2TUs76mqVyxbvwU4p6r2JHkV8MWq2gaQZENV7U/yG8AL
q+re5Q+eZDuwHWDdYzYet29Ckh6pVuNk9jeravNwW14SADdU1Z5heTfw4iRXJHl+Ve0/3INX1ZVVtbWqtq47ecOKBpckzcdVT18/uFBVdzDZw9gNvDnJpaOlkiQBI52jmCXJk4CvVtXVSe4DfnlYdT9wKvD/Dj1Jko6vuSoK4FzgrUm+C3wbeP0wfiXw0SRf9GS2JK2u414UVXXKrLGquh64fmr8Y8DHDrH924G3H7eQkqSZ5uEchSRpjlkUkqSWRSFJalkUkqSWRSFJalkUkqSWRSFJalkUkqSWRSFJalkUkqSWRSFJalkUkqTWvP312Ifl3NM3sOPybWPHkKQ1xT0KSVLLopAktSwKSVLLopAktSwKSVLLopAktSwKSVLLopAktSwKSVLLopAktSwKSVLLopAktSwKSVLLopAktSwKSVLLopAktSwKSVLLopAktSwKSVLLopAktSwKSVLLopAktSwKSVLLopAktSwKSVLLopAktSwKSVLLopAktSwKSVLLopAktSwKSVLLopAktSwKSVLLopAktSwKSVLrhLEDrKTd+/azdMm1Y8eQpFW19/Jtx/Xx3aOQJLUsCklSy6KQJLUsCklSy6KQJLUsCklSy6KQJLUsCklSy6KQJLUsCklSy6KQJLUsCklSy6KQJLUsCklSa8WKIsnjk+wabl9Ksm9Yvi/J52Z8zu8medERPPb5ST6yUlklSUduxd6Poqq+AmwGSHIZ8EBVvS3JEnDIX/JVdemhxpOsq6oDK5VNknTsVuvQ07okVyW5NcnHkzwaIMm7k1w0LO9NckWSG4GfS3JBktuG+69cpZySpGVWqyjOBt5RVU8H7gNeNWO7r1TVFuDvgKuAlwHnAU9chYySpENYraLYU1W7huWdwNKM7a4ZPj51+Jw7q6qAq2c9cJLtSXYk2XHgG/tXKq8kabBaRfHg1PIBZp8b+frRPnBVXVlVW6tq67qTNxxTOEnSbPN6eextwFKSs4b7rx0zjCQ9ks1lUVTVt4DtwLXDyex7Ro4kSY9YK3Z57LSqumxqeS9wztT9t00t/+LU8tKyx/gok3MVkqQRzeUehSRpflgUkqSWRSFJalkUkqSWRSFJalkUkqSWRSFJalkUkqSWRSFJalkUkqSWRSFJalkUkqSWRSFJah2Xvx47lnNP38COy7eNHUOS1hT3KCRJLYtCktSyKCRJLYtCktSyKCRJLYtCktSyKCRJLYtCktSyKCRJLYtCktSyKCRJLYtCktSyKCRJLYtCktSyKCRJLYtCktSyKCRJrVTV2BlWTJL7gdvHznGMTgPuHTvEMVjU3GD2sSxq9kXNDYfP/oNVtXHWyjX1VqjA7VW1dewQxyLJjkXMvqi5wexjWdTsi5obHn52Dz1JkloWhSSptdaK4sqxAzwMi5p9UXOD2ceyqNkXNTc8zOxr6mS2JGnlrbU9CknSCrMoJEmtNVEUSS5IcnuSu5JcMnaew0myN8nuJLuS7BjGHpfkE0nuHD7+wNg5AZK8K8k9SW6ZGjtk1kz8yfA83Jxky3jJZ2a/LMm+Ye53JXnp1Lo3DtlvT/Iz46SGJE9Jcl2SzyW5NcmvD+NzP+9N9kWY90cluSHJTUP23xnGz0jy2SHjNUlOGsbXD/fvGtYvzVnudyfZMzXnm4fxo3+9VNVC34B1wL8DZwInATcBTxs712Ey7wVOWzb2+8Alw/IlwBVj5xyyvADYAtxyuKzAS4F/AAI8G/jsHGa/DPjNQ2z7tOG1sx44Y3hNrRsp9yZgy7B8KnDHkG/u573JvgjzHuCUYflE4LPDfH4AeM0w/k7g9cPyrwDvHJZfA1wzZ7nfDVx0iO2P+vWyFvYofhS4q6r+o6r+B3g/cOHImY7FhcB7huX3AC8fL8pDquqfgK8uG56V9ULgL2viX4DHJtm0KkEPYUb2WS4E3l9VD1bVHuAuJq+tVVdVd1fVjcPy/cDngdNZgHlvss8yT/NeVfXAcPfE4VbATwIfHMaXz/vB
5+ODwE8lyeqkfUiTe5ajfr2shaI4HfivqftfoH9hzoMCPp5kZ5Ltw9gTquruYflLwBPGiXZEZmVdlOfiV4dd7ndNHeKby+zD4YxnMvlX4kLN+7LssADznmRdkl3APcAnmOzh3FdV3xk2mc73vezD+v3A41c18GB57qo6OOdvGeb8j5KsH8aOes7XQlEsoudV1RbgJcAbkrxgemVN9g8X4rrlRco6+DPgLGAzcDfwB6OmaSQ5Bfgb4OKq+tr0unmf90NkX4h5r6oDVbUZeDKTPZunjpvoyCzPneQc4I1M8j8LeBzw28f6+GuhKPYBT5m6/+RhbG5V1b7h4z3Ah5i8IL98cPdv+HjPeAkPa1bWuX8uqurLww/Vd4GreOgwx1xlT3Iik1+0762qvx2GF2LeD5V9Ueb9oKq6D7gOeA6TQzMH/y7edL7vZR/WbwC+srpJ/6+p3BcMhwGrqh4E/oKHMedroSj+FTh7uDLhJCYnlT48cqaZknx/klMPLgM/DdzCJPPrhs1eB/z9OAmPyKysHwZ+Ybiq4tnA/qlDJXNh2bHYVzCZe5hkf81wJcsZwNnADaudDyZXpQB/Dny+qv5watXcz/us7Asy7xuTPHZYfjTwYibnWK4DLho2Wz7vB5+Pi4BPDXt6q2pG7tum/lERJudVpuf86F4vY5ylX+kbk7P4dzA5nvimsfMcJuuZTK7yuAm49WBeJsc2/xG4E/gk8Lixsw653sfkUMG3mRzL/KVZWZlcRfGO4XnYDWydw+x/NWS7efiB2TS1/ZuG7LcDLxkx9/OYHFa6Gdg13F66CPPeZF+Eef8R4N+GjLcAlw7jZzIpr7uAvwbWD+OPGu7fNaw/c85yf2qY81uAq3noyqijfr34JzwkSa21cOhJknQcWRSSpJZFIUlqWRSSpJZFIUlqWRSSpJZFIUlq/S9VU+heLhrM6wAAAABJRU5ErkJggg==\n",
- "text/plain": [
- "<Figure size 432x288 with 1 Axes>"
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "# Number of passengers per cabin class, as a horizontal bar chart\n",
- "train_df['class'].value_counts().plot.barh()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 12,
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/home/zhangmeng/.virtualenvs/tf_py/lib/python3.6/site-packages/ipykernel_launcher.py:4: UserWarning: Boolean Series key will be reindexed to match DataFrame index.\n",
- " after removing the cwd from sys.path.\n"
- ]
- },
- {
- "data": {
- "text/plain": [
- "<AxesSubplot:ylabel='age'>"
- ]
- },
- "execution_count": 12,
- "metadata": {},
- "output_type": "execute_result"
- },
- {
- "data": {
- "image/png": "\n",
- "text/plain": [
- "<Figure size 432x288 with 1 Axes>"
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "# Survival rate per age for passengers aged strictly between 30 and 50.\n",
- "# survived is 1 for survivors and 0 otherwise, so the group mean is the survival ratio.\n",
- "\n",
- "df_age = pd.concat([train_df, y_train], axis=1)\n",
- "# Combine both age bounds into a single mask: chaining two boolean selections\n",
- "# makes pandas reindex the second mask and emit a UserWarning.\n",
- "df_age.loc[(df_age['age'] > 30) & (df_age['age'] < 50)].groupby('age').survived.mean().plot(kind='barh')\n",
- "\n",
- "\n",
- "# Same comparison grouped by sex instead of age:\n",
- "# pd.concat([train_df, y_train], axis = 1).groupby('sex').survived.mean().plot(kind='barh')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 13,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "array(['male', 'female'], dtype=object)"
- ]
- },
- "execution_count": 13,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# Distinct values observed in the 'sex' column\n",
- "train_df['sex'].unique()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 14,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "sex ['male' 'female']\n",
- "n_siblings_spouses [1 0 3 4 2 5 8]\n",
- "parch [0 1 2 5 3 4]\n",
- "class ['Third' 'First' 'Second']\n",
- "deck ['unknown' 'C' 'G' 'A' 'B' 'D' 'F' 'E']\n",
- "embark_town ['Southampton' 'Cherbourg' 'Queenstown' 'unknown']\n",
- "alone ['n' 'y']\n"
- ]
- },
- {
- "data": {
- "text/plain": [
- "[IndicatorColumn(categorical_column=VocabularyListCategoricalColumn(key='sex', vocabulary_list=('male', 'female'), dtype=tf.string, default_value=-1, num_oov_buckets=0)),\n",
- " IndicatorColumn(categorical_column=VocabularyListCategoricalColumn(key='n_siblings_spouses', vocabulary_list=(1, 0, 3, 4, 2, 5, 8), dtype=tf.int64, default_value=-1, num_oov_buckets=0)),\n",
- " IndicatorColumn(categorical_column=VocabularyListCategoricalColumn(key='parch', vocabulary_list=(0, 1, 2, 5, 3, 4), dtype=tf.int64, default_value=-1, num_oov_buckets=0)),\n",
- " IndicatorColumn(categorical_column=VocabularyListCategoricalColumn(key='class', vocabulary_list=('Third', 'First', 'Second'), dtype=tf.string, default_value=-1, num_oov_buckets=0)),\n",
- " IndicatorColumn(categorical_column=VocabularyListCategoricalColumn(key='deck', vocabulary_list=('unknown', 'C', 'G', 'A', 'B', 'D', 'F', 'E'), dtype=tf.string, default_value=-1, num_oov_buckets=0)),\n",
- " IndicatorColumn(categorical_column=VocabularyListCategoricalColumn(key='embark_town', vocabulary_list=('Southampton', 'Cherbourg', 'Queenstown', 'unknown'), dtype=tf.string, default_value=-1, num_oov_buckets=0)),\n",
- " IndicatorColumn(categorical_column=VocabularyListCategoricalColumn(key='alone', vocabulary_list=('n', 'y'), dtype=tf.string, default_value=-1, num_oov_buckets=0)),\n",
- " NumericColumn(key='age', shape=(1,), default_value=None, dtype=tf.float32, normalizer_fn=None),\n",
- " NumericColumn(key='fare', shape=(1,), default_value=None, dtype=tf.float32, normalizer_fn=None)]"
- ]
- },
- "execution_count": 14,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# Feature engineering with tf.feature_column.\n",
- "# Categorical (discrete) features; n_siblings_spouses is the family-member count.\n",
- "categorical_columns = ['sex', 'n_siblings_spouses', 'parch', 'class',\n",
- "                       'deck', 'embark_town', 'alone']\n",
- "# Continuous features: age and fare.\n",
- "numeric_columns = ['age', 'fare']\n",
- "\n",
- "feature_columns = []\n",
- "# Categorical features: the vocabulary is the set of values seen in the training data.\n",
- "for categorical_column in categorical_columns:\n",
- "    vocab = train_df[categorical_column].unique()  # e.g. array(['male', 'female'], dtype=object)\n",
- "    print(categorical_column, vocab)\n",
- "    feature_columns.append(\n",
- "        # indicator_column one-hot encodes the vocabulary-based categorical column.\n",
- "        tf.feature_column.indicator_column(\n",
- "            tf.feature_column.categorical_column_with_vocabulary_list(\n",
- "                categorical_column, vocab)))\n",
- "# Continuous features become float32 numeric columns.\n",
- "for numeric_column in numeric_columns:\n",
- "    feature_columns.append(\n",
- "        tf.feature_column.numeric_column(\n",
- "            numeric_column, dtype=tf.float32))\n",
- "# Display the resulting list of feature columns.\n",
- "feature_columns"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 18,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "{'sex': 0 male\n",
- " 1 female\n",
- " 2 female\n",
- " 3 female\n",
- " 4 male\n",
- " ... \n",
- " 622 male\n",
- " 623 male\n",
- " 624 female\n",
- " 625 female\n",
- " 626 male\n",
- " Name: sex, Length: 627, dtype: object,\n",
- " 'age': 0 22.0\n",
- " 1 38.0\n",
- " 2 26.0\n",
- " 3 35.0\n",
- " 4 28.0\n",
- " ... \n",
- " 622 28.0\n",
- " 623 25.0\n",
- " 624 19.0\n",
- " 625 28.0\n",
- " 626 32.0\n",
- " Name: age, Length: 627, dtype: float64,\n",
- " 'n_siblings_spouses': 0 1\n",
- " 1 1\n",
- " 2 0\n",
- " 3 1\n",
- " 4 0\n",
- " ..\n",
- " 622 0\n",
- " 623 0\n",
- " 624 0\n",
- " 625 1\n",
- " 626 0\n",
- " Name: n_siblings_spouses, Length: 627, dtype: int64,\n",
- " 'parch': 0 0\n",
- " 1 0\n",
- " 2 0\n",
- " 3 0\n",
- " 4 0\n",
- " ..\n",
- " 622 0\n",
- " 623 0\n",
- " 624 0\n",
- " 625 2\n",
- " 626 0\n",
- " Name: parch, Length: 627, dtype: int64,\n",
- " 'fare': 0 7.2500\n",
- " 1 71.2833\n",
- " 2 7.9250\n",
- " 3 53.1000\n",
- " 4 8.4583\n",
- " ... \n",
- " 622 10.5000\n",
- " 623 7.0500\n",
- " 624 30.0000\n",
- " 625 23.4500\n",
- " 626 7.7500\n",
- " Name: fare, Length: 627, dtype: float64,\n",
- " 'class': 0 Third\n",
- " 1 First\n",
- " 2 Third\n",
- " 3 First\n",
- " 4 Third\n",
- " ... \n",
- " 622 Second\n",
- " 623 Third\n",
- " 624 First\n",
- " 625 Third\n",
- " 626 Third\n",
- " Name: class, Length: 627, dtype: object,\n",
- " 'deck': 0 unknown\n",
- " 1 C\n",
- " 2 unknown\n",
- " 3 C\n",
- " 4 unknown\n",
- " ... \n",
- " 622 unknown\n",
- " 623 unknown\n",
- " 624 B\n",
- " 625 unknown\n",
- " 626 unknown\n",
- " Name: deck, Length: 627, dtype: object,\n",
- " 'embark_town': 0 Southampton\n",
- " 1 Cherbourg\n",
- " 2 Southampton\n",
- " 3 Southampton\n",
- " 4 Queenstown\n",
- " ... \n",
- " 622 Southampton\n",
- " 623 Southampton\n",
- " 624 Southampton\n",
- " 625 Southampton\n",
- " 626 Queenstown\n",
- " Name: embark_town, Length: 627, dtype: object,\n",
- " 'alone': 0 n\n",
- " 1 n\n",
- " 2 y\n",
- " 3 n\n",
- " 4 y\n",
- " ..\n",
- " 622 y\n",
- " 623 y\n",
- " 624 y\n",
- " 625 n\n",
- " 626 y\n",
- " Name: alone, Length: 627, dtype: object}"
- ]
- },
- "execution_count": 18,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "dict(train_df)  # Convert the DataFrame to a dict: keys are column names, values are the corresponding column Series"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 15,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "0 0\n",
- "1 1\n",
- "2 1\n",
- "3 1\n",
- "4 0\n",
- " ..\n",
- "622 0\n",
- "623 0\n",
- "624 1\n",
- "625 0\n",
- "626 0\n",
- "Name: survived, Length: 627, dtype: int64"
- ]
- },
- "execution_count": 15,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# Display the training labels (the `survived` column).\n",
- "y_train"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 16,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Build a tf.data.Dataset from the feature DataFrame and its labels.\n",
- "def make_dataset(data_df, label_df, epochs = 10, shuffle = True,\n",
- "                 batch_size = 32, shuffle_buffer_size = 10000):\n",
- "    \"\"\"Return a batched dataset of (feature dict, label) pairs.\n",
- "\n",
- "    Args:\n",
- "        data_df: DataFrame of features; passed through dict() so each element\n",
- "            of the dataset is a dict keyed by column name.\n",
- "        label_df: Labels, sliced row by row alongside data_df.\n",
- "        epochs: Number of times the data is repeated.\n",
- "        shuffle: Whether to shuffle the samples before repeating.\n",
- "        shuffle_buffer_size: Size of the shuffle buffer (default 10000).\n",
- "        batch_size: Number of samples per batch.\n",
- "\n",
- "    Returns:\n",
- "        The dataset after the optional shuffle, repeat(epochs) and batch(batch_size).\n",
- "    \"\"\"\n",
- "    # dict(data_df) maps each column name to its Series.\n",
- "    dataset = tf.data.Dataset.from_tensor_slices((dict(data_df), label_df))\n",
- "    # Optionally randomise the order of the samples.\n",
- "    if shuffle:\n",
- "        dataset = dataset.shuffle(shuffle_buffer_size)\n",
- "    dataset = dataset.repeat(epochs).batch(batch_size)\n",
- "    return dataset  # a BatchDataset"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 17,
- "metadata": {},
- "outputs": [],
- "source": [
- "train_dataset = make_dataset(train_df, y_train, batch_size = 5)  # small batch size for a quick test"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 22,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "{'sex': <tf.Tensor: shape=(5,), dtype=string, numpy=array([b'male', b'male', b'female', b'male', b'female'], dtype=object)>, 'age': <tf.Tensor: shape=(5,), dtype=float64, numpy=array([28., 26., 45., 19., 11.])>, 'n_siblings_spouses': <tf.Tensor: shape=(5,), dtype=int64, numpy=array([0, 0, 0, 0, 4])>, 'parch': <tf.Tensor: shape=(5,), dtype=int64, numpy=array([0, 0, 0, 0, 2])>, 'fare': <tf.Tensor: shape=(5,), dtype=float64, numpy=array([ 7.8958, 7.775 , 7.75 , 7.775 , 31.275 ])>, 'class': <tf.Tensor: shape=(5,), dtype=string, numpy=array([b'Third', b'Third', b'Third', b'Third', b'Third'], dtype=object)>, 'deck': <tf.Tensor: shape=(5,), dtype=string, numpy=\n",
- "array([b'unknown', b'unknown', b'unknown', b'unknown', b'unknown'],\n",
- " dtype=object)>, 'embark_town': <tf.Tensor: shape=(5,), dtype=string, numpy=\n",
- "array([b'Southampton', b'Southampton', b'Southampton', b'Southampton',\n",
- " b'Southampton'], dtype=object)>, 'alone': <tf.Tensor: shape=(5,), dtype=string, numpy=array([b'y', b'y', b'y', b'y', b'n'], dtype=object)>}\n",
- "--------------------------------------------------\n",
- "tf.Tensor([0 0 0 0 0], shape=(5,), dtype=int64)\n"
- ]
- }
- ],
- "source": [
- "# Peek at the first batch: the feature dict, a separator line, then the labels.\n",
- "first_batch = train_dataset.take(1)\n",
- "for x, y in first_batch:\n",
- "    for item in (x, '-'*50, y):\n",
- "        print(item)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# 通过feature_columns把离散型(非数值类型)tensor变为数值类型"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 23,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "--------------------------------------------------\n",
- "WARNING:tensorflow:Layer dense_features is casting an input tensor from dtype float64 to the layer's dtype of float32, which is new behavior in TensorFlow 2. The layer has dtype float32 because it's dtype defaults to floatx.\n",
- "\n",
- "If you intended to run this layer in float32, you can safely ignore this warning. If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2.\n",
- "\n",
- "To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.\n",
- "\n",
- "[[19.]\n",
- " [ 9.]\n",
- " [28.]\n",
- " [24.]\n",
- " [20.]]\n",
- "WARNING:tensorflow:Layer dense_features_1 is casting an input tensor from dtype float64 to the layer's dtype of float32, which is new behavior in TensorFlow 2. The layer has dtype float32 because it's dtype defaults to floatx.\n",
- "\n",
- "If you intended to run this layer in float32, you can safely ignore this warning. If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2.\n",
- "\n",
- "To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.\n",
- "\n",
- "tf.Tensor(\n",
- "[[1. 0.]\n",
- " [1. 0.]\n",
- " [0. 1.]\n",
- " [1. 0.]\n",
- " [1. 0.]], shape=(5, 2), dtype=float32)\n"
- ]
- }
- ],
- "source": [
- "print('-'*50)\n",
- "# keras.layers.DenseFeatures:\n",
- "# a layer that produces a dense Tensor based on given feature_columns.\n",
- "for x, y in train_dataset.take(1):\n",
- "    # Two columns to demonstrate: numeric age (index 7) and categorical sex (index 0).\n",
- "    age_column, gender_column = feature_columns[7], feature_columns[0]\n",
- "    # Apply DenseFeatures to each column in turn.\n",
- "    age_layer = keras.layers.DenseFeatures(age_column)\n",
- "    print(age_layer(x).numpy())\n",
- "    gender_layer = keras.layers.DenseFeatures(gender_column)\n",
- "    print(gender_layer(x))  # becomes a one-hot encoding"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 24,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "WARNING:tensorflow:Layer dense_features_2 is casting an input tensor from dtype float64 to the layer's dtype of float32, which is new behavior in TensorFlow 2. The layer has dtype float32 because it's dtype defaults to floatx.\n",
- "\n",
- "If you intended to run this layer in float32, you can safely ignore this warning. If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2.\n",
- "\n",
- "To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.\n",
- "\n",
- "tf.Tensor(\n",
- "[[ 58. 1. 0. 0. 1. 0. 0. 1.\n",
- " 0. 0. 0. 0. 0. 0. 1. 0.\n",
- " 0. 0. 153.4625 0. 1. 0. 0. 0.\n",
- " 0. 0. 0. 1. 0. 0. 0. 0.\n",
- " 0. 1. ]\n",
- " [ 36. 1. 0. 0. 0. 1. 1. 0.\n",
- " 0. 0. 0. 0. 0. 0. 1. 0.\n",
- " 0. 0. 27.75 1. 0. 0. 0. 0.\n",
- " 0. 0. 0. 0. 1. 0. 0. 0.\n",
- " 1. 0. ]\n",
- " [ 42. 1. 0. 0. 0. 1. 1. 0.\n",
- " 0. 0. 0. 0. 0. 0. 1. 0.\n",
- " 0. 0. 27. 1. 0. 0. 0. 0.\n",
- " 0. 0. 1. 0. 0. 0. 0. 0.\n",
- " 1. 0. ]\n",
- " [ 3. 1. 0. 0. 0. 1. 1. 0.\n",
- " 0. 0. 0. 0. 0. 0. 1. 0.\n",
- " 0. 0. 18.75 1. 0. 0. 0. 0.\n",
- " 0. 0. 0. 1. 0. 0. 0. 0.\n",
- " 1. 0. ]\n",
- " [ 49. 0. 1. 1. 0. 0. 1. 0.\n",
- " 0. 0. 0. 0. 0. 0. 1. 0.\n",
- " 0. 0. 0. 0. 1. 0. 0. 0.\n",
- " 0. 0. 1. 0. 0. 0. 0. 0.\n",
- " 1. 0. ]], shape=(5, 34), dtype=float32)\n"
- ]
- }
- ],
- "source": [
- "# Apply DenseFeatures with every feature column to the first batch.\n",
- "for x, y in train_dataset.take(1):\n",
- "    dense_layer = keras.layers.DenseFeatures(feature_columns)\n",
- "    print(dense_layer(x))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 18,
- "metadata": {},
- "outputs": [],
- "source": [
- "# DenseFeatures turns the raw features (including non-numeric ones) into numeric features.\n",
- "model = keras.models.Sequential([\n",
- "    keras.layers.DenseFeatures(feature_columns),  # feature_columns handles every feature\n",
- "    keras.layers.Dense(100, activation='relu'),\n",
- "    keras.layers.Dense(100, activation='relu'),\n",
- "    keras.layers.Dense(2, activation='softmax'),\n",
- "])\n",
- "# Use `learning_rate`; `lr` is only kept as a deprecated alias.\n",
- "model.compile(loss='sparse_categorical_crossentropy',\n",
- "              optimizer=keras.optimizers.SGD(learning_rate=0.005),\n",
- "              metrics=['accuracy'])\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 19,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "[<tensorflow.python.feature_column.dense_features_v2.DenseFeatures at 0x7f5da04fd828>,\n",
- " <tensorflow.python.keras.layers.core.Dense at 0x7f5da04fdfd0>,\n",
- " <tensorflow.python.keras.layers.core.Dense at 0x7f5da052c278>,\n",
- " <tensorflow.python.keras.layers.core.Dense at 0x7f5da052c4e0>]"
- ]
- },
- "execution_count": 19,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# List the layers that make up the model.\n",
- "model.layers"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 31,
- "metadata": {},
- "outputs": [],
- "source": [
- "# model.summary()  -- not usable at this point because the first layer is DenseFeatures; it is called after model.fit instead\n",
- "# model.variables"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 32,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "19.90625"
- ]
- },
- "execution_count": 32,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# Training batches per epoch at the default batch size of 32\n",
- "# (the training set has 627 rows; the previous literal 637 was a typo).\n",
- "len(train_df) / 32"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 33,
- "metadata": {
- "collapsed": true
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 1/100\n",
- "WARNING:tensorflow:Layer dense_features_3 is casting an input tensor from dtype float64 to the layer's dtype of float32, which is new behavior in TensorFlow 2. The layer has dtype float32 because it's dtype defaults to floatx.\n",
- "\n",
- "If you intended to run this layer in float32, you can safely ignore this warning. If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2.\n",
- "\n",
- "To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.\n",
- "\n",
- "19/19 [==============================] - 1s 59ms/step - loss: 1.6348 - accuracy: 0.5658 - val_loss: 0.7247 - val_accuracy: 0.6641\n",
- "Epoch 2/100\n",
- "19/19 [==============================] - 0s 6ms/step - loss: 0.9808 - accuracy: 0.6020 - val_loss: 0.6643 - val_accuracy: 0.7070\n",
- "Epoch 3/100\n",
- "19/19 [==============================] - 0s 6ms/step - loss: 0.6924 - accuracy: 0.6776 - val_loss: 0.6271 - val_accuracy: 0.6328\n",
- "Epoch 4/100\n",
- "19/19 [==============================] - 0s 6ms/step - loss: 0.6906 - accuracy: 0.6875 - val_loss: 0.6064 - val_accuracy: 0.6836\n",
- "Epoch 5/100\n",
- "19/19 [==============================] - 0s 4ms/step - loss: 0.6415 - accuracy: 0.6727 - val_loss: 0.5961 - val_accuracy: 0.7148\n",
- "Epoch 6/100\n",
- "19/19 [==============================] - 0s 7ms/step - loss: 0.6482 - accuracy: 0.6760 - val_loss: 0.6056 - val_accuracy: 0.7031\n",
- "Epoch 7/100\n",
- "19/19 [==============================] - 0s 6ms/step - loss: 0.6552 - accuracy: 0.6497 - val_loss: 0.6448 - val_accuracy: 0.6602\n",
- "Epoch 8/100\n",
- "19/19 [==============================] - 0s 7ms/step - loss: 0.5845 - accuracy: 0.6875 - val_loss: 0.6202 - val_accuracy: 0.6562\n",
- "Epoch 9/100\n",
- "19/19 [==============================] - 0s 7ms/step - loss: 0.6384 - accuracy: 0.6464 - val_loss: 0.5930 - val_accuracy: 0.7109\n",
- "Epoch 10/100\n",
- "19/19 [==============================] - 0s 6ms/step - loss: 0.6100 - accuracy: 0.6875 - val_loss: 0.6665 - val_accuracy: 0.6289\n",
- "Epoch 11/100\n",
- "19/19 [==============================] - 0s 9ms/step - loss: 0.6134 - accuracy: 0.6612 - val_loss: 0.5899 - val_accuracy: 0.7031\n",
- "Epoch 12/100\n",
- "19/19 [==============================] - 0s 16ms/step - loss: 0.6154 - accuracy: 0.6776 - val_loss: 0.5820 - val_accuracy: 0.7031\n",
- "Epoch 13/100\n",
- "19/19 [==============================] - 0s 7ms/step - loss: 0.5937 - accuracy: 0.6941 - val_loss: 0.6564 - val_accuracy: 0.6133\n",
- "Epoch 14/100\n",
- "19/19 [==============================] - 0s 8ms/step - loss: 0.6062 - accuracy: 0.6875 - val_loss: 0.5792 - val_accuracy: 0.7031\n",
- "Epoch 15/100\n",
- "19/19 [==============================] - 0s 6ms/step - loss: 0.5699 - accuracy: 0.7056 - val_loss: 0.5909 - val_accuracy: 0.6797\n",
- "Epoch 16/100\n",
- "19/19 [==============================] - 0s 7ms/step - loss: 0.5868 - accuracy: 0.6760 - val_loss: 0.6631 - val_accuracy: 0.6562\n",
- "Epoch 17/100\n",
- "19/19 [==============================] - 0s 8ms/step - loss: 0.6175 - accuracy: 0.6793 - val_loss: 0.5750 - val_accuracy: 0.7109\n",
- "Epoch 18/100\n",
- "19/19 [==============================] - 0s 7ms/step - loss: 0.6196 - accuracy: 0.6562 - val_loss: 0.5981 - val_accuracy: 0.6641\n",
- "Epoch 19/100\n",
- "19/19 [==============================] - 0s 7ms/step - loss: 0.5879 - accuracy: 0.6974 - val_loss: 0.5780 - val_accuracy: 0.6992\n",
- "Epoch 20/100\n",
- "19/19 [==============================] - 0s 7ms/step - loss: 0.5752 - accuracy: 0.7171 - val_loss: 0.5764 - val_accuracy: 0.7148\n",
- "Epoch 21/100\n",
- "19/19 [==============================] - 0s 7ms/step - loss: 0.5970 - accuracy: 0.6924 - val_loss: 0.6356 - val_accuracy: 0.6328\n",
- "Epoch 22/100\n",
- "19/19 [==============================] - 0s 6ms/step - loss: 0.5937 - accuracy: 0.7089 - val_loss: 0.5692 - val_accuracy: 0.7227\n",
- "Epoch 23/100\n",
- "19/19 [==============================] - 0s 8ms/step - loss: 0.5777 - accuracy: 0.7056 - val_loss: 0.5817 - val_accuracy: 0.7031\n",
- "Epoch 24/100\n",
- "19/19 [==============================] - 0s 7ms/step - loss: 0.5804 - accuracy: 0.6957 - val_loss: 0.5662 - val_accuracy: 0.7188\n",
- "Epoch 25/100\n",
- "19/19 [==============================] - 0s 6ms/step - loss: 0.5770 - accuracy: 0.7072 - val_loss: 0.5827 - val_accuracy: 0.7109\n",
- "Epoch 26/100\n",
- "19/19 [==============================] - 0s 6ms/step - loss: 0.5817 - accuracy: 0.7105 - val_loss: 0.5854 - val_accuracy: 0.6797\n",
- "Epoch 27/100\n",
- "19/19 [==============================] - 0s 8ms/step - loss: 0.5823 - accuracy: 0.7039 - val_loss: 0.5649 - val_accuracy: 0.6992\n",
- "Epoch 28/100\n",
- "19/19 [==============================] - 0s 7ms/step - loss: 0.5729 - accuracy: 0.7270 - val_loss: 0.5539 - val_accuracy: 0.7227\n",
- "Epoch 29/100\n",
- "19/19 [==============================] - 0s 8ms/step - loss: 0.5694 - accuracy: 0.6974 - val_loss: 0.5540 - val_accuracy: 0.7188\n",
- "Epoch 30/100\n",
- "19/19 [==============================] - 0s 5ms/step - loss: 0.5706 - accuracy: 0.7270 - val_loss: 0.5607 - val_accuracy: 0.6953\n",
- "Epoch 31/100\n",
- "19/19 [==============================] - 0s 10ms/step - loss: 0.5768 - accuracy: 0.7253 - val_loss: 0.5836 - val_accuracy: 0.6484\n",
- "Epoch 32/100\n",
- "19/19 [==============================] - 0s 13ms/step - loss: 0.5717 - accuracy: 0.7056 - val_loss: 0.5533 - val_accuracy: 0.7109\n",
- "Epoch 33/100\n",
- "19/19 [==============================] - 0s 6ms/step - loss: 0.5645 - accuracy: 0.7122 - val_loss: 0.5806 - val_accuracy: 0.6836\n",
- "Epoch 34/100\n",
- "19/19 [==============================] - 0s 5ms/step - loss: 0.5564 - accuracy: 0.7352 - val_loss: 0.5539 - val_accuracy: 0.7109\n",
- "Epoch 35/100\n",
- "19/19 [==============================] - 0s 5ms/step - loss: 0.5838 - accuracy: 0.7138 - val_loss: 0.5531 - val_accuracy: 0.7266\n",
- "Epoch 36/100\n",
- "19/19 [==============================] - 0s 6ms/step - loss: 0.5649 - accuracy: 0.7204 - val_loss: 0.5561 - val_accuracy: 0.7344\n",
- "Epoch 37/100\n",
- "19/19 [==============================] - 0s 4ms/step - loss: 0.5630 - accuracy: 0.7303 - val_loss: 0.5460 - val_accuracy: 0.7305\n",
- "Epoch 38/100\n",
- "19/19 [==============================] - 0s 10ms/step - loss: 0.5529 - accuracy: 0.7352 - val_loss: 0.5460 - val_accuracy: 0.7188\n",
- "Epoch 39/100\n",
- "19/19 [==============================] - 0s 9ms/step - loss: 0.5879 - accuracy: 0.7105 - val_loss: 0.5653 - val_accuracy: 0.7109\n",
- "Epoch 40/100\n",
- "19/19 [==============================] - 0s 7ms/step - loss: 0.5764 - accuracy: 0.7155 - val_loss: 0.5658 - val_accuracy: 0.7031\n",
- "Epoch 41/100\n",
- "19/19 [==============================] - 0s 6ms/step - loss: 0.5607 - accuracy: 0.7105 - val_loss: 0.5686 - val_accuracy: 0.7188\n",
- "Epoch 42/100\n",
- "19/19 [==============================] - 0s 9ms/step - loss: 0.6006 - accuracy: 0.7138 - val_loss: 0.5411 - val_accuracy: 0.7188\n",
- "Epoch 43/100\n",
- "19/19 [==============================] - 0s 5ms/step - loss: 0.5600 - accuracy: 0.7220 - val_loss: 0.5581 - val_accuracy: 0.7227\n",
- "Epoch 44/100\n",
- "19/19 [==============================] - 0s 8ms/step - loss: 0.5509 - accuracy: 0.7516 - val_loss: 0.5548 - val_accuracy: 0.7305\n",
- "Epoch 45/100\n",
- "19/19 [==============================] - 0s 7ms/step - loss: 0.5581 - accuracy: 0.7286 - val_loss: 0.5423 - val_accuracy: 0.7266\n",
- "Epoch 46/100\n",
- "19/19 [==============================] - 0s 5ms/step - loss: 0.5403 - accuracy: 0.7533 - val_loss: 0.5361 - val_accuracy: 0.7227\n",
- "Epoch 47/100\n",
- "19/19 [==============================] - 0s 7ms/step - loss: 0.5738 - accuracy: 0.7138 - val_loss: 0.5308 - val_accuracy: 0.7305\n",
- "Epoch 48/100\n",
- "19/19 [==============================] - 0s 11ms/step - loss: 0.5554 - accuracy: 0.7319 - val_loss: 0.5390 - val_accuracy: 0.7266\n",
- "Epoch 49/100\n",
- "19/19 [==============================] - 0s 10ms/step - loss: 0.5614 - accuracy: 0.7122 - val_loss: 0.5323 - val_accuracy: 0.7266\n",
- "Epoch 50/100\n",
- "19/19 [==============================] - 0s 7ms/step - loss: 0.5808 - accuracy: 0.7007 - val_loss: 0.5294 - val_accuracy: 0.7383\n",
- "Epoch 51/100\n",
- "19/19 [==============================] - 0s 7ms/step - loss: 0.5450 - accuracy: 0.7270 - val_loss: 0.6058 - val_accuracy: 0.6523\n",
- "Epoch 52/100\n",
- "19/19 [==============================] - 0s 6ms/step - loss: 0.5535 - accuracy: 0.7434 - val_loss: 0.5364 - val_accuracy: 0.7266\n",
- "Epoch 53/100\n",
- "19/19 [==============================] - 0s 7ms/step - loss: 0.5363 - accuracy: 0.7500 - val_loss: 0.5505 - val_accuracy: 0.7266\n",
- "Epoch 54/100\n",
- "19/19 [==============================] - 0s 11ms/step - loss: 0.5521 - accuracy: 0.7401 - val_loss: 0.5354 - val_accuracy: 0.7227\n",
- "Epoch 55/100\n",
- "19/19 [==============================] - 0s 12ms/step - loss: 0.5175 - accuracy: 0.7648 - val_loss: 0.5477 - val_accuracy: 0.7266\n",
- "Epoch 56/100\n",
- "19/19 [==============================] - 0s 10ms/step - loss: 0.5386 - accuracy: 0.7484 - val_loss: 0.5433 - val_accuracy: 0.7344\n",
- "Epoch 57/100\n",
- "19/19 [==============================] - 0s 8ms/step - loss: 0.5678 - accuracy: 0.7007 - val_loss: 0.5391 - val_accuracy: 0.7344\n",
- "Epoch 58/100\n",
- "19/19 [==============================] - 0s 8ms/step - loss: 0.5226 - accuracy: 0.7434 - val_loss: 0.5691 - val_accuracy: 0.7266\n",
- "Epoch 59/100\n",
- "19/19 [==============================] - 0s 6ms/step - loss: 0.5508 - accuracy: 0.7368 - val_loss: 0.5308 - val_accuracy: 0.7383\n",
- "Epoch 60/100\n",
- "19/19 [==============================] - 0s 8ms/step - loss: 0.5464 - accuracy: 0.7352 - val_loss: 0.5257 - val_accuracy: 0.7383\n",
- "Epoch 61/100\n",
- "19/19 [==============================] - 0s 9ms/step - loss: 0.5719 - accuracy: 0.7188 - val_loss: 0.5185 - val_accuracy: 0.7383\n",
- "Epoch 62/100\n",
- "19/19 [==============================] - 0s 8ms/step - loss: 0.5675 - accuracy: 0.7352 - val_loss: 0.5203 - val_accuracy: 0.7305\n",
- "Epoch 63/100\n",
- "19/19 [==============================] - 0s 7ms/step - loss: 0.5375 - accuracy: 0.7500 - val_loss: 0.5256 - val_accuracy: 0.7422\n",
- "Epoch 64/100\n",
- "19/19 [==============================] - 0s 6ms/step - loss: 0.5197 - accuracy: 0.7467 - val_loss: 0.5179 - val_accuracy: 0.7266\n",
- "Epoch 65/100\n",
- "19/19 [==============================] - 0s 7ms/step - loss: 0.5418 - accuracy: 0.7303 - val_loss: 0.5151 - val_accuracy: 0.7305\n",
- "Epoch 66/100\n",
- "19/19 [==============================] - 0s 7ms/step - loss: 0.5415 - accuracy: 0.7319 - val_loss: 0.5171 - val_accuracy: 0.7344\n",
- "Epoch 67/100\n",
- "19/19 [==============================] - 0s 6ms/step - loss: 0.5446 - accuracy: 0.7467 - val_loss: 0.6223 - val_accuracy: 0.6797\n",
- "Epoch 68/100\n",
- "19/19 [==============================] - 0s 6ms/step - loss: 0.5550 - accuracy: 0.7303 - val_loss: 0.5630 - val_accuracy: 0.7305\n",
- "Epoch 69/100\n",
- "19/19 [==============================] - 0s 6ms/step - loss: 0.5047 - accuracy: 0.7632 - val_loss: 0.5114 - val_accuracy: 0.7422\n",
- "Epoch 70/100\n",
- "19/19 [==============================] - 0s 11ms/step - loss: 0.5578 - accuracy: 0.7434 - val_loss: 0.5562 - val_accuracy: 0.7266\n",
- "Epoch 71/100\n",
- "19/19 [==============================] - 0s 7ms/step - loss: 0.5290 - accuracy: 0.7368 - val_loss: 0.5365 - val_accuracy: 0.7422\n",
- "Epoch 72/100\n",
- "19/19 [==============================] - 0s 6ms/step - loss: 0.5085 - accuracy: 0.7484 - val_loss: 0.5249 - val_accuracy: 0.7383\n",
- "Epoch 73/100\n",
- "19/19 [==============================] - 0s 7ms/step - loss: 0.5272 - accuracy: 0.7451 - val_loss: 0.5063 - val_accuracy: 0.7695\n",
- "Epoch 74/100\n",
- "19/19 [==============================] - 0s 7ms/step - loss: 0.5289 - accuracy: 0.7599 - val_loss: 0.5061 - val_accuracy: 0.7539\n",
- "Epoch 75/100\n",
- "19/19 [==============================] - 0s 7ms/step - loss: 0.5288 - accuracy: 0.7566 - val_loss: 0.5457 - val_accuracy: 0.7148\n",
- "Epoch 76/100\n",
- "19/19 [==============================] - 0s 5ms/step - loss: 0.5457 - accuracy: 0.7582 - val_loss: 0.5351 - val_accuracy: 0.7227\n",
- "Epoch 77/100\n",
- "19/19 [==============================] - 0s 6ms/step - loss: 0.5110 - accuracy: 0.7599 - val_loss: 0.5524 - val_accuracy: 0.7227\n",
- "Epoch 78/100\n",
- "19/19 [==============================] - 0s 7ms/step - loss: 0.5156 - accuracy: 0.7401 - val_loss: 0.5335 - val_accuracy: 0.7305\n",
- "Epoch 79/100\n",
- "19/19 [==============================] - 0s 7ms/step - loss: 0.5466 - accuracy: 0.7188 - val_loss: 0.5024 - val_accuracy: 0.7500\n",
- "Epoch 80/100\n",
- "19/19 [==============================] - 0s 9ms/step - loss: 0.5362 - accuracy: 0.7352 - val_loss: 0.5054 - val_accuracy: 0.7695\n",
- "Epoch 81/100\n",
- "19/19 [==============================] - 0s 7ms/step - loss: 0.5526 - accuracy: 0.7204 - val_loss: 0.5390 - val_accuracy: 0.7227\n",
- "Epoch 82/100\n",
- "19/19 [==============================] - 0s 8ms/step - loss: 0.4990 - accuracy: 0.7714 - val_loss: 0.5243 - val_accuracy: 0.7539\n",
- "Epoch 83/100\n",
- "19/19 [==============================] - 0s 8ms/step - loss: 0.4918 - accuracy: 0.7961 - val_loss: 0.5003 - val_accuracy: 0.7500\n",
- "Epoch 84/100\n",
- "19/19 [==============================] - 0s 6ms/step - loss: 0.5606 - accuracy: 0.7171 - val_loss: 0.5543 - val_accuracy: 0.7188\n",
- "Epoch 85/100\n",
- "19/19 [==============================] - 0s 7ms/step - loss: 0.5410 - accuracy: 0.7451 - val_loss: 0.5011 - val_accuracy: 0.7539\n",
- "Epoch 86/100\n",
- "19/19 [==============================] - 0s 7ms/step - loss: 0.5198 - accuracy: 0.7434 - val_loss: 0.5084 - val_accuracy: 0.7305\n",
- "Epoch 87/100\n",
- "19/19 [==============================] - 0s 6ms/step - loss: 0.5142 - accuracy: 0.7516 - val_loss: 0.5121 - val_accuracy: 0.7578\n",
- "Epoch 88/100\n",
- "19/19 [==============================] - 0s 6ms/step - loss: 0.5621 - accuracy: 0.7253 - val_loss: 0.5153 - val_accuracy: 0.7422\n",
- "Epoch 89/100\n",
- "19/19 [==============================] - 0s 8ms/step - loss: 0.4897 - accuracy: 0.7747 - val_loss: 0.4928 - val_accuracy: 0.7539\n",
- "Epoch 90/100\n",
- "19/19 [==============================] - 0s 10ms/step - loss: 0.5021 - accuracy: 0.7829 - val_loss: 0.4975 - val_accuracy: 0.7461\n",
- "Epoch 91/100\n",
- "19/19 [==============================] - 0s 6ms/step - loss: 0.5165 - accuracy: 0.7599 - val_loss: 0.5363 - val_accuracy: 0.7070\n",
- "Epoch 92/100\n",
- "19/19 [==============================] - 0s 13ms/step - loss: 0.5383 - accuracy: 0.7467 - val_loss: 0.4968 - val_accuracy: 0.7539\n",
- "Epoch 93/100\n",
- "19/19 [==============================] - 0s 6ms/step - loss: 0.5140 - accuracy: 0.7582 - val_loss: 0.5597 - val_accuracy: 0.7266\n",
- "Epoch 94/100\n",
- "19/19 [==============================] - 0s 7ms/step - loss: 0.5385 - accuracy: 0.7566 - val_loss: 0.4961 - val_accuracy: 0.7656\n",
- "Epoch 95/100\n",
- "19/19 [==============================] - 0s 9ms/step - loss: 0.5342 - accuracy: 0.7434 - val_loss: 0.5080 - val_accuracy: 0.7188\n",
- "Epoch 96/100\n",
- "19/19 [==============================] - 0s 7ms/step - loss: 0.5023 - accuracy: 0.7664 - val_loss: 0.5634 - val_accuracy: 0.6680\n",
- "Epoch 97/100\n",
- "19/19 [==============================] - 0s 5ms/step - loss: 0.5102 - accuracy: 0.7928 - val_loss: 0.5013 - val_accuracy: 0.7422\n",
- "Epoch 98/100\n",
- "19/19 [==============================] - 0s 10ms/step - loss: 0.5202 - accuracy: 0.7368 - val_loss: 0.4993 - val_accuracy: 0.7578\n",
- "Epoch 99/100\n",
- "19/19 [==============================] - 0s 7ms/step - loss: 0.4922 - accuracy: 0.7878 - val_loss: 0.5875 - val_accuracy: 0.7148\n",
- "Epoch 100/100\n",
- "19/19 [==============================] - 0s 5ms/step - loss: 0.5336 - accuracy: 0.7451 - val_loss: 0.5038 - val_accuracy: 0.7578\n"
- ]
- }
- ],
- "source": [
- "# Two ways to train:\n",
- "# 1. model.fit (used here)\n",
- "# 2. model -> estimator -> train\n",
- "\n",
- "EPOCHS = 100\n",
- "BATCH_SIZE = 32\n",
- "\n",
- "train_dataset = make_dataset(train_df, y_train, epochs = EPOCHS,\n",
- "                             batch_size = BATCH_SIZE)\n",
- "eval_dataset = make_dataset(eval_df, y_eval, epochs = EPOCHS, shuffle = False,\n",
- "                            batch_size = BATCH_SIZE)  # validation set\n",
- "# Each dataset yields (features, label) pairs; feature_columns only process the features.\n",
- "# Step counts are derived from the data sizes instead of being hard-coded:\n",
- "# only full batches are counted, so each epoch stays within one pass over the data.\n",
- "history = model.fit(train_dataset,\n",
- "                    validation_data = eval_dataset,\n",
- "                    steps_per_epoch = len(train_df) // BATCH_SIZE,\n",
- "                    validation_steps = len(eval_df) // BATCH_SIZE,\n",
- "                    epochs = EPOCHS)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 37,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Model: \"sequential\"\n",
- "_________________________________________________________________\n",
- "Layer (type) Output Shape Param # \n",
- "=================================================================\n",
- "dense_features_3 (DenseFeatu multiple 0 \n",
- "_________________________________________________________________\n",
- "dense (Dense) multiple 3500 \n",
- "_________________________________________________________________\n",
- "dense_1 (Dense) multiple 10100 \n",
- "_________________________________________________________________\n",
- "dense_2 (Dense) multiple 202 \n",
- "=================================================================\n",
- "Total params: 13,802\n",
- "Trainable params: 13,802\n",
- "Non-trainable params: 0\n",
- "_________________________________________________________________\n"
- ]
- }
- ],
- "source": [
- "# Print the per-layer parameter summary of the model.\n",
- "model.summary()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 36,
- "metadata": {
- "scrolled": false
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "INFO:tensorflow:Using default config.\n",
- "WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmprqx91llq\n",
- "INFO:tensorflow:Using the Keras model provided.\n",
- "WARNING:tensorflow:You are creating an Estimator from a Keras model manually subclassed from `Model`, that was already called on some inputs (and thus already had weights). We are currently unable to preserve the model's state (its weights) as part of the estimator in this case. Be warned that the estimator has been created using a freshly initialized version of your model.\n",
- "Note that this doesn't affect the state of the model instance you passed as `keras_model` argument.\n",
- "INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmprqx91llq', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true\n",
- "graph_options {\n",
- " rewrite_options {\n",
- " meta_optimizer_iterations: ONE\n",
- " }\n",
- "}\n",
- ", '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}\n",
- "INFO:tensorflow:Calling model_fn.\n"
- ]
- },
- {
- "ename": "ValueError",
- "evalue": "Unexpectedly found an instance of type `<class 'dict'>`. Expected a symbolic tensor instance.",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
- "\u001b[0;32m<ipython-input-36-11b6f07b1baa>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;31m#是否是上面用了dict\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m estimator.train(input_fn = lambda : make_dataset(\n\u001b[0;32m----> 7\u001b[0;31m train_df, y_train, epochs=100),steps=1)\n\u001b[0m",
- "\u001b[0;32m~/.virtualenvs/tf_py3/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/estimator.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(self, input_fn, hooks, steps, max_steps, saving_listeners)\u001b[0m\n\u001b[1;32m 347\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 348\u001b[0m \u001b[0msaving_listeners\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_check_listeners_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msaving_listeners\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 349\u001b[0;31m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_train_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_fn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msaving_listeners\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 350\u001b[0m \u001b[0mlogging\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minfo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Loss for final step: %s.'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 351\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m~/.virtualenvs/tf_py3/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/estimator.py\u001b[0m in \u001b[0;36m_train_model\u001b[0;34m(self, input_fn, hooks, saving_listeners)\u001b[0m\n\u001b[1;32m 1180\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_train_model_distributed\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_fn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msaving_listeners\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1181\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1182\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_train_model_default\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_fn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msaving_listeners\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1183\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1184\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_train_model_default\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput_fn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msaving_listeners\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m~/.virtualenvs/tf_py3/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/estimator.py\u001b[0m in \u001b[0;36m_train_model_default\u001b[0;34m(self, input_fn, hooks, saving_listeners)\u001b[0m\n\u001b[1;32m 1209\u001b[0m \u001b[0mworker_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mextend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_hooks\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1210\u001b[0m estimator_spec = self._call_model_fn(features, labels, ModeKeys.TRAIN,\n\u001b[0;32m-> 1211\u001b[0;31m self.config)\n\u001b[0m\u001b[1;32m 1212\u001b[0m \u001b[0mglobal_step_tensor\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompat\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mv1\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_global_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1213\u001b[0m return self._train_with_estimator_spec(estimator_spec, worker_hooks,\n",
- "\u001b[0;32m~/.virtualenvs/tf_py3/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/estimator.py\u001b[0m in \u001b[0;36m_call_model_fn\u001b[0;34m(self, features, labels, mode, config)\u001b[0m\n\u001b[1;32m 1168\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1169\u001b[0m \u001b[0mlogging\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minfo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Calling model_fn.'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1170\u001b[0;31m \u001b[0mmodel_fn_results\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_model_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfeatures\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mfeatures\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1171\u001b[0m \u001b[0mlogging\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minfo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Done calling model_fn.'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1172\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m~/.virtualenvs/tf_py3/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/keras.py\u001b[0m in \u001b[0;36mmodel_fn\u001b[0;34m(features, labels, mode)\u001b[0m\n\u001b[1;32m 284\u001b[0m \u001b[0mfeatures\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mfeatures\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 285\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlabels\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 286\u001b[0;31m optimizer_config=optimizer_config)\n\u001b[0m\u001b[1;32m 287\u001b[0m \u001b[0mmodel_output_names\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 288\u001b[0m \u001b[0;31m# We need to make sure that the output names of the last layer in the model\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m~/.virtualenvs/tf_py3/lib/python3.6/site-packages/tensorflow_estimator/python/estimator/keras.py\u001b[0m in \u001b[0;36m_clone_and_build_model\u001b[0;34m(mode, keras_model, custom_objects, features, labels, optimizer_config)\u001b[0m\n\u001b[1;32m 222\u001b[0m \u001b[0min_place_reset\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mnot\u001b[0m \u001b[0mkeras_model\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_is_graph_network\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 223\u001b[0m \u001b[0moptimizer_iterations\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mglobal_step\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 224\u001b[0;31m optimizer_config=optimizer_config)\n\u001b[0m\u001b[1;32m 225\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 226\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0msample_weight_tensors\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m~/.virtualenvs/tf_py3/lib/python3.6/site-packages/tensorflow/python/keras/models.py\u001b[0m in \u001b[0;36mclone_and_build_model\u001b[0;34m(model, input_tensors, target_tensors, custom_objects, compile_clone, in_place_reset, optimizer_iterations, optimizer_config)\u001b[0m\n\u001b[1;32m 651\u001b[0m \u001b[0mclone\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mclone_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput_tensors\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minput_tensors\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 652\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mSequential\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 653\u001b[0;31m \u001b[0mclone\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mclone_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput_tensors\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minput_tensors\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 654\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;32mnot\u001b[0m \u001b[0mclone\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_is_graph_network\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_build_input_shape\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 655\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mops\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexecuting_eagerly_outside_functions\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m~/.virtualenvs/tf_py3/lib/python3.6/site-packages/tensorflow/python/keras/models.py\u001b[0m in \u001b[0;36mclone_model\u001b[0;34m(model, input_tensors, clone_function)\u001b[0m\n\u001b[1;32m 422\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mSequential\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 423\u001b[0m return _clone_sequential_model(\n\u001b[0;32m--> 424\u001b[0;31m model, input_tensors=input_tensors, layer_fn=clone_function)\n\u001b[0m\u001b[1;32m 425\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 426\u001b[0m return _clone_functional_model(\n",
- "\u001b[0;32m~/.virtualenvs/tf_py3/lib/python3.6/site-packages/tensorflow/python/keras/models.py\u001b[0m in \u001b[0;36m_clone_sequential_model\u001b[0;34m(model, input_tensors, layer_fn)\u001b[0m\n\u001b[1;32m 341\u001b[0m \u001b[0minput_tensors\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_tensors\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 342\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgeneric_utils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto_list\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_tensors\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 343\u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0mK\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_keras_tensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 344\u001b[0m \u001b[0morigin_layer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_keras_history\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlayer\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 345\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0morigin_layer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mInputLayer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m~/.virtualenvs/tf_py3/lib/python3.6/site-packages/tensorflow/python/keras/backend.py\u001b[0m in \u001b[0;36mis_keras_tensor\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m 1018\u001b[0m sparse_tensor.SparseTensor, ragged_tensor.RaggedTensor)):\n\u001b[1;32m 1019\u001b[0m raise ValueError('Unexpectedly found an instance of type `' + str(type(x)) +\n\u001b[0;32m-> 1020\u001b[0;31m '`. Expected a symbolic tensor instance.')\n\u001b[0m\u001b[1;32m 1021\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mhasattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'_keras_history'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1022\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;31mValueError\u001b[0m: Unexpectedly found an instance of type `<class 'dict'>`. Expected a symbolic tensor instance."
- ]
- }
- ],
- "source": [
- "#第二种方法,TensorFlow的已知bug,等修复\n",
- "estimator = keras.estimator.model_to_estimator(model)\n",
- "# 1. function 输入必须是函数\n",
- "# 2. return 返回值 a. (features, labels) b. dataset -> (feature, label)\n",
- "#是否是上面用了dict\n",
- "estimator.train(input_fn = lambda : make_dataset(\n",
- " train_df, y_train, epochs=100),steps=1)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "m.reset_states()\n",
- "m.update_state([[0, 1], [0, 0]], [[1, 2], [0, 0]],\n",
- " sample_weight=[1, 0])\n",
- "m.result().numpy()"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.6.9"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
- }
|