{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2.2.0\n", "sys.version_info(major=3, minor=6, micro=9, releaselevel='final', serial=0)\n", "matplotlib 3.3.4\n", "numpy 1.19.5\n", "pandas 1.1.5\n", "sklearn 0.24.2\n", "tensorflow 2.2.0\n", "tensorflow.keras 2.3.0-tf\n" ] } ], "source": [ "import matplotlib as mpl\n", "import matplotlib.pyplot as plt\n", "%matplotlib inline\n", "import numpy as np\n", "import sklearn\n", "import pandas as pd\n", "import os\n", "import sys\n", "import time\n", "import tensorflow as tf\n", "\n", "from tensorflow import keras\n", "\n", "print(tf.__version__)\n", "print(sys.version_info)\n", "for module in mpl, np, pd, sklearn, tf, keras:\n", " print(module.__name__, module.__version__)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "test_00.csv test_08.csv train_06.csv train_14.csv valid_02.csv\r\n", "test_01.csv test_09.csv train_07.csv train_15.csv valid_03.csv\r\n", "test_02.csv train_00.csv train_08.csv train_16.csv valid_04.csv\r\n", "test_03.csv train_01.csv train_09.csv train_17.csv valid_05.csv\r\n", "test_04.csv train_02.csv train_10.csv train_18.csv valid_06.csv\r\n", "test_05.csv train_03.csv train_11.csv train_19.csv valid_07.csv\r\n", "test_06.csv train_04.csv train_12.csv valid_00.csv valid_08.csv\r\n", "test_07.csv train_05.csv train_13.csv valid_01.csv valid_09.csv\r\n" ] } ], "source": [ "!ls generate_csv" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['./generate_csv/train_08.csv',\n", " './generate_csv/train_11.csv',\n", " './generate_csv/train_18.csv',\n", " './generate_csv/train_15.csv',\n", " './generate_csv/train_17.csv',\n", " './generate_csv/train_00.csv',\n", " './generate_csv/train_01.csv',\n", " './generate_csv/train_19.csv',\n", " './generate_csv/train_14.csv',\n", " './generate_csv/train_02.csv',\n", " './generate_csv/train_16.csv',\n", " './generate_csv/train_09.csv',\n", " './generate_csv/train_03.csv',\n", " './generate_csv/train_12.csv',\n", " './generate_csv/train_10.csv',\n", " './generate_csv/train_13.csv',\n", " './generate_csv/train_05.csv',\n", " './generate_csv/train_07.csv',\n", " './generate_csv/train_04.csv',\n", " './generate_csv/train_06.csv']\n", "['./generate_csv/valid_01.csv',\n", " './generate_csv/valid_05.csv',\n", " './generate_csv/valid_02.csv',\n", " './generate_csv/valid_04.csv',\n", " './generate_csv/valid_08.csv',\n", " './generate_csv/valid_07.csv',\n", " './generate_csv/valid_06.csv',\n", " './generate_csv/valid_00.csv',\n", " './generate_csv/valid_09.csv',\n", " './generate_csv/valid_03.csv']\n", "['./generate_csv/test_00.csv',\n", " './generate_csv/test_07.csv',\n", " './generate_csv/test_01.csv',\n", " './generate_csv/test_08.csv',\n", " './generate_csv/test_06.csv',\n", " './generate_csv/test_02.csv',\n", " './generate_csv/test_04.csv',\n", " './generate_csv/test_05.csv',\n", " './generate_csv/test_09.csv',\n", " './generate_csv/test_03.csv']\n" ] } ], "source": [ "source_dir = \"./generate_csv/\"\n", "\n", "#通过判断开头去添加文件\n", "def get_filenames_by_prefix(source_dir, prefix_name):\n", " all_files = os.listdir(source_dir)\n", " results = []\n", " for filename in all_files:\n", " if filename.startswith(prefix_name):\n", " results.append(os.path.join(source_dir, filename))\n", " return results\n", "\n", "train_filenames = 
get_filenames_by_prefix(source_dir, \"train\")\n", "valid_filenames = get_filenames_by_prefix(source_dir, \"valid\")\n", "test_filenames = get_filenames_by_prefix(source_dir, \"test\")\n", "\n", "import pprint\n", "pprint.pprint(train_filenames)\n", "pprint.pprint(valid_filenames)\n", "pprint.pprint(test_filenames)\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# the helpers below are the same ones used in the earlier CSV notebook\n", "def parse_csv_line(line, n_fields = 9):\n", " defs = [tf.constant(np.nan)] * n_fields\n", " parsed_fields = tf.io.decode_csv(line, record_defaults=defs)\n", " x = tf.stack(parsed_fields[0:-1])\n", " y = tf.stack(parsed_fields[-1:])\n", " return x, y\n", "\n", "def csv_reader_dataset(filenames, n_readers=5,\n", " batch_size=32, n_parse_threads=5,\n", " shuffle_buffer_size=10000):\n", " dataset = tf.data.Dataset.list_files(filenames)\n", " dataset = dataset.repeat()\n", " dataset = dataset.interleave(\n", " lambda filename: tf.data.TextLineDataset(filename).skip(1),\n", " cycle_length = n_readers\n", " )\n", " dataset = dataset.shuffle(shuffle_buffer_size)\n", " dataset = dataset.map(parse_csv_line,\n", " num_parallel_calls=n_parse_threads)\n", " dataset = dataset.batch(batch_size)\n", " return dataset\n", "\n", "batch_size = 32\n", "train_set = csv_reader_dataset(train_filenames,\n", " batch_size = batch_size)\n", "valid_set = csv_reader_dataset(valid_filenames,\n", " batch_size = batch_size)\n", "test_set = csv_reader_dataset(test_filenames,\n", " batch_size = batch_size)\n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "chapter_4.tar.gz\t tf02_data_generate_csv.ipynb\r\n", "generate_csv\t\t tf03-tfrecord_basic_api.ipynb\r\n", "generate_tfrecords\t tf04_data_generate_tfrecord.ipynb\r\n", "temp.csv\t\t tfrecord_basic\r\n", "tf01-dataset_basic_api.ipynb\r\n" ] } ], "source": [ "!ls" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Save train_set, valid_set and test_set to TFRecord files" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# wrap the basic serialization steps into one function\n", "def serialize_example(x, y):\n", " \"\"\"Convert x, y to a tf.train.Example and serialize it\"\"\"\n", " input_features = tf.train.FloatList(value = x) # features\n", " label = tf.train.FloatList(value = y) # label\n", " features = tf.train.Features(\n", " feature = {\n", " \"input_features\": tf.train.Feature(\n", " float_list = input_features),\n", " \"label\": tf.train.Feature(float_list = label)\n", " }\n", " )\n", " # wrap the Features into an Example\n", " example = tf.train.Example(features = features)\n", " return example.SerializeToString() # serialize the Example to a byte string\n", "# n_shards: how many files to write; steps_per_shard plays the same role as steps_per_epoch\n", "def csv_dataset_to_tfrecords(base_filename, dataset,\n", " n_shards, steps_per_shard,\n", " compression_type = None):\n", " # compression type of the output files\n", " options = tf.io.TFRecordOptions(\n", " compression_type = compression_type)\n", " all_filenames = []\n", " \n", " for shard_id in range(n_shards):\n", " filename_fullpath = '{}_{:05d}-of-{:05d}'.format(\n", " base_filename, shard_id, n_shards) # base_filename is just a prefix\n", " # open the shard file for writing\n", " with tf.io.TFRecordWriter(filename_fullpath, options) as writer:\n", " # skip the batches already written to previous shards, then take the next window\n", " for x_batch, y_batch in dataset.skip(shard_id * steps_per_shard).take(steps_per_shard):\n", " for x_example, y_example in zip(x_batch, y_batch):\n", " writer.write(\n", " serialize_example(x_example, y_example))\n", " all_filenames.append(filename_fullpath)\n", " 
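# Note: the dataset passed in has repeat() applied, so it is effectively\n", " # infinite; skip(shard_id * steps_per_shard).take(steps_per_shard) carves out\n", " # one window of batches per shard, and every shard starts a fresh iteration\n", " # of the pipeline, which is why writing all the shards takes a while.\n", " 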
# return the full list of generated TFRecord filenames\n", " return all_filenames" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "!rm -rf generate_tfrecords" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "collapsed": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(<tf.Tensor: shape=(32, 8), dtype=float32, numpy=...>, <tf.Tensor: shape=(32, 1), dtype=float32, numpy=...>)\n" ] } ], "source": [ "for i in train_set.take(1):\n", " print(i) " ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 40 s, sys: 8.95 s, total: 48.9 s\n", "Wall time: 42.6 s\n" ] } ], "source": [ "%%time\n", "# the training set is split into 20 shards, valid and test into 10 each\n", "n_shards = 20\n", "train_steps_per_shard = 11610 // batch_size // n_shards\n", "valid_steps_per_shard = 3880 // batch_size // 10\n", "test_steps_per_shard = 5170 // batch_size // 10\n", "\n", "output_dir = \"generate_tfrecords\"\n", "if not os.path.exists(output_dir):\n", " os.mkdir(output_dir)\n", "\n", "train_basename = os.path.join(output_dir, \"train\")\n", "valid_basename = os.path.join(output_dir, \"valid\")\n", "test_basename = os.path.join(output_dir, \"test\")\n", "\n", "train_tfrecord_filenames = csv_dataset_to_tfrecords(\n", " train_basename, train_set, n_shards, train_steps_per_shard, None)\n", "valid_tfrecord_filenames = csv_dataset_to_tfrecords(\n", " valid_basename, valid_set, 10, valid_steps_per_shard, None)\n", "test_tfrecord_filenames = csv_dataset_to_tfrecords(\n", " test_basename, test_set, 10, test_steps_per_shard, None)\n", "# after running, generate_tfrecords holds 20 + 10 + 10 = 40 shard files" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "total 1960\r\n", "-rw-rw-r-- 1 luke luke 47616 Jul 23 11:33 test_00000-of-00010\r\n", "-rw-rw-r-- 1 luke luke 47616 Jul 23 11:33 test_00001-of-00010\r\n", "-rw-rw-r-- 1 luke luke 47616 Jul 23 11:33 test_00002-of-00010\r\n", "-rw-rw-r-- 1 luke luke 47616 Jul 23 11:33 test_00003-of-00010\r\n", "-rw-rw-r-- 1 luke luke 47616 Jul 23 11:33 test_00004-of-00010\r\n", "-rw-rw-r-- 1 luke luke 47616 Jul 23 11:33 test_00005-of-00010\r\n", "-rw-rw-r-- 1 luke luke 47616 Jul 23 11:33 test_00006-of-00010\r\n", "-rw-rw-r-- 1 luke luke 47616 Jul 23 11:33 test_00007-of-00010\r\n", "-rw-rw-r-- 1 luke luke 47616 Jul 23 11:33 test_00008-of-00010\r\n", "-rw-rw-r-- 1 luke luke 47616 Jul 23 11:33 test_00009-of-00010\r\n", "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:32 train_00000-of-00020\r\n", "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:32 train_00001-of-00020\r\n", "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:32 train_00002-of-00020\r\n", "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:32 train_00003-of-00020\r\n", "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:32 train_00004-of-00020\r\n", "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:32 train_00005-of-00020\r\n", "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:32 train_00006-of-00020\r\n", "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:32 train_00007-of-00020\r\n", "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:32 train_00008-of-00020\r\n", "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:33 train_00009-of-00020\r\n", "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:33 train_00010-of-00020\r\n", "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:33 train_00011-of-00020\r\n", "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:33 train_00012-of-00020\r\n", "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:33 train_00013-of-00020\r\n", "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:33 train_00014-of-00020\r\n", "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:33 train_00015-of-00020\r\n", "-rw-rw-r-- 1 luke luke 53568 Jul 23 
11:33 train_00016-of-00020\r\n", "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:33 train_00017-of-00020\r\n", "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:33 train_00018-of-00020\r\n", "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:33 train_00019-of-00020\r\n", "-rw-rw-r-- 1 luke luke 35712 Jul 23 11:33 valid_00000-of-00010\r\n", "-rw-rw-r-- 1 luke luke 35712 Jul 23 11:33 valid_00001-of-00010\r\n", "-rw-rw-r-- 1 luke luke 35712 Jul 23 11:33 valid_00002-of-00010\r\n", "-rw-rw-r-- 1 luke luke 35712 Jul 23 11:33 valid_00003-of-00010\r\n", "-rw-rw-r-- 1 luke luke 35712 Jul 23 11:33 valid_00004-of-00010\r\n", "-rw-rw-r-- 1 luke luke 35712 Jul 23 11:33 valid_00005-of-00010\r\n", "-rw-rw-r-- 1 luke luke 35712 Jul 23 11:33 valid_00006-of-00010\r\n", "-rw-rw-r-- 1 luke luke 35712 Jul 23 11:33 valid_00007-of-00010\r\n", "-rw-rw-r-- 1 luke luke 35712 Jul 23 11:33 valid_00008-of-00010\r\n", "-rw-rw-r-- 1 luke luke 35712 Jul 23 11:33 valid_00009-of-00010\r\n" ] } ], "source": [ "!ls -l generate_tfrecords" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [], "source": [ "# generate GZIP-compressed TFRecords as well (kept commented out)\n", "# n_shards = 20\n", "# train_steps_per_shard = 11610 // batch_size // n_shards\n", "# valid_steps_per_shard = 3880 // batch_size // n_shards\n", "# test_steps_per_shard = 5170 // batch_size // n_shards\n", "\n", "# output_dir = \"generate_tfrecords_zip\"\n", "# if not os.path.exists(output_dir):\n", "# os.mkdir(output_dir)\n", "\n", "# train_basename = os.path.join(output_dir, \"train\")\n", "# valid_basename = os.path.join(output_dir, \"valid\")\n", "# test_basename = os.path.join(output_dir, \"test\")\n", "# # only the compression_type argument needs to change\n", "# train_tfrecord_filenames = csv_dataset_to_tfrecords(\n", "# train_basename, train_set, n_shards, train_steps_per_shard,\n", "# compression_type = \"GZIP\")\n", "# valid_tfrecord_filenames = csv_dataset_to_tfrecords(\n", "# valid_basename, valid_set, n_shards, valid_steps_per_shard,\n", "# compression_type = \"GZIP\")\n", "# test_tfrecord_filenames = csv_dataset_to_tfrecords(\n", "# test_basename, test_set, n_shards, test_steps_per_shard,\n", "# compression_type = \"GZIP\")" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "total 860\r\n", "-rw-rw-r-- 1 luke luke 10171 May 7 11:16 test_00000-of-00020\r\n", "-rw-rw-r-- 1 luke luke 10230 May 7 11:16 test_00001-of-00020\r\n", "-rw-rw-r-- 1 luke luke 10204 May 7 11:16 test_00002-of-00020\r\n", "-rw-rw-r-- 1 luke luke 10213 May 7 11:16 test_00003-of-00020\r\n", "-rw-rw-r-- 1 luke luke 10229 May 7 11:16 test_00004-of-00020\r\n", "-rw-rw-r-- 1 luke luke 10200 May 7 11:16 test_00005-of-00020\r\n", "-rw-rw-r-- 1 luke luke 10199 May 7 11:16 test_00006-of-00020\r\n", "-rw-rw-r-- 1 luke luke 10215 May 7 11:16 test_00007-of-00020\r\n", "-rw-rw-r-- 1 luke luke 10179 May 7 11:16 test_00008-of-00020\r\n", "-rw-rw-r-- 1 luke luke 10149 May 7 11:16 test_00009-of-00020\r\n", "-rw-rw-r-- 1 luke luke 10141 May 7 11:16 test_00010-of-00020\r\n", "-rw-rw-r-- 1 luke luke 10221 May 7 11:16 test_00011-of-00020\r\n", "-rw-rw-r-- 1 luke luke 10209 May 7 11:16 test_00012-of-00020\r\n", "-rw-rw-r-- 1 luke luke 10214 May 7 11:16 test_00013-of-00020\r\n", "-rw-rw-r-- 1 luke luke 10212 May 7 11:16 test_00014-of-00020\r\n", "-rw-rw-r-- 1 luke luke 10209 May 7 11:16 test_00015-of-00020\r\n", "-rw-rw-r-- 1 luke luke 10185 May 7 11:16 test_00016-of-00020\r\n", "-rw-rw-r-- 1 luke luke 10266 May 7 11:16 test_00017-of-00020\r\n", "-rw-rw-r-- 1 luke luke 10258 May 7 11:16 test_00018-of-00020\r\n", 
"-rw-rw-r-- 1 luke luke 10170 May 7 11:16 test_00019-of-00020\r\n", "-rw-rw-r-- 1 luke luke 22359 May 7 19:17 train_00000-of-00020\r\n", "-rw-rw-r-- 1 luke luke 22447 May 7 19:17 train_00001-of-00020\r\n", "-rw-rw-r-- 1 luke luke 22366 May 7 19:17 train_00002-of-00020\r\n", "-rw-rw-r-- 1 luke luke 22311 May 7 19:17 train_00003-of-00020\r\n", "-rw-rw-r-- 1 luke luke 22384 May 7 19:17 train_00004-of-00020\r\n", "-rw-rw-r-- 1 luke luke 22341 May 7 19:17 train_00005-of-00020\r\n", "-rw-rw-r-- 1 luke luke 22416 May 7 19:17 train_00006-of-00020\r\n", "-rw-rw-r-- 1 luke luke 22285 May 7 19:17 train_00007-of-00020\r\n", "-rw-rw-r-- 1 luke luke 22415 May 7 19:17 train_00008-of-00020\r\n", "-rw-rw-r-- 1 luke luke 22365 May 7 19:17 train_00009-of-00020\r\n", "-rw-rw-r-- 1 luke luke 22431 May 7 19:17 train_00010-of-00020\r\n", "-rw-rw-r-- 1 luke luke 22367 May 7 19:17 train_00011-of-00020\r\n", "-rw-rw-r-- 1 luke luke 22346 May 7 19:17 train_00012-of-00020\r\n", "-rw-rw-r-- 1 luke luke 22332 May 7 19:17 train_00013-of-00020\r\n", "-rw-rw-r-- 1 luke luke 22452 May 7 19:17 train_00014-of-00020\r\n", "-rw-rw-r-- 1 luke luke 20 May 7 19:17 train_00015-of-00020\r\n", "-rw-rw-r-- 1 luke luke 22427 May 7 11:16 train_00016-of-00020\r\n", "-rw-rw-r-- 1 luke luke 22427 May 7 11:16 train_00017-of-00020\r\n", "-rw-rw-r-- 1 luke luke 22454 May 7 11:16 train_00018-of-00020\r\n", "-rw-rw-r-- 1 luke luke 22309 May 7 11:16 train_00019-of-00020\r\n", "-rw-rw-r-- 1 luke luke 7747 May 7 11:16 valid_00000-of-00020\r\n", "-rw-rw-r-- 1 luke luke 7744 May 7 11:16 valid_00001-of-00020\r\n", "-rw-rw-r-- 1 luke luke 7749 May 7 11:16 valid_00002-of-00020\r\n", "-rw-rw-r-- 1 luke luke 7755 May 7 11:16 valid_00003-of-00020\r\n", "-rw-rw-r-- 1 luke luke 7744 May 7 11:16 valid_00004-of-00020\r\n", "-rw-rw-r-- 1 luke luke 7678 May 7 11:16 valid_00005-of-00020\r\n", "-rw-rw-r-- 1 luke luke 7762 May 7 11:16 valid_00006-of-00020\r\n", "-rw-rw-r-- 1 luke luke 7720 May 7 11:16 valid_00007-of-00020\r\n", "-rw-rw-r-- 1 luke luke 7727 May 7 11:16 valid_00008-of-00020\r\n", "-rw-rw-r-- 1 luke luke 7739 May 7 11:16 valid_00009-of-00020\r\n", "-rw-rw-r-- 1 luke luke 7762 May 7 11:16 valid_00010-of-00020\r\n", "-rw-rw-r-- 1 luke luke 7727 May 7 11:16 valid_00011-of-00020\r\n", "-rw-rw-r-- 1 luke luke 7729 May 7 11:16 valid_00012-of-00020\r\n", "-rw-rw-r-- 1 luke luke 7763 May 7 11:16 valid_00013-of-00020\r\n", "-rw-rw-r-- 1 luke luke 7727 May 7 11:16 valid_00014-of-00020\r\n", "-rw-rw-r-- 1 luke luke 7749 May 7 11:16 valid_00015-of-00020\r\n", "-rw-rw-r-- 1 luke luke 7741 May 7 11:16 valid_00016-of-00020\r\n", "-rw-rw-r-- 1 luke luke 7753 May 7 11:16 valid_00017-of-00020\r\n", "-rw-rw-r-- 1 luke luke 7702 May 7 11:16 valid_00018-of-00020\r\n", "-rw-rw-r-- 1 luke luke 7711 May 7 11:16 valid_00019-of-00020\r\n" ] } ], "source": [ "!ls -l generate_tfrecords_zip" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['generate_tfrecords/train_00000-of-00020',\n", " 'generate_tfrecords/train_00001-of-00020',\n", " 'generate_tfrecords/train_00002-of-00020',\n", " 'generate_tfrecords/train_00003-of-00020',\n", " 'generate_tfrecords/train_00004-of-00020',\n", " 'generate_tfrecords/train_00005-of-00020',\n", " 'generate_tfrecords/train_00006-of-00020',\n", " 'generate_tfrecords/train_00007-of-00020',\n", " 'generate_tfrecords/train_00008-of-00020',\n", " 'generate_tfrecords/train_00009-of-00020',\n", " 'generate_tfrecords/train_00010-of-00020',\n", " 
'generate_tfrecords/train_00011-of-00020',\n", " 'generate_tfrecords/train_00012-of-00020',\n", " 'generate_tfrecords/train_00013-of-00020',\n", " 'generate_tfrecords/train_00014-of-00020',\n", " 'generate_tfrecords/train_00015-of-00020',\n", " 'generate_tfrecords/train_00016-of-00020',\n", " 'generate_tfrecords/train_00017-of-00020',\n", " 'generate_tfrecords/train_00018-of-00020',\n", " 'generate_tfrecords/train_00019-of-00020']\n", "['generate_tfrecords/valid_00000-of-00010',\n", " 'generate_tfrecords/valid_00001-of-00010',\n", " 'generate_tfrecords/valid_00002-of-00010',\n", " 'generate_tfrecords/valid_00003-of-00010',\n", " 'generate_tfrecords/valid_00004-of-00010',\n", " 'generate_tfrecords/valid_00005-of-00010',\n", " 'generate_tfrecords/valid_00006-of-00010',\n", " 'generate_tfrecords/valid_00007-of-00010',\n", " 'generate_tfrecords/valid_00008-of-00010',\n", " 'generate_tfrecords/valid_00009-of-00010']\n", "['generate_tfrecords/test_00000-of-00010',\n", " 'generate_tfrecords/test_00001-of-00010',\n", " 'generate_tfrecords/test_00002-of-00010',\n", " 'generate_tfrecords/test_00003-of-00010',\n", " 'generate_tfrecords/test_00004-of-00010',\n", " 'generate_tfrecords/test_00005-of-00010',\n", " 'generate_tfrecords/test_00006-of-00010',\n", " 'generate_tfrecords/test_00007-of-00010',\n", " 'generate_tfrecords/test_00008-of-00010',\n", " 'generate_tfrecords/test_00009-of-00010']\n" ] } ], "source": [ "# print the generated shard filenames\n", "pprint.pprint(train_tfrecord_filenames)\n", "pprint.pprint(valid_tfrecord_filenames)\n", "pprint.pprint(test_tfrecord_filenames)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 58 µs, sys: 14 µs, total: 72 µs\n", "Wall time: 80.1 µs\n" ] } ], "source": [ "%%time\n", "# read the data back out of the TFRecord files\n", "expected_features = {\n", " \"input_features\": tf.io.FixedLenFeature([8], dtype=tf.float32),\n", " \"label\": tf.io.FixedLenFeature([1], dtype=tf.float32)\n", "}\n", "\n", "def parse_example(serialized_example):\n", " example = tf.io.parse_single_example(serialized_example,\n", " expected_features)\n", " return example[\"input_features\"], example[\"label\"]\n", "\n", "def tfrecords_reader_dataset(filenames, n_readers=5,\n", " batch_size=32, n_parse_threads=5,\n", " shuffle_buffer_size=10000):\n", " dataset = tf.data.Dataset.list_files(filenames)\n", " dataset = dataset.repeat() # repeat so we can run as many epochs as needed\n", " dataset = dataset.interleave(\n", "# lambda filename: tf.data.TFRecordDataset(\n", "# filename, compression_type = \"GZIP\"),\n", " lambda filename: tf.data.TFRecordDataset(\n", " filename),\n", " cycle_length = n_readers\n", " )\n", " # shuffle, i.e. randomize the order of the samples\n", " dataset = dataset.shuffle(shuffle_buffer_size)\n", " dataset = dataset.map(parse_example,\n", " num_parallel_calls=n_parse_threads) # parse each serialized byte string into float tensors\n", " dataset = dataset.batch(batch_size) # records were written one sample at a time, so batch them here\n", " return dataset\n", "\n", "# quick check that tfrecords_reader_dataset works\n", "# tfrecords_train = tfrecords_reader_dataset(train_tfrecord_filenames,\n", "# batch_size = 3)\n", "# for x_batch, y_batch in tfrecords_train.take(10):\n", "# print(x_batch)\n", "# print(y_batch)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:AutoGraph could not transform <function tfrecords_reader_dataset.<locals>.<lambda> at 0x7f98284712f0> and will run it as-is.\n", "Please report this to the TensorFlow team. 
When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.\n", "Cause: Unable to locate the source code of <function tfrecords_reader_dataset.<locals>.<lambda> at 0x7f98284712f0>. Note that functions defined in certain environments, like the interactive Python shell, do not expose their source code. If that is the case, you should define them in a .py source file. If you are certain the code is graph-compatible, wrap the call using @tf.autograph.do_not_convert. 
Original error: could not get source code\n", "To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert\n", "CPU times: user 218 ms, sys: 83.8 ms, total: 302 ms\n", "Wall time: 294 ms\n" ] } ], "source": [ "%%time\n", "# build the three datasets; they yield tensors that can be fed straight to model.fit\n", "\n", "batch_size = 32\n", "tfrecords_train_set = tfrecords_reader_dataset(\n", " train_tfrecord_filenames, batch_size = batch_size)\n", "tfrecords_valid_set = tfrecords_reader_dataset(\n", " valid_tfrecord_filenames, batch_size = batch_size)\n", "tfrecords_test_set = tfrecords_reader_dataset(\n", " test_tfrecord_filenames, batch_size = batch_size)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensorflow.python.data.ops.dataset_ops.BatchDataset" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "type(tfrecords_train_set)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(<tf.Tensor: shape=(32, 8), dtype=float32, numpy=...>, <tf.Tensor: shape=(32, 1), dtype=float32, numpy=...>)\n" ] } ], "source": [ "for i in tfrecords_train_set.take(1):\n", " print(i)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1/100\n", "348/348 [==============================] - 1s 3ms/step - loss: 0.8099 - val_loss: 0.6248\n", "Epoch 2/100\n", "348/348 [==============================] - 1s 3ms/step - loss: 0.5202 - val_loss: 0.5199\n", "Epoch 3/100\n", "348/348 [==============================] - 1s 3ms/step - loss: 0.4712 - val_loss: 0.4874\n", "Epoch 4/100\n", "348/348 [==============================] - 1s 3ms/step - loss: 0.4509 - val_loss: 0.4747\n", "Epoch 5/100\n", "348/348 [==============================] - 1s 2ms/step - loss: 0.4298 - val_loss: 0.4615\n", "Epoch 6/100\n", "348/348 [==============================] - 1s 2ms/step - loss: 0.4159 - val_loss: 0.4296\n", "Epoch 7/100\n", "348/348 [==============================] - 1s 2ms/step - loss: 0.4033 - val_loss: 0.4194\n", "Epoch 8/100\n", "348/348 [==============================] - 1s 2ms/step - loss: 0.4042 - val_loss: 0.4123\n", "Epoch 9/100\n", "348/348 [==============================] - 1s 2ms/step - loss: 0.5006 - val_loss: 0.4300\n", "Epoch 10/100\n", "348/348 [==============================] - 1s 2ms/step - loss: 0.3920 - val_loss: 0.4135\n", "Epoch 11/100\n", "348/348 [==============================] - 1s 3ms/step - loss: 0.3976 - val_loss: 0.4100\n", "Epoch 12/100\n", "348/348 [==============================] - 1s 2ms/step - loss: 0.3836 - val_loss: 0.3966\n", "Epoch 13/100\n", "348/348 [==============================] - 1s 2ms/step - loss: 0.3744 - val_loss: 0.3917\n", "Epoch 14/100\n", "348/348 [==============================] - 1s 2ms/step - loss: 0.4394 - val_loss: 0.4169\n", "Epoch 15/100\n", "348/348 [==============================] - 1s 2ms/step - loss: 0.3968 - val_loss: 0.3938\n", "Epoch 16/100\n", "348/348 [==============================] - 1s 2ms/step - loss: 0.3682 - val_loss: 0.3880\n", "Epoch 17/100\n", "348/348 [==============================] - 1s 2ms/step - loss: 0.3709 - val_loss: 0.3835\n", "Epoch 18/100\n", "348/348 [==============================] - 1s 2ms/step - loss: 0.3666 - val_loss: 0.3795\n", "Epoch 19/100\n", "348/348 [==============================] - 1s 2ms/step - loss: 0.3692 - val_loss: 0.3756\n", "Epoch 20/100\n", "348/348 [==============================] - 1s 2ms/step - loss: 0.3587 - val_loss: 0.3736\n", "Epoch 21/100\n", "348/348 
[==============================] - 1s 2ms/step - loss: 0.3554 - val_loss: 0.3765\n", "Epoch 22/100\n", "348/348 [==============================] - 1s 2ms/step - loss: 0.3619 - val_loss: 0.3732\n", "Epoch 23/100\n", "348/348 [==============================] - 1s 2ms/step - loss: 0.3529 - val_loss: 0.4280\n", "Epoch 24/100\n", "348/348 [==============================] - 1s 2ms/step - loss: 0.3537 - val_loss: 0.3658\n", "Epoch 25/100\n", "348/348 [==============================] - 1s 2ms/step - loss: 0.3515 - val_loss: 0.3704\n", "Epoch 26/100\n", "348/348 [==============================] - 1s 2ms/step - loss: 0.3707 - val_loss: 0.3642\n", "Epoch 27/100\n", "348/348 [==============================] - 1s 2ms/step - loss: 0.3512 - val_loss: 0.3651\n" ] } ], "source": [ "# build and train the model on the TFRecord datasets\n", "model = keras.models.Sequential([\n", " keras.layers.Dense(30, activation='relu',\n", " input_shape=[8]),\n", " keras.layers.Dense(1),\n", "])\n", "model.compile(loss=\"mean_squared_error\", optimizer=\"sgd\")\n", "callbacks = [keras.callbacks.EarlyStopping(\n", " patience=5, min_delta=1e-2)]\n", "\n", "history = model.fit(tfrecords_train_set,\n", " validation_data = tfrecords_valid_set,\n", " steps_per_epoch = 11160 // batch_size,\n", " validation_steps = 3870 // batch_size,\n", " epochs = 100,\n", " callbacks = callbacks)" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "161/161 [==============================] - 0s 2ms/step - loss: 0.3376\n" ] }, { "data": { "text/plain": [ "0.33755674958229065" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.evaluate(tfrecords_test_set, steps = 5160 // batch_size)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.9" } }, "nbformat": 4, "nbformat_minor": 2 }