
core.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from collections.abc import Iterable

from tensorlayer.layers.core.common import _save_weights, _load_weights
import tensorlayer as tl
from tensorlayer.layers.core import Module
import numpy as np
import time

if tl.BACKEND == 'tensorflow':
    import tensorflow as tf
if tl.BACKEND == 'mindspore':
    from mindspore.ops import composite
    from mindspore.ops import operations as P
    from mindspore.common import ParameterTuple
if tl.BACKEND == 'paddle':
    import paddle as pd

__all__ = ['Model', 'WithLoss', 'TrainOneStep']

class Model:
    """
    High-Level API for Training or Testing.

    `Model` groups layers into an object with training and inference features.

    Parameters
    ----------
    network : tensorlayer model
        The training or testing network.
    loss_fn : function
        Objective function.
    optimizer : class
        Optimizer for updating the weights.
    metrics : class
        Dict or set of metrics to be evaluated by the model during training and testing.

    Methods
    ---------
    train()
        Model training.
    eval()
        Model prediction.
    save_weights()
        Input file_path, save model weights into a file of given format.
        Use load_weights() to restore.
    load_weights()
        Load model weights from a given file, which should be previously saved by save_weights().

    Examples
    --------
    >>> import tensorlayer as tl
    >>> class Net(Module):
    >>>     def __init__(self):
    >>>         super(Net, self).__init__()
    >>>         self.conv = tl.layers.Conv2d(n_filter=32, filter_size=(3, 3), strides=(2, 2), in_channels=5, name='conv2d')
    >>>         self.bn = tl.layers.BatchNorm2d(num_features=32, act=tl.ReLU)
    >>>         self.flatten = tl.layers.Flatten()
    >>>         self.fc = tl.layers.Dense(n_units=12, in_channels=32*224*224)  # padding=0
    >>>
    >>>     def construct(self, x):
    >>>         x = self.conv(x)
    >>>         x = self.bn(x)
    >>>         x = self.flatten(x)
    >>>         out = self.fc(x)
    >>>         return out
    >>>
    >>> net = Net()
    >>> loss = tl.cost.softmax_cross_entropy_with_logits
    >>> optim = tl.optimizers.Momentum(params=net.trainable_weights, learning_rate=0.1, momentum=0.9)
    >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
    >>> dataset = get_dataset()
    >>> model.train(2, dataset)
    """

    def __init__(self, network, loss_fn=None, optimizer=None, metrics=None, **kwargs):
        self.network = network
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.metrics = metrics
        self.all_weights = network.all_weights
        self.train_weights = self.network.trainable_weights

    def train(self, n_epoch, train_dataset=None, test_dataset=False, print_train_batch=False, print_freq=5):
        if not isinstance(train_dataset, Iterable):
            raise Exception("Expected train_dataset to be an Iterable, but got {}.".format(type(train_dataset)))
        if tl.BACKEND == 'tensorflow':
            self.tf_train(
                n_epoch=n_epoch, train_dataset=train_dataset, network=self.network, loss_fn=self.loss_fn,
                train_weights=self.train_weights, optimizer=self.optimizer, metrics=self.metrics,
                print_train_batch=print_train_batch, print_freq=print_freq, test_dataset=test_dataset
            )
        elif tl.BACKEND == 'mindspore':
            self.ms_train(
                n_epoch=n_epoch, train_dataset=train_dataset, network=self.network, loss_fn=self.loss_fn,
                train_weights=self.train_weights, optimizer=self.optimizer, metrics=self.metrics,
                print_train_batch=print_train_batch, print_freq=print_freq, test_dataset=test_dataset
            )
        elif tl.BACKEND == 'paddle':
            self.pd_train(
                n_epoch=n_epoch, train_dataset=train_dataset, network=self.network, loss_fn=self.loss_fn,
                train_weights=self.train_weights, optimizer=self.optimizer, metrics=self.metrics,
                print_train_batch=print_train_batch, print_freq=print_freq, test_dataset=test_dataset
            )
        else:
            raise NotImplementedError("This backend is not supported")

    def eval(self, test_dataset):
        self.network.set_eval()
        test_loss, test_acc, n_iter = 0, 0, 0
        for X_batch, y_batch in test_dataset:
            _logits = self.network(X_batch)
            test_loss += self.loss_fn(_logits, y_batch)
            if self.metrics:
                try:
                    # metrics given as a plain callable
                    test_acc += self.metrics(_logits, y_batch)
                except Exception:
                    # metrics given as an object with the update/result/reset interface
                    self.metrics.update(_logits, y_batch)
                    test_acc += self.metrics.result()
                    self.metrics.reset()
            else:
                test_acc += np.mean(np.equal(np.argmax(_logits, 1), y_batch))
            n_iter += 1
        print(" test loss: {}".format(test_loss / n_iter))
        print(" test acc: {}".format(test_acc / n_iter))

    def save_weights(self, file_path, format=None):
        """Input file_path, save model weights into a file of given format.

        Use self.load_weights() to restore.

        Parameters
        ----------
        file_path : str
            Filename to which the model weights will be saved.
        format : str or None
            Saved file format.
            Value should be None, 'hdf5', 'npz', 'npz_dict' or 'ckpt'. Other formats are not supported now.
            1) If this is set to None, then the postfix of file_path will be used to decide the saved format.
            If the postfix is not in ['h5', 'hdf5', 'npz', 'ckpt'], then the file will be saved in hdf5 format by default.
            2) 'hdf5' will save model weights names in a list and each layer has its weights stored in a group of
            the hdf5 file.
            3) 'npz' will save model weights sequentially into a npz file.
            4) 'npz_dict' will save model weights along with their names as a dict into a npz file.
            5) 'ckpt' will save model weights into a tensorflow ckpt file.
            Default None.

        Examples
        --------
        1) Save model weights in hdf5 format by default.
        >>> net = vgg16()
        >>> optimizer = tl.optimizers.Adam(learning_rate=0.001)
        >>> metric = tl.metric.Accuracy()
        >>> model = tl.models.Model(network=net, loss_fn=tl.cost.softmax_cross_entropy_with_logits, optimizer=optimizer, metrics=metric)
        >>> model.save_weights('./model.h5')
        ...
        >>> model.load_weights('./model.h5')

        2) Save model weights in npz/npz_dict format
        >>> model.save_weights('./model.npz')
        >>> model.save_weights('./model.npz', format='npz_dict')
        """
        _save_weights(net=self, file_path=file_path, format=format)

    def load_weights(self, file_path, format=None, in_order=True, skip=False):
        """Load model weights from a given file, which should be previously saved by self.save_weights().

        Parameters
        ----------
        file_path : str
            Filename from which the model weights will be loaded.
        format : str or None
            If not specified (None), the postfix of the file_path will be used to decide its format. If specified,
            value should be 'hdf5', 'npz', 'npz_dict' or 'ckpt'. Other formats are not supported now.
            In addition, it should be the same format used when you saved the file with self.save_weights().
            Default is None.
        in_order : bool
            Allow loading weights into the model sequentially or by name. Only useful when 'format' is 'hdf5'.
            If 'in_order' is True, weights from the file will be loaded into the model in a sequential way.
            If 'in_order' is False, weights from the file will be loaded into the model by matching the name
            with the weights of the model, particularly useful when trying to restore a model in eager(graph) mode from
            a weights file which is saved in graph(eager) mode.
            Default is True.
        skip : bool
            Allow skipping weights whose name is mismatched between the file and model. Only useful when 'format' is
            'hdf5' or 'npz_dict'. If 'skip' is True, the 'in_order' argument will be ignored and those loaded weights
            whose name is not found in the model weights (self.all_weights) will be skipped. If 'skip' is False, an error
            will occur when a mismatch is found.
            Default is False.

        Examples
        --------
        1) load model from a hdf5 file.
        >>> net = vgg16()
        >>> optimizer = tl.optimizers.Adam(learning_rate=0.001)
        >>> metric = tl.metric.Accuracy()
        >>> model = tl.models.Model(network=net, loss_fn=tl.cost.softmax_cross_entropy_with_logits, optimizer=optimizer, metrics=metric)
        >>> model.load_weights('./model_graph.h5', in_order=False, skip=True)  # load weights by name, skipping mismatch
        >>> model.load_weights('./model_eager.h5')  # load sequentially

        2) load model from a npz file
        >>> model.load_weights('./model.npz')

        3) load model from a npz file, which is saved as npz_dict previously
        >>> model.load_weights('./model.npz', format='npz_dict')

        Notes
        -------
        1) 'in_order' is only useful when 'format' is 'hdf5'. If you are trying to load a weights file which is
        saved in a different mode, it is recommended to set 'in_order' to False.
        2) 'skip' is useful when 'format' is 'hdf5' or 'npz_dict'. If 'skip' is True,
        the 'in_order' argument will be ignored.
        """
        _load_weights(net=self, file_path=file_path, format=format, in_order=in_order, skip=skip)

    def tf_train(
        self, n_epoch, train_dataset, network, loss_fn, train_weights, optimizer, metrics, print_train_batch,
        print_freq, test_dataset
    ):
        for epoch in range(n_epoch):
            start_time = time.time()
            train_loss, train_acc, n_iter = 0, 0, 0
            for X_batch, y_batch in train_dataset:
                network.set_train()
                with tf.GradientTape() as tape:
                    # compute outputs
                    _logits = network(X_batch)
                    # compute loss and update model
                    _loss_ce = loss_fn(_logits, y_batch)
                grad = tape.gradient(_loss_ce, train_weights)
                optimizer.apply_gradients(zip(grad, train_weights))
                train_loss += _loss_ce
                if metrics:
                    metrics.update(_logits, y_batch)
                    train_acc += metrics.result()
                    metrics.reset()
                else:
                    train_acc += np.mean(np.equal(np.argmax(_logits, 1), y_batch))
                n_iter += 1
                if print_train_batch:
                    print("Epoch {} of {} took {}".format(epoch + 1, n_epoch, time.time() - start_time))
                    print(" train loss: {}".format(train_loss / n_iter))
                    print(" train acc: {}".format(train_acc / n_iter))
            if epoch + 1 == 1 or (epoch + 1) % print_freq == 0:
                print("Epoch {} of {} took {}".format(epoch + 1, n_epoch, time.time() - start_time))
                print(" train loss: {}".format(train_loss / n_iter))
                print(" train acc: {}".format(train_acc / n_iter))
            if test_dataset:
                # use training and evaluation sets to evaluate the model every print_freq epoch
                if epoch + 1 == 1 or (epoch + 1) % print_freq == 0:
                    network.set_eval()
                    val_loss, val_acc, n_iter = 0, 0, 0
                    for X_batch, y_batch in test_dataset:
                        _logits = network(X_batch)  # is_train=False, disable dropout
                        val_loss += loss_fn(_logits, y_batch, name='eval_loss')
                        if metrics:
                            metrics.update(_logits, y_batch)
                            val_acc += metrics.result()
                            metrics.reset()
                        else:
                            val_acc += np.mean(np.equal(np.argmax(_logits, 1), y_batch))
                        n_iter += 1
                    print(" val loss: {}".format(val_loss / n_iter))
                    print(" val acc: {}".format(val_acc / n_iter))

    def ms_train(
        self, n_epoch, train_dataset, network, loss_fn, train_weights, optimizer, metrics, print_train_batch,
        print_freq, test_dataset
    ):
        net_with_criterion = WithLoss(network, loss_fn)
        train_network = GradWrap(net_with_criterion, network.trainable_weights)
        train_network.set_train()
        for epoch in range(n_epoch):
            start_time = time.time()
            train_loss, train_acc, n_iter = 0, 0, 0
            for X_batch, y_batch in train_dataset:
                output = network(X_batch)
                loss_output = loss_fn(output, y_batch)
                grads = train_network(X_batch, y_batch)
                success = optimizer.apply_gradients(zip(grads, train_weights))
                loss = loss_output.asnumpy()
                train_loss += loss
                if metrics:
                    metrics.update(output, y_batch)
                    train_acc += metrics.result()
                    metrics.reset()
                else:
                    train_acc += np.mean((P.Equal()(P.Argmax(axis=1)(output), y_batch).asnumpy()))
                n_iter += 1
                if print_train_batch:
                    print("Epoch {} of {} took {}".format(epoch + 1, n_epoch, time.time() - start_time))
                    print(" train loss: {}".format(train_loss / n_iter))
                    print(" train acc: {}".format(train_acc / n_iter))
            if epoch + 1 == 1 or (epoch + 1) % print_freq == 0:
                print("Epoch {} of {} took {}".format(epoch + 1, n_epoch, time.time() - start_time))
                print(" train loss: {}".format(train_loss / n_iter))
                print(" train acc: {}".format(train_acc / n_iter))
            if test_dataset:
                # use training and evaluation sets to evaluate the model every print_freq epoch
                if epoch + 1 == 1 or (epoch + 1) % print_freq == 0:
                    network.set_eval()
                    val_loss, val_acc, n_iter = 0, 0, 0
                    for X_batch, y_batch in test_dataset:
                        _logits = network(X_batch)
                        val_loss += loss_fn(_logits, y_batch, name='eval_loss')
                        if metrics:
                            metrics.update(_logits, y_batch)
                            val_acc += metrics.result()
                            metrics.reset()
                        else:
                            val_acc += np.mean((P.Equal()(P.Argmax(axis=1)(_logits), y_batch).asnumpy()))
                        n_iter += 1
                    print(" val loss: {}".format(val_loss / n_iter))
                    print(" val acc: {}".format(val_acc / n_iter))

    def pd_train(
        self, n_epoch, train_dataset, network, loss_fn, train_weights, optimizer, metrics, print_train_batch,
        print_freq, test_dataset
    ):
        for epoch in range(n_epoch):
            start_time = time.time()
            train_loss, train_acc, n_iter = 0, 0, 0
            for X_batch, y_batch in train_dataset:
                network.set_train()
                output = network(X_batch)
                loss = loss_fn(output, y_batch)
                loss_ce = loss.numpy()
                params_grads = optimizer.gradient(loss, train_weights)
                optimizer.apply_gradients(params_grads)
                train_loss += loss_ce
                if metrics:
                    metrics.update(output, y_batch)
                    train_acc += metrics.result()
                    metrics.reset()
                else:
                    train_acc += pd.metric.accuracy(output, y_batch)
                n_iter += 1
                if print_train_batch:
                    print("Epoch {} of {} took {}".format(epoch + 1, n_epoch, time.time() - start_time))
                    print(" train loss: {}".format(train_loss / n_iter))
                    print(" train acc: {}".format(train_acc / n_iter))
            if epoch + 1 == 1 or (epoch + 1) % print_freq == 0:
                print("Epoch {} of {} took {}".format(epoch + 1, n_epoch, time.time() - start_time))
                print(" train loss: {}".format(train_loss / n_iter))
                print(" train acc: {}".format(train_acc / n_iter))
            if test_dataset:
                # use training and evaluation sets to evaluate the model every print_freq epoch
                if epoch + 1 == 1 or (epoch + 1) % print_freq == 0:
                    network.set_eval()
                    val_loss, val_acc, n_iter = 0, 0, 0
                    for X_batch, y_batch in test_dataset:
                        _logits = network(X_batch)  # is_train=False, disable dropout
                        val_loss += loss_fn(_logits, y_batch, name='eval_loss')
                        if metrics:
                            metrics.update(_logits, y_batch)
                            val_acc += metrics.result()
                            metrics.reset()
                        else:
                            val_acc += np.mean(np.equal(np.argmax(_logits, 1), y_batch))
                        n_iter += 1
                    print(" val loss: {}".format(val_loss / n_iter))
                    print(" val acc: {}".format(val_acc / n_iter))

class WithLoss(Module):
    """
    High-Level API for Training or Testing.

    Wraps the network with a loss function. This Module accepts data and label as inputs and
    the computed loss will be returned.

    Parameters
    ----------
    backbone : tensorlayer model
        The tensorlayer network.
    loss_fn : function
        Objective function.

    Methods
    ---------
    forward()
        Model inference.

    Examples
    --------
    >>> import tensorlayer as tl
    >>> net = vgg16()
    >>> loss_fn = tl.cost.softmax_cross_entropy_with_logits
    >>> net_with_loss = tl.models.WithLoss(net, loss_fn)
    """

    def __init__(self, backbone, loss_fn):
        super(WithLoss, self).__init__()
        self._backbone = backbone
        self._loss_fn = loss_fn

    def forward(self, data, label):
        out = self._backbone(data)
        return self._loss_fn(out, label)

    @property
    def backbone_network(self):
        return self._backbone

class GradWrap(Module):
    """ GradWrap definition """

    def __init__(self, network, trainable_weights):
        super(GradWrap, self).__init__(auto_prefix=False)
        self.network = network
        self.weights = ParameterTuple(trainable_weights)

    def forward(self, x, label):
        # GradOperation(get_by_list=True) builds a function that returns the
        # gradients of `network`'s output with respect to `self.weights`.
        return composite.GradOperation(get_by_list=True)(self.network, self.weights)(x, label)

class TrainOneStepWithTF(object):

    def __init__(self, net_with_loss, optimizer, train_weights):
        self.net_with_loss = net_with_loss
        self.optimizer = optimizer
        self.train_weights = train_weights

    def __call__(self, data, label):
        with tf.GradientTape() as tape:
            loss = self.net_with_loss(data, label)
        grad = tape.gradient(loss, self.train_weights)
        self.optimizer.apply_gradients(zip(grad, self.train_weights))
        return loss

class TrainOneStepWithMS(object):

    def __init__(self, net_with_loss, optimizer, train_weights):
        self.net_with_loss = net_with_loss
        self.optimizer = optimizer
        self.train_weights = train_weights
        self.train_network = GradWrap(net_with_loss, train_weights)

    def __call__(self, data, label):
        loss = self.net_with_loss(data, label)
        grads = self.train_network(data, label)
        self.optimizer.apply_gradients(zip(grads, self.train_weights))
        loss = loss.asnumpy()
        return loss

class TrainOneStepWithPD(object):

    def __init__(self, net_with_loss, optimizer, train_weights):
        self.net_with_loss = net_with_loss
        self.optimizer = optimizer
        self.train_weights = train_weights

    def __call__(self, data, label):
        loss = self.net_with_loss(data, label)
        params_grads = self.optimizer.gradient(loss, self.train_weights)
        self.optimizer.apply_gradients(params_grads)
        return loss.numpy()

class TrainOneStep(object):
    """
    High-Level API for Training One Step.

    Wraps the network with an optimizer. It can be trained in one step using the optimizer to get the loss.

    Parameters
    ----------
    net_with_loss : tensorlayer WithLoss
        The training or testing network.
    optimizer : class
        Optimizer for updating the weights.
    train_weights : list
        The trainable weights that the optimizer will update.

    Examples
    --------
    >>> import tensorlayer as tl
    >>> net = vgg16()
    >>> train_weights = net.trainable_weights
    >>> loss_fn = tl.cost.softmax_cross_entropy_with_logits
    >>> optimizer = tl.optimizers.Adam(learning_rate=1e-3)
    >>> net_with_loss = tl.models.WithLoss(net, loss_fn)
    >>> train_one_step = tl.models.TrainOneStep(net_with_loss, optimizer, train_weights)
    >>> inputs, labels = tl.layers.Input((128, 784), dtype=tl.float32), tl.layers.Input((128, 1), dtype=tl.int32)
    >>> train_one_step(inputs, labels)
    """

    def __init__(self, net_with_loss, optimizer, train_weights):
        if tl.BACKEND == 'tensorflow':
            self.net_with_train = TrainOneStepWithTF(net_with_loss, optimizer, train_weights)
        elif tl.BACKEND == 'mindspore':
            self.net_with_train = TrainOneStepWithMS(net_with_loss, optimizer, train_weights)
        elif tl.BACKEND == 'paddle':
            self.net_with_train = TrainOneStepWithPD(net_with_loss, optimizer, train_weights)
        else:
            raise NotImplementedError("This backend is not supported")

    def __call__(self, data, label):
        loss = self.net_with_train(data, label)
        return loss
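
Taken together, WithLoss and TrainOneStep give one training-step pattern that works unchanged on every supported backend: wrap the network with its loss, then drive the optimizer one batch at a time. A minimal sketch of a manual training loop built on them; make_batches() and the epoch count are hypothetical stand-ins for your own data pipeline, and vgg16() is the placeholder network the docstrings above also use:

    import tensorlayer as tl

    net = vgg16()  # any tensorlayer Module; vgg16() is the docstrings' example network
    net_with_loss = tl.models.WithLoss(net, tl.cost.softmax_cross_entropy_with_logits)
    optimizer = tl.optimizers.Adam(learning_rate=1e-3)
    train_one_step = tl.models.TrainOneStep(net_with_loss, optimizer, net.trainable_weights)

    for epoch in range(5):
        for data, label in make_batches():  # hypothetical iterator yielding (inputs, labels) batches
            loss = train_one_step(data, label)  # one forward/backward/update step, returns the batch loss
        print("epoch {}: last batch loss {}".format(epoch + 1, loss))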

TensorLayer 3.0 is a deep learning library that supports multiple deep learning frameworks as computation backends. Planned compatibility: TensorFlow, PyTorch, MindSpore, and Paddle.
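
Because core.py branches on tl.BACKEND at import time (see the conditional imports at the top of the file), the backend has to be chosen before tensorlayer is first imported. A minimal sketch, assuming the backend is selected through the TL_BACKEND environment variable as in the TensorLayer 3.x examples; check your installed version's documentation:

    import os

    # Must be set before the first `import tensorlayer`.
    # Assumption: TensorLayer 3.x reads TL_BACKEND when initializing tl.BACKEND.
    os.environ['TL_BACKEND'] = 'tensorflow'  # or 'mindspore' / 'paddle'

    import tensorlayer as tl
    print(tl.BACKEND)  # -> 'tensorflow'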