- {
- "cells": [
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "2.2.0\n",
- "sys.version_info(major=3, minor=6, micro=9, releaselevel='final', serial=0)\n",
- "matplotlib 3.3.4\n",
- "numpy 1.19.5\n",
- "pandas 1.1.5\n",
- "sklearn 0.24.2\n",
- "tensorflow 2.2.0\n",
- "tensorflow.keras 2.3.0-tf\n"
- ]
- }
- ],
- "source": [
- "import matplotlib as mpl\n",
- "import matplotlib.pyplot as plt\n",
- "%matplotlib inline\n",
- "import numpy as np\n",
- "import sklearn\n",
- "import pandas as pd\n",
- "import os\n",
- "import sys\n",
- "import time\n",
- "import tensorflow as tf\n",
- "\n",
- "from tensorflow import keras\n",
- "\n",
- "print(tf.__version__)\n",
- "print(sys.version_info)\n",
- "for module in mpl, np, pd, sklearn, tf, keras:\n",
- " print(module.__name__, module.__version__)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "test_00.csv test_08.csv train_06.csv train_14.csv valid_02.csv\r\n",
- "test_01.csv test_09.csv train_07.csv train_15.csv valid_03.csv\r\n",
- "test_02.csv train_00.csv train_08.csv train_16.csv valid_04.csv\r\n",
- "test_03.csv train_01.csv train_09.csv train_17.csv valid_05.csv\r\n",
- "test_04.csv train_02.csv train_10.csv train_18.csv valid_06.csv\r\n",
- "test_05.csv train_03.csv train_11.csv train_19.csv valid_07.csv\r\n",
- "test_06.csv train_04.csv train_12.csv valid_00.csv valid_08.csv\r\n",
- "test_07.csv train_05.csv train_13.csv valid_01.csv valid_09.csv\r\n"
- ]
- }
- ],
- "source": [
- "!ls generate_csv"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "['./generate_csv/train_08.csv',\n",
- " './generate_csv/train_11.csv',\n",
- " './generate_csv/train_18.csv',\n",
- " './generate_csv/train_15.csv',\n",
- " './generate_csv/train_17.csv',\n",
- " './generate_csv/train_00.csv',\n",
- " './generate_csv/train_01.csv',\n",
- " './generate_csv/train_19.csv',\n",
- " './generate_csv/train_14.csv',\n",
- " './generate_csv/train_02.csv',\n",
- " './generate_csv/train_16.csv',\n",
- " './generate_csv/train_09.csv',\n",
- " './generate_csv/train_03.csv',\n",
- " './generate_csv/train_12.csv',\n",
- " './generate_csv/train_10.csv',\n",
- " './generate_csv/train_13.csv',\n",
- " './generate_csv/train_05.csv',\n",
- " './generate_csv/train_07.csv',\n",
- " './generate_csv/train_04.csv',\n",
- " './generate_csv/train_06.csv']\n",
- "['./generate_csv/valid_01.csv',\n",
- " './generate_csv/valid_05.csv',\n",
- " './generate_csv/valid_02.csv',\n",
- " './generate_csv/valid_04.csv',\n",
- " './generate_csv/valid_08.csv',\n",
- " './generate_csv/valid_07.csv',\n",
- " './generate_csv/valid_06.csv',\n",
- " './generate_csv/valid_00.csv',\n",
- " './generate_csv/valid_09.csv',\n",
- " './generate_csv/valid_03.csv']\n",
- "['./generate_csv/test_00.csv',\n",
- " './generate_csv/test_07.csv',\n",
- " './generate_csv/test_01.csv',\n",
- " './generate_csv/test_08.csv',\n",
- " './generate_csv/test_06.csv',\n",
- " './generate_csv/test_02.csv',\n",
- " './generate_csv/test_04.csv',\n",
- " './generate_csv/test_05.csv',\n",
- " './generate_csv/test_09.csv',\n",
- " './generate_csv/test_03.csv']\n"
- ]
- }
- ],
- "source": [
- "source_dir = \"./generate_csv/\"\n",
- "\n",
- "#通过判断开头去添加文件\n",
- "def get_filenames_by_prefix(source_dir, prefix_name):\n",
- " all_files = os.listdir(source_dir)\n",
- " results = []\n",
- " for filename in all_files:\n",
- " if filename.startswith(prefix_name):\n",
- " results.append(os.path.join(source_dir, filename))\n",
- " return results\n",
- "\n",
- "train_filenames = get_filenames_by_prefix(source_dir, \"train\")\n",
- "valid_filenames = get_filenames_by_prefix(source_dir, \"valid\")\n",
- "test_filenames = get_filenames_by_prefix(source_dir, \"test\")\n",
- "\n",
- "import pprint\n",
- "pprint.pprint(train_filenames)\n",
- "pprint.pprint(valid_filenames)\n",
- "pprint.pprint(test_filenames)\n"
- ]
- },
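- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "A side note (not in the original notebook): the same lists could be built with the standard-library `glob` module. The sketch below just cross-checks `get_filenames_by_prefix` against a glob pattern; `train_filenames_glob` is a name introduced here."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Sketch: cross-check get_filenames_by_prefix against the glob module\n",
- "import glob\n",
- "\n",
- "train_filenames_glob = sorted(glob.glob(os.path.join(source_dir, \"train_*.csv\")))\n",
- "assert sorted(train_filenames) == train_filenames_glob\n",
- "print(len(train_filenames_glob), \"train files\")"
- ]
- },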
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [],
- "source": [
- "#下面的接口都是之前用过的\n",
- "def parse_csv_line(line, n_fields = 9):\n",
- " defs = [tf.constant(np.nan)] * n_fields\n",
- " parsed_fields = tf.io.decode_csv(line, record_defaults=defs)\n",
- " x = tf.stack(parsed_fields[0:-1])\n",
- " y = tf.stack(parsed_fields[-1:])\n",
- " return x, y\n",
- "\n",
- "def csv_reader_dataset(filenames, n_readers=5,\n",
- " batch_size=32, n_parse_threads=5,\n",
- " shuffle_buffer_size=10000):\n",
- " dataset = tf.data.Dataset.list_files(filenames)\n",
- " dataset = dataset.repeat()\n",
- " dataset = dataset.interleave(\n",
- " lambda filename: tf.data.TextLineDataset(filename).skip(1),\n",
- " cycle_length = n_readers\n",
- " )\n",
- " dataset.shuffle(shuffle_buffer_size)\n",
- " dataset = dataset.map(parse_csv_line,\n",
- " num_parallel_calls=n_parse_threads)\n",
- " dataset = dataset.batch(batch_size)\n",
- " return dataset\n",
- "\n",
- "batch_size = 32\n",
- "train_set = csv_reader_dataset(train_filenames,\n",
- " batch_size = batch_size)\n",
- "valid_set = csv_reader_dataset(valid_filenames,\n",
- " batch_size = batch_size)\n",
- "test_set = csv_reader_dataset(test_filenames,\n",
- " batch_size = batch_size)\n"
- ]
- },
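- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "An optional tweak (an assumption, not part of the original pipeline): appending `prefetch` lets tf.data prepare the next batch while the current one is consumed. In TF 2.2 the autotune constant lives under `tf.data.experimental`; `csv_reader_dataset_prefetched` is a hypothetical wrapper."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Hypothetical variant: overlap preprocessing and consumption with prefetch\n",
- "def csv_reader_dataset_prefetched(filenames, batch_size=32):\n",
- "    dataset = csv_reader_dataset(filenames, batch_size=batch_size)\n",
- "    # AUTOTUNE lets tf.data pick the prefetch buffer size dynamically\n",
- "    return dataset.prefetch(tf.data.experimental.AUTOTUNE)\n",
- "\n",
- "# e.g. train_set = csv_reader_dataset_prefetched(train_filenames, batch_size)"
- ]
- },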
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "chapter_4.tar.gz\t tf02_data_generate_csv.ipynb\r\n",
- "generate_csv\t\t tf03-tfrecord_basic_api.ipynb\r\n",
- "generate_tfrecords\t tf04_data_generate_tfrecord.ipynb\r\n",
- "temp.csv\t\t tfrecord_basic\r\n",
- "tf01-dataset_basic_api.ipynb\r\n"
- ]
- }
- ],
- "source": [
- "!ls"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "# 把train_set,valid_set,test_set 存储到tfrecord类型的文件中"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {},
- "outputs": [],
- "source": [
- "#把基础的如何序列化的步骤搞到一个函数\n",
- "def serialize_example(x, y):\n",
- " \"\"\"Converts x, y to tf.train.Example and serialize\"\"\"\n",
- " input_feautres = tf.train.FloatList(value = x) #特征\n",
- " label = tf.train.FloatList(value = y)#标签\n",
- " features = tf.train.Features(\n",
- " feature = {\n",
- " \"input_features\": tf.train.Feature(\n",
- " float_list = input_feautres),\n",
- " \"label\": tf.train.Feature(float_list = label)\n",
- " }\n",
- " )\n",
- " #把features变为example\n",
- " example = tf.train.Example(features = features)\n",
- " return example.SerializeToString() #把example序列化\n",
- "#n_shards是存为多少个文件,steps_per_shard和 steps_per_epoch类似\n",
- "def csv_dataset_to_tfrecords(base_filename, dataset,\n",
- " n_shards, steps_per_shard,\n",
- " compression_type = None):\n",
- " #压缩文件类型\n",
- " options = tf.io.TFRecordOptions(\n",
- " compression_type = compression_type)\n",
- " all_filenames = []\n",
- " \n",
- " for shard_id in range(n_shards):\n",
- " filename_fullpath = '{}_{:05d}-of-{:05d}'.format(\n",
- " base_filename, shard_id, n_shards) #base_filename是一个前缀\n",
- " #打开文件\n",
- " with tf.io.TFRecordWriter(filename_fullpath, options) as writer:\n",
- " #取出数据,为什么skip,上一个文件写了前500行,下一个文件存后面的数据\n",
- " for x_batch, y_batch in dataset.skip(shard_id * steps_per_shard).take(steps_per_shard):\n",
- " for x_example, y_example in zip(x_batch, y_batch):\n",
- " writer.write(\n",
- " serialize_example(x_example, y_example))\n",
- " all_filenames.append(filename_fullpath)\n",
- " #返回所有tfrecord文件名\n",
- " return all_filenames"
- ]
- },
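- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "A quick sanity check (a sketch added here, not from the original run): serialize one hand-made sample and parse it back with `tf.io.parse_single_example` to confirm the feature schema that `serialize_example` writes. `sample_x` and `sample_y` are made-up values."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Sketch: round-trip one example through the TFRecord schema\n",
- "sample_x = np.arange(8, dtype=np.float32)     # 8 input features\n",
- "sample_y = np.array([1.5], dtype=np.float32)  # 1 label\n",
- "serialized = serialize_example(sample_x, sample_y)\n",
- "\n",
- "parsed = tf.io.parse_single_example(\n",
- "    serialized,\n",
- "    {\n",
- "        \"input_features\": tf.io.FixedLenFeature([8], dtype=tf.float32),\n",
- "        \"label\": tf.io.FixedLenFeature([1], dtype=tf.float32),\n",
- "    })\n",
- "print(parsed[\"input_features\"], parsed[\"label\"])"
- ]
- },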
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [],
- "source": [
- "!rm -rf generate_tfrecords"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {
- "collapsed": true
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "(<tf.Tensor: shape=(32, 8), dtype=float32, numpy=\n",
- "array([[ 8.01544309e-01, 2.72161424e-01, -1.16243929e-01,\n",
- " -2.02311516e-01, -5.43051600e-01, -2.10396163e-02,\n",
- " -5.89762092e-01, -8.24184567e-02],\n",
- " [ 4.85305160e-01, -8.49241912e-01, -6.53012618e-02,\n",
- " -2.33796556e-02, 1.49743509e+00, -7.79065788e-02,\n",
- " -9.02363241e-01, 7.81451464e-01],\n",
- " [-6.67222738e-01, -4.82395217e-02, 3.45294058e-01,\n",
- " 5.38266838e-01, 1.85218394e+00, -6.11253828e-02,\n",
- " -8.41709316e-01, 1.52048469e+00],\n",
- " [ 6.30343556e-01, 1.87416613e+00, -6.71321452e-02,\n",
- " -1.25433668e-01, -1.97375536e-01, -2.27226317e-02,\n",
- " -6.92407250e-01, 7.26523340e-01],\n",
- " [ 6.36364639e-01, -1.08954263e+00, 9.26090255e-02,\n",
- " -2.05381244e-01, 1.20256710e+00, -3.63012254e-02,\n",
- " -6.78410172e-01, 1.82235345e-01],\n",
- " [-2.98072815e-01, 3.52261662e-01, -1.09205075e-01,\n",
- " -2.50555217e-01, -3.40640247e-02, -6.03400404e-03,\n",
- " 1.08055484e+00, -1.06113815e+00],\n",
- " [-7.43205428e-01, 9.12963331e-01, -6.44320250e-01,\n",
- " -1.47909701e-01, 7.39851117e-01, 1.14276908e-01,\n",
- " -7.95052409e-01, 6.81582153e-01],\n",
- " [ 1.51805115e+00, -5.28840959e-01, 8.10247064e-01,\n",
- " -1.92141697e-01, 4.41353947e-01, 2.73350589e-02,\n",
- " -8.18380833e-01, 8.56353521e-01],\n",
- " [ 1.63122582e+00, 3.52261662e-01, 4.08057608e-02,\n",
- " -1.40889511e-01, -4.63210404e-01, -6.75162375e-02,\n",
- " -8.27712238e-01, 5.96693158e-01],\n",
- " [-1.06354737e+00, 1.87416613e+00, -4.93448943e-01,\n",
- " -6.96261302e-02, -2.73587584e-01, -1.34195149e-01,\n",
- " 1.03389800e+00, -1.34576583e+00],\n",
- " [-3.05882934e-02, -9.29342151e-01, 2.59621471e-01,\n",
- " -6.01274054e-03, -5.00409126e-01, -3.07798684e-02,\n",
- " 1.59844637e+00, -1.81515181e+00],\n",
- " [-8.24676275e-01, -4.82395217e-02, -3.44865829e-01,\n",
- " -8.47758725e-02, 5.01234829e-01, -3.46999951e-02,\n",
- " 5.30003488e-01, -8.74119252e-02],\n",
- " [-3.29563528e-01, 9.93063569e-01, -8.77174079e-01,\n",
- " -3.63671094e-01, -1.11645639e+00, -8.51059332e-02,\n",
- " 1.06655777e+00, -1.38571358e+00],\n",
- " [ 4.04922552e-02, -6.89041436e-01, -4.43798512e-01,\n",
- " 2.23745853e-02, -2.21872270e-01, -1.48285031e-01,\n",
- " -8.88366222e-01, 6.36640906e-01],\n",
- " [-9.97422278e-01, 1.23336422e+00, -7.57719278e-01,\n",
- " -1.11092515e-02, -2.30037838e-01, 5.48742227e-02,\n",
- " -7.57726908e-01, 7.06549466e-01],\n",
- " [ 1.90638328e+00, 5.12462139e-01, 4.47582811e-01,\n",
- " -2.76721776e-01, -6.31058335e-01, -7.08114654e-02,\n",
- " -7.06404328e-01, 7.46497214e-01],\n",
- " [ 7.00647458e-02, 3.18607129e-02, -2.57098645e-01,\n",
- " -3.00019473e-01, -2.66329288e-01, -9.85835046e-02,\n",
- " 1.08522058e+00, -1.37073314e+00],\n",
- " [-6.91253722e-01, -4.82395217e-02, -8.84629071e-01,\n",
- " 4.13295318e-04, -1.42938361e-01, -1.61392525e-01,\n",
- " -7.20401347e-01, 6.16667032e-01],\n",
- " [-4.27712232e-01, -9.29342151e-01, -3.57968122e-01,\n",
- " 6.92046210e-02, -8.15237403e-01, -5.46457283e-02,\n",
- " -5.71099281e-01, 5.51751971e-01],\n",
- " [-1.86834182e-03, -1.24974310e+00, 2.20688999e-01,\n",
- " -2.50608753e-02, 3.33386868e-01, 2.75048837e-02,\n",
- " 9.17255700e-01, -6.81634605e-01],\n",
- " [ 3.88017392e+00, -9.29342151e-01, 1.29029870e+00,\n",
- " -1.72691330e-01, -2.25501403e-01, 5.14101014e-02,\n",
- " -8.46374989e-01, 8.86314332e-01],\n",
- " [ 1.23452818e+00, -1.32984328e+00, 5.75435698e-01,\n",
- " -1.19361848e-01, -4.16938782e-01, 1.54011786e-01,\n",
- " 1.09921765e+00, -1.34576583e+00],\n",
- " [ 9.00758684e-01, -1.28339753e-01, 8.26199055e-02,\n",
- " -2.27456287e-01, -2.84474999e-01, 8.01625699e-02,\n",
- " -9.16360319e-01, 8.11412275e-01],\n",
- " [-1.04972430e-01, -2.13084579e+00, -8.56280804e-01,\n",
- " 3.74467552e-01, -7.72594988e-01, -1.98224083e-01,\n",
- " 9.96572435e-01, -1.42566133e+00],\n",
- " [-4.90587056e-01, 3.52261662e-01, -6.64385974e-01,\n",
- " 2.82370716e-01, 5.10307670e-01, 4.01096910e-01,\n",
- " 8.51936042e-01, -1.31580508e+00],\n",
- " [ 3.72503400e-01, -6.89041436e-01, 6.45801365e-01,\n",
- " 8.00678432e-02, -3.15322757e-01, -2.51115970e-02,\n",
- " 5.62663257e-01, -5.74511178e-02],\n",
- " [ 1.58142820e-01, 1.15326405e+00, -9.75820422e-02,\n",
- " -2.74813175e-01, -6.60091519e-01, -9.36449692e-02,\n",
- " -8.51040661e-01, 7.26523340e-01],\n",
- " [-2.78091401e-01, 2.72161424e-01, 4.07564074e-01,\n",
- " -3.05503774e-02, -1.23885356e-01, -2.59928163e-02,\n",
- " 1.40715313e+00, -8.61399472e-01],\n",
- " [-2.48145923e-01, -1.08954263e+00, -3.17847952e-02,\n",
- " -1.67912588e-01, -9.64032352e-01, -6.52535707e-02,\n",
- " 1.46314144e+00, -1.50555682e+00],\n",
- " [ 4.95588928e-01, 1.23336422e+00, -2.66338676e-01,\n",
- " -3.35229747e-02, -1.15637708e+00, -1.26926363e-01,\n",
- " 8.65933120e-01, -1.33577895e+00],\n",
- " [ 4.06818151e-01, 1.79406595e+00, 5.06766677e-01,\n",
- " -1.45333722e-01, -4.59581256e-01, -9.45077166e-02,\n",
- " 9.73243952e-01, -1.28085077e+00],\n",
- " [-1.15839255e+00, 1.23336422e+00, -7.80114472e-01,\n",
- " -3.65703627e-02, 1.13002844e-02, 2.21132398e-01,\n",
- " -7.76389658e-01, 6.61608279e-01]], dtype=float32)>, <tf.Tensor: shape=(32, 1), dtype=float32, numpy=\n",
- "array([[3.226 ],\n",
- " [2.956 ],\n",
- " [1.59 ],\n",
- " [2.419 ],\n",
- " [2.429 ],\n",
- " [1.514 ],\n",
- " [1.438 ],\n",
- " [2.898 ],\n",
- " [3.376 ],\n",
- " [1.982 ],\n",
- " [1.598 ],\n",
- " [0.717 ],\n",
- " [1.563 ],\n",
- " [2.852 ],\n",
- " [1.739 ],\n",
- " [5.00001],\n",
- " [1.61 ],\n",
- " [2.321 ],\n",
- " [2.385 ],\n",
- " [1.293 ],\n",
- " [5.00001],\n",
- " [2.487 ],\n",
- " [2.392 ],\n",
- " [2.75 ],\n",
- " [2.028 ],\n",
- " [1.431 ],\n",
- " [2.26 ],\n",
- " [1.171 ],\n",
- " [1.031 ],\n",
- " [2.538 ],\n",
- " [2.519 ],\n",
- " [1.125 ]], dtype=float32)>)\n"
- ]
- }
- ],
- "source": [
- "for i in train_set.take(1):\n",
- " print(i) "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "CPU times: user 40 s, sys: 8.95 s, total: 48.9 s\n",
- "Wall time: 42.6 s\n"
- ]
- }
- ],
- "source": [
- "%%time\n",
- "# 训练集和测试集都分20\n",
- "n_shards = 20\n",
- "train_steps_per_shard = 11610 // batch_size // n_shards\n",
- "valid_steps_per_shard = 3880 // batch_size // 10\n",
- "test_steps_per_shard = 5170 // batch_size // 10\n",
- "\n",
- "output_dir = \"generate_tfrecords\"\n",
- "if not os.path.exists(output_dir):\n",
- " os.mkdir(output_dir)\n",
- "\n",
- "train_basename = os.path.join(output_dir, \"train\")\n",
- "valid_basename = os.path.join(output_dir, \"valid\")\n",
- "test_basename = os.path.join(output_dir, \"test\")\n",
- "\n",
- "train_tfrecord_filenames = csv_dataset_to_tfrecords(\n",
- " train_basename, train_set, n_shards, train_steps_per_shard, None)\n",
- "valid_tfrecord_filenames = csv_dataset_to_tfrecords(\n",
- " valid_basename, valid_set, 10, valid_steps_per_shard, None)\n",
- "test_tfrecord_fielnames = csv_dataset_to_tfrecords(\n",
- " test_basename, test_set, 10, test_steps_per_shard, None)\n",
- "#执行会发现目录下总计生成了60个文件,这里文件数目改为一致,为了对比时间"
- ]
- },
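- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Shard-size check (a sketch, not in the original notebook): with `train_steps_per_shard = 11610 // 32 // 20 = 18` batches of 32 examples, each train shard should hold 576 records, which can be verified by iterating over one shard file. `count_records` is a helper introduced here."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Sketch: count the serialized examples stored in one shard file\n",
- "def count_records(filename):\n",
- "    return sum(1 for _ in tf.data.TFRecordDataset(filename))\n",
- "\n",
- "print(count_records(train_tfrecord_filenames[0]))  # expected: 18 * 32 = 576"
- ]
- },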
- {
- "cell_type": "code",
- "execution_count": 9,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "总用量 1960\r\n",
- "-rw-rw-r-- 1 luke luke 47616 Jul 23 11:33 test_00000-of-00010\r\n",
- "-rw-rw-r-- 1 luke luke 47616 Jul 23 11:33 test_00001-of-00010\r\n",
- "-rw-rw-r-- 1 luke luke 47616 Jul 23 11:33 test_00002-of-00010\r\n",
- "-rw-rw-r-- 1 luke luke 47616 Jul 23 11:33 test_00003-of-00010\r\n",
- "-rw-rw-r-- 1 luke luke 47616 Jul 23 11:33 test_00004-of-00010\r\n",
- "-rw-rw-r-- 1 luke luke 47616 Jul 23 11:33 test_00005-of-00010\r\n",
- "-rw-rw-r-- 1 luke luke 47616 Jul 23 11:33 test_00006-of-00010\r\n",
- "-rw-rw-r-- 1 luke luke 47616 Jul 23 11:33 test_00007-of-00010\r\n",
- "-rw-rw-r-- 1 luke luke 47616 Jul 23 11:33 test_00008-of-00010\r\n",
- "-rw-rw-r-- 1 luke luke 47616 Jul 23 11:33 test_00009-of-00010\r\n",
- "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:32 train_00000-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:32 train_00001-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:32 train_00002-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:32 train_00003-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:32 train_00004-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:32 train_00005-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:32 train_00006-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:32 train_00007-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:32 train_00008-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:33 train_00009-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:33 train_00010-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:33 train_00011-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:33 train_00012-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:33 train_00013-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:33 train_00014-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:33 train_00015-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:33 train_00016-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:33 train_00017-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:33 train_00018-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:33 train_00019-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 35712 Jul 23 11:33 valid_00000-of-00010\r\n",
- "-rw-rw-r-- 1 luke luke 35712 Jul 23 11:33 valid_00001-of-00010\r\n",
- "-rw-rw-r-- 1 luke luke 35712 Jul 23 11:33 valid_00002-of-00010\r\n",
- "-rw-rw-r-- 1 luke luke 35712 Jul 23 11:33 valid_00003-of-00010\r\n",
- "-rw-rw-r-- 1 luke luke 35712 Jul 23 11:33 valid_00004-of-00010\r\n",
- "-rw-rw-r-- 1 luke luke 35712 Jul 23 11:33 valid_00005-of-00010\r\n",
- "-rw-rw-r-- 1 luke luke 35712 Jul 23 11:33 valid_00006-of-00010\r\n",
- "-rw-rw-r-- 1 luke luke 35712 Jul 23 11:33 valid_00007-of-00010\r\n",
- "-rw-rw-r-- 1 luke luke 35712 Jul 23 11:33 valid_00008-of-00010\r\n",
- "-rw-rw-r-- 1 luke luke 35712 Jul 23 11:33 valid_00009-of-00010\r\n"
- ]
- }
- ],
- "source": [
- "!ls -l generate_tfrecords"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 33,
- "metadata": {},
- "outputs": [],
- "source": [
- "#生成一下压缩的\n",
- "# n_shards = 20\n",
- "# train_steps_per_shard = 11610 // batch_size // n_shards\n",
- "# valid_steps_per_shard = 3880 // batch_size // n_shards\n",
- "# test_steps_per_shard = 5170 // batch_size // n_shards\n",
- "\n",
- "# output_dir = \"generate_tfrecords_zip\"\n",
- "# if not os.path.exists(output_dir):\n",
- "# os.mkdir(output_dir)\n",
- "\n",
- "# train_basename = os.path.join(output_dir, \"train\")\n",
- "# valid_basename = os.path.join(output_dir, \"valid\")\n",
- "# test_basename = os.path.join(output_dir, \"test\")\n",
- "# #只需修改参数的类型即可\n",
- "# train_tfrecord_filenames = csv_dataset_to_tfrecords(\n",
- "# train_basename, train_set, n_shards, train_steps_per_shard,\n",
- "# compression_type = \"GZIP\")\n",
- "# valid_tfrecord_filenames = csv_dataset_to_tfrecords(\n",
- "# valid_basename, valid_set, n_shards, valid_steps_per_shard,\n",
- "# compression_type = \"GZIP\")\n",
- "# test_tfrecord_fielnames = csv_dataset_to_tfrecords(\n",
- "# test_basename, test_set, n_shards, test_steps_per_shard,\n",
- "# compression_type = \"GZIP\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 34,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "总用量 860\r\n",
- "-rw-rw-r-- 1 luke luke 10171 May 7 11:16 test_00000-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 10230 May 7 11:16 test_00001-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 10204 May 7 11:16 test_00002-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 10213 May 7 11:16 test_00003-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 10229 May 7 11:16 test_00004-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 10200 May 7 11:16 test_00005-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 10199 May 7 11:16 test_00006-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 10215 May 7 11:16 test_00007-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 10179 May 7 11:16 test_00008-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 10149 May 7 11:16 test_00009-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 10141 May 7 11:16 test_00010-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 10221 May 7 11:16 test_00011-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 10209 May 7 11:16 test_00012-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 10214 May 7 11:16 test_00013-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 10212 May 7 11:16 test_00014-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 10209 May 7 11:16 test_00015-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 10185 May 7 11:16 test_00016-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 10266 May 7 11:16 test_00017-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 10258 May 7 11:16 test_00018-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 10170 May 7 11:16 test_00019-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 22359 May 7 19:17 train_00000-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 22447 May 7 19:17 train_00001-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 22366 May 7 19:17 train_00002-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 22311 May 7 19:17 train_00003-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 22384 May 7 19:17 train_00004-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 22341 May 7 19:17 train_00005-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 22416 May 7 19:17 train_00006-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 22285 May 7 19:17 train_00007-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 22415 May 7 19:17 train_00008-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 22365 May 7 19:17 train_00009-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 22431 May 7 19:17 train_00010-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 22367 May 7 19:17 train_00011-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 22346 May 7 19:17 train_00012-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 22332 May 7 19:17 train_00013-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 22452 May 7 19:17 train_00014-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 20 May 7 19:17 train_00015-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 22427 May 7 11:16 train_00016-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 22427 May 7 11:16 train_00017-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 22454 May 7 11:16 train_00018-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 22309 May 7 11:16 train_00019-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 7747 May 7 11:16 valid_00000-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 7744 May 7 11:16 valid_00001-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 7749 May 7 11:16 valid_00002-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 7755 May 7 11:16 valid_00003-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 7744 May 7 11:16 valid_00004-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 7678 May 7 11:16 valid_00005-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 7762 May 7 11:16 valid_00006-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 7720 May 7 11:16 valid_00007-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 7727 May 7 11:16 valid_00008-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 7739 May 7 11:16 valid_00009-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 7762 May 7 11:16 valid_00010-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 7727 May 7 11:16 valid_00011-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 7729 May 7 11:16 valid_00012-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 7763 May 7 11:16 valid_00013-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 7727 May 7 11:16 valid_00014-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 7749 May 7 11:16 valid_00015-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 7741 May 7 11:16 valid_00016-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 7753 May 7 11:16 valid_00017-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 7702 May 7 11:16 valid_00018-of-00020\r\n",
- "-rw-rw-r-- 1 luke luke 7711 May 7 11:16 valid_00019-of-00020\r\n"
- ]
- }
- ],
- "source": [
- "!ls -l generate_tfrecords_zip"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "['generate_tfrecords/train_00000-of-00020',\n",
- " 'generate_tfrecords/train_00001-of-00020',\n",
- " 'generate_tfrecords/train_00002-of-00020',\n",
- " 'generate_tfrecords/train_00003-of-00020',\n",
- " 'generate_tfrecords/train_00004-of-00020',\n",
- " 'generate_tfrecords/train_00005-of-00020',\n",
- " 'generate_tfrecords/train_00006-of-00020',\n",
- " 'generate_tfrecords/train_00007-of-00020',\n",
- " 'generate_tfrecords/train_00008-of-00020',\n",
- " 'generate_tfrecords/train_00009-of-00020',\n",
- " 'generate_tfrecords/train_00010-of-00020',\n",
- " 'generate_tfrecords/train_00011-of-00020',\n",
- " 'generate_tfrecords/train_00012-of-00020',\n",
- " 'generate_tfrecords/train_00013-of-00020',\n",
- " 'generate_tfrecords/train_00014-of-00020',\n",
- " 'generate_tfrecords/train_00015-of-00020',\n",
- " 'generate_tfrecords/train_00016-of-00020',\n",
- " 'generate_tfrecords/train_00017-of-00020',\n",
- " 'generate_tfrecords/train_00018-of-00020',\n",
- " 'generate_tfrecords/train_00019-of-00020']\n",
- "['generate_tfrecords/valid_00000-of-00010',\n",
- " 'generate_tfrecords/valid_00001-of-00010',\n",
- " 'generate_tfrecords/valid_00002-of-00010',\n",
- " 'generate_tfrecords/valid_00003-of-00010',\n",
- " 'generate_tfrecords/valid_00004-of-00010',\n",
- " 'generate_tfrecords/valid_00005-of-00010',\n",
- " 'generate_tfrecords/valid_00006-of-00010',\n",
- " 'generate_tfrecords/valid_00007-of-00010',\n",
- " 'generate_tfrecords/valid_00008-of-00010',\n",
- " 'generate_tfrecords/valid_00009-of-00010']\n",
- "['generate_tfrecords/test_00000-of-00010',\n",
- " 'generate_tfrecords/test_00001-of-00010',\n",
- " 'generate_tfrecords/test_00002-of-00010',\n",
- " 'generate_tfrecords/test_00003-of-00010',\n",
- " 'generate_tfrecords/test_00004-of-00010',\n",
- " 'generate_tfrecords/test_00005-of-00010',\n",
- " 'generate_tfrecords/test_00006-of-00010',\n",
- " 'generate_tfrecords/test_00007-of-00010',\n",
- " 'generate_tfrecords/test_00008-of-00010',\n",
- " 'generate_tfrecords/test_00009-of-00010']\n"
- ]
- }
- ],
- "source": [
- "#打印一下文件名\n",
- "pprint.pprint(train_tfrecord_filenames)\n",
- "pprint.pprint(valid_tfrecord_filenames)\n",
- "pprint.pprint(test_tfrecord_fielnames)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "CPU times: user 58 µs, sys: 14 µs, total: 72 µs\n",
- "Wall time: 80.1 µs\n"
- ]
- }
- ],
- "source": [
- "%%time\n",
- "#把数据读取出来\n",
- "expected_features = {\n",
- " \"input_features\": tf.io.FixedLenFeature([8], dtype=tf.float32),\n",
- " \"label\": tf.io.FixedLenFeature([1], dtype=tf.float32)\n",
- "}\n",
- "\n",
- "def parse_example(serialized_example):\n",
- " example = tf.io.parse_single_example(serialized_example,\n",
- " expected_features)\n",
- " return example[\"input_features\"], example[\"label\"]\n",
- "\n",
- "def tfrecords_reader_dataset(filenames, n_readers=5,\n",
- " batch_size=32, n_parse_threads=5,\n",
- " shuffle_buffer_size=10000):\n",
- " dataset = tf.data.Dataset.list_files(filenames)\n",
- " dataset = dataset.repeat() #为了能够无限次epoch\n",
- " dataset = dataset.interleave(\n",
- "# lambda filename: tf.data.TFRecordDataset(\n",
- "# filename, compression_type = \"GZIP\"),\n",
- " lambda filename: tf.data.TFRecordDataset(\n",
- " filename),\n",
- " cycle_length = n_readers\n",
- " )\n",
- " #洗牌,就是给数据打乱,样本顺序打乱\n",
- " dataset.shuffle(shuffle_buffer_size)\n",
- " dataset = dataset.map(parse_example,\n",
- " num_parallel_calls=n_parse_threads)#把对应的一个样本是字节流的,变为浮点类型\n",
- " dataset = dataset.batch(batch_size) #原来写进去是一条一条的sample,要分配\n",
- " return dataset\n",
- "\n",
- "#测试一下,tfrecords_reader_dataset是否可以正常运行\n",
- "# tfrecords_train = tfrecords_reader_dataset(train_tfrecord_filenames,\n",
- "# batch_size = 3)\n",
- "# for x_batch, y_batch in tfrecords_train.take(10):\n",
- "# print(x_batch)\n",
- "# print(y_batch)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 12,
- "metadata": {
- "collapsed": true
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "WARNING:tensorflow:AutoGraph could not transform <function tfrecords_reader_dataset.<locals>.<lambda> at 0x7f98284712f0> and will run it as-is.\n",
- "Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.\n",
- "Cause: Unable to locate the source code of <function tfrecords_reader_dataset.<locals>.<lambda> at 0x7f98284712f0>. Note that functions defined in certain environments, like the interactive Python shell do not expose their source code. If that is the case, you should to define them in a .py source file. If you are certain the code is graph-compatible, wrap the call using @tf.autograph.do_not_convert. Original error: could not get source code\n",
- "To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert\n",
- "WARNING: AutoGraph could not transform <function tfrecords_reader_dataset.<locals>.<lambda> at 0x7f98284712f0> and will run it as-is.\n",
- "Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.\n",
- "Cause: Unable to locate the source code of <function tfrecords_reader_dataset.<locals>.<lambda> at 0x7f98284712f0>. Note that functions defined in certain environments, like the interactive Python shell do not expose their source code. If that is the case, you should to define them in a .py source file. If you are certain the code is graph-compatible, wrap the call using @tf.autograph.do_not_convert. Original error: could not get source code\n",
- "To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert\n",
- "WARNING:tensorflow:AutoGraph could not transform <function parse_example at 0x7f98284710d0> and will run it as-is.\n",
- "Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.\n",
- "Cause: Unable to locate the source code of <function parse_example at 0x7f98284710d0>. Note that functions defined in certain environments, like the interactive Python shell do not expose their source code. If that is the case, you should to define them in a .py source file. If you are certain the code is graph-compatible, wrap the call using @tf.autograph.do_not_convert. Original error: could not get source code\n",
- "To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert\n",
- "WARNING: AutoGraph could not transform <function parse_example at 0x7f98284710d0> and will run it as-is.\n",
- "Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.\n",
- "Cause: Unable to locate the source code of <function parse_example at 0x7f98284710d0>. Note that functions defined in certain environments, like the interactive Python shell do not expose their source code. If that is the case, you should to define them in a .py source file. If you are certain the code is graph-compatible, wrap the call using @tf.autograph.do_not_convert. Original error: could not get source code\n",
- "To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert\n",
- "WARNING:tensorflow:AutoGraph could not transform <function tfrecords_reader_dataset.<locals>.<lambda> at 0x7f9828471378> and will run it as-is.\n",
- "Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.\n",
- "Cause: Unable to locate the source code of <function tfrecords_reader_dataset.<locals>.<lambda> at 0x7f9828471378>. Note that functions defined in certain environments, like the interactive Python shell do not expose their source code. If that is the case, you should to define them in a .py source file. If you are certain the code is graph-compatible, wrap the call using @tf.autograph.do_not_convert. Original error: could not get source code\n",
- "To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert\n",
- "WARNING: AutoGraph could not transform <function tfrecords_reader_dataset.<locals>.<lambda> at 0x7f9828471378> and will run it as-is.\n",
- "Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.\n",
- "Cause: Unable to locate the source code of <function tfrecords_reader_dataset.<locals>.<lambda> at 0x7f9828471378>. Note that functions defined in certain environments, like the interactive Python shell do not expose their source code. If that is the case, you should to define them in a .py source file. If you are certain the code is graph-compatible, wrap the call using @tf.autograph.do_not_convert. Original error: could not get source code\n",
- "To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert\n",
- "WARNING:tensorflow:AutoGraph could not transform <function tfrecords_reader_dataset.<locals>.<lambda> at 0x7f9828471598> and will run it as-is.\n",
- "Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.\n",
- "Cause: Unable to locate the source code of <function tfrecords_reader_dataset.<locals>.<lambda> at 0x7f9828471598>. Note that functions defined in certain environments, like the interactive Python shell do not expose their source code. If that is the case, you should to define them in a .py source file. If you are certain the code is graph-compatible, wrap the call using @tf.autograph.do_not_convert. Original error: could not get source code\n",
- "To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert\n",
- "WARNING: AutoGraph could not transform <function tfrecords_reader_dataset.<locals>.<lambda> at 0x7f9828471598> and will run it as-is.\n",
- "Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.\n",
- "Cause: Unable to locate the source code of <function tfrecords_reader_dataset.<locals>.<lambda> at 0x7f9828471598>. Note that functions defined in certain environments, like the interactive Python shell do not expose their source code. If that is the case, you should to define them in a .py source file. If you are certain the code is graph-compatible, wrap the call using @tf.autograph.do_not_convert. Original error: could not get source code\n",
- "To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert\n",
- "CPU times: user 218 ms, sys: 83.8 ms, total: 302 ms\n",
- "Wall time: 294 ms\n"
- ]
- }
- ],
- "source": [
- "%%time\n",
- "#得到dataset,dataset是tensor,可以直接拿tensor训练\n",
- "\n",
- "batch_size = 32\n",
- "tfrecords_train_set = tfrecords_reader_dataset(\n",
- " train_tfrecord_filenames, batch_size = batch_size)\n",
- "tfrecords_valid_set = tfrecords_reader_dataset(\n",
- " valid_tfrecord_filenames, batch_size = batch_size)\n",
- "tfrecords_test_set = tfrecords_reader_dataset(\n",
- " test_tfrecord_fielnames, batch_size = batch_size)"
- ]
- },
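- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "The AutoGraph warnings above are harmless; as the message itself suggests, they can be silenced by decorating the mapped functions with `@tf.autograph.experimental.do_not_convert`. A sketch with a hypothetical `parse_example_quiet`:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Sketch: silence the AutoGraph warning, as the warning message suggests\n",
- "@tf.autograph.experimental.do_not_convert\n",
- "def parse_example_quiet(serialized_example):\n",
- "    example = tf.io.parse_single_example(serialized_example,\n",
- "                                         expected_features)\n",
- "    return example[\"input_features\"], example[\"label\"]"
- ]
- },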
- {
- "cell_type": "code",
- "execution_count": 17,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "tensorflow.python.data.ops.dataset_ops.BatchDataset"
- ]
- },
- "execution_count": 17,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "type(tfrecords_train_set)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 13,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "(<tf.Tensor: shape=(32, 8), dtype=float32, numpy=\n",
- "array([[-1.33849168e+00, 1.15326405e+00, -5.64006925e-01,\n",
- " -1.73587687e-02, -1.50110144e-02, 1.30692095e-01,\n",
- " -7.62392581e-01, 6.61608279e-01],\n",
- " [ 1.69143653e+00, -4.48740691e-01, 1.74070787e+00,\n",
- " 2.62504518e-01, 3.60605478e-01, 3.49154733e-02,\n",
- " -9.30357397e-01, 8.06418836e-01],\n",
- " [-3.68034393e-01, -1.00944233e+00, 9.95716763e+00,\n",
- " 8.32341957e+00, -1.11282730e+00, -1.44638717e-01,\n",
- " 1.34183347e+00, -2.12248623e-01],\n",
- " [ 6.85438991e-01, -5.28840959e-01, 2.43798941e-01,\n",
- " -1.07175767e-01, -3.28932047e-01, -7.14625940e-02,\n",
- " -5.66433609e-01, -9.73988622e-02],\n",
- " [ 2.22390127e+00, -1.00944233e+00, 1.12938309e+00,\n",
- " 6.29329979e-01, -4.16031510e-01, 3.57014686e-01,\n",
- " -7.34398425e-01, 3.57006729e-01],\n",
- " [-3.75547409e-01, 7.52762854e-01, -5.35495043e-01,\n",
- " -6.88404664e-02, 8.90460610e-01, 8.53771530e-03,\n",
- " 9.73243952e-01, -1.45062864e+00],\n",
- " [ 1.26937580e+00, -1.49004376e+00, 1.30503133e-01,\n",
- " -1.33789301e-01, 6.90857649e-01, 1.06014162e-01,\n",
- " 7.58622229e-01, -1.11107290e+00],\n",
- " [-4.36877012e-01, -1.32984328e+00, 2.76863605e-01,\n",
- " 2.20493942e-01, -1.10647631e+00, -9.77752283e-02,\n",
- " -7.95052409e-01, 1.43060231e+00],\n",
- " [-5.33587039e-01, 9.93063569e-01, -2.71614999e-01,\n",
- " -1.86064959e-01, -3.83369207e-01, -8.69813338e-02,\n",
- " -1.04530334e-01, 2.62130827e-01],\n",
- " [ 7.65653625e-02, 6.72662616e-01, 1.92134786e+00,\n",
- " 9.22839880e-01, -1.18722475e+00, 2.84929629e-02,\n",
- " 9.35918450e-01, -7.26575792e-01],\n",
- " [ 9.79352236e-01, 1.87416613e+00, 1.42585576e-01,\n",
- " -1.65525392e-01, 1.85499236e-01, -7.59028569e-02,\n",
- " 9.77909684e-01, -1.44563520e+00],\n",
- " [-4.48919147e-01, -2.88540244e-01, -4.44056481e-01,\n",
- " 5.86968623e-02, 9.30381179e-01, -3.20860595e-02,\n",
- " -6.78410172e-01, 5.91699719e-01],\n",
- " [ 9.33741331e-01, -4.48740691e-01, 1.85706854e-01,\n",
- " -6.31846488e-02, -1.69249669e-01, 1.85529351e-01,\n",
- " -6.36418939e-01, 5.56745410e-01],\n",
- " [-9.18775439e-01, 4.32361901e-01, -5.84726632e-01,\n",
- " -5.63860834e-02, 3.48897241e-02, 5.70803955e-02,\n",
- " -1.36893225e+00, 1.21588326e+00],\n",
- " [-6.18148386e-01, 1.87416613e+00, -4.37868014e-02,\n",
- " -8.39404464e-02, -6.44667625e-01, -4.46900912e-02,\n",
- " 9.96572435e-01, -1.34576583e+00],\n",
- " [-3.01620234e-02, -2.08439991e-01, 3.14824611e-01,\n",
- " 1.60544395e-01, -4.75005120e-01, -2.96259038e-02,\n",
- " 1.71975434e+00, -3.12117994e-01],\n",
- " [-9.78879511e-01, -6.89041436e-01, -5.42888701e-01,\n",
- " 3.03960383e-01, 3.80565763e-01, 2.37069577e-02,\n",
- " -1.35026944e+00, 1.23585713e+00],\n",
- " [-9.83834922e-01, -6.08941197e-01, -3.04150850e-01,\n",
- " -9.17670801e-02, 2.08181381e-01, -4.22854684e-02,\n",
- " 2.23298025e+00, -1.35575283e+00],\n",
- " [-5.00178158e-01, -1.08954263e+00, 7.68675879e-02,\n",
- " 5.26065975e-02, 4.82181817e-01, -2.72741280e-02,\n",
- " 1.52379537e+00, -6.51673794e-01],\n",
- " [-3.39900553e-01, 2.72161424e-01, -4.90426540e-01,\n",
- " 7.07200915e-02, 2.49916553e-01, -9.07782912e-02,\n",
- " 9.96572435e-01, -1.45562208e+00],\n",
- " [ 1.51482344e-01, 3.18607129e-02, -2.61826307e-01,\n",
- " -1.51025891e-01, -8.21501911e-02, 2.73460269e-01,\n",
- " -6.36418939e-01, 5.76719284e-01],\n",
- " [-9.25968707e-01, -4.82395217e-02, -1.45638064e-01,\n",
- " -3.07705432e-01, -8.87820303e-01, -2.56793760e-02,\n",
- " 1.42115021e+00, -9.66262281e-01],\n",
- " [ 6.59276664e-01, -2.08439991e-01, 2.04598561e-01,\n",
- " -2.05626473e-01, -8.42456043e-01, -6.28024265e-02,\n",
- " 1.40715313e+00, -8.71386409e-01],\n",
- " [-1.23248763e-01, 1.87416613e+00, 2.59033680e-01,\n",
- " -1.21581666e-01, -5.81157565e-01, -7.85437226e-02,\n",
- " 1.08055484e+00, -8.66392910e-01],\n",
- " [ 3.76925975e-01, -1.81044471e+00, -2.83305883e-01,\n",
- " -5.30055724e-02, 1.07472621e-01, -1.56231388e-01,\n",
- " -1.34093809e+00, 1.20090282e+00],\n",
- " [-5.23303270e-01, -1.32984328e+00, -1.37612730e-01,\n",
- " -1.91330597e-01, 2.53545702e-01, -1.70176193e-01,\n",
- " -1.16830754e+00, 1.18092895e+00],\n",
- " [-3.61214072e-01, 5.12462139e-01, -6.50760174e-01,\n",
- " -1.81008682e-01, -3.18951875e-01, -1.45220742e-01,\n",
- " 9.17255700e-01, -1.41068089e+00],\n",
- " [ 3.42984200e-01, 1.92061186e-01, 2.13778451e-01,\n",
- " -2.96686888e-01, -5.69362879e-01, -4.00676690e-02,\n",
- " -1.14031339e+00, 1.17593551e+00],\n",
- " [ 1.50238574e+00, -1.57014406e+00, 3.20324659e-01,\n",
- " -5.00563085e-02, -3.75203639e-01, 9.73348692e-02,\n",
- " -8.51040661e-01, 8.56353521e-01],\n",
- " [-7.03136027e-01, 1.39356470e+00, -7.07432747e-01,\n",
- " -4.66516092e-02, 2.47186041e+00, 1.57230482e-01,\n",
- " -7.15735674e-01, 6.86575592e-01],\n",
- " [-7.26261199e-01, 4.32361901e-01, -1.33732617e-01,\n",
- " -7.91970268e-02, 4.94928146e-03, -1.02592900e-01,\n",
- " -5.61767936e-01, 1.72248408e-01],\n",
- " [-3.96328092e-01, -2.08439991e-01, -1.22804724e-01,\n",
- " -1.11505985e-01, 1.06738138e+00, 9.11513790e-02,\n",
- " 1.29517651e+00, -1.57047188e+00]], dtype=float32)>, <tf.Tensor: shape=(32, 1), dtype=float32, numpy=\n",
- "array([[1.072 ],\n",
- " [3.506 ],\n",
- " [1.406 ],\n",
- " [3.425 ],\n",
- " [5.00001],\n",
- " [3.5 ],\n",
- " [2.607 ],\n",
- " [0.663 ],\n",
- " [0.57 ],\n",
- " [1.5 ],\n",
- " [3.636 ],\n",
- " [2.042 ],\n",
- " [2.176 ],\n",
- " [1.394 ],\n",
- " [2.079 ],\n",
- " [1.095 ],\n",
- " [0.992 ],\n",
- " [0.598 ],\n",
- " [1.273 ],\n",
- " [3.229 ],\n",
- " [1.5 ],\n",
- " [1.086 ],\n",
- " [1.634 ],\n",
- " [1.184 ],\n",
- " [1.648 ],\n",
- " [1.633 ],\n",
- " [3.184 ],\n",
- " [2.055 ],\n",
- " [2.309 ],\n",
- " [1.629 ],\n",
- " [2.603 ],\n",
- " [1.438 ]], dtype=float32)>)\n"
- ]
- }
- ],
- "source": [
- "for i in tfrecords_train_set.take(1):\n",
- " print(i)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 14,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 1/100\n",
- "348/348 [==============================] - 1s 3ms/step - loss: 0.8099 - val_loss: 0.6248\n",
- "Epoch 2/100\n",
- "348/348 [==============================] - 1s 3ms/step - loss: 0.5202 - val_loss: 0.5199\n",
- "Epoch 3/100\n",
- "348/348 [==============================] - 1s 3ms/step - loss: 0.4712 - val_loss: 0.4874\n",
- "Epoch 4/100\n",
- "348/348 [==============================] - 1s 3ms/step - loss: 0.4509 - val_loss: 0.4747\n",
- "Epoch 5/100\n",
- "348/348 [==============================] - 1s 2ms/step - loss: 0.4298 - val_loss: 0.4615\n",
- "Epoch 6/100\n",
- "348/348 [==============================] - 1s 2ms/step - loss: 0.4159 - val_loss: 0.4296\n",
- "Epoch 7/100\n",
- "348/348 [==============================] - 1s 2ms/step - loss: 0.4033 - val_loss: 0.4194\n",
- "Epoch 8/100\n",
- "348/348 [==============================] - 1s 2ms/step - loss: 0.4042 - val_loss: 0.4123\n",
- "Epoch 9/100\n",
- "348/348 [==============================] - 1s 2ms/step - loss: 0.5006 - val_loss: 0.4300\n",
- "Epoch 10/100\n",
- "348/348 [==============================] - 1s 2ms/step - loss: 0.3920 - val_loss: 0.4135\n",
- "Epoch 11/100\n",
- "348/348 [==============================] - 1s 3ms/step - loss: 0.3976 - val_loss: 0.4100\n",
- "Epoch 12/100\n",
- "348/348 [==============================] - 1s 2ms/step - loss: 0.3836 - val_loss: 0.3966\n",
- "Epoch 13/100\n",
- "348/348 [==============================] - 1s 2ms/step - loss: 0.3744 - val_loss: 0.3917\n",
- "Epoch 14/100\n",
- "348/348 [==============================] - 1s 2ms/step - loss: 0.4394 - val_loss: 0.4169\n",
- "Epoch 15/100\n",
- "348/348 [==============================] - 1s 2ms/step - loss: 0.3968 - val_loss: 0.3938\n",
- "Epoch 16/100\n",
- "348/348 [==============================] - 1s 2ms/step - loss: 0.3682 - val_loss: 0.3880\n",
- "Epoch 17/100\n",
- "348/348 [==============================] - 1s 2ms/step - loss: 0.3709 - val_loss: 0.3835\n",
- "Epoch 18/100\n",
- "348/348 [==============================] - 1s 2ms/step - loss: 0.3666 - val_loss: 0.3795\n",
- "Epoch 19/100\n",
- "348/348 [==============================] - 1s 2ms/step - loss: 0.3692 - val_loss: 0.3756\n",
- "Epoch 20/100\n",
- "348/348 [==============================] - 1s 2ms/step - loss: 0.3587 - val_loss: 0.3736\n",
- "Epoch 21/100\n",
- "348/348 [==============================] - 1s 2ms/step - loss: 0.3554 - val_loss: 0.3765\n",
- "Epoch 22/100\n",
- "348/348 [==============================] - 1s 2ms/step - loss: 0.3619 - val_loss: 0.3732\n",
- "Epoch 23/100\n",
- "348/348 [==============================] - 1s 2ms/step - loss: 0.3529 - val_loss: 0.4280\n",
- "Epoch 24/100\n",
- "348/348 [==============================] - 1s 2ms/step - loss: 0.3537 - val_loss: 0.3658\n",
- "Epoch 25/100\n",
- "348/348 [==============================] - 1s 2ms/step - loss: 0.3515 - val_loss: 0.3704\n",
- "Epoch 26/100\n",
- "348/348 [==============================] - 1s 2ms/step - loss: 0.3707 - val_loss: 0.3642\n",
- "Epoch 27/100\n",
- "348/348 [==============================] - 1s 2ms/step - loss: 0.3512 - val_loss: 0.3651\n"
- ]
- }
- ],
- "source": [
- "#开始训练\n",
- "model = keras.models.Sequential([\n",
- " keras.layers.Dense(30, activation='relu',\n",
- " input_shape=[8]),\n",
- " keras.layers.Dense(1),\n",
- "])\n",
- "model.compile(loss=\"mean_squared_error\", optimizer=\"sgd\")\n",
- "callbacks = [keras.callbacks.EarlyStopping(\n",
- " patience=5, min_delta=1e-2)]\n",
- "\n",
- "history = model.fit(tfrecords_train_set,\n",
- " validation_data = tfrecords_valid_set,\n",
- " steps_per_epoch = 11160 // batch_size,\n",
- " validation_steps = 3870 // batch_size,\n",
- " epochs = 100,\n",
- " callbacks = callbacks)"
- ]
- },
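- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Since matplotlib and pandas are already imported, the learning curves can be plotted from the returned `history` object (a standard sketch, not part of the original notebook)."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Sketch: plot training vs. validation loss per epoch\n",
- "def plot_learning_curves(history):\n",
- "    pd.DataFrame(history.history).plot(figsize=(8, 5))\n",
- "    plt.grid(True)\n",
- "    plt.gca().set_ylim(0, 1)\n",
- "    plt.show()\n",
- "\n",
- "plot_learning_curves(history)"
- ]
- },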
- {
- "cell_type": "code",
- "execution_count": 20,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "161/161 [==============================] - 0s 2ms/step - loss: 0.3376\n"
- ]
- },
- {
- "data": {
- "text/plain": [
- "0.33755674958229065"
- ]
- },
- "execution_count": 20,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "model.evaluate(tfrecords_test_set, steps = 5160 // batch_size)"
- ]
- },
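- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "The trained model also predicts directly on batches drawn from the TFRecord dataset; a sketch comparing predictions with labels for one batch:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# Sketch: predicted vs. actual values for one test batch\n",
- "for x_batch, y_batch in tfrecords_test_set.take(1):\n",
- "    y_pred = model.predict(x_batch)\n",
- "    print(np.c_[y_pred[:5], y_batch.numpy()[:5]])"
- ]
- },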
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.6.9"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
- }