
tf02_data_generate_csv.ipynb 31 kB

  1. {
  2. "cells": [
  3. {
  4. "cell_type": "code",
  5. "execution_count": 1,
  6. "metadata": {},
  7. "outputs": [
  8. {
  9. "name": "stdout",
  10. "output_type": "stream",
  11. "text": [
  12. "2.2.0\n",
  13. "sys.version_info(major=3, minor=6, micro=9, releaselevel='final', serial=0)\n",
  14. "matplotlib 3.3.4\n",
  15. "numpy 1.19.5\n",
  16. "pandas 1.1.5\n",
  17. "sklearn 0.24.2\n",
  18. "tensorflow 2.2.0\n",
  19. "tensorflow.keras 2.3.0-tf\n"
  20. ]
  21. }
  22. ],
  23. "source": [
  24. "import matplotlib as mpl\n",
  25. "import matplotlib.pyplot as plt\n",
  26. "%matplotlib inline\n",
  27. "import numpy as np\n",
  28. "import sklearn\n",
  29. "import pandas as pd\n",
  30. "import os\n",
  31. "import sys\n",
  32. "import time\n",
  33. "import tensorflow as tf\n",
  34. "\n",
  35. "from tensorflow import keras\n",
  36. "\n",
  37. "print(tf.__version__)\n",
  38. "print(sys.version_info)\n",
  39. "for module in mpl, np, pd, sklearn, tf, keras:\n",
  40. " print(module.__name__, module.__version__)"
  41. ]
  42. },
  43. {
  44. "cell_type": "code",
  45. "execution_count": 2,
  46. "metadata": {},
  47. "outputs": [],
  48. "source": [
  49. "from sklearn.datasets import fetch_california_housing\n",
  50. "\n",
  51. "housing = fetch_california_housing()"
  52. ]
  53. },
  54. {
  55. "cell_type": "code",
  56. "execution_count": 3,
  57. "metadata": {},
  58. "outputs": [
  59. {
  60. "name": "stdout",
  61. "output_type": "stream",
  62. "text": [
  63. "(11610, 8) (11610,)\n",
  64. "(3870, 8) (3870,)\n",
  65. "(5160, 8) (5160,)\n"
  66. ]
  67. }
  68. ],
  69. "source": [
  70. "from sklearn.model_selection import train_test_split\n",
  71. "\n",
  72. "x_train_all, x_test, y_train_all, y_test = train_test_split(\n",
  73. " housing.data, housing.target, random_state = 7)\n",
  74. "x_train, x_valid, y_train, y_valid = train_test_split(\n",
  75. " x_train_all, y_train_all, random_state = 11)\n",
  76. "print(x_train.shape, y_train.shape)\n",
  77. "print(x_valid.shape, y_valid.shape)\n",
  78. "print(x_test.shape, y_test.shape)\n"
  79. ]
  80. },
  81. {
  82. "cell_type": "code",
  83. "execution_count": 4,
  84. "metadata": {},
  85. "outputs": [],
  86. "source": [
  87. "from sklearn.preprocessing import StandardScaler\n",
  88. "\n",
  89. "scaler = StandardScaler()\n",
  90. "x_train_scaled = scaler.fit_transform(x_train)\n",
  91. "x_valid_scaled = scaler.transform(x_valid)\n",
  92. "x_test_scaled = scaler.transform(x_test)"
  93. ]
  94. },
  95. {
  96. "cell_type": "code",
  97. "execution_count": 6,
  98. "metadata": {},
  99. "outputs": [],
  100. "source": [
  101. "!rm -rf generate_csv"
  102. ]
  103. },
  104. {
  105. "cell_type": "code",
  106. "execution_count": 5,
  107. "metadata": {},
  108. "outputs": [
  109. {
  110. "name": "stdout",
  111. "output_type": "stream",
  112. "text": [
  113. "tf01-dataset_basic_api.ipynb tf03-tfrecord_basic_api.ipynb\r\n",
  114. "tf02_data_generate_csv.ipynb tf04_data_generate_tfrecord.ipynb\r\n"
  115. ]
  116. }
  117. ],
  118. "source": [
  119. "!ls"
  120. ]
  121. },
  122. {
  123. "cell_type": "code",
  124. "execution_count": 7,
  125. "metadata": {},
  126. "outputs": [
  127. {
  128. "data": {
  129. "text/plain": [
  130. "numpy.ndarray"
  131. ]
  132. },
  133. "execution_count": 7,
  134. "metadata": {},
  135. "output_type": "execute_result"
  136. }
  137. ],
  138. "source": [
  139. "type(x_train_scaled)"
  140. ]
  141. },
  142. {
  143. "cell_type": "code",
  144. "execution_count": 8,
  145. "metadata": {},
  146. "outputs": [
  147. {
  148. "data": {
  149. "text/plain": [
  150. "['.ipynb_checkpoints',\n",
  151. " 'tf02_data_generate_csv.ipynb',\n",
  152. " 'tf04_data_generate_tfrecord.ipynb',\n",
  153. " 'tf03-tfrecord_basic_api.ipynb',\n",
  154. " 'tf01-dataset_basic_api.ipynb']"
  155. ]
  156. },
  157. "execution_count": 8,
  158. "metadata": {},
  159. "output_type": "execute_result"
  160. }
  161. ],
  162. "source": [
  163. "os.listdir()"
  164. ]
  165. },
  166. {
  167. "cell_type": "code",
  168. "execution_count": 14,
  169. "metadata": {},
  170. "outputs": [
  171. {
  172. "name": "stdout",
  173. "output_type": "stream",
  174. "text": [
  175. "0 [0 1 2 3 4]\n",
  176. "1 [5 6 7 8 9]\n",
  177. "2 [10 11 12 13 14]\n",
  178. "3 [15 16 17 18 19]\n"
  179. ]
  180. }
  181. ],
  182. "source": [
  183. "#为了把数据分好\n",
  184. "for file_idx, row_indices in enumerate(np.array_split(np.arange(20), 4)):\n",
  185. " print(file_idx,row_indices)"
  186. ]
  187. },
  188. {
  189. "cell_type": "code",
  190. "execution_count": 10,
  191. "metadata": {
  192. "scrolled": true
  193. },
  194. "outputs": [
  195. {
  196. "name": "stdout",
  197. "output_type": "stream",
  198. "text": [
  199. "MedInc,HouseAge,AveRooms,AveBedrms,Population,AveOccup,Latitude,Longitude,MidianHouseValue\n",
  200. "--------------------------------------------------\n"
  201. ]
  202. }
  203. ],
  204. "source": [
  205. "#下面要把特征工程后的数据存为csv文件\n",
  206. "output_dir = \"generate_csv\"\n",
  207. "if not os.path.exists(output_dir):\n",
  208. " os.mkdir(output_dir)\n",
  209. "\n",
  210. "#save_to_csv是工作可以直接复用的\n",
  211. "def save_to_csv(output_dir, data, name_prefix,\n",
  212. " header=None, n_parts=10):\n",
  213. " #生成文件名 格式generate_csv/{}_{:02d}.csv\n",
  214. " path_format = os.path.join(output_dir, \"{}_{:02d}.csv\") \n",
  215. " filenames = []\n",
  216. " #把数据分为n_parts部分,写到文件中去\n",
  217. " for file_idx, row_indices in enumerate(\n",
  218. " np.array_split(np.arange(len(data)), n_parts)):\n",
  219. " #print(file_idx,row_indices)\n",
  220. " #生成子文件名\n",
  221. " part_csv = path_format.format(name_prefix, file_idx)\n",
  222. " filenames.append(part_csv) #文件名添加到列表\n",
  223. " with open(part_csv, \"w\", encoding=\"utf-8\") as f:\n",
  224. " #先写头部\n",
  225. " if header is not None:\n",
  226. " f.write(header + \"\\n\")\n",
  227. " for row_index in row_indices:\n",
  228. " #把字符串化后的每个字符串用逗号拼接起来\n",
  229. " f.write(\",\".join(\n",
  230. " [repr(col) for col in data[row_index]]))\n",
  231. " f.write('\\n')\n",
  232. " return filenames\n",
  233. "#np.c_把x和y合并起来,按轴1合并\n",
  234. "train_data = np.c_[x_train_scaled, y_train]\n",
  235. "valid_data = np.c_[x_valid_scaled, y_valid]\n",
  236. "test_data = np.c_[x_test_scaled, y_test]\n",
  237. "#头部,特征,也有目标\n",
  238. "header_cols = housing.feature_names + [\"MidianHouseValue\"]\n",
  239. "#把列表变为字符串\n",
  240. "header_str = \",\".join(header_cols)\n",
  241. "print(header_str)\n",
  242. "print('-'*50)\n",
  243. "train_filenames = save_to_csv(output_dir, train_data, \"train\",\n",
  244. " header_str, n_parts=20)\n",
  245. "valid_filenames = save_to_csv(output_dir, valid_data, \"valid\",\n",
  246. " header_str, n_parts=10)\n",
  247. "test_filenames = save_to_csv(output_dir, test_data, \"test\",\n",
  248. " header_str, n_parts=10)"
  249. ]
  250. },
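{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Editor's sketch (not part of the original run): read one generated part back\n",
"# with pandas to confirm the header and the 8-feature + 1-target layout.\n",
"# Assumes train_filenames[0] exists after save_to_csv above.\n",
"sample_part = pd.read_csv(train_filenames[0])\n",
"print(sample_part.columns.tolist())\n",
"print(sample_part.shape)"
]
},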
  251. {
  252. "cell_type": "code",
  253. "execution_count": 12,
  254. "metadata": {},
  255. "outputs": [],
  256. "source": [
  257. "temp_array=np.array([[1,2,3],[4,5,6]])\n",
  258. "np.savetxt(\"temp.csv\",temp_array) #savetxt会自动将整型数或者浮点数转为字符串存储"
  259. ]
  260. },
  261. {
  262. "cell_type": "code",
  263. "execution_count": 13,
  264. "metadata": {},
  265. "outputs": [
  266. {
  267. "name": "stdout",
  268. "output_type": "stream",
  269. "text": [
  270. "1.000000000000000000e+00 2.000000000000000000e+00 3.000000000000000000e+00\r\n",
  271. "4.000000000000000000e+00 5.000000000000000000e+00 6.000000000000000000e+00\r\n"
  272. ]
  273. }
  274. ],
  275. "source": [
  276. "!cat temp.csv"
  277. ]
  278. },
  279. {
  280. "cell_type": "code",
  281. "execution_count": 14,
  282. "metadata": {
  283. "collapsed": true
  284. },
  285. "outputs": [
  286. {
  287. "name": "stdout",
  288. "output_type": "stream",
  289. "text": [
  290. "['generate_csv/train_00.csv', 'generate_csv/train_01.csv', 'generate_csv/train_02.csv', 'generate_csv/train_03.csv', 'generate_csv/train_04.csv', 'generate_csv/train_05.csv', 'generate_csv/train_06.csv', 'generate_csv/train_07.csv', 'generate_csv/train_08.csv', 'generate_csv/train_09.csv', 'generate_csv/train_10.csv', 'generate_csv/train_11.csv', 'generate_csv/train_12.csv', 'generate_csv/train_13.csv', 'generate_csv/train_14.csv', 'generate_csv/train_15.csv', 'generate_csv/train_16.csv', 'generate_csv/train_17.csv', 'generate_csv/train_18.csv', 'generate_csv/train_19.csv']\n",
  291. "train filenames:\n",
  292. "['generate_csv/train_00.csv',\n",
  293. " 'generate_csv/train_01.csv',\n",
  294. " 'generate_csv/train_02.csv',\n",
  295. " 'generate_csv/train_03.csv',\n",
  296. " 'generate_csv/train_04.csv',\n",
  297. " 'generate_csv/train_05.csv',\n",
  298. " 'generate_csv/train_06.csv',\n",
  299. " 'generate_csv/train_07.csv',\n",
  300. " 'generate_csv/train_08.csv',\n",
  301. " 'generate_csv/train_09.csv',\n",
  302. " 'generate_csv/train_10.csv',\n",
  303. " 'generate_csv/train_11.csv',\n",
  304. " 'generate_csv/train_12.csv',\n",
  305. " 'generate_csv/train_13.csv',\n",
  306. " 'generate_csv/train_14.csv',\n",
  307. " 'generate_csv/train_15.csv',\n",
  308. " 'generate_csv/train_16.csv',\n",
  309. " 'generate_csv/train_17.csv',\n",
  310. " 'generate_csv/train_18.csv',\n",
  311. " 'generate_csv/train_19.csv']\n",
  312. "valid filenames:\n",
  313. "['generate_csv/valid_00.csv',\n",
  314. " 'generate_csv/valid_01.csv',\n",
  315. " 'generate_csv/valid_02.csv',\n",
  316. " 'generate_csv/valid_03.csv',\n",
  317. " 'generate_csv/valid_04.csv',\n",
  318. " 'generate_csv/valid_05.csv',\n",
  319. " 'generate_csv/valid_06.csv',\n",
  320. " 'generate_csv/valid_07.csv',\n",
  321. " 'generate_csv/valid_08.csv',\n",
  322. " 'generate_csv/valid_09.csv']\n",
  323. "test filenames:\n",
  324. "['generate_csv/test_00.csv',\n",
  325. " 'generate_csv/test_01.csv',\n",
  326. " 'generate_csv/test_02.csv',\n",
  327. " 'generate_csv/test_03.csv',\n",
  328. " 'generate_csv/test_04.csv',\n",
  329. " 'generate_csv/test_05.csv',\n",
  330. " 'generate_csv/test_06.csv',\n",
  331. " 'generate_csv/test_07.csv',\n",
  332. " 'generate_csv/test_08.csv',\n",
  333. " 'generate_csv/test_09.csv']\n"
  334. ]
  335. }
  336. ],
  337. "source": [
  338. "#看下生成文件的文件名\n",
  339. "print(train_filenames)\n",
  340. "import pprint #为了打印美观性\n",
  341. "print(\"train filenames:\")\n",
  342. "pprint.pprint(train_filenames)\n",
  343. "print(\"valid filenames:\")\n",
  344. "pprint.pprint(valid_filenames)\n",
  345. "print(\"test filenames:\")\n",
  346. "pprint.pprint(test_filenames)"
  347. ]
  348. },
  349. {
  350. "cell_type": "code",
  351. "execution_count": 16,
  352. "metadata": {},
  353. "outputs": [
  354. {
  355. "name": "stdout",
  356. "output_type": "stream",
  357. "text": [
  358. "tf.Tensor(b'generate_csv/train_13.csv', shape=(), dtype=string)\n",
  359. "tf.Tensor(b'generate_csv/train_01.csv', shape=(), dtype=string)\n",
  360. "tf.Tensor(b'generate_csv/train_14.csv', shape=(), dtype=string)\n",
  361. "tf.Tensor(b'generate_csv/train_11.csv', shape=(), dtype=string)\n",
  362. "tf.Tensor(b'generate_csv/train_12.csv', shape=(), dtype=string)\n",
  363. "tf.Tensor(b'generate_csv/train_06.csv', shape=(), dtype=string)\n",
  364. "tf.Tensor(b'generate_csv/train_15.csv', shape=(), dtype=string)\n",
  365. "tf.Tensor(b'generate_csv/train_10.csv', shape=(), dtype=string)\n",
  366. "tf.Tensor(b'generate_csv/train_05.csv', shape=(), dtype=string)\n",
  367. "tf.Tensor(b'generate_csv/train_02.csv', shape=(), dtype=string)\n",
  368. "tf.Tensor(b'generate_csv/train_00.csv', shape=(), dtype=string)\n",
  369. "tf.Tensor(b'generate_csv/train_07.csv', shape=(), dtype=string)\n",
  370. "tf.Tensor(b'generate_csv/train_16.csv', shape=(), dtype=string)\n",
  371. "tf.Tensor(b'generate_csv/train_09.csv', shape=(), dtype=string)\n",
  372. "tf.Tensor(b'generate_csv/train_19.csv', shape=(), dtype=string)\n",
  373. "tf.Tensor(b'generate_csv/train_03.csv', shape=(), dtype=string)\n",
  374. "tf.Tensor(b'generate_csv/train_04.csv', shape=(), dtype=string)\n",
  375. "tf.Tensor(b'generate_csv/train_18.csv', shape=(), dtype=string)\n",
  376. "tf.Tensor(b'generate_csv/train_17.csv', shape=(), dtype=string)\n",
  377. "tf.Tensor(b'generate_csv/train_08.csv', shape=(), dtype=string)\n"
  378. ]
  379. }
  380. ],
  381. "source": [
  382. "# 1. filename -> dataset\n",
  383. "# 2. read file -> dataset -> datasets -> merge\n",
  384. "# 3. parse csv\n",
  385. "#list_files把文件名搞为一个dataset\n",
  386. "# list_files默认行为是按不确定的随机混排顺序返回文件名\n",
  387. "filename_dataset = tf.data.Dataset.list_files(train_filenames)\n",
  388. "for filename in filename_dataset:\n",
  389. " print(filename)"
  390. ]
  391. },
  392. {
  393. "cell_type": "code",
  394. "execution_count": 17,
  395. "metadata": {},
  396. "outputs": [
  397. {
  398. "name": "stdout",
  399. "output_type": "stream",
  400. "text": [
  401. "tf.Tensor(b'generate_csv/train_00.csv', shape=(), dtype=string)\n",
  402. "tf.Tensor(b'generate_csv/train_01.csv', shape=(), dtype=string)\n",
  403. "tf.Tensor(b'generate_csv/train_02.csv', shape=(), dtype=string)\n",
  404. "tf.Tensor(b'generate_csv/train_03.csv', shape=(), dtype=string)\n",
  405. "tf.Tensor(b'generate_csv/train_04.csv', shape=(), dtype=string)\n",
  406. "tf.Tensor(b'generate_csv/train_05.csv', shape=(), dtype=string)\n",
  407. "tf.Tensor(b'generate_csv/train_06.csv', shape=(), dtype=string)\n",
  408. "tf.Tensor(b'generate_csv/train_07.csv', shape=(), dtype=string)\n",
  409. "tf.Tensor(b'generate_csv/train_08.csv', shape=(), dtype=string)\n",
  410. "tf.Tensor(b'generate_csv/train_09.csv', shape=(), dtype=string)\n",
  411. "tf.Tensor(b'generate_csv/train_10.csv', shape=(), dtype=string)\n",
  412. "tf.Tensor(b'generate_csv/train_11.csv', shape=(), dtype=string)\n",
  413. "tf.Tensor(b'generate_csv/train_12.csv', shape=(), dtype=string)\n",
  414. "tf.Tensor(b'generate_csv/train_13.csv', shape=(), dtype=string)\n",
  415. "tf.Tensor(b'generate_csv/train_14.csv', shape=(), dtype=string)\n",
  416. "tf.Tensor(b'generate_csv/train_15.csv', shape=(), dtype=string)\n",
  417. "tf.Tensor(b'generate_csv/train_16.csv', shape=(), dtype=string)\n",
  418. "tf.Tensor(b'generate_csv/train_17.csv', shape=(), dtype=string)\n",
  419. "tf.Tensor(b'generate_csv/train_18.csv', shape=(), dtype=string)\n",
  420. "tf.Tensor(b'generate_csv/train_19.csv', shape=(), dtype=string)\n"
  421. ]
  422. }
  423. ],
  424. "source": [
  425. "filename_mydataset=tf.data.Dataset.from_tensor_slices(train_filenames)\n",
  426. "filename_mydataset=filename_mydataset.repeat(1)\n",
  427. "for i in filename_mydataset:\n",
  428. " print(i)"
  429. ]
  430. },
  431. {
  432. "cell_type": "code",
  433. "execution_count": null,
  434. "metadata": {},
  435. "outputs": [],
  436. "source": [
  437. "# 把数据从文件中拿出来"
  438. ]
  439. },
  440. {
  441. "cell_type": "code",
  442. "execution_count": 24,
  443. "metadata": {
  444. "scrolled": false
  445. },
  446. "outputs": [
  447. {
  448. "name": "stdout",
  449. "output_type": "stream",
  450. "text": [
  451. "tf.Tensor(b'0.801544314532886,0.27216142415910205,-0.11624392696666119,-0.2023115137272354,-0.5430515742518128,-0.021039615516440048,-0.5897620622908205,-0.08241845654707416,3.226', shape=(), dtype=string)\n",
  452. "tf.Tensor(b'-0.2980728090942217,0.3522616607867429,-0.10920507530549702,-0.25055520947444,-0.034064024638222286,-0.006034004264459185,1.080554840130013,-1.0611381656679573,1.514', shape=(), dtype=string)\n",
  453. "tf.Tensor(b'0.8115083791797953,-0.04823952235146133,0.5187339067174729,-0.029386394873127775,-0.034064024638222286,-0.05081594842905086,-0.7157356834231196,0.9162751241885168,2.147', shape=(), dtype=string)\n",
  454. "tf.Tensor(b'-0.6906143291679195,-0.1283397589791022,7.0201810347470595,5.624287386169439,-0.2663292879200034,-0.03662080416157129,-0.6457503383496215,1.2058962626018372,1.352', shape=(), dtype=string)\n",
  455. "tf.Tensor(b'0.401276648075221,-0.9293421252555106,-0.05333050451405854,-0.1865945262276826,0.6545661895448709,0.026434465728210874,0.9312527706398824,-1.4406417263474771,2.512', shape=(), dtype=string)\n",
  456. "tf.Tensor(b'-0.8757754235423053,1.874166156711919,-0.9487499555702599,-0.09657184824705009,-0.7163432355284542,-0.07790191228558485,0.9825753570271144,-1.4206678547327694,2.75', shape=(), dtype=string)\n",
  457. "tf.Tensor(b'0.15782311132800697,0.43236189741438374,0.3379948076652917,-0.015880306122244434,-0.3733890577139493,-0.05305245634489608,0.8006134598360177,-1.2359095422966828,3.169', shape=(), dtype=string)\n",
  458. "tf.Tensor(b'2.2878417437355094,-1.8905449647872008,0.6607106467795992,-0.14964778023694128,-0.06672632728722275,0.44788055801575993,-0.5337737862320228,0.5667323709310584,3.59', shape=(), dtype=string)\n",
  459. "tf.Tensor(b'-1.0591781535672364,1.393564736946074,-0.026331968874673636,-0.11006759528831847,-0.6138198966579805,-0.09695934953589447,0.3247131133362288,-0.037477245413977976,0.672', shape=(), dtype=string)\n",
  460. "tf.Tensor(b'-0.2223565745313433,1.393564736946074,0.02991299565857307,0.0801452044790158,-0.509481985418118,-0.06238599304952824,-0.86503775291325,0.8613469772480595,2.0', shape=(), dtype=string)\n",
  461. "tf.Tensor(b'-0.03058829290446139,-0.9293421252555106,0.2596214817762415,-0.00601274044096368,-0.5004091235711734,-0.030779867916061836,1.5984463936739026,-1.8151518191233238,1.598', shape=(), dtype=string)\n",
  462. "tf.Tensor(b'1.9063832474401923,0.5124621340420246,0.44758280183798754,-0.276721775345798,-0.6310583341671753,-0.07081146722873086,-0.7064043040799849,0.7464972154634646,5.00001', shape=(), dtype=string)\n",
  463. "tf.Tensor(b'-0.9868720801669367,0.832863080552588,-0.18684708416901633,-0.14888949288707784,-0.4532302419670616,-0.11504995754593579,1.6730974284189664,-0.7465496877362412,1.138', shape=(), dtype=string)\n",
  464. "tf.Tensor(b'0.29422955783115173,1.874166156711919,0.004626028663628252,-0.28479278487900694,-0.5602900117610076,-0.1196496378702887,1.3558305307524392,-0.9512818717870428,1.625', shape=(), dtype=string)\n",
  465. "tf.Tensor(b'0.7751155655229017,1.874166156711919,0.15645971958808144,-0.18905190538070707,-0.6292437617977863,-0.08791603438866835,-0.7483955111240856,0.5717258388347319,4.851', shape=(), dtype=string)\n"
  466. ]
  467. }
  468. ],
  469. "source": [
  470. "#一访问list_files的dataset对象就随机了文件顺序\n",
  471. "# for filename in filename_dataset:\n",
  472. "# print(filename)\n",
  473. "n_readers = 5\n",
  474. "dataset = filename_mydataset.interleave(\n",
  475. " #前面1行是header\n",
  476. "# lambda filename: tf.data.TextLineDataset(filename),\n",
  477. " #不带header,把特征名字去掉\n",
  478. " lambda filename: tf.data.TextLineDataset(filename).skip(1),\n",
  479. " cycle_length = n_readers, #cycle_length和block_length增加获取了数据的随机性\n",
  480. " block_length=2\n",
  481. ")\n",
  482. "for line in dataset.take(15):\n",
  483. " print(line)"
  484. ]
  485. },
  486. {
  487. "cell_type": "code",
  488. "execution_count": null,
  489. "metadata": {},
  490. "outputs": [],
  491. "source": [
  492. "# 把每一行数据切分为对应类型"
  493. ]
  494. },
  495. {
  496. "cell_type": "code",
  497. "execution_count": 18,
  498. "metadata": {},
  499. "outputs": [
  500. {
  501. "name": "stdout",
  502. "output_type": "stream",
  503. "text": [
  504. "[<tf.Tensor: shape=(), dtype=int32, numpy=1>, <tf.Tensor: shape=(), dtype=int32, numpy=2>, <tf.Tensor: shape=(), dtype=float32, numpy=3.0>, <tf.Tensor: shape=(), dtype=string, numpy=b'4'>, <tf.Tensor: shape=(), dtype=float32, numpy=5.0>]\n"
  505. ]
  506. }
  507. ],
  508. "source": [
  509. "#parse csv 解析csv,通过decode_csv\n",
  510. "# tf.io.decode_csv(str, record_defaults)\n",
  511. "\n",
  512. "sample_str = '1,2,3,4,5'\n",
  513. "record_defaults = [\n",
  514. " tf.constant(0, dtype=tf.int32),\n",
  515. " 0,\n",
  516. " np.nan,\n",
  517. " \"hello1\",\n",
  518. " tf.constant([])#没有固定类型,默认是float32\n",
  519. "]\n",
  520. "#sample_str数据格式化,按照record_defaults进行处理\n",
  521. "parsed_fields = tf.io.decode_csv(sample_str, record_defaults)\n",
  522. "print(parsed_fields)"
  523. ]
  524. },
  525. {
  526. "cell_type": "code",
  527. "execution_count": 19,
  528. "metadata": {},
  529. "outputs": [
  530. {
  531. "data": {
  532. "text/plain": [
  533. "[<tf.Tensor: shape=(), dtype=int32, numpy=0>,\n",
  534. " <tf.Tensor: shape=(), dtype=int32, numpy=0>,\n",
  535. " <tf.Tensor: shape=(), dtype=float32, numpy=nan>,\n",
  536. " <tf.Tensor: shape=(), dtype=string, numpy=b'hello1'>,\n",
  537. " <tf.Tensor: shape=(), dtype=float32, numpy=1.0>]"
  538. ]
  539. },
  540. "execution_count": 19,
  541. "metadata": {},
  542. "output_type": "execute_result"
  543. }
  544. ],
  545. "source": [
  546. "#我们传一个空的字符串测试\n",
  547. "#最后一个为1是可以转换的\n",
  548. "try:\n",
  549. " parsed_fields = tf.io.decode_csv(',,,,1', record_defaults)\n",
  550. "except tf.errors.InvalidArgumentError as ex:\n",
  551. " print(ex)\n",
  552. "parsed_fields"
  553. ]
  554. },
  555. {
  556. "cell_type": "code",
  557. "execution_count": 20,
  558. "metadata": {},
  559. "outputs": [
  560. {
  561. "name": "stdout",
  562. "output_type": "stream",
  563. "text": [
  564. "Expect 5 fields but have 7 in record 0 [Op:DecodeCSV]\n"
  565. ]
  566. }
  567. ],
  568. "source": [
  569. "#我们给的值过多的情况\n",
  570. "try:\n",
  571. " parsed_fields = tf.io.decode_csv('1,2,3,4,5,6,', record_defaults)\n",
  572. "except tf.errors.InvalidArgumentError as ex:\n",
  573. " print(ex)"
  574. ]
  575. },
  576. {
  577. "cell_type": "code",
  578. "execution_count": 21,
  579. "metadata": {},
  580. "outputs": [
  581. {
  582. "data": {
  583. "text/plain": [
  584. "(<tf.Tensor: shape=(8,), dtype=float32, numpy=\n",
  585. " array([-0.9868721 , 0.8328631 , -0.18684709, -0.1488895 , -0.45323023,\n",
  586. " -0.11504996, 1.6730974 , -0.74654967], dtype=float32)>,\n",
  587. " <tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.138], dtype=float32)>)"
  588. ]
  589. },
  590. "execution_count": 21,
  591. "metadata": {},
  592. "output_type": "execute_result"
  593. }
  594. ],
  595. "source": [
  596. "#解析一行\n",
  597. "def parse_csv_line(line, n_fields = 9):\n",
  598. " #先写一个默认的格式,就是9个nan,如果从csv中读取缺失数据,就会变为nan\n",
  599. " defs = [tf.constant(np.nan)] * n_fields\n",
  600. " #使用decode_csv解析\n",
  601. " parsed_fields = tf.io.decode_csv(line, record_defaults=defs)\n",
  602. " #前8个是x,最后一个是y\n",
  603. " x = tf.stack(parsed_fields[0:-1])\n",
  604. " y = tf.stack(parsed_fields[-1:])\n",
  605. " return x, y\n",
  606. "\n",
  607. "parse_csv_line(b'-0.9868720801669367,0.832863080552588,-0.18684708416901633,-0.14888949288707784,-0.4532302419670616,-0.11504995754593579,1.6730974284189664,-0.7465496877362412,1.138',\n",
  608. " n_fields=9)"
  609. ]
  610. },
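{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Editor's sketch (not part of the original run): map parse_csv_line over the\n",
"# interleaved line dataset built a few cells above to preview one parsed (x, y)\n",
"# pair. Assumes `dataset` still refers to that TextLineDataset.\n",
"for x_sample, y_sample in dataset.map(parse_csv_line).take(1):\n",
"    print(x_sample)\n",
"    print(y_sample)"
]
},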
  611. {
  612. "cell_type": "code",
  613. "execution_count": 22,
  614. "metadata": {
  615. "collapsed": true
  616. },
  617. "outputs": [
  618. {
  619. "name": "stdout",
  620. "output_type": "stream",
  621. "text": [
  622. "<BatchDataset shapes: ((None, 8), (None, 1)), types: (tf.float32, tf.float32)>\n",
  623. "--------------------------------------------------\n",
  624. "x:\n",
  625. "<tf.Tensor: shape=(4, 8), dtype=float32, numpy=\n",
  626. "array([[ 0.15782312, 0.4323619 , 0.3379948 , -0.01588031, -0.37338907,\n",
  627. " -0.05305246, 0.80061346, -1.2359096 ],\n",
  628. " [-1.0591781 , 1.3935647 , -0.02633197, -0.1100676 , -0.6138199 ,\n",
  629. " -0.09695935, 0.3247131 , -0.03747724],\n",
  630. " [-0.82195884, 1.8741661 , 0.1821235 , -0.03170019, -0.6011179 ,\n",
  631. " -0.14337493, 1.0852206 , -0.8613995 ],\n",
  632. " [ 0.63034356, 1.8741661 , -0.06713215, -0.12543367, -0.19737554,\n",
  633. " -0.02272263, -0.69240725, 0.72652334]], dtype=float32)>\n",
  634. "y:\n",
  635. "<tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
  636. "array([[3.169],\n",
  637. " [0.672],\n",
  638. " [1.054],\n",
  639. " [2.419]], dtype=float32)>\n",
  640. "x:\n",
  641. "<tf.Tensor: shape=(4, 8), dtype=float32, numpy=\n",
  642. "array([[ 0.48530516, -0.8492419 , -0.06530126, -0.02337966, 1.4974351 ,\n",
  643. " -0.07790658, -0.90236324, 0.78145146],\n",
  644. " [ 2.2878418 , -1.890545 , 0.66071063, -0.14964779, -0.06672633,\n",
  645. " 0.44788057, -0.5337738 , 0.56673235],\n",
  646. " [-0.22235657, 1.3935647 , 0.029913 , 0.0801452 , -0.50948197,\n",
  647. " -0.06238599, -0.86503774, 0.86134696],\n",
  648. " [-0.46794146, -0.92934215, 0.11909926, -0.06047011, 0.30344644,\n",
  649. " -0.02185189, 1.8737221 , -1.0411643 ]], dtype=float32)>\n",
  650. "y:\n",
  651. "<tf.Tensor: shape=(4, 1), dtype=float32, numpy=\n",
  652. "array([[2.956],\n",
  653. " [3.59 ],\n",
  654. " [2. ],\n",
  655. " [1.012]], dtype=float32)>\n"
  656. ]
  657. }
  658. ],
  659. "source": [
  660. "# 1. filename -> dataset\n",
  661. "# 2. read file -> dataset -> datasets -> merge\n",
  662. "# 3. parse csv\n",
  663. "#完成整个流程\n",
  664. "def csv_reader_dataset(filenames, n_readers=5,\n",
  665. " batch_size=32, n_parse_threads=5,\n",
  666. " shuffle_buffer_size=10000):\n",
  667. " #把文件名类别变为dataset tensor\n",
  668. " dataset = tf.data.Dataset.list_files(filenames)\n",
  669. " #变为repeat dataset可以让读到最后一个样本时,从新去读第一个样本\n",
  670. " dataset = dataset.repeat()\n",
  671. " dataset = dataset.interleave(\n",
  672. " #skip(1)是因为每个文件存了特征名字,target名字\n",
  673. " lambda filename: tf.data.TextLineDataset(filename).skip(1),\n",
  674. " cycle_length = n_readers\n",
  675. " )\n",
  676. " dataset.shuffle(shuffle_buffer_size) #对数据进行洗牌,混乱\n",
  677. " #map,通过parse_csv_line对数据集进行映射,map只会给函数传递一个参数,这个参数\n",
  678. " #就是dataset中的tensor\n",
  679. " dataset = dataset.map(parse_csv_line,\n",
  680. " num_parallel_calls=n_parse_threads)\n",
  681. " dataset = dataset.batch(batch_size)\n",
  682. " return dataset\n",
  683. "#这里是一个测试,写4是为了大家理解\n",
  684. "train_set = csv_reader_dataset(train_filenames, batch_size=4)\n",
  685. "print(train_set)\n",
  686. "print('-'*50)\n",
  687. "i=0\n",
  688. "#是csv_reader_dataset处理后的结果,\n",
  689. "for x_batch, y_batch in train_set.take(2):\n",
  690. "# i=i+1\n",
  691. " print(\"x:\")\n",
  692. " pprint.pprint(x_batch)\n",
  693. " print(\"y:\")\n",
  694. " pprint.pprint(y_batch)\n",
  695. "# print(i)"
  696. ]
  697. },
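{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Editor's sketch (not part of the original run): the same pipeline with an\n",
"# extra prefetch step so input preprocessing can overlap with training.\n",
"# prefetch(1) keeps one batch prepared in advance.\n",
"train_set_prefetched = csv_reader_dataset(train_filenames, batch_size=4).prefetch(1)\n",
"for x_batch, y_batch in train_set_prefetched.take(1):\n",
"    print(x_batch.shape, y_batch.shape)"
]
},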
  698. {
  699. "cell_type": "code",
  700. "execution_count": 23,
  701. "metadata": {},
  702. "outputs": [
  703. {
  704. "name": "stdout",
  705. "output_type": "stream",
  706. "text": [
  707. "CPU times: user 137 ms, sys: 40.3 ms, total: 177 ms\n",
  708. "Wall time: 160 ms\n"
  709. ]
  710. }
  711. ],
  712. "source": [
  713. "%%time\n",
  714. "batch_size = 32\n",
  715. "train_set = csv_reader_dataset(train_filenames,\n",
  716. " batch_size = batch_size)\n",
  717. "valid_set = csv_reader_dataset(valid_filenames,\n",
  718. " batch_size = batch_size)\n",
  719. "test_set = csv_reader_dataset(test_filenames,\n",
  720. " batch_size = batch_size)\n",
  721. "\n",
  722. "# print(train_set)\n",
  723. "# print(valid_set)\n",
  724. "# print(test_set)"
  725. ]
  726. },
  727. {
  728. "cell_type": "code",
  729. "execution_count": 24,
  730. "metadata": {},
  731. "outputs": [
  732. {
  733. "name": "stdout",
  734. "output_type": "stream",
  735. "text": [
  736. "Epoch 1/100\n",
  737. "348/348 [==============================] - 1s 3ms/step - loss: 1.1306 - val_loss: 0.9811\n",
  738. "Epoch 2/100\n",
  739. "348/348 [==============================] - 1s 3ms/step - loss: 2.4388 - val_loss: 0.5692\n",
  740. "Epoch 3/100\n",
  741. "348/348 [==============================] - 1s 3ms/step - loss: 0.5545 - val_loss: 0.6181\n",
  742. "Epoch 4/100\n",
  743. "348/348 [==============================] - 1s 4ms/step - loss: 0.6097 - val_loss: 0.4497\n",
  744. "Epoch 5/100\n",
  745. "348/348 [==============================] - 1s 3ms/step - loss: 0.4277 - val_loss: 0.4555\n",
  746. "Epoch 6/100\n",
  747. "348/348 [==============================] - 1s 4ms/step - loss: 0.3998 - val_loss: 0.3870\n",
  748. "Epoch 7/100\n",
  749. "348/348 [==============================] - 1s 4ms/step - loss: 0.3889 - val_loss: 0.4119\n",
  750. "Epoch 8/100\n",
  751. "348/348 [==============================] - 1s 3ms/step - loss: 0.3831 - val_loss: 0.3941\n",
  752. "Epoch 9/100\n",
  753. "348/348 [==============================] - 1s 3ms/step - loss: 0.3870 - val_loss: 0.4068\n",
  754. "Epoch 10/100\n",
  755. "348/348 [==============================] - 1s 3ms/step - loss: 0.3689 - val_loss: 0.3801\n",
  756. "Epoch 11/100\n",
  757. "348/348 [==============================] - 1s 3ms/step - loss: 0.3804 - val_loss: 0.3957\n"
  758. ]
  759. }
  760. ],
  761. "source": [
  762. "#我们知道长度为8\n",
  763. "model = keras.models.Sequential([\n",
  764. " keras.layers.Dense(30, activation='relu',\n",
  765. " input_shape=[8]),\n",
  766. " keras.layers.Dense(1),\n",
  767. "])\n",
  768. "model.compile(loss=\"mean_squared_error\", optimizer=\"sgd\")\n",
  769. "callbacks = [keras.callbacks.EarlyStopping(\n",
  770. " patience=5, min_delta=1e-2)]\n",
  771. "\n",
  772. "#当是BatchDataset,必须制定steps_per_epoch,validation_steps\n",
  773. "history = model.fit(train_set,\n",
  774. " validation_data = valid_set,\n",
  775. " steps_per_epoch = 11160 // batch_size, #每epoch训练的步数\n",
  776. " validation_steps = 3870 // batch_size,\n",
  777. " epochs = 100,\n",
  778. " callbacks = callbacks)"
  779. ]
  780. },
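{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Editor's sketch (not part of the original run): visualize the training and\n",
"# validation loss recorded in history.history using pandas and matplotlib,\n",
"# both of which are already imported above.\n",
"pd.DataFrame(history.history).plot(figsize=(8, 5))\n",
"plt.grid(True)\n",
"plt.show()"
]
},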
  781. {
  782. "cell_type": "code",
  783. "execution_count": 25,
  784. "metadata": {},
  785. "outputs": [
  786. {
  787. "name": "stdout",
  788. "output_type": "stream",
  789. "text": [
  790. "161/161 [==============================] - 0s 2ms/step - loss: 0.3995\n"
  791. ]
  792. },
  793. {
  794. "data": {
  795. "text/plain": [
  796. "0.39946985244750977"
  797. ]
  798. },
  799. "execution_count": 25,
  800. "metadata": {},
  801. "output_type": "execute_result"
  802. }
  803. ],
  804. "source": [
  805. "model.evaluate(test_set, steps = 5160 // batch_size)"
  806. ]
  807. },
  808. {
  809. "cell_type": "code",
  810. "execution_count": 37,
  811. "metadata": {},
  812. "outputs": [
  813. {
  814. "data": {
  815. "text/plain": [
  816. "[<tf.Tensor: shape=(4,), dtype=int64, numpy=array([0, 1, 2, 3])>,\n",
  817. " <tf.Tensor: shape=(4,), dtype=int64, numpy=array([4, 5, 6, 7])>]"
  818. ]
  819. },
  820. "execution_count": 37,
  821. "metadata": {},
  822. "output_type": "execute_result"
  823. }
  824. ],
  825. "source": [
  826. "dataset = tf.data.Dataset.range(8)\n",
  827. "dataset = dataset.batch(4) #把tensor组合到一起,就是分了batch\n",
  828. "list(dataset)"
  829. ]
  830. }
  831. ],
  832. "metadata": {
  833. "kernelspec": {
  834. "display_name": "Python 3",
  835. "language": "python",
  836. "name": "python3"
  837. },
  838. "language_info": {
  839. "codemirror_mode": {
  840. "name": "ipython",
  841. "version": 3
  842. },
  843. "file_extension": ".py",
  844. "mimetype": "text/x-python",
  845. "name": "python",
  846. "nbconvert_exporter": "python",
  847. "pygments_lexer": "ipython3",
  848. "version": "3.6.9"
  849. }
  850. },
  851. "nbformat": 4,
  852. "nbformat_minor": 2
  853. }

With the development of artificial intelligence and big data, there is growing demand for automation tools in every field. During the current period of epidemic prevention and control, MindSpore is used to implement a YOLO model for object detection and semantic segmentation: it can perform mask-wearing detection and pedestrian social-distancing detection on both videos and images, enabling automated epidemic-control management of public places.