
utils.py 26 kB

#! /usr/bin/python
# -*- coding: utf-8 -*-

import os
import random
import subprocess
import sys
import time
from collections import Counter
from sys import exit as _exit
from sys import platform as _platform

import numpy as np
import tensorflow as tf
from sklearn.metrics import accuracy_score, confusion_matrix, f1_score

import tensorlayer as tl

__all__ = [
    'fit', 'test', 'predict', 'evaluation', 'dict_to_one', 'flatten_list', 'class_balancing_oversample',
    'get_random_int', 'list_string_to_dict', 'exit_tensorflow', 'open_tensorboard', 'clear_all_placeholder_variables',
    'set_gpu_fraction', 'train_epoch', 'run_epoch'
]
def fit(
        network, train_op, cost, X_train, y_train, acc=None, batch_size=100, n_epoch=100, print_freq=5, X_val=None,
        y_val=None, eval_train=True, tensorboard_dir=None, tensorboard_epoch_freq=5, tensorboard_weight_histograms=True,
        tensorboard_graph_vis=True
):
    """Train a given non-time-series network with the given cost function, training data, batch_size, n_epoch etc.

    - MNIST example: click `here <https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_mnist_simple.py>`_.
    - To control the training details yourself, the authors HIGHLY recommend ``tl.iterate``; see the two MNIST examples `1 <https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_mlp_dropout1.py>`_ and `2 <https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_mlp_dropout2.py>`_.

    Parameters
    ----------
    network : TensorLayer Model
        the network to be trained.
    train_op : TensorFlow optimizer
        The optimizer for training e.g. tf.optimizers.Adam().
    cost : TensorLayer or TensorFlow loss function
        Metric for the loss function, e.g. tl.cost.cross_entropy.
    X_train : numpy.array
        The input of the training data.
    y_train : numpy.array
        The target of the training data.
    acc : TensorFlow/numpy expression or None
        Metric for accuracy or others. If None, the information is not printed.
    batch_size : int
        The batch size for training and evaluating.
    n_epoch : int
        The number of training epochs.
    print_freq : int
        Print the training information every ``print_freq`` epochs.
    X_val : numpy.array or None
        The input of the validation data. If None, validation is not performed.
    y_val : numpy.array or None
        The target of the validation data. If None, validation is not performed.
    eval_train : boolean
        Whether to evaluate the model during training.
        If X_val and y_val are not None, it reflects whether to evaluate the model on the training data.
    tensorboard_dir : string
        Path to the log directory; if set, summary data will be stored in the tensorboard_dir/ directory for visualization with TensorBoard. (default None)
    tensorboard_epoch_freq : int
        How many epochs between storing TensorBoard checkpoints for visualization in the log directory (default 5).
    tensorboard_weight_histograms : boolean
        If True, updates the TensorBoard data in the log directory for visualization
        of the weight histograms every ``tensorboard_epoch_freq`` epochs (default True).
    tensorboard_graph_vis : boolean
        If True, stores the graph in the TensorBoard summaries saved to the log directory (default True).

    Examples
    --------
    See `tutorial_mnist_simple.py <https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_mnist_simple.py>`_

    >>> tl.utils.fit(network, train_op=tf.optimizers.Adam(learning_rate=0.0001),
    ...              cost=tl.cost.cross_entropy, X_train=X_train, y_train=y_train, acc=acc,
    ...              batch_size=64, n_epoch=20, X_val=X_val, y_val=y_val, eval_train=True)
    >>> tl.utils.fit(network, train_op, cost, X_train, y_train,
    ...              acc=acc, batch_size=500, n_epoch=200, print_freq=5,
    ...              X_val=X_val, y_val=y_val, eval_train=False, tensorboard_dir='./logs')

    Notes
    --------
    ``tensorboard_weight_histograms`` and ``tensorboard_graph_vis`` are not supported now.

    """
    if X_train.shape[0] < batch_size:
        raise AssertionError("Number of training examples should be bigger than the batch size")

    if tensorboard_dir is not None:
        tl.logging.info("Setting up tensorboard ...")
        # Set up tensorboard summaries and saver
        tl.files.exists_or_mkdir(tensorboard_dir)

        # Only write summaries for more recent TensorFlow versions
        if hasattr(tf, 'summary') and hasattr(tf.summary, 'create_file_writer'):
            train_writer = tf.summary.create_file_writer(tensorboard_dir + '/train')
            val_writer = tf.summary.create_file_writer(tensorboard_dir + '/validation')
            if tensorboard_graph_vis:
                # FIXME : not sure how to add the tl network graph
                pass
        else:
            train_writer = None
            val_writer = None

        tl.logging.info("Finished! use `tensorboard --logdir=%s/` to start tensorboard" % tensorboard_dir)

    tl.logging.info("Start training the network ...")
    start_time_begin = time.time()
    for epoch in range(n_epoch):
        start_time = time.time()
        loss_ep, _, __ = train_epoch(network, X_train, y_train, cost=cost, train_op=train_op, batch_size=batch_size)

        train_loss, train_acc = None, None
        val_loss, val_acc = None, None

        if tensorboard_dir is not None and hasattr(tf, 'summary'):
            if epoch + 1 == 1 or (epoch + 1) % tensorboard_epoch_freq == 0:
                if eval_train is True:
                    train_loss, train_acc, _ = run_epoch(
                        network, X_train, y_train, cost=cost, acc=acc, batch_size=batch_size
                    )
                    with train_writer.as_default():
                        tf.summary.scalar('loss', train_loss, step=epoch)
                        if acc is not None:
                            tf.summary.scalar('acc', train_acc, step=epoch)
                        # FIXME : there seems to be an internal error in Tensorboard (misuse of tf.name_scope)
                        # if tensorboard_weight_histograms is not None:
                        #     for param in network.all_weights:
                        #         tf.summary.histogram(param.name, param, step=epoch)
                if (X_val is not None) and (y_val is not None):
                    val_loss, val_acc, _ = run_epoch(network, X_val, y_val, cost=cost, acc=acc, batch_size=batch_size)
                    with val_writer.as_default():
                        tf.summary.scalar('loss', val_loss, step=epoch)
                        if acc is not None:
                            tf.summary.scalar('acc', val_acc, step=epoch)
                        # FIXME : there seems to be an internal error in Tensorboard (misuse of tf.name_scope)
                        # if tensorboard_weight_histograms is not None:
                        #     for param in network.all_weights:
                        #         tf.summary.histogram(param.name, param, step=epoch)

        if epoch + 1 == 1 or (epoch + 1) % print_freq == 0:
            if (X_val is not None) and (y_val is not None):
                tl.logging.info("Epoch %d of %d took %fs" % (epoch + 1, n_epoch, time.time() - start_time))
                if eval_train is True:
                    if train_loss is None:
                        train_loss, train_acc, _ = run_epoch(
                            network, X_train, y_train, cost=cost, acc=acc, batch_size=batch_size
                        )
                    tl.logging.info("   train loss: %f" % train_loss)
                    if acc is not None:
                        tl.logging.info("   train acc: %f" % train_acc)
                if val_loss is None:
                    val_loss, val_acc, _ = run_epoch(network, X_val, y_val, cost=cost, acc=acc, batch_size=batch_size)
                tl.logging.info("   val loss: %f" % val_loss)
                if acc is not None:
                    tl.logging.info("   val acc: %f" % val_acc)
            else:
                tl.logging.info(
                    "Epoch %d of %d took %fs, loss %f" % (epoch + 1, n_epoch, time.time() - start_time, loss_ep)
                )
    tl.logging.info("Total training time: %fs" % (time.time() - start_time_begin))
def test(network, acc, X_test, y_test, batch_size, cost=None):
    """
    Test a given non-time-series network with the given test data and metric.

    Parameters
    ----------
    network : TensorLayer Model
        The network.
    acc : TensorFlow/numpy expression or None
        Metric for accuracy or others.
            - If None, the information is not printed.
    X_test : numpy.array
        The input of the testing data.
    y_test : numpy array
        The target of the testing data.
    batch_size : int or None
        The batch size for testing. When the dataset is large, we should use minibatches for testing;
        if the dataset is small, we can set it to None.
    cost : TensorLayer or TensorFlow loss function
        Metric for the loss function, e.g. tl.cost.cross_entropy. If None, the loss is not printed.

    Examples
    --------
    See `tutorial_mnist_simple.py <https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_mnist_simple.py>`_

    >>> def acc(_logits, y_batch):
    ...     return np.mean(np.equal(np.argmax(_logits, 1), y_batch))
    >>> tl.utils.test(network, acc, X_test, y_test, batch_size=None, cost=tl.cost.cross_entropy)

    """
    tl.logging.info('Start testing the network ...')
    network.eval()
    if batch_size is None:
        y_pred = network(X_test)
        if cost is not None:
            test_loss = cost(y_pred, y_test)
            tl.logging.info("   test loss: %f" % test_loss)
        test_acc = acc(y_pred, y_test)
        tl.logging.info("   test acc: %f" % test_acc)
        return test_acc
    else:
        test_loss, test_acc, n_batch = run_epoch(
            network, X_test, y_test, cost=cost, acc=acc, batch_size=batch_size, shuffle=False
        )
        if cost is not None:
            tl.logging.info("   test loss: %f" % test_loss)
        tl.logging.info("   test acc: %f" % test_acc)
        return test_acc
def predict(network, X, batch_size=None):
    """
    Return the prediction results of the given non-time-series network.

    Parameters
    ----------
    network : TensorLayer Model
        The network.
    X : numpy.array
        The inputs.
    batch_size : int or None
        The batch size for prediction. When the dataset is large, we should use minibatches for prediction;
        if the dataset is small, we can set it to None.

    Examples
    --------
    See `tutorial_mnist_simple.py <https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_mnist_simple.py>`_

    >>> _logits = tl.utils.predict(network, X_test)
    >>> y_pred = np.argmax(_logits, 1)

    """
    network.eval()
    if batch_size is None:
        y_pred = network(X)
        return y_pred
    else:
        result = None
        for X_a, _ in tl.iterate.minibatches(X, X, batch_size, shuffle=False):
            result_a = network(X_a)
            if result is None:
                result = result_a
            else:
                result = np.concatenate((result, result_a))
        if result is None:
            # No full batch was produced, i.e. len(X) < batch_size: run the whole remainder at once.
            if len(X) % batch_size != 0:
                result = network(X[-(len(X) % batch_size):, :])
        else:
            # minibatches() drops the last incomplete batch: append it if there is one.
            if len(X) != len(result) and len(X) % batch_size != 0:
                result_a = network(X[-(len(X) % batch_size):, :])
                result = np.concatenate((result, result_a))
        return result
## Evaluation
def evaluation(y_test=None, y_predict=None, n_classes=None):
    """
    Input the predicted results and target results along with the number of classes;
    return the confusion matrix, the F1-score of each class,
    the accuracy and the macro F1-score.

    Parameters
    ----------
    y_test : list
        The target results.
    y_predict : list
        The predicted results.
    n_classes : int
        The number of classes.

    Examples
    --------
    >>> c_mat, f1, acc, f1_macro = tl.utils.evaluation(y_test, y_predict, n_classes)

    """
    c_mat = confusion_matrix(y_test, y_predict, labels=[x for x in range(n_classes)])
    f1 = f1_score(y_test, y_predict, average=None, labels=[x for x in range(n_classes)])
    f1_macro = f1_score(y_test, y_predict, average='macro')
    acc = accuracy_score(y_test, y_predict)
    tl.logging.info('confusion matrix: \n%s' % c_mat)
    tl.logging.info('f1-score        : %s' % f1)
    tl.logging.info('f1-score(macro) : %f' % f1_macro)  # same output as f1_score(y_true, y_pred, average='macro')
    tl.logging.info('accuracy-score  : %f' % acc)
    return c_mat, f1, acc, f1_macro
def dict_to_one(dp_dict):
    """Input a dictionary; return a dictionary in which all items are set to one.

    Used to disable dropout, dropconnect layers and so on.

    Parameters
    ----------
    dp_dict : dictionary
        The dictionary contains keys and numbers, e.g. keeping probabilities.
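
    Examples
    --------
    A minimal illustration with a hypothetical keep-probability dictionary:

    >>> dp_dict = {'drop1': 0.5, 'drop2': 0.8}
    >>> tl.utils.dict_to_one(dp_dict)
    {'drop1': 1, 'drop2': 1}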
  264. """
  265. return {x: 1 for x in dp_dict}
def flatten_list(list_of_list):
    """Input a list of lists; return a single flat list containing all items.

    Parameters
    ----------
    list_of_list : a list of lists

    Examples
    --------
    >>> tl.utils.flatten_list([[1, 2, 3], [4, 5], [6]])
    [1, 2, 3, 4, 5, 6]

    """
    return sum(list_of_list, [])
def class_balancing_oversample(X_train=None, y_train=None, printable=True):
    """Input the features and labels; return the features and labels after oversampling.

    Parameters
    ----------
    X_train : numpy.array
        The inputs.
    y_train : numpy.array
        The targets.
    printable : boolean
        If True, print the balancing information.

    Examples
    --------
    One X

    >>> X_train, y_train = class_balancing_oversample(X_train, y_train, printable=True)

    Two X

    >>> X, y = tl.utils.class_balancing_oversample(X_train=np.hstack((X1, X2)), y_train=y, printable=False)
    >>> X1 = X[:, 0:5]
    >>> X2 = X[:, 5:]

    """
    # ======== Classes balancing
    if printable:
        tl.logging.info("Classes balancing for training examples...")

    c = Counter(y_train)

    if printable:
        tl.logging.info('the occurrence number of each stage: %s' % c.most_common())
        tl.logging.info('the least stage is Label %s which has %s instances' % c.most_common()[-1])
        tl.logging.info('the most stage is Label %s which has %s instances' % c.most_common(1)[0])

    most_num = c.most_common(1)[0][1]

    if printable:
        tl.logging.info('most num is %d, all classes tend to be this num' % most_num)

    locations = {}
    number = {}
    for lab, num in c.most_common():  # find the indices from y_train
        number[lab] = num
        locations[lab] = np.where(np.array(y_train) == lab)[0]

    if printable:
        tl.logging.info('convert list(np.array) to dict format')
    X = {}  # convert list to dict
    for lab, num in number.items():
        X[lab] = X_train[locations[lab]]

    # oversampling
    if printable:
        tl.logging.info('start oversampling')
    for key in X:
        temp = X[key]
        while True:
            if len(X[key]) >= most_num:
                break
            X[key] = np.vstack((X[key], temp))

    if printable:
        tl.logging.info('first features of label 0 > %d' % len(X[0][0]))
        tl.logging.info('the occurrence num of each stage after oversampling')
        for key in X:
            tl.logging.info("%s %d" % (key, len(X[key])))

    if printable:
        tl.logging.info('make each stage have same num of instances')
    for key in X:
        X[key] = X[key][0:most_num, :]
        if printable:
            tl.logging.info("%s %d" % (key, len(X[key])))

    # convert dict to list
    if printable:
        tl.logging.info('convert from dict to list format')
    y_train = []
    X_train = np.empty(shape=(0, len(X[0][0])))
    for key in X:
        X_train = np.vstack((X_train, X[key]))
        y_train.extend([key for i in range(len(X[key]))])
    # tl.logging.info(len(X_train), len(y_train))
    c = Counter(y_train)
    if printable:
        tl.logging.info('the occurrence number of each stage after oversampling: %s' % c.most_common())
    # ================ End of Classes balancing
    return X_train, y_train
## Random
def get_random_int(min_v=0, max_v=10, number=5, seed=None):
    """Return a list of random integers within the given range and of the given quantity.

    Parameters
    -----------
    min_v : number
        The minimum value.
    max_v : number
        The maximum value.
    number : int
        Number of values.
    seed : int or None
        The seed for random.

    Examples
    ---------
    >>> r = get_random_int(min_v=0, max_v=10, number=5)
    [10, 2, 3, 3, 7]

    """
    rnd = random.Random()
    if seed is not None:  # note: a seed of 0 is a valid seed
        rnd = random.Random(seed)
    return [rnd.randint(min_v, max_v) for p in range(0, number)]
def list_string_to_dict(string):
    """Inputs ``['a', 'b', 'c']``, returns ``{'a': 0, 'b': 1, 'c': 2}``."""
    dictionary = {}
    for idx, c in enumerate(string):
        dictionary.update({c: idx})
    return dictionary
def exit_tensorflow(port=6006):
    """Close TensorBoard and the Nvidia process if available.

    Parameters
    ----------
    port : int
        TensorBoard port you want to close, `6006` as default.
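
    Examples
    --------
    A minimal illustration (note that on Linux this also terminates the calling process):

    >>> tl.utils.exit_tensorflow(port=6006)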
  383. """
  384. text = "[TL] Close tensorboard and nvidia-process if available"
  385. text2 = "[TL] Close tensorboard and nvidia-process not yet supported by this function (tl.ops.exit_tf) on "
  386. if _platform == "linux" or _platform == "linux2":
  387. tl.logging.info('linux: %s' % text)
  388. os.system('nvidia-smi')
  389. os.system('fuser ' + str(port) + '/tcp -k') # kill tensorboard 6006
  390. os.system("nvidia-smi | grep python |awk '{print $3}'|xargs kill") # kill all nvidia-smi python process
  391. _exit()
  392. elif _platform == "darwin":
  393. tl.logging.info('OS X: %s' % text)
  394. subprocess.Popen(
  395. "lsof -i tcp:" + str(port) + " | grep -v PID | awk '{print $2}' | xargs kill", shell=True
  396. ) # kill tensorboard
  397. elif _platform == "win32":
  398. raise NotImplementedError("this function is not supported on the Windows platform")
  399. else:
  400. tl.logging.info(text2 + _platform)
def open_tensorboard(log_dir='/tmp/tensorflow', port=6006):
    """Open TensorBoard.

    Parameters
    ----------
    log_dir : str
        Directory where your TensorBoard logs are saved.
    port : int
        TensorBoard port you want to open; 6006 is the TensorBoard default.
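
    Examples
    --------
    A minimal illustration, assuming logs were written to ``/tmp/tensorflow``:

    >>> tl.utils.open_tensorboard(log_dir='/tmp/tensorflow', port=6006)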
  409. """
  410. text = "[TL] Open tensorboard, go to localhost:" + str(port) + " to access"
  411. text2 = " not yet supported by this function (tl.ops.open_tb)"
  412. if not tl.files.exists_or_mkdir(log_dir, verbose=False):
  413. tl.logging.info("[TL] Log reportory was created at %s" % log_dir)
  414. if _platform == "linux" or _platform == "linux2":
  415. tl.logging.info('linux: %s' % text)
  416. subprocess.Popen(
  417. sys.prefix + " | python -m tensorflow.tensorboard --logdir=" + log_dir + " --port=" + str(port), shell=True
  418. ) # open tensorboard in localhost:6006/ or whatever port you chose
  419. elif _platform == "darwin":
  420. tl.logging.info('OS X: %s' % text)
  421. subprocess.Popen(
  422. sys.prefix + " | python -m tensorflow.tensorboard --logdir=" + log_dir + " --port=" + str(port), shell=True
  423. ) # open tensorboard in localhost:6006/ or whatever port you chose
  424. elif _platform == "win32":
  425. raise NotImplementedError("this function is not supported on the Windows platform")
  426. else:
  427. tl.logging.info(_platform + text2)
def clear_all_placeholder_variables(printable=True):
    """Clear all the placeholder variables of keep prob,
    including the keeping probabilities of all dropout, denoising, dropconnect, etc.

    Parameters
    ----------
    printable : boolean
        If True, print all deleted variables.
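
    Examples
    --------
    A minimal illustration:

    >>> tl.utils.clear_all_placeholder_variables(printable=True)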
  435. """
  436. tl.logging.info('clear all .....................................')
  437. gl = globals().copy()
  438. for var in gl:
  439. if var[0] == '_': continue
  440. if 'func' in str(globals()[var]): continue
  441. if 'module' in str(globals()[var]): continue
  442. if 'class' in str(globals()[var]): continue
  443. if printable:
  444. tl.logging.info(" clear_all ------- %s" % str(globals()[var]))
  445. del globals()[var]
def set_gpu_fraction(gpu_fraction=0.3):
    """Set the GPU memory fraction for the application.

    Parameters
    ----------
    gpu_fraction : None or float
        Fraction of GPU memory in (0, 1]. If None, allow GPU memory growth.

    References
    ----------
    - `TensorFlow using GPU <https://www.tensorflow.org/alpha/guide/using_gpu#allowing_gpu_memory_growth>`__
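
    Examples
    --------
    A minimal illustration with an illustrative fraction:

    >>> tl.utils.set_gpu_fraction(0.5)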
  455. """
  456. if gpu_fraction is None:
  457. tl.logging.info("[TL]: ALLOW GPU MEM GROWTH")
  458. tf.config.gpu.set_per_process_memory_growth(True)
  459. else:
  460. tl.logging.info("[TL]: GPU MEM Fraction %f" % gpu_fraction)
  461. tf.config.gpu.set_per_process_memory_fraction(0.4)
  462. # gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_fraction)
  463. # sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
  464. # return sess
def train_epoch(
        network, X, y, cost, train_op=tf.optimizers.Adam(learning_rate=0.0001), acc=None, batch_size=100, shuffle=True
):
    """Train a given non-time-series network with the given cost function, training data, batch_size etc.
    for one epoch.

    Parameters
    ----------
    network : TensorLayer Model
        the network to be trained.
    X : numpy.array
        The input of the training data.
    y : numpy.array
        The target of the training data.
    cost : TensorLayer or TensorFlow loss function
        Metric for the loss function, e.g. tl.cost.cross_entropy.
    train_op : TensorFlow optimizer
        The optimizer for training e.g. tf.optimizers.Adam().
    acc : TensorFlow/numpy expression or None
        Metric for accuracy or others. If None, the accuracy is not computed.
    batch_size : int
        The batch size for training and evaluating.
    shuffle : boolean
        Indicating whether to shuffle the dataset during training.

    Returns
    -------
    loss_ep : Tensor. Average loss of this epoch.
    acc_ep : Tensor or None. Average accuracy (metric) of this epoch. None if ``acc`` is not given.
    n_step : int. Number of iterations taken in this epoch.
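
    Examples
    --------
    A minimal sketch, assuming ``network``, ``X_train`` and ``y_train`` are prepared as in ``fit``:

    >>> loss_ep, acc_ep, n_step = tl.utils.train_epoch(
    ...     network, X_train, y_train, cost=tl.cost.cross_entropy,
    ...     train_op=tf.optimizers.Adam(learning_rate=0.0001), batch_size=64)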
  493. """
  494. network.train()
  495. loss_ep = 0
  496. acc_ep = 0
  497. n_step = 0
  498. for X_batch, y_batch in tl.iterate.minibatches(X, y, batch_size, shuffle=shuffle):
  499. _loss, _acc = _train_step(network, X_batch, y_batch, cost=cost, train_op=train_op, acc=acc)
  500. loss_ep += _loss
  501. if acc is not None:
  502. acc_ep += _acc
  503. n_step += 1
  504. loss_ep = loss_ep / n_step
  505. acc_ep = acc_ep / n_step if acc is not None else None
  506. return loss_ep, acc_ep, n_step
def run_epoch(network, X, y, cost=None, acc=None, batch_size=100, shuffle=False):
    """Run a given non-time-series network with the given cost function, test data, batch_size etc.
    for one epoch.

    Parameters
    ----------
    network : TensorLayer Model
        the network to be run.
    X : numpy.array
        The input data.
    y : numpy.array
        The target data.
    cost : TensorLayer or TensorFlow loss function
        Metric for the loss function, e.g. tl.cost.cross_entropy. If None, the loss is not computed.
    acc : TensorFlow/numpy expression or None
        Metric for accuracy or others. If None, the accuracy is not computed.
    batch_size : int
        The batch size for evaluating.
    shuffle : boolean
        Indicating whether to shuffle the dataset.

    Returns
    -------
    loss_ep : Tensor. Average loss of this epoch. None if ``cost`` is not given.
    acc_ep : Tensor. Average accuracy (metric) of this epoch. None if ``acc`` is not given.
    n_step : int. Number of iterations taken in this epoch.
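
    Examples
    --------
    A minimal sketch, assuming ``network``, ``X_val``, ``y_val`` and an ``acc`` metric are prepared as in ``fit``:

    >>> val_loss, val_acc, n_step = tl.utils.run_epoch(
    ...     network, X_val, y_val, cost=tl.cost.cross_entropy, acc=acc, batch_size=64)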
  531. """
  532. network.eval()
  533. loss_ep = 0
  534. acc_ep = 0
  535. n_step = 0
  536. for X_batch, y_batch in tl.iterate.minibatches(X, y, batch_size, shuffle=shuffle):
  537. _loss, _acc = _run_step(network, X_batch, y_batch, cost=cost, acc=acc)
  538. if cost is not None:
  539. loss_ep += _loss
  540. if acc is not None:
  541. acc_ep += _acc
  542. n_step += 1
  543. loss_ep = loss_ep / n_step if cost is not None else None
  544. acc_ep = acc_ep / n_step if acc is not None else None
  545. return loss_ep, acc_ep, n_step
@tf.function
def _train_step(network, X_batch, y_batch, cost, train_op=tf.optimizers.Adam(learning_rate=0.0001), acc=None):
    """Train for one step."""
    with tf.GradientTape() as tape:
        y_pred = network(X_batch)
        _loss = cost(y_pred, y_batch)

    grad = tape.gradient(_loss, network.trainable_weights)
    train_op.apply_gradients(zip(grad, network.trainable_weights))

    if acc is not None:
        _acc = acc(y_pred, y_batch)
        return _loss, _acc
    else:
        return _loss, None
# @tf.function  # FIXME : enabling tf.function causes some bugs with numpy, needs fixing
def _run_step(network, X_batch, y_batch, cost=None, acc=None):
    """Run for one step."""
    y_pred = network(X_batch)
    _loss, _acc = None, None
    if cost is not None:
        _loss = cost(y_pred, y_batch)
    if acc is not None:
        _acc = acc(y_pred, y_batch)
    return _loss, _acc

TensorLayer 3.0 is a deep learning library that supports multiple deep learning frameworks as computation backends. It plans to be compatible with TensorFlow, PyTorch, MindSpore, and Paddle.
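As a rough sketch of what selecting a backend could look like (the ``TL_BACKEND`` environment variable below is an assumption based on the TensorLayer 3.x line, not something this file defines; it would have to be set before importing the library):

>>> import os
>>> os.environ['TL_BACKEND'] = 'tensorflow'  # assumed backend switch; e.g. 'tensorflow', 'paddle', 'mindspore'
>>> import tensorlayer as tl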