
tf04_data_generate_tfrecord.ipynb 53 kB

  1. {
  2. "cells": [
  3. {
  4. "cell_type": "code",
  5. "execution_count": 1,
  6. "metadata": {},
  7. "outputs": [
  8. {
  9. "name": "stdout",
  10. "output_type": "stream",
  11. "text": [
  12. "2.2.0\n",
  13. "sys.version_info(major=3, minor=6, micro=9, releaselevel='final', serial=0)\n",
  14. "matplotlib 3.3.4\n",
  15. "numpy 1.19.5\n",
  16. "pandas 1.1.5\n",
  17. "sklearn 0.24.2\n",
  18. "tensorflow 2.2.0\n",
  19. "tensorflow.keras 2.3.0-tf\n"
  20. ]
  21. }
  22. ],
  23. "source": [
  24. "import matplotlib as mpl\n",
  25. "import matplotlib.pyplot as plt\n",
  26. "%matplotlib inline\n",
  27. "import numpy as np\n",
  28. "import sklearn\n",
  29. "import pandas as pd\n",
  30. "import os\n",
  31. "import sys\n",
  32. "import time\n",
  33. "import tensorflow as tf\n",
  34. "\n",
  35. "from tensorflow import keras\n",
  36. "\n",
  37. "print(tf.__version__)\n",
  38. "print(sys.version_info)\n",
  39. "for module in mpl, np, pd, sklearn, tf, keras:\n",
  40. " print(module.__name__, module.__version__)"
  41. ]
  42. },
  43. {
  44. "cell_type": "code",
  45. "execution_count": 2,
  46. "metadata": {},
  47. "outputs": [
  48. {
  49. "name": "stdout",
  50. "output_type": "stream",
  51. "text": [
  52. "test_00.csv test_08.csv train_06.csv train_14.csv valid_02.csv\r\n",
  53. "test_01.csv test_09.csv train_07.csv train_15.csv valid_03.csv\r\n",
  54. "test_02.csv train_00.csv train_08.csv train_16.csv valid_04.csv\r\n",
  55. "test_03.csv train_01.csv train_09.csv train_17.csv valid_05.csv\r\n",
  56. "test_04.csv train_02.csv train_10.csv train_18.csv valid_06.csv\r\n",
  57. "test_05.csv train_03.csv train_11.csv train_19.csv valid_07.csv\r\n",
  58. "test_06.csv train_04.csv train_12.csv valid_00.csv valid_08.csv\r\n",
  59. "test_07.csv train_05.csv train_13.csv valid_01.csv valid_09.csv\r\n"
  60. ]
  61. }
  62. ],
  63. "source": [
  64. "!ls generate_csv"
  65. ]
  66. },
  67. {
  68. "cell_type": "code",
  69. "execution_count": 2,
  70. "metadata": {},
  71. "outputs": [
  72. {
  73. "name": "stdout",
  74. "output_type": "stream",
  75. "text": [
  76. "['./generate_csv/train_08.csv',\n",
  77. " './generate_csv/train_11.csv',\n",
  78. " './generate_csv/train_18.csv',\n",
  79. " './generate_csv/train_15.csv',\n",
  80. " './generate_csv/train_17.csv',\n",
  81. " './generate_csv/train_00.csv',\n",
  82. " './generate_csv/train_01.csv',\n",
  83. " './generate_csv/train_19.csv',\n",
  84. " './generate_csv/train_14.csv',\n",
  85. " './generate_csv/train_02.csv',\n",
  86. " './generate_csv/train_16.csv',\n",
  87. " './generate_csv/train_09.csv',\n",
  88. " './generate_csv/train_03.csv',\n",
  89. " './generate_csv/train_12.csv',\n",
  90. " './generate_csv/train_10.csv',\n",
  91. " './generate_csv/train_13.csv',\n",
  92. " './generate_csv/train_05.csv',\n",
  93. " './generate_csv/train_07.csv',\n",
  94. " './generate_csv/train_04.csv',\n",
  95. " './generate_csv/train_06.csv']\n",
  96. "['./generate_csv/valid_01.csv',\n",
  97. " './generate_csv/valid_05.csv',\n",
  98. " './generate_csv/valid_02.csv',\n",
  99. " './generate_csv/valid_04.csv',\n",
  100. " './generate_csv/valid_08.csv',\n",
  101. " './generate_csv/valid_07.csv',\n",
  102. " './generate_csv/valid_06.csv',\n",
  103. " './generate_csv/valid_00.csv',\n",
  104. " './generate_csv/valid_09.csv',\n",
  105. " './generate_csv/valid_03.csv']\n",
  106. "['./generate_csv/test_00.csv',\n",
  107. " './generate_csv/test_07.csv',\n",
  108. " './generate_csv/test_01.csv',\n",
  109. " './generate_csv/test_08.csv',\n",
  110. " './generate_csv/test_06.csv',\n",
  111. " './generate_csv/test_02.csv',\n",
  112. " './generate_csv/test_04.csv',\n",
  113. " './generate_csv/test_05.csv',\n",
  114. " './generate_csv/test_09.csv',\n",
  115. " './generate_csv/test_03.csv']\n"
  116. ]
  117. }
  118. ],
  119. "source": [
  120. "source_dir = \"./generate_csv/\"\n",
  121. "\n",
  122. "#通过判断开头去添加文件\n",
  123. "def get_filenames_by_prefix(source_dir, prefix_name):\n",
  124. " all_files = os.listdir(source_dir)\n",
  125. " results = []\n",
  126. " for filename in all_files:\n",
  127. " if filename.startswith(prefix_name):\n",
  128. " results.append(os.path.join(source_dir, filename))\n",
  129. " return results\n",
  130. "\n",
  131. "train_filenames = get_filenames_by_prefix(source_dir, \"train\")\n",
  132. "valid_filenames = get_filenames_by_prefix(source_dir, \"valid\")\n",
  133. "test_filenames = get_filenames_by_prefix(source_dir, \"test\")\n",
  134. "\n",
  135. "import pprint\n",
  136. "pprint.pprint(train_filenames)\n",
  137. "pprint.pprint(valid_filenames)\n",
  138. "pprint.pprint(test_filenames)\n"
  139. ]
  140. },
  141. {
  142. "cell_type": "code",
  143. "execution_count": 3,
  144. "metadata": {},
  145. "outputs": [],
  146. "source": [
  147. "#下面的接口都是之前用过的\n",
  148. "def parse_csv_line(line, n_fields = 9):\n",
  149. " defs = [tf.constant(np.nan)] * n_fields\n",
  150. " parsed_fields = tf.io.decode_csv(line, record_defaults=defs)\n",
  151. " x = tf.stack(parsed_fields[0:-1])\n",
  152. " y = tf.stack(parsed_fields[-1:])\n",
  153. " return x, y\n",
  154. "\n",
  155. "def csv_reader_dataset(filenames, n_readers=5,\n",
  156. " batch_size=32, n_parse_threads=5,\n",
  157. " shuffle_buffer_size=10000):\n",
  158. " dataset = tf.data.Dataset.list_files(filenames)\n",
  159. " dataset = dataset.repeat()\n",
  160. " dataset = dataset.interleave(\n",
  161. " lambda filename: tf.data.TextLineDataset(filename).skip(1),\n",
  162. " cycle_length = n_readers\n",
  163. " )\n",
  164. " dataset.shuffle(shuffle_buffer_size)\n",
  165. " dataset = dataset.map(parse_csv_line,\n",
  166. " num_parallel_calls=n_parse_threads)\n",
  167. " dataset = dataset.batch(batch_size)\n",
  168. " return dataset\n",
  169. "\n",
  170. "batch_size = 32\n",
  171. "train_set = csv_reader_dataset(train_filenames,\n",
  172. " batch_size = batch_size)\n",
  173. "valid_set = csv_reader_dataset(valid_filenames,\n",
  174. " batch_size = batch_size)\n",
  175. "test_set = csv_reader_dataset(test_filenames,\n",
  176. " batch_size = batch_size)\n"
  177. ]
  178. },
  179. {
  180. "cell_type": "code",
  181. "execution_count": 4,
  182. "metadata": {},
  183. "outputs": [
  184. {
  185. "name": "stdout",
  186. "output_type": "stream",
  187. "text": [
  188. "chapter_4.tar.gz\t tf02_data_generate_csv.ipynb\r\n",
  189. "generate_csv\t\t tf03-tfrecord_basic_api.ipynb\r\n",
  190. "generate_tfrecords\t tf04_data_generate_tfrecord.ipynb\r\n",
  191. "temp.csv\t\t tfrecord_basic\r\n",
  192. "tf01-dataset_basic_api.ipynb\r\n"
  193. ]
  194. }
  195. ],
  196. "source": [
  197. "!ls"
  198. ]
  199. },
  200. {
  201. "cell_type": "markdown",
  202. "metadata": {},
  203. "source": [
  204. "# 把train_set,valid_set,test_set 存储到tfrecord类型的文件中"
  205. ]
  206. },
  207. {
  208. "cell_type": "code",
  209. "execution_count": 5,
  210. "metadata": {},
  211. "outputs": [],
  212. "source": [
  213. "#把基础的如何序列化的步骤搞到一个函数\n",
  214. "def serialize_example(x, y):\n",
  215. " \"\"\"Converts x, y to tf.train.Example and serialize\"\"\"\n",
  216. " input_feautres = tf.train.FloatList(value = x) #特征\n",
  217. " label = tf.train.FloatList(value = y)#标签\n",
  218. " features = tf.train.Features(\n",
  219. " feature = {\n",
  220. " \"input_features\": tf.train.Feature(\n",
  221. " float_list = input_feautres),\n",
  222. " \"label\": tf.train.Feature(float_list = label)\n",
  223. " }\n",
  224. " )\n",
  225. " #把features变为example\n",
  226. " example = tf.train.Example(features = features)\n",
  227. " return example.SerializeToString() #把example序列化\n",
  228. "#n_shards是存为多少个文件,steps_per_shard和 steps_per_epoch类似\n",
  229. "def csv_dataset_to_tfrecords(base_filename, dataset,\n",
  230. " n_shards, steps_per_shard,\n",
  231. " compression_type = None):\n",
  232. " #压缩文件类型\n",
  233. " options = tf.io.TFRecordOptions(\n",
  234. " compression_type = compression_type)\n",
  235. " all_filenames = []\n",
  236. " \n",
  237. " for shard_id in range(n_shards):\n",
  238. " filename_fullpath = '{}_{:05d}-of-{:05d}'.format(\n",
  239. " base_filename, shard_id, n_shards) #base_filename是一个前缀\n",
  240. " #打开文件\n",
  241. " with tf.io.TFRecordWriter(filename_fullpath, options) as writer:\n",
  242. " #取出数据,为什么skip,上一个文件写了前500行,下一个文件存后面的数据\n",
  243. " for x_batch, y_batch in dataset.skip(shard_id * steps_per_shard).take(steps_per_shard):\n",
  244. " for x_example, y_example in zip(x_batch, y_batch):\n",
  245. " writer.write(\n",
  246. " serialize_example(x_example, y_example))\n",
  247. " all_filenames.append(filename_fullpath)\n",
  248. " #返回所有tfrecord文件名\n",
  249. " return all_filenames"
  250. ]
  251. },
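The cell above only produces raw serialized bytes, so it is easy to mis-match feature names or shapes between writing and reading. A quick sanity check is to serialize one record and immediately parse it back; this is a minimal sketch using the same feature names and shapes as the notebook ("input_features" with 8 floats, "label" with 1 float), and the helper name `roundtrip_check` is only illustrative:

```python
import numpy as np
import tensorflow as tf

def roundtrip_check():
    # One fake (x, y) pair with the same shapes as the CSV rows: 8 features, 1 label.
    x = np.arange(8, dtype=np.float32)
    y = np.array([1.5], dtype=np.float32)

    # Serialize with the same feature names the notebook writes.
    example = tf.train.Example(features=tf.train.Features(feature={
        "input_features": tf.train.Feature(float_list=tf.train.FloatList(value=x)),
        "label": tf.train.Feature(float_list=tf.train.FloatList(value=y)),
    }))
    serialized = example.SerializeToString()

    # Parse it back with fixed-length specs matching what was written.
    parsed = tf.io.parse_single_example(serialized, {
        "input_features": tf.io.FixedLenFeature([8], dtype=tf.float32),
        "label": tf.io.FixedLenFeature([1], dtype=tf.float32),
    })
    print(parsed["input_features"].numpy(), parsed["label"].numpy())

roundtrip_check()
```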
  252. {
  253. "cell_type": "code",
  254. "execution_count": 6,
  255. "metadata": {},
  256. "outputs": [],
  257. "source": [
  258. "!rm -rf generate_tfrecords"
  259. ]
  260. },
  261. {
  262. "cell_type": "code",
  263. "execution_count": 7,
  264. "metadata": {
  265. "collapsed": true
  266. },
  267. "outputs": [
  268. {
  269. "name": "stdout",
  270. "output_type": "stream",
  271. "text": [
  272. "(<tf.Tensor: shape=(32, 8), dtype=float32, numpy=\n",
  273. "array([[ 8.01544309e-01, 2.72161424e-01, -1.16243929e-01,\n",
  274. " -2.02311516e-01, -5.43051600e-01, -2.10396163e-02,\n",
  275. " -5.89762092e-01, -8.24184567e-02],\n",
  276. " [ 4.85305160e-01, -8.49241912e-01, -6.53012618e-02,\n",
  277. " -2.33796556e-02, 1.49743509e+00, -7.79065788e-02,\n",
  278. " -9.02363241e-01, 7.81451464e-01],\n",
  279. " [-6.67222738e-01, -4.82395217e-02, 3.45294058e-01,\n",
  280. " 5.38266838e-01, 1.85218394e+00, -6.11253828e-02,\n",
  281. " -8.41709316e-01, 1.52048469e+00],\n",
  282. " [ 6.30343556e-01, 1.87416613e+00, -6.71321452e-02,\n",
  283. " -1.25433668e-01, -1.97375536e-01, -2.27226317e-02,\n",
  284. " -6.92407250e-01, 7.26523340e-01],\n",
  285. " [ 6.36364639e-01, -1.08954263e+00, 9.26090255e-02,\n",
  286. " -2.05381244e-01, 1.20256710e+00, -3.63012254e-02,\n",
  287. " -6.78410172e-01, 1.82235345e-01],\n",
  288. " [-2.98072815e-01, 3.52261662e-01, -1.09205075e-01,\n",
  289. " -2.50555217e-01, -3.40640247e-02, -6.03400404e-03,\n",
  290. " 1.08055484e+00, -1.06113815e+00],\n",
  291. " [-7.43205428e-01, 9.12963331e-01, -6.44320250e-01,\n",
  292. " -1.47909701e-01, 7.39851117e-01, 1.14276908e-01,\n",
  293. " -7.95052409e-01, 6.81582153e-01],\n",
  294. " [ 1.51805115e+00, -5.28840959e-01, 8.10247064e-01,\n",
  295. " -1.92141697e-01, 4.41353947e-01, 2.73350589e-02,\n",
  296. " -8.18380833e-01, 8.56353521e-01],\n",
  297. " [ 1.63122582e+00, 3.52261662e-01, 4.08057608e-02,\n",
  298. " -1.40889511e-01, -4.63210404e-01, -6.75162375e-02,\n",
  299. " -8.27712238e-01, 5.96693158e-01],\n",
  300. " [-1.06354737e+00, 1.87416613e+00, -4.93448943e-01,\n",
  301. " -6.96261302e-02, -2.73587584e-01, -1.34195149e-01,\n",
  302. " 1.03389800e+00, -1.34576583e+00],\n",
  303. " [-3.05882934e-02, -9.29342151e-01, 2.59621471e-01,\n",
  304. " -6.01274054e-03, -5.00409126e-01, -3.07798684e-02,\n",
  305. " 1.59844637e+00, -1.81515181e+00],\n",
  306. " [-8.24676275e-01, -4.82395217e-02, -3.44865829e-01,\n",
  307. " -8.47758725e-02, 5.01234829e-01, -3.46999951e-02,\n",
  308. " 5.30003488e-01, -8.74119252e-02],\n",
  309. " [-3.29563528e-01, 9.93063569e-01, -8.77174079e-01,\n",
  310. " -3.63671094e-01, -1.11645639e+00, -8.51059332e-02,\n",
  311. " 1.06655777e+00, -1.38571358e+00],\n",
  312. " [ 4.04922552e-02, -6.89041436e-01, -4.43798512e-01,\n",
  313. " 2.23745853e-02, -2.21872270e-01, -1.48285031e-01,\n",
  314. " -8.88366222e-01, 6.36640906e-01],\n",
  315. " [-9.97422278e-01, 1.23336422e+00, -7.57719278e-01,\n",
  316. " -1.11092515e-02, -2.30037838e-01, 5.48742227e-02,\n",
  317. " -7.57726908e-01, 7.06549466e-01],\n",
  318. " [ 1.90638328e+00, 5.12462139e-01, 4.47582811e-01,\n",
  319. " -2.76721776e-01, -6.31058335e-01, -7.08114654e-02,\n",
  320. " -7.06404328e-01, 7.46497214e-01],\n",
  321. " [ 7.00647458e-02, 3.18607129e-02, -2.57098645e-01,\n",
  322. " -3.00019473e-01, -2.66329288e-01, -9.85835046e-02,\n",
  323. " 1.08522058e+00, -1.37073314e+00],\n",
  324. " [-6.91253722e-01, -4.82395217e-02, -8.84629071e-01,\n",
  325. " 4.13295318e-04, -1.42938361e-01, -1.61392525e-01,\n",
  326. " -7.20401347e-01, 6.16667032e-01],\n",
  327. " [-4.27712232e-01, -9.29342151e-01, -3.57968122e-01,\n",
  328. " 6.92046210e-02, -8.15237403e-01, -5.46457283e-02,\n",
  329. " -5.71099281e-01, 5.51751971e-01],\n",
  330. " [-1.86834182e-03, -1.24974310e+00, 2.20688999e-01,\n",
  331. " -2.50608753e-02, 3.33386868e-01, 2.75048837e-02,\n",
  332. " 9.17255700e-01, -6.81634605e-01],\n",
  333. " [ 3.88017392e+00, -9.29342151e-01, 1.29029870e+00,\n",
  334. " -1.72691330e-01, -2.25501403e-01, 5.14101014e-02,\n",
  335. " -8.46374989e-01, 8.86314332e-01],\n",
  336. " [ 1.23452818e+00, -1.32984328e+00, 5.75435698e-01,\n",
  337. " -1.19361848e-01, -4.16938782e-01, 1.54011786e-01,\n",
  338. " 1.09921765e+00, -1.34576583e+00],\n",
  339. " [ 9.00758684e-01, -1.28339753e-01, 8.26199055e-02,\n",
  340. " -2.27456287e-01, -2.84474999e-01, 8.01625699e-02,\n",
  341. " -9.16360319e-01, 8.11412275e-01],\n",
  342. " [-1.04972430e-01, -2.13084579e+00, -8.56280804e-01,\n",
  343. " 3.74467552e-01, -7.72594988e-01, -1.98224083e-01,\n",
  344. " 9.96572435e-01, -1.42566133e+00],\n",
  345. " [-4.90587056e-01, 3.52261662e-01, -6.64385974e-01,\n",
  346. " 2.82370716e-01, 5.10307670e-01, 4.01096910e-01,\n",
  347. " 8.51936042e-01, -1.31580508e+00],\n",
  348. " [ 3.72503400e-01, -6.89041436e-01, 6.45801365e-01,\n",
  349. " 8.00678432e-02, -3.15322757e-01, -2.51115970e-02,\n",
  350. " 5.62663257e-01, -5.74511178e-02],\n",
  351. " [ 1.58142820e-01, 1.15326405e+00, -9.75820422e-02,\n",
  352. " -2.74813175e-01, -6.60091519e-01, -9.36449692e-02,\n",
  353. " -8.51040661e-01, 7.26523340e-01],\n",
  354. " [-2.78091401e-01, 2.72161424e-01, 4.07564074e-01,\n",
  355. " -3.05503774e-02, -1.23885356e-01, -2.59928163e-02,\n",
  356. " 1.40715313e+00, -8.61399472e-01],\n",
  357. " [-2.48145923e-01, -1.08954263e+00, -3.17847952e-02,\n",
  358. " -1.67912588e-01, -9.64032352e-01, -6.52535707e-02,\n",
  359. " 1.46314144e+00, -1.50555682e+00],\n",
  360. " [ 4.95588928e-01, 1.23336422e+00, -2.66338676e-01,\n",
  361. " -3.35229747e-02, -1.15637708e+00, -1.26926363e-01,\n",
  362. " 8.65933120e-01, -1.33577895e+00],\n",
  363. " [ 4.06818151e-01, 1.79406595e+00, 5.06766677e-01,\n",
  364. " -1.45333722e-01, -4.59581256e-01, -9.45077166e-02,\n",
  365. " 9.73243952e-01, -1.28085077e+00],\n",
  366. " [-1.15839255e+00, 1.23336422e+00, -7.80114472e-01,\n",
  367. " -3.65703627e-02, 1.13002844e-02, 2.21132398e-01,\n",
  368. " -7.76389658e-01, 6.61608279e-01]], dtype=float32)>, <tf.Tensor: shape=(32, 1), dtype=float32, numpy=\n",
  369. "array([[3.226 ],\n",
  370. " [2.956 ],\n",
  371. " [1.59 ],\n",
  372. " [2.419 ],\n",
  373. " [2.429 ],\n",
  374. " [1.514 ],\n",
  375. " [1.438 ],\n",
  376. " [2.898 ],\n",
  377. " [3.376 ],\n",
  378. " [1.982 ],\n",
  379. " [1.598 ],\n",
  380. " [0.717 ],\n",
  381. " [1.563 ],\n",
  382. " [2.852 ],\n",
  383. " [1.739 ],\n",
  384. " [5.00001],\n",
  385. " [1.61 ],\n",
  386. " [2.321 ],\n",
  387. " [2.385 ],\n",
  388. " [1.293 ],\n",
  389. " [5.00001],\n",
  390. " [2.487 ],\n",
  391. " [2.392 ],\n",
  392. " [2.75 ],\n",
  393. " [2.028 ],\n",
  394. " [1.431 ],\n",
  395. " [2.26 ],\n",
  396. " [1.171 ],\n",
  397. " [1.031 ],\n",
  398. " [2.538 ],\n",
  399. " [2.519 ],\n",
  400. " [1.125 ]], dtype=float32)>)\n"
  401. ]
  402. }
  403. ],
  404. "source": [
  405. "for i in train_set.take(1):\n",
  406. " print(i) "
  407. ]
  408. },
  409. {
  410. "cell_type": "code",
  411. "execution_count": 8,
  412. "metadata": {},
  413. "outputs": [
  414. {
  415. "name": "stdout",
  416. "output_type": "stream",
  417. "text": [
  418. "CPU times: user 40 s, sys: 8.95 s, total: 48.9 s\n",
  419. "Wall time: 42.6 s\n"
  420. ]
  421. }
  422. ],
  423. "source": [
  424. "%%time\n",
  425. "# 训练集和测试集都分20\n",
  426. "n_shards = 20\n",
  427. "train_steps_per_shard = 11610 // batch_size // n_shards\n",
  428. "valid_steps_per_shard = 3880 // batch_size // 10\n",
  429. "test_steps_per_shard = 5170 // batch_size // 10\n",
  430. "\n",
  431. "output_dir = \"generate_tfrecords\"\n",
  432. "if not os.path.exists(output_dir):\n",
  433. " os.mkdir(output_dir)\n",
  434. "\n",
  435. "train_basename = os.path.join(output_dir, \"train\")\n",
  436. "valid_basename = os.path.join(output_dir, \"valid\")\n",
  437. "test_basename = os.path.join(output_dir, \"test\")\n",
  438. "\n",
  439. "train_tfrecord_filenames = csv_dataset_to_tfrecords(\n",
  440. " train_basename, train_set, n_shards, train_steps_per_shard, None)\n",
  441. "valid_tfrecord_filenames = csv_dataset_to_tfrecords(\n",
  442. " valid_basename, valid_set, 10, valid_steps_per_shard, None)\n",
  443. "test_tfrecord_fielnames = csv_dataset_to_tfrecords(\n",
  444. " test_basename, test_set, 10, test_steps_per_shard, None)\n",
  445. "#执行会发现目录下总计生成了60个文件,这里文件数目改为一致,为了对比时间"
  446. ]
  447. },
  448. {
  449. "cell_type": "code",
  450. "execution_count": 9,
  451. "metadata": {},
  452. "outputs": [
  453. {
  454. "name": "stdout",
  455. "output_type": "stream",
  456. "text": [
  457. "总用量 1960\r\n",
  458. "-rw-rw-r-- 1 luke luke 47616 Jul 23 11:33 test_00000-of-00010\r\n",
  459. "-rw-rw-r-- 1 luke luke 47616 Jul 23 11:33 test_00001-of-00010\r\n",
  460. "-rw-rw-r-- 1 luke luke 47616 Jul 23 11:33 test_00002-of-00010\r\n",
  461. "-rw-rw-r-- 1 luke luke 47616 Jul 23 11:33 test_00003-of-00010\r\n",
  462. "-rw-rw-r-- 1 luke luke 47616 Jul 23 11:33 test_00004-of-00010\r\n",
  463. "-rw-rw-r-- 1 luke luke 47616 Jul 23 11:33 test_00005-of-00010\r\n",
  464. "-rw-rw-r-- 1 luke luke 47616 Jul 23 11:33 test_00006-of-00010\r\n",
  465. "-rw-rw-r-- 1 luke luke 47616 Jul 23 11:33 test_00007-of-00010\r\n",
  466. "-rw-rw-r-- 1 luke luke 47616 Jul 23 11:33 test_00008-of-00010\r\n",
  467. "-rw-rw-r-- 1 luke luke 47616 Jul 23 11:33 test_00009-of-00010\r\n",
  468. "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:32 train_00000-of-00020\r\n",
  469. "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:32 train_00001-of-00020\r\n",
  470. "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:32 train_00002-of-00020\r\n",
  471. "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:32 train_00003-of-00020\r\n",
  472. "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:32 train_00004-of-00020\r\n",
  473. "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:32 train_00005-of-00020\r\n",
  474. "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:32 train_00006-of-00020\r\n",
  475. "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:32 train_00007-of-00020\r\n",
  476. "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:32 train_00008-of-00020\r\n",
  477. "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:33 train_00009-of-00020\r\n",
  478. "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:33 train_00010-of-00020\r\n",
  479. "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:33 train_00011-of-00020\r\n",
  480. "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:33 train_00012-of-00020\r\n",
  481. "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:33 train_00013-of-00020\r\n",
  482. "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:33 train_00014-of-00020\r\n",
  483. "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:33 train_00015-of-00020\r\n",
  484. "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:33 train_00016-of-00020\r\n",
  485. "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:33 train_00017-of-00020\r\n",
  486. "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:33 train_00018-of-00020\r\n",
  487. "-rw-rw-r-- 1 luke luke 53568 Jul 23 11:33 train_00019-of-00020\r\n",
  488. "-rw-rw-r-- 1 luke luke 35712 Jul 23 11:33 valid_00000-of-00010\r\n",
  489. "-rw-rw-r-- 1 luke luke 35712 Jul 23 11:33 valid_00001-of-00010\r\n",
  490. "-rw-rw-r-- 1 luke luke 35712 Jul 23 11:33 valid_00002-of-00010\r\n",
  491. "-rw-rw-r-- 1 luke luke 35712 Jul 23 11:33 valid_00003-of-00010\r\n",
  492. "-rw-rw-r-- 1 luke luke 35712 Jul 23 11:33 valid_00004-of-00010\r\n",
  493. "-rw-rw-r-- 1 luke luke 35712 Jul 23 11:33 valid_00005-of-00010\r\n",
  494. "-rw-rw-r-- 1 luke luke 35712 Jul 23 11:33 valid_00006-of-00010\r\n",
  495. "-rw-rw-r-- 1 luke luke 35712 Jul 23 11:33 valid_00007-of-00010\r\n",
  496. "-rw-rw-r-- 1 luke luke 35712 Jul 23 11:33 valid_00008-of-00010\r\n",
  497. "-rw-rw-r-- 1 luke luke 35712 Jul 23 11:33 valid_00009-of-00010\r\n"
  498. ]
  499. }
  500. ],
  501. "source": [
  502. "!ls -l generate_tfrecords"
  503. ]
  504. },
  505. {
  506. "cell_type": "code",
  507. "execution_count": 33,
  508. "metadata": {},
  509. "outputs": [],
  510. "source": [
  511. "#生成一下压缩的\n",
  512. "# n_shards = 20\n",
  513. "# train_steps_per_shard = 11610 // batch_size // n_shards\n",
  514. "# valid_steps_per_shard = 3880 // batch_size // n_shards\n",
  515. "# test_steps_per_shard = 5170 // batch_size // n_shards\n",
  516. "\n",
  517. "# output_dir = \"generate_tfrecords_zip\"\n",
  518. "# if not os.path.exists(output_dir):\n",
  519. "# os.mkdir(output_dir)\n",
  520. "\n",
  521. "# train_basename = os.path.join(output_dir, \"train\")\n",
  522. "# valid_basename = os.path.join(output_dir, \"valid\")\n",
  523. "# test_basename = os.path.join(output_dir, \"test\")\n",
  524. "# #只需修改参数的类型即可\n",
  525. "# train_tfrecord_filenames = csv_dataset_to_tfrecords(\n",
  526. "# train_basename, train_set, n_shards, train_steps_per_shard,\n",
  527. "# compression_type = \"GZIP\")\n",
  528. "# valid_tfrecord_filenames = csv_dataset_to_tfrecords(\n",
  529. "# valid_basename, valid_set, n_shards, valid_steps_per_shard,\n",
  530. "# compression_type = \"GZIP\")\n",
  531. "# test_tfrecord_fielnames = csv_dataset_to_tfrecords(\n",
  532. "# test_basename, test_set, n_shards, test_steps_per_shard,\n",
  533. "# compression_type = \"GZIP\")"
  534. ]
  535. },
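If the GZIP variant above is used, the shards can only be read back when the reader passes the matching compression type; a plain `tf.data.TFRecordDataset` would otherwise fail with a corrupted-record error. A minimal sketch, assuming the `generate_tfrecords_zip` directory from the commented-out cell exists:

```python
import os
import tensorflow as tf

# Read a couple of raw records back from the GZIP-compressed shards.
zip_dir = "generate_tfrecords_zip"
gzip_filenames = [os.path.join(zip_dir, name)
                  for name in sorted(os.listdir(zip_dir))
                  if name.startswith("train")]

# compression_type must match the TFRecordOptions used at write time.
raw_dataset = tf.data.TFRecordDataset(gzip_filenames, compression_type="GZIP")
for raw_record in raw_dataset.take(2):
    print(len(raw_record.numpy()), "bytes in this serialized Example")
```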
  536. {
  537. "cell_type": "code",
  538. "execution_count": 34,
  539. "metadata": {},
  540. "outputs": [
  541. {
  542. "name": "stdout",
  543. "output_type": "stream",
  544. "text": [
  545. "总用量 860\r\n",
  546. "-rw-rw-r-- 1 luke luke 10171 May 7 11:16 test_00000-of-00020\r\n",
  547. "-rw-rw-r-- 1 luke luke 10230 May 7 11:16 test_00001-of-00020\r\n",
  548. "-rw-rw-r-- 1 luke luke 10204 May 7 11:16 test_00002-of-00020\r\n",
  549. "-rw-rw-r-- 1 luke luke 10213 May 7 11:16 test_00003-of-00020\r\n",
  550. "-rw-rw-r-- 1 luke luke 10229 May 7 11:16 test_00004-of-00020\r\n",
  551. "-rw-rw-r-- 1 luke luke 10200 May 7 11:16 test_00005-of-00020\r\n",
  552. "-rw-rw-r-- 1 luke luke 10199 May 7 11:16 test_00006-of-00020\r\n",
  553. "-rw-rw-r-- 1 luke luke 10215 May 7 11:16 test_00007-of-00020\r\n",
  554. "-rw-rw-r-- 1 luke luke 10179 May 7 11:16 test_00008-of-00020\r\n",
  555. "-rw-rw-r-- 1 luke luke 10149 May 7 11:16 test_00009-of-00020\r\n",
  556. "-rw-rw-r-- 1 luke luke 10141 May 7 11:16 test_00010-of-00020\r\n",
  557. "-rw-rw-r-- 1 luke luke 10221 May 7 11:16 test_00011-of-00020\r\n",
  558. "-rw-rw-r-- 1 luke luke 10209 May 7 11:16 test_00012-of-00020\r\n",
  559. "-rw-rw-r-- 1 luke luke 10214 May 7 11:16 test_00013-of-00020\r\n",
  560. "-rw-rw-r-- 1 luke luke 10212 May 7 11:16 test_00014-of-00020\r\n",
  561. "-rw-rw-r-- 1 luke luke 10209 May 7 11:16 test_00015-of-00020\r\n",
  562. "-rw-rw-r-- 1 luke luke 10185 May 7 11:16 test_00016-of-00020\r\n",
  563. "-rw-rw-r-- 1 luke luke 10266 May 7 11:16 test_00017-of-00020\r\n",
  564. "-rw-rw-r-- 1 luke luke 10258 May 7 11:16 test_00018-of-00020\r\n",
  565. "-rw-rw-r-- 1 luke luke 10170 May 7 11:16 test_00019-of-00020\r\n",
  566. "-rw-rw-r-- 1 luke luke 22359 May 7 19:17 train_00000-of-00020\r\n",
  567. "-rw-rw-r-- 1 luke luke 22447 May 7 19:17 train_00001-of-00020\r\n",
  568. "-rw-rw-r-- 1 luke luke 22366 May 7 19:17 train_00002-of-00020\r\n",
  569. "-rw-rw-r-- 1 luke luke 22311 May 7 19:17 train_00003-of-00020\r\n",
  570. "-rw-rw-r-- 1 luke luke 22384 May 7 19:17 train_00004-of-00020\r\n",
  571. "-rw-rw-r-- 1 luke luke 22341 May 7 19:17 train_00005-of-00020\r\n",
  572. "-rw-rw-r-- 1 luke luke 22416 May 7 19:17 train_00006-of-00020\r\n",
  573. "-rw-rw-r-- 1 luke luke 22285 May 7 19:17 train_00007-of-00020\r\n",
  574. "-rw-rw-r-- 1 luke luke 22415 May 7 19:17 train_00008-of-00020\r\n",
  575. "-rw-rw-r-- 1 luke luke 22365 May 7 19:17 train_00009-of-00020\r\n",
  576. "-rw-rw-r-- 1 luke luke 22431 May 7 19:17 train_00010-of-00020\r\n",
  577. "-rw-rw-r-- 1 luke luke 22367 May 7 19:17 train_00011-of-00020\r\n",
  578. "-rw-rw-r-- 1 luke luke 22346 May 7 19:17 train_00012-of-00020\r\n",
  579. "-rw-rw-r-- 1 luke luke 22332 May 7 19:17 train_00013-of-00020\r\n",
  580. "-rw-rw-r-- 1 luke luke 22452 May 7 19:17 train_00014-of-00020\r\n",
  581. "-rw-rw-r-- 1 luke luke 20 May 7 19:17 train_00015-of-00020\r\n",
  582. "-rw-rw-r-- 1 luke luke 22427 May 7 11:16 train_00016-of-00020\r\n",
  583. "-rw-rw-r-- 1 luke luke 22427 May 7 11:16 train_00017-of-00020\r\n",
  584. "-rw-rw-r-- 1 luke luke 22454 May 7 11:16 train_00018-of-00020\r\n",
  585. "-rw-rw-r-- 1 luke luke 22309 May 7 11:16 train_00019-of-00020\r\n",
  586. "-rw-rw-r-- 1 luke luke 7747 May 7 11:16 valid_00000-of-00020\r\n",
  587. "-rw-rw-r-- 1 luke luke 7744 May 7 11:16 valid_00001-of-00020\r\n",
  588. "-rw-rw-r-- 1 luke luke 7749 May 7 11:16 valid_00002-of-00020\r\n",
  589. "-rw-rw-r-- 1 luke luke 7755 May 7 11:16 valid_00003-of-00020\r\n",
  590. "-rw-rw-r-- 1 luke luke 7744 May 7 11:16 valid_00004-of-00020\r\n",
  591. "-rw-rw-r-- 1 luke luke 7678 May 7 11:16 valid_00005-of-00020\r\n",
  592. "-rw-rw-r-- 1 luke luke 7762 May 7 11:16 valid_00006-of-00020\r\n",
  593. "-rw-rw-r-- 1 luke luke 7720 May 7 11:16 valid_00007-of-00020\r\n",
  594. "-rw-rw-r-- 1 luke luke 7727 May 7 11:16 valid_00008-of-00020\r\n",
  595. "-rw-rw-r-- 1 luke luke 7739 May 7 11:16 valid_00009-of-00020\r\n",
  596. "-rw-rw-r-- 1 luke luke 7762 May 7 11:16 valid_00010-of-00020\r\n",
  597. "-rw-rw-r-- 1 luke luke 7727 May 7 11:16 valid_00011-of-00020\r\n",
  598. "-rw-rw-r-- 1 luke luke 7729 May 7 11:16 valid_00012-of-00020\r\n",
  599. "-rw-rw-r-- 1 luke luke 7763 May 7 11:16 valid_00013-of-00020\r\n",
  600. "-rw-rw-r-- 1 luke luke 7727 May 7 11:16 valid_00014-of-00020\r\n",
  601. "-rw-rw-r-- 1 luke luke 7749 May 7 11:16 valid_00015-of-00020\r\n",
  602. "-rw-rw-r-- 1 luke luke 7741 May 7 11:16 valid_00016-of-00020\r\n",
  603. "-rw-rw-r-- 1 luke luke 7753 May 7 11:16 valid_00017-of-00020\r\n",
  604. "-rw-rw-r-- 1 luke luke 7702 May 7 11:16 valid_00018-of-00020\r\n",
  605. "-rw-rw-r-- 1 luke luke 7711 May 7 11:16 valid_00019-of-00020\r\n"
  606. ]
  607. }
  608. ],
  609. "source": [
  610. "!ls -l generate_tfrecords_zip"
  611. ]
  612. },
  613. {
  614. "cell_type": "code",
  615. "execution_count": 10,
  616. "metadata": {},
  617. "outputs": [
  618. {
  619. "name": "stdout",
  620. "output_type": "stream",
  621. "text": [
  622. "['generate_tfrecords/train_00000-of-00020',\n",
  623. " 'generate_tfrecords/train_00001-of-00020',\n",
  624. " 'generate_tfrecords/train_00002-of-00020',\n",
  625. " 'generate_tfrecords/train_00003-of-00020',\n",
  626. " 'generate_tfrecords/train_00004-of-00020',\n",
  627. " 'generate_tfrecords/train_00005-of-00020',\n",
  628. " 'generate_tfrecords/train_00006-of-00020',\n",
  629. " 'generate_tfrecords/train_00007-of-00020',\n",
  630. " 'generate_tfrecords/train_00008-of-00020',\n",
  631. " 'generate_tfrecords/train_00009-of-00020',\n",
  632. " 'generate_tfrecords/train_00010-of-00020',\n",
  633. " 'generate_tfrecords/train_00011-of-00020',\n",
  634. " 'generate_tfrecords/train_00012-of-00020',\n",
  635. " 'generate_tfrecords/train_00013-of-00020',\n",
  636. " 'generate_tfrecords/train_00014-of-00020',\n",
  637. " 'generate_tfrecords/train_00015-of-00020',\n",
  638. " 'generate_tfrecords/train_00016-of-00020',\n",
  639. " 'generate_tfrecords/train_00017-of-00020',\n",
  640. " 'generate_tfrecords/train_00018-of-00020',\n",
  641. " 'generate_tfrecords/train_00019-of-00020']\n",
  642. "['generate_tfrecords/valid_00000-of-00010',\n",
  643. " 'generate_tfrecords/valid_00001-of-00010',\n",
  644. " 'generate_tfrecords/valid_00002-of-00010',\n",
  645. " 'generate_tfrecords/valid_00003-of-00010',\n",
  646. " 'generate_tfrecords/valid_00004-of-00010',\n",
  647. " 'generate_tfrecords/valid_00005-of-00010',\n",
  648. " 'generate_tfrecords/valid_00006-of-00010',\n",
  649. " 'generate_tfrecords/valid_00007-of-00010',\n",
  650. " 'generate_tfrecords/valid_00008-of-00010',\n",
  651. " 'generate_tfrecords/valid_00009-of-00010']\n",
  652. "['generate_tfrecords/test_00000-of-00010',\n",
  653. " 'generate_tfrecords/test_00001-of-00010',\n",
  654. " 'generate_tfrecords/test_00002-of-00010',\n",
  655. " 'generate_tfrecords/test_00003-of-00010',\n",
  656. " 'generate_tfrecords/test_00004-of-00010',\n",
  657. " 'generate_tfrecords/test_00005-of-00010',\n",
  658. " 'generate_tfrecords/test_00006-of-00010',\n",
  659. " 'generate_tfrecords/test_00007-of-00010',\n",
  660. " 'generate_tfrecords/test_00008-of-00010',\n",
  661. " 'generate_tfrecords/test_00009-of-00010']\n"
  662. ]
  663. }
  664. ],
  665. "source": [
  666. "#打印一下文件名\n",
  667. "pprint.pprint(train_tfrecord_filenames)\n",
  668. "pprint.pprint(valid_tfrecord_filenames)\n",
  669. "pprint.pprint(test_tfrecord_fielnames)"
  670. ]
  671. },
  672. {
  673. "cell_type": "code",
  674. "execution_count": 11,
  675. "metadata": {},
  676. "outputs": [
  677. {
  678. "name": "stdout",
  679. "output_type": "stream",
  680. "text": [
  681. "CPU times: user 58 µs, sys: 14 µs, total: 72 µs\n",
  682. "Wall time: 80.1 µs\n"
  683. ]
  684. }
  685. ],
  686. "source": [
  687. "%%time\n",
  688. "#把数据读取出来\n",
  689. "expected_features = {\n",
  690. " \"input_features\": tf.io.FixedLenFeature([8], dtype=tf.float32),\n",
  691. " \"label\": tf.io.FixedLenFeature([1], dtype=tf.float32)\n",
  692. "}\n",
  693. "\n",
  694. "def parse_example(serialized_example):\n",
  695. " example = tf.io.parse_single_example(serialized_example,\n",
  696. " expected_features)\n",
  697. " return example[\"input_features\"], example[\"label\"]\n",
  698. "\n",
  699. "def tfrecords_reader_dataset(filenames, n_readers=5,\n",
  700. " batch_size=32, n_parse_threads=5,\n",
  701. " shuffle_buffer_size=10000):\n",
  702. " dataset = tf.data.Dataset.list_files(filenames)\n",
  703. " dataset = dataset.repeat() #为了能够无限次epoch\n",
  704. " dataset = dataset.interleave(\n",
  705. "# lambda filename: tf.data.TFRecordDataset(\n",
  706. "# filename, compression_type = \"GZIP\"),\n",
  707. " lambda filename: tf.data.TFRecordDataset(\n",
  708. " filename),\n",
  709. " cycle_length = n_readers\n",
  710. " )\n",
  711. " #洗牌,就是给数据打乱,样本顺序打乱\n",
  712. " dataset.shuffle(shuffle_buffer_size)\n",
  713. " dataset = dataset.map(parse_example,\n",
  714. " num_parallel_calls=n_parse_threads)#把对应的一个样本是字节流的,变为浮点类型\n",
  715. " dataset = dataset.batch(batch_size) #原来写进去是一条一条的sample,要分配\n",
  716. " return dataset\n",
  717. "\n",
  718. "#测试一下,tfrecords_reader_dataset是否可以正常运行\n",
  719. "# tfrecords_train = tfrecords_reader_dataset(train_tfrecord_filenames,\n",
  720. "# batch_size = 3)\n",
  721. "# for x_batch, y_batch in tfrecords_train.take(10):\n",
  722. "# print(x_batch)\n",
  723. "# print(y_batch)"
  724. ]
  725. },
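If `parse_example` ever fails (for instance with a shape mismatch), it can help to look at one raw record directly. A small sketch, assuming the uncompressed shards written above under `generate_tfrecords`:

```python
import tensorflow as tf

# Inspect one raw record from an uncompressed shard to see its feature names and lengths.
raw_dataset = tf.data.TFRecordDataset(["generate_tfrecords/train_00000-of-00020"])
for raw_record in raw_dataset.take(1):
    example = tf.train.Example.FromString(raw_record.numpy())
    for name, feature in example.features.feature.items():
        print(name, len(feature.float_list.value))
```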
  726. {
  727. "cell_type": "code",
  728. "execution_count": 12,
  729. "metadata": {
  730. "collapsed": true
  731. },
  732. "outputs": [
  733. {
  734. "name": "stdout",
  735. "output_type": "stream",
  736. "text": [
  737. "WARNING:tensorflow:AutoGraph could not transform <function tfrecords_reader_dataset.<locals>.<lambda> at 0x7f98284712f0> and will run it as-is.\n",
  738. "Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.\n",
  739. "Cause: Unable to locate the source code of <function tfrecords_reader_dataset.<locals>.<lambda> at 0x7f98284712f0>. Note that functions defined in certain environments, like the interactive Python shell do not expose their source code. If that is the case, you should to define them in a .py source file. If you are certain the code is graph-compatible, wrap the call using @tf.autograph.do_not_convert. Original error: could not get source code\n",
  740. "To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert\n",
  741. "WARNING: AutoGraph could not transform <function tfrecords_reader_dataset.<locals>.<lambda> at 0x7f98284712f0> and will run it as-is.\n",
  742. "Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.\n",
  743. "Cause: Unable to locate the source code of <function tfrecords_reader_dataset.<locals>.<lambda> at 0x7f98284712f0>. Note that functions defined in certain environments, like the interactive Python shell do not expose their source code. If that is the case, you should to define them in a .py source file. If you are certain the code is graph-compatible, wrap the call using @tf.autograph.do_not_convert. Original error: could not get source code\n",
  744. "To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert\n",
  745. "WARNING:tensorflow:AutoGraph could not transform <function parse_example at 0x7f98284710d0> and will run it as-is.\n",
  746. "Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.\n",
  747. "Cause: Unable to locate the source code of <function parse_example at 0x7f98284710d0>. Note that functions defined in certain environments, like the interactive Python shell do not expose their source code. If that is the case, you should to define them in a .py source file. If you are certain the code is graph-compatible, wrap the call using @tf.autograph.do_not_convert. Original error: could not get source code\n",
  748. "To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert\n",
  749. "WARNING: AutoGraph could not transform <function parse_example at 0x7f98284710d0> and will run it as-is.\n",
  750. "Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.\n",
  751. "Cause: Unable to locate the source code of <function parse_example at 0x7f98284710d0>. Note that functions defined in certain environments, like the interactive Python shell do not expose their source code. If that is the case, you should to define them in a .py source file. If you are certain the code is graph-compatible, wrap the call using @tf.autograph.do_not_convert. Original error: could not get source code\n",
  752. "To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert\n",
  753. "WARNING:tensorflow:AutoGraph could not transform <function tfrecords_reader_dataset.<locals>.<lambda> at 0x7f9828471378> and will run it as-is.\n",
  754. "Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.\n",
  755. "Cause: Unable to locate the source code of <function tfrecords_reader_dataset.<locals>.<lambda> at 0x7f9828471378>. Note that functions defined in certain environments, like the interactive Python shell do not expose their source code. If that is the case, you should to define them in a .py source file. If you are certain the code is graph-compatible, wrap the call using @tf.autograph.do_not_convert. Original error: could not get source code\n",
  756. "To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert\n",
  757. "WARNING: AutoGraph could not transform <function tfrecords_reader_dataset.<locals>.<lambda> at 0x7f9828471378> and will run it as-is.\n",
  758. "Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.\n",
  759. "Cause: Unable to locate the source code of <function tfrecords_reader_dataset.<locals>.<lambda> at 0x7f9828471378>. Note that functions defined in certain environments, like the interactive Python shell do not expose their source code. If that is the case, you should to define them in a .py source file. If you are certain the code is graph-compatible, wrap the call using @tf.autograph.do_not_convert. Original error: could not get source code\n",
  760. "To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert\n",
  761. "WARNING:tensorflow:AutoGraph could not transform <function tfrecords_reader_dataset.<locals>.<lambda> at 0x7f9828471598> and will run it as-is.\n",
  762. "Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.\n",
  763. "Cause: Unable to locate the source code of <function tfrecords_reader_dataset.<locals>.<lambda> at 0x7f9828471598>. Note that functions defined in certain environments, like the interactive Python shell do not expose their source code. If that is the case, you should to define them in a .py source file. If you are certain the code is graph-compatible, wrap the call using @tf.autograph.do_not_convert. Original error: could not get source code\n",
  764. "To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert\n",
  765. "WARNING: AutoGraph could not transform <function tfrecords_reader_dataset.<locals>.<lambda> at 0x7f9828471598> and will run it as-is.\n",
  766. "Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.\n",
  767. "Cause: Unable to locate the source code of <function tfrecords_reader_dataset.<locals>.<lambda> at 0x7f9828471598>. Note that functions defined in certain environments, like the interactive Python shell do not expose their source code. If that is the case, you should to define them in a .py source file. If you are certain the code is graph-compatible, wrap the call using @tf.autograph.do_not_convert. Original error: could not get source code\n",
  768. "To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert\n",
  769. "CPU times: user 218 ms, sys: 83.8 ms, total: 302 ms\n",
  770. "Wall time: 294 ms\n"
  771. ]
  772. }
  773. ],
  774. "source": [
  775. "%%time\n",
  776. "#得到dataset,dataset是tensor,可以直接拿tensor训练\n",
  777. "\n",
  778. "batch_size = 32\n",
  779. "tfrecords_train_set = tfrecords_reader_dataset(\n",
  780. " train_tfrecord_filenames, batch_size = batch_size)\n",
  781. "tfrecords_valid_set = tfrecords_reader_dataset(\n",
  782. " valid_tfrecord_filenames, batch_size = batch_size)\n",
  783. "tfrecords_test_set = tfrecords_reader_dataset(\n",
  784. " test_tfrecord_fielnames, batch_size = batch_size)"
  785. ]
  786. },
  787. {
  788. "cell_type": "code",
  789. "execution_count": 17,
  790. "metadata": {},
  791. "outputs": [
  792. {
  793. "data": {
  794. "text/plain": [
  795. "tensorflow.python.data.ops.dataset_ops.BatchDataset"
  796. ]
  797. },
  798. "execution_count": 17,
  799. "metadata": {},
  800. "output_type": "execute_result"
  801. }
  802. ],
  803. "source": [
  804. "type(tfrecords_train_set)"
  805. ]
  806. },
  807. {
  808. "cell_type": "code",
  809. "execution_count": 13,
  810. "metadata": {},
  811. "outputs": [
  812. {
  813. "name": "stdout",
  814. "output_type": "stream",
  815. "text": [
  816. "(<tf.Tensor: shape=(32, 8), dtype=float32, numpy=\n",
  817. "array([[-1.33849168e+00, 1.15326405e+00, -5.64006925e-01,\n",
  818. " -1.73587687e-02, -1.50110144e-02, 1.30692095e-01,\n",
  819. " -7.62392581e-01, 6.61608279e-01],\n",
  820. " [ 1.69143653e+00, -4.48740691e-01, 1.74070787e+00,\n",
  821. " 2.62504518e-01, 3.60605478e-01, 3.49154733e-02,\n",
  822. " -9.30357397e-01, 8.06418836e-01],\n",
  823. " [-3.68034393e-01, -1.00944233e+00, 9.95716763e+00,\n",
  824. " 8.32341957e+00, -1.11282730e+00, -1.44638717e-01,\n",
  825. " 1.34183347e+00, -2.12248623e-01],\n",
  826. " [ 6.85438991e-01, -5.28840959e-01, 2.43798941e-01,\n",
  827. " -1.07175767e-01, -3.28932047e-01, -7.14625940e-02,\n",
  828. " -5.66433609e-01, -9.73988622e-02],\n",
  829. " [ 2.22390127e+00, -1.00944233e+00, 1.12938309e+00,\n",
  830. " 6.29329979e-01, -4.16031510e-01, 3.57014686e-01,\n",
  831. " -7.34398425e-01, 3.57006729e-01],\n",
  832. " [-3.75547409e-01, 7.52762854e-01, -5.35495043e-01,\n",
  833. " -6.88404664e-02, 8.90460610e-01, 8.53771530e-03,\n",
  834. " 9.73243952e-01, -1.45062864e+00],\n",
  835. " [ 1.26937580e+00, -1.49004376e+00, 1.30503133e-01,\n",
  836. " -1.33789301e-01, 6.90857649e-01, 1.06014162e-01,\n",
  837. " 7.58622229e-01, -1.11107290e+00],\n",
  838. " [-4.36877012e-01, -1.32984328e+00, 2.76863605e-01,\n",
  839. " 2.20493942e-01, -1.10647631e+00, -9.77752283e-02,\n",
  840. " -7.95052409e-01, 1.43060231e+00],\n",
  841. " [-5.33587039e-01, 9.93063569e-01, -2.71614999e-01,\n",
  842. " -1.86064959e-01, -3.83369207e-01, -8.69813338e-02,\n",
  843. " -1.04530334e-01, 2.62130827e-01],\n",
  844. " [ 7.65653625e-02, 6.72662616e-01, 1.92134786e+00,\n",
  845. " 9.22839880e-01, -1.18722475e+00, 2.84929629e-02,\n",
  846. " 9.35918450e-01, -7.26575792e-01],\n",
  847. " [ 9.79352236e-01, 1.87416613e+00, 1.42585576e-01,\n",
  848. " -1.65525392e-01, 1.85499236e-01, -7.59028569e-02,\n",
  849. " 9.77909684e-01, -1.44563520e+00],\n",
  850. " [-4.48919147e-01, -2.88540244e-01, -4.44056481e-01,\n",
  851. " 5.86968623e-02, 9.30381179e-01, -3.20860595e-02,\n",
  852. " -6.78410172e-01, 5.91699719e-01],\n",
  853. " [ 9.33741331e-01, -4.48740691e-01, 1.85706854e-01,\n",
  854. " -6.31846488e-02, -1.69249669e-01, 1.85529351e-01,\n",
  855. " -6.36418939e-01, 5.56745410e-01],\n",
  856. " [-9.18775439e-01, 4.32361901e-01, -5.84726632e-01,\n",
  857. " -5.63860834e-02, 3.48897241e-02, 5.70803955e-02,\n",
  858. " -1.36893225e+00, 1.21588326e+00],\n",
  859. " [-6.18148386e-01, 1.87416613e+00, -4.37868014e-02,\n",
  860. " -8.39404464e-02, -6.44667625e-01, -4.46900912e-02,\n",
  861. " 9.96572435e-01, -1.34576583e+00],\n",
  862. " [-3.01620234e-02, -2.08439991e-01, 3.14824611e-01,\n",
  863. " 1.60544395e-01, -4.75005120e-01, -2.96259038e-02,\n",
  864. " 1.71975434e+00, -3.12117994e-01],\n",
  865. " [-9.78879511e-01, -6.89041436e-01, -5.42888701e-01,\n",
  866. " 3.03960383e-01, 3.80565763e-01, 2.37069577e-02,\n",
  867. " -1.35026944e+00, 1.23585713e+00],\n",
  868. " [-9.83834922e-01, -6.08941197e-01, -3.04150850e-01,\n",
  869. " -9.17670801e-02, 2.08181381e-01, -4.22854684e-02,\n",
  870. " 2.23298025e+00, -1.35575283e+00],\n",
  871. " [-5.00178158e-01, -1.08954263e+00, 7.68675879e-02,\n",
  872. " 5.26065975e-02, 4.82181817e-01, -2.72741280e-02,\n",
  873. " 1.52379537e+00, -6.51673794e-01],\n",
  874. " [-3.39900553e-01, 2.72161424e-01, -4.90426540e-01,\n",
  875. " 7.07200915e-02, 2.49916553e-01, -9.07782912e-02,\n",
  876. " 9.96572435e-01, -1.45562208e+00],\n",
  877. " [ 1.51482344e-01, 3.18607129e-02, -2.61826307e-01,\n",
  878. " -1.51025891e-01, -8.21501911e-02, 2.73460269e-01,\n",
  879. " -6.36418939e-01, 5.76719284e-01],\n",
  880. " [-9.25968707e-01, -4.82395217e-02, -1.45638064e-01,\n",
  881. " -3.07705432e-01, -8.87820303e-01, -2.56793760e-02,\n",
  882. " 1.42115021e+00, -9.66262281e-01],\n",
  883. " [ 6.59276664e-01, -2.08439991e-01, 2.04598561e-01,\n",
  884. " -2.05626473e-01, -8.42456043e-01, -6.28024265e-02,\n",
  885. " 1.40715313e+00, -8.71386409e-01],\n",
  886. " [-1.23248763e-01, 1.87416613e+00, 2.59033680e-01,\n",
  887. " -1.21581666e-01, -5.81157565e-01, -7.85437226e-02,\n",
  888. " 1.08055484e+00, -8.66392910e-01],\n",
  889. " [ 3.76925975e-01, -1.81044471e+00, -2.83305883e-01,\n",
  890. " -5.30055724e-02, 1.07472621e-01, -1.56231388e-01,\n",
  891. " -1.34093809e+00, 1.20090282e+00],\n",
  892. " [-5.23303270e-01, -1.32984328e+00, -1.37612730e-01,\n",
  893. " -1.91330597e-01, 2.53545702e-01, -1.70176193e-01,\n",
  894. " -1.16830754e+00, 1.18092895e+00],\n",
  895. " [-3.61214072e-01, 5.12462139e-01, -6.50760174e-01,\n",
  896. " -1.81008682e-01, -3.18951875e-01, -1.45220742e-01,\n",
  897. " 9.17255700e-01, -1.41068089e+00],\n",
  898. " [ 3.42984200e-01, 1.92061186e-01, 2.13778451e-01,\n",
  899. " -2.96686888e-01, -5.69362879e-01, -4.00676690e-02,\n",
  900. " -1.14031339e+00, 1.17593551e+00],\n",
  901. " [ 1.50238574e+00, -1.57014406e+00, 3.20324659e-01,\n",
  902. " -5.00563085e-02, -3.75203639e-01, 9.73348692e-02,\n",
  903. " -8.51040661e-01, 8.56353521e-01],\n",
  904. " [-7.03136027e-01, 1.39356470e+00, -7.07432747e-01,\n",
  905. " -4.66516092e-02, 2.47186041e+00, 1.57230482e-01,\n",
  906. " -7.15735674e-01, 6.86575592e-01],\n",
  907. " [-7.26261199e-01, 4.32361901e-01, -1.33732617e-01,\n",
  908. " -7.91970268e-02, 4.94928146e-03, -1.02592900e-01,\n",
  909. " -5.61767936e-01, 1.72248408e-01],\n",
  910. " [-3.96328092e-01, -2.08439991e-01, -1.22804724e-01,\n",
  911. " -1.11505985e-01, 1.06738138e+00, 9.11513790e-02,\n",
  912. " 1.29517651e+00, -1.57047188e+00]], dtype=float32)>, <tf.Tensor: shape=(32, 1), dtype=float32, numpy=\n",
  913. "array([[1.072 ],\n",
  914. " [3.506 ],\n",
  915. " [1.406 ],\n",
  916. " [3.425 ],\n",
  917. " [5.00001],\n",
  918. " [3.5 ],\n",
  919. " [2.607 ],\n",
  920. " [0.663 ],\n",
  921. " [0.57 ],\n",
  922. " [1.5 ],\n",
  923. " [3.636 ],\n",
  924. " [2.042 ],\n",
  925. " [2.176 ],\n",
  926. " [1.394 ],\n",
  927. " [2.079 ],\n",
  928. " [1.095 ],\n",
  929. " [0.992 ],\n",
  930. " [0.598 ],\n",
  931. " [1.273 ],\n",
  932. " [3.229 ],\n",
  933. " [1.5 ],\n",
  934. " [1.086 ],\n",
  935. " [1.634 ],\n",
  936. " [1.184 ],\n",
  937. " [1.648 ],\n",
  938. " [1.633 ],\n",
  939. " [3.184 ],\n",
  940. " [2.055 ],\n",
  941. " [2.309 ],\n",
  942. " [1.629 ],\n",
  943. " [2.603 ],\n",
  944. " [1.438 ]], dtype=float32)>)\n"
  945. ]
  946. }
  947. ],
  948. "source": [
  949. "for i in tfrecords_train_set.take(1):\n",
  950. " print(i)"
  951. ]
  952. },
  953. {
  954. "cell_type": "code",
  955. "execution_count": 14,
  956. "metadata": {},
  957. "outputs": [
  958. {
  959. "name": "stdout",
  960. "output_type": "stream",
  961. "text": [
  962. "Epoch 1/100\n",
  963. "348/348 [==============================] - 1s 3ms/step - loss: 0.8099 - val_loss: 0.6248\n",
  964. "Epoch 2/100\n",
  965. "348/348 [==============================] - 1s 3ms/step - loss: 0.5202 - val_loss: 0.5199\n",
  966. "Epoch 3/100\n",
  967. "348/348 [==============================] - 1s 3ms/step - loss: 0.4712 - val_loss: 0.4874\n",
  968. "Epoch 4/100\n",
  969. "348/348 [==============================] - 1s 3ms/step - loss: 0.4509 - val_loss: 0.4747\n",
  970. "Epoch 5/100\n",
  971. "348/348 [==============================] - 1s 2ms/step - loss: 0.4298 - val_loss: 0.4615\n",
  972. "Epoch 6/100\n",
  973. "348/348 [==============================] - 1s 2ms/step - loss: 0.4159 - val_loss: 0.4296\n",
  974. "Epoch 7/100\n",
  975. "348/348 [==============================] - 1s 2ms/step - loss: 0.4033 - val_loss: 0.4194\n",
  976. "Epoch 8/100\n",
  977. "348/348 [==============================] - 1s 2ms/step - loss: 0.4042 - val_loss: 0.4123\n",
  978. "Epoch 9/100\n",
  979. "348/348 [==============================] - 1s 2ms/step - loss: 0.5006 - val_loss: 0.4300\n",
  980. "Epoch 10/100\n",
  981. "348/348 [==============================] - 1s 2ms/step - loss: 0.3920 - val_loss: 0.4135\n",
  982. "Epoch 11/100\n",
  983. "348/348 [==============================] - 1s 3ms/step - loss: 0.3976 - val_loss: 0.4100\n",
  984. "Epoch 12/100\n",
  985. "348/348 [==============================] - 1s 2ms/step - loss: 0.3836 - val_loss: 0.3966\n",
  986. "Epoch 13/100\n",
  987. "348/348 [==============================] - 1s 2ms/step - loss: 0.3744 - val_loss: 0.3917\n",
  988. "Epoch 14/100\n",
  989. "348/348 [==============================] - 1s 2ms/step - loss: 0.4394 - val_loss: 0.4169\n",
  990. "Epoch 15/100\n",
  991. "348/348 [==============================] - 1s 2ms/step - loss: 0.3968 - val_loss: 0.3938\n",
  992. "Epoch 16/100\n",
  993. "348/348 [==============================] - 1s 2ms/step - loss: 0.3682 - val_loss: 0.3880\n",
  994. "Epoch 17/100\n",
  995. "348/348 [==============================] - 1s 2ms/step - loss: 0.3709 - val_loss: 0.3835\n",
  996. "Epoch 18/100\n",
  997. "348/348 [==============================] - 1s 2ms/step - loss: 0.3666 - val_loss: 0.3795\n",
  998. "Epoch 19/100\n",
  999. "348/348 [==============================] - 1s 2ms/step - loss: 0.3692 - val_loss: 0.3756\n",
  1000. "Epoch 20/100\n",
  1001. "348/348 [==============================] - 1s 2ms/step - loss: 0.3587 - val_loss: 0.3736\n",
  1002. "Epoch 21/100\n",
  1003. "348/348 [==============================] - 1s 2ms/step - loss: 0.3554 - val_loss: 0.3765\n",
  1004. "Epoch 22/100\n",
  1005. "348/348 [==============================] - 1s 2ms/step - loss: 0.3619 - val_loss: 0.3732\n",
  1006. "Epoch 23/100\n",
  1007. "348/348 [==============================] - 1s 2ms/step - loss: 0.3529 - val_loss: 0.4280\n",
  1008. "Epoch 24/100\n",
  1009. "348/348 [==============================] - 1s 2ms/step - loss: 0.3537 - val_loss: 0.3658\n",
  1010. "Epoch 25/100\n",
  1011. "348/348 [==============================] - 1s 2ms/step - loss: 0.3515 - val_loss: 0.3704\n",
  1012. "Epoch 26/100\n",
  1013. "348/348 [==============================] - 1s 2ms/step - loss: 0.3707 - val_loss: 0.3642\n",
  1014. "Epoch 27/100\n",
  1015. "348/348 [==============================] - 1s 2ms/step - loss: 0.3512 - val_loss: 0.3651\n"
  1016. ]
  1017. }
  1018. ],
  1019. "source": [
  1020. "#开始训练\n",
  1021. "model = keras.models.Sequential([\n",
  1022. " keras.layers.Dense(30, activation='relu',\n",
  1023. " input_shape=[8]),\n",
  1024. " keras.layers.Dense(1),\n",
  1025. "])\n",
  1026. "model.compile(loss=\"mean_squared_error\", optimizer=\"sgd\")\n",
  1027. "callbacks = [keras.callbacks.EarlyStopping(\n",
  1028. " patience=5, min_delta=1e-2)]\n",
  1029. "\n",
  1030. "history = model.fit(tfrecords_train_set,\n",
  1031. " validation_data = tfrecords_valid_set,\n",
  1032. " steps_per_epoch = 11160 // batch_size,\n",
  1033. " validation_steps = 3870 // batch_size,\n",
  1034. " epochs = 100,\n",
  1035. " callbacks = callbacks)"
  1036. ]
  1037. },
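Since `model.fit` above returns a History object, the recorded losses can be plotted to check convergence. A small sketch using the matplotlib and pandas imports from the first cell, assuming the training cell above has been run so that `history` exists:

```python
import pandas as pd
import matplotlib.pyplot as plt

# Plot the training and validation loss recorded by model.fit.
pd.DataFrame(history.history).plot(figsize=(8, 5))
plt.grid(True)
plt.xlabel("epoch")
plt.ylabel("mean squared error")
plt.show()
```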
  1038. {
  1039. "cell_type": "code",
  1040. "execution_count": 20,
  1041. "metadata": {},
  1042. "outputs": [
  1043. {
  1044. "name": "stdout",
  1045. "output_type": "stream",
  1046. "text": [
  1047. "161/161 [==============================] - 0s 2ms/step - loss: 0.3376\n"
  1048. ]
  1049. },
  1050. {
  1051. "data": {
  1052. "text/plain": [
  1053. "0.33755674958229065"
  1054. ]
  1055. },
  1056. "execution_count": 20,
  1057. "metadata": {},
  1058. "output_type": "execute_result"
  1059. }
  1060. ],
  1061. "source": [
  1062. "model.evaluate(tfrecords_test_set, steps = 5160 // batch_size)"
  1063. ]
  1064. },
  1065. {
  1066. "cell_type": "code",
  1067. "execution_count": null,
  1068. "metadata": {},
  1069. "outputs": [],
  1070. "source": []
  1071. }
  1072. ],
  1073. "metadata": {
  1074. "kernelspec": {
  1075. "display_name": "Python 3",
  1076. "language": "python",
  1077. "name": "python3"
  1078. },
  1079. "language_info": {
  1080. "codemirror_mode": {
  1081. "name": "ipython",
  1082. "version": 3
  1083. },
  1084. "file_extension": ".py",
  1085. "mimetype": "text/x-python",
  1086. "name": "python",
  1087. "nbconvert_exporter": "python",
  1088. "pygments_lexer": "ipython3",
  1089. "version": "3.6.9"
  1090. }
  1091. },
  1092. "nbformat": 4,
  1093. "nbformat_minor": 2
  1094. }

With the development of artificial intelligence and big data, every field has a growing need for automation tools. During the current period of epidemic prevention and control, a YOLO model implemented with MindSpore can perform object detection and semantic segmentation on video or images, detecting mask wearing and pedestrian social distancing, so that epidemic prevention in public places can be managed automatically.