tensorflow_cost.py

#! /usr/bin/python
# -*- coding: utf-8 -*-

import numbers

import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops, math_ops, nn_ops, standard_ops

from tensorlayer import logging

__all__ = [
    'softmax_cross_entropy_with_logits',
    'sigmoid_cross_entropy',
    'binary_cross_entropy',
    'mean_squared_error',
    'normalized_mean_square_error',
    'absolute_difference_error',
    'dice_coe',
    'dice_hard_coe',
    'iou_coe',
    'cross_entropy_seq',
    'cross_entropy_seq_with_mask',
    'cosine_similarity',
    'li_regularizer',
    'lo_regularizer',
    'maxnorm_regularizer',
    'maxnorm_o_regularizer',
    'maxnorm_i_regularizer',
]


def softmax_cross_entropy_with_logits(output, target, name=None):
    """Softmax cross-entropy operation. Returns the TensorFlow expression of the cross-entropy between two
    distributions; the softmax is applied internally. See ``tf.nn.sparse_softmax_cross_entropy_with_logits``.

    Parameters
    ----------
    output : Tensor
        A batch of distributions with shape: [batch_size, num of classes].
    target : Tensor
        A batch of indices with shape: [batch_size, ].
    name : string
        Name of this loss.

    Examples
    --------
    >>> import tensorlayer as tl
    >>> ce = tl.cost.softmax_cross_entropy_with_logits(y_logits, y_target_logits, 'my_loss')

    References
    -----------
    - About cross-entropy: `<https://en.wikipedia.org/wiki/Cross_entropy>`__.

    """
    # if name is None:
    #     raise Exception("Please give a unique name to tl.cost.cross_entropy for TF1.0+")
    return tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target, logits=output), name=name)


def sigmoid_cross_entropy(output, target, name=None):
    """Sigmoid cross-entropy operation, see ``tf.nn.sigmoid_cross_entropy_with_logits``.

    Parameters
    ----------
    output : Tensor
        A batch of logits with shape: [batch_size, num of classes].
    target : Tensor
        A batch of targets with the same shape and type as `output`.
    name : string
        Name of this loss.

    """
    return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output), name=name)
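
# Illustrative usage (a sketch, not part of the original module): sigmoid cross-entropy
# takes raw logits and float targets of the same shape; the tensors below are made up.
# import tensorflow as tf
# logits = tf.constant([[1.2, -0.8], [0.3, 2.1]])
# labels = tf.constant([[1., 0.], [0., 1.]])
# loss = sigmoid_cross_entropy(logits, labels, name='sigmoid_loss')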


def binary_cross_entropy(output, target, epsilon=1e-8, name='bce_loss'):
    """Binary cross-entropy operation.

    Parameters
    ----------
    output : Tensor
        Tensor with type of `float32` or `float64`.
    target : Tensor
        The target distribution, format the same with `output`.
    epsilon : float
        A small value added inside the logarithms to avoid taking the log of zero.
    name : str
        An optional name to attach to this function.

    References
    -----------
    - `ericjang-DRAW <https://github.com/ericjang/draw/blob/master/draw.py#L73>`__

    """
    # with ops.op_scope([output, target], name, "bce_loss") as name:
    #     output = ops.convert_to_tensor(output, name="preds")
    #     target = ops.convert_to_tensor(targets, name="target")

    # with tf.name_scope(name):
    return tf.reduce_mean(
        tf.reduce_sum(
            -(target * tf.math.log(output + epsilon) + (1. - target) * tf.math.log(1. - output + epsilon)), axis=1
        ), name=name
    )

    # For brevity, let `x = output`, `z = target`. The binary cross-entropy loss is
    #
    #     loss(x, z) = - sum_i (z[i] * log(x[i]) + (1 - z[i]) * log(1 - x[i]))
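
# Illustrative usage (a sketch, not part of the original module): `output` should already
# be probabilities in (0, 1), e.g. after a sigmoid; the tensors below are made up.
# import tensorflow as tf
# probs = tf.constant([[0.9, 0.2], [0.4, 0.7]])
# labels = tf.constant([[1., 0.], [0., 1.]])
# loss = binary_cross_entropy(probs, labels)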


def mean_squared_error(output, target, is_mean=False, axis=-1, name="mean_squared_error"):
    """Return the TensorFlow expression of mean-squared-error (L2) of two batches of data.

    Parameters
    ----------
    output : Tensor
        2D, 3D or 4D tensor i.e. [batch_size, n_feature], [batch_size, height, width] or [batch_size, height, width, channel].
    target : Tensor
        The target distribution, format the same with `output`.
    is_mean : boolean
        Whether to compute the mean or the sum for each example.
            - If True, use ``tf.reduce_mean`` to compute the loss between one target and predict data.
            - If False, use ``tf.reduce_sum`` (default).
    axis : int or list of int
        The dimensions to reduce.
    name : str
        An optional name to attach to this function.

    References
    ------------
    - `Wiki Mean Squared Error <https://en.wikipedia.org/wiki/Mean_squared_error>`__

    """
    # with tf.name_scope(name):
    #     if len(output.shape) == 2:  # [batch_size, n_feature]
    #         axis = 1
    #     elif len(output.shape) == 3:  # [batch_size, w, h]
    #         axis = [1, 2]
    #     elif len(output.shape) == 4:  # [batch_size, w, h, c]
    #         axis = [1, 2, 3]
    #     else:
    #         raise Exception("Unknown dimension")
    if is_mean:
        mse = tf.reduce_mean(tf.reduce_mean(tf.math.squared_difference(output, target), axis), name=name)
    else:
        mse = tf.reduce_mean(tf.reduce_sum(tf.math.squared_difference(output, target), axis), name=name)
    return mse
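
# Illustrative usage (a sketch, not part of the original module): the squared error is
# reduced over `axis` per example, then averaged over the batch; the tensors are made up.
# import tensorflow as tf
# y_pred = tf.constant([[1.0, 2.0], [3.0, 4.0]])
# y_true = tf.constant([[1.0, 0.0], [3.0, 6.0]])
# mse_sum = mean_squared_error(y_pred, y_true)                 # reduce_sum per example
# mse_mean = mean_squared_error(y_pred, y_true, is_mean=True)  # reduce_mean per example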


def normalized_mean_square_error(output, target, axis=-1, name="normalized_mean_squared_error_loss"):
    """Return the TensorFlow expression of normalized mean-squared-error of two distributions.

    Parameters
    ----------
    output : Tensor
        2D, 3D or 4D tensor i.e. [batch_size, n_feature], [batch_size, height, width] or [batch_size, height, width, channel].
    target : Tensor
        The target distribution, format the same with `output`.
    axis : int or list of int
        The dimensions to reduce.
    name : str
        An optional name to attach to this function.

    """
    with tf.name_scope("normalized_mean_squared_error_loss"):
        # if len(output.shape) == 2:  # [batch_size, n_feature]
        #     axis = 1
        # elif len(output.shape) == 3:  # [batch_size, w, h]
        #     axis = [1, 2]
        # elif len(output.shape) == 4:  # [batch_size, w, h, c]
        #     axis = [1, 2, 3]
        nmse_a = tf.sqrt(tf.reduce_sum(tf.math.squared_difference(output, target), axis=axis))
        nmse_b = tf.sqrt(tf.reduce_sum(tf.square(target), axis=axis))
        nmse = tf.reduce_mean(nmse_a / nmse_b, name=name)
    return nmse
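
# Illustrative usage (a sketch, not part of the original module): NMSE divides each
# example's L2 error by the L2 norm of its target, so the data scale cancels out.
# import tensorflow as tf
# y_pred = tf.constant([[1.0, 2.0], [3.0, 4.0]])
# y_true = tf.constant([[1.0, 1.0], [3.0, 5.0]])
# nmse = normalized_mean_square_error(y_pred, y_true)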


def absolute_difference_error(output, target, is_mean=False, axis=-1, name="absolute_difference_error_loss"):
    """Return the TensorFlow expression of absolute difference error (L1) of two batches of data.

    Parameters
    ----------
    output : Tensor
        2D, 3D or 4D tensor i.e. [batch_size, n_feature], [batch_size, height, width] or [batch_size, height, width, channel].
    target : Tensor
        The target distribution, format the same with `output`.
    is_mean : boolean
        Whether to compute the mean or the sum for each example.
            - If True, use ``tf.reduce_mean`` to compute the loss between one target and predict data.
            - If False, use ``tf.reduce_sum`` (default).
    axis : int or list of int
        The dimensions to reduce.
    name : str
        An optional name to attach to this function.

    """
    # with tf.name_scope("absolute_difference_error_loss"):
    #     if len(output.shape) == 2:  # [batch_size, n_feature]
    #         axis = 1
    #     elif len(output.shape) == 3:  # [batch_size, w, h]
    #         axis = [1, 2]
    #     elif len(output.shape) == 4:  # [batch_size, w, h, c]
    #         axis = [1, 2, 3]
    #     else:
    #         raise Exception("Unknown dimension")
    if is_mean:
        loss = tf.reduce_mean(tf.reduce_mean(tf.abs(output - target), axis), name=name)
    else:
        loss = tf.reduce_mean(tf.reduce_sum(tf.abs(output - target), axis), name=name)
    return loss
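
# Illustrative usage (a sketch, not part of the original module): the L1 counterpart of
# `mean_squared_error`, with the same `is_mean` / `axis` behaviour; made-up tensors.
# import tensorflow as tf
# y_pred = tf.constant([[1.0, 2.0], [3.0, 4.0]])
# y_true = tf.constant([[0.0, 2.0], [3.0, 6.0]])
# l1 = absolute_difference_error(y_pred, y_true, is_mean=True)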


def dice_coe(output, target, loss_type='jaccard', axis=(1, 2, 3), smooth=1e-5):
    """Soft dice (Sørensen or Jaccard) coefficient for comparing the similarity
    of two batches of data, usually used for binary image segmentation
    i.e. the labels are binary. The coefficient is between 0 and 1; 1 means a perfect match.

    Parameters
    -----------
    output : Tensor
        A distribution with shape: [batch_size, ....], (any dimensions).
    target : Tensor
        The target distribution, format the same with `output`.
    loss_type : str
        ``jaccard`` or ``sorensen``, default is ``jaccard``.
    axis : tuple of int
        All dimensions are reduced, default ``[1, 2, 3]``.
    smooth : float
        This small value will be added to the numerator and denominator.
            - If both output and target are empty, it makes sure dice is 1.
            - If either output or target is empty (all pixels are background), dice = ``smooth / (small_value + smooth)``; when smooth is very small, dice stays close to 0 (even when the image values are below the threshold), so in this case a larger smooth gives a larger dice.

    Examples
    ---------
    >>> import tensorlayer as tl
    >>> outputs = tl.ops.softmax(outputs)
    >>> dice_loss = 1 - tl.cost.dice_coe(outputs, y_)

    References
    -----------
    - `Wiki-Dice <https://en.wikipedia.org/wiki/Sørensen–Dice_coefficient>`__

    """
    inse = tf.reduce_sum(output * target, axis=axis)
    if loss_type == 'jaccard':
        l = tf.reduce_sum(output * output, axis=axis)
        r = tf.reduce_sum(target * target, axis=axis)
    elif loss_type == 'sorensen':
        l = tf.reduce_sum(output, axis=axis)
        r = tf.reduce_sum(target, axis=axis)
    else:
        raise Exception("Unknown loss_type")
    # old axis=[0,1,2,3]
    # dice = 2 * (inse) / (l + r)
    # epsilon = 1e-5
    # dice = tf.clip_by_value(dice, 0, 1.0 - epsilon)  # if all empty, dice = 1
    # new haodong
    dice = (2. * inse + smooth) / (l + r + smooth)
    dice = tf.reduce_mean(dice, name='dice_coe')
    return dice


def dice_hard_coe(output, target, threshold=0.5, axis=(1, 2, 3), smooth=1e-5):
    """Non-differentiable Sørensen–Dice coefficient for comparing the similarity
    of two batches of data, usually used for binary image segmentation i.e. the labels are binary.
    The coefficient is between 0 and 1; 1 means a perfect match.

    Parameters
    -----------
    output : tensor
        A distribution with shape: [batch_size, ....], (any dimensions).
    target : tensor
        The target distribution, format the same with `output`.
    threshold : float
        The threshold value above which a value is counted as true.
    axis : tuple of integer
        All dimensions are reduced, default ``(1, 2, 3)``.
    smooth : float
        This small value will be added to the numerator and denominator, see ``dice_coe``.

    References
    -----------
    - `Wiki-Dice <https://en.wikipedia.org/wiki/Sørensen–Dice_coefficient>`__

    """
    output = tf.cast(output > threshold, dtype=tf.float32)
    target = tf.cast(target > threshold, dtype=tf.float32)
    inse = tf.reduce_sum(tf.multiply(output, target), axis=axis)
    l = tf.reduce_sum(output, axis=axis)
    r = tf.reduce_sum(target, axis=axis)
    # old axis=[0,1,2,3]
    # hard_dice = 2 * (inse) / (l + r)
    # epsilon = 1e-5
    # hard_dice = tf.clip_by_value(hard_dice, 0, 1.0 - epsilon)
    # new haodong
    hard_dice = (2. * inse + smooth) / (l + r + smooth)
    hard_dice = tf.reduce_mean(hard_dice, name='hard_dice')
    return hard_dice
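
# Illustrative usage (a sketch, not part of the original module): both inputs are
# thresholded internally, so probabilities can be passed directly; shapes are made up.
# import tensorflow as tf
# probs = tf.random.uniform([2, 16, 16, 1])   # predicted foreground probabilities
# masks = tf.cast(tf.random.uniform([2, 16, 16, 1]) > 0.5, tf.float32)
# hard_dice = dice_hard_coe(probs, masks, threshold=0.5)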


def iou_coe(output, target, threshold=0.5, axis=(1, 2, 3), smooth=1e-5):
    """Non-differentiable Intersection over Union (IoU) for comparing the
    similarity of two batches of data, usually used for evaluating binary image segmentation.
    The coefficient is between 0 and 1; 1 means a perfect match.

    Parameters
    -----------
    output : tensor
        A batch of distributions with shape: [batch_size, ....], (any dimensions).
    target : tensor
        The target distribution, format the same with `output`.
    threshold : float
        The threshold value above which a value is counted as true.
    axis : tuple of integer
        All dimensions are reduced, default ``(1, 2, 3)``.
    smooth : float
        This small value will be added to the numerator and denominator, see ``dice_coe``.

    Notes
    ------
    - IoU cannot be used as a training loss; people usually train with the dice coefficient and evaluate with IoU and hard-dice.

    """
    pre = tf.cast(output > threshold, dtype=tf.float32)
    truth = tf.cast(target > threshold, dtype=tf.float32)
    inse = tf.reduce_sum(tf.multiply(pre, truth), axis=axis)  # AND
    union = tf.reduce_sum(tf.cast(tf.add(pre, truth) >= 1, dtype=tf.float32), axis=axis)  # OR
    # old axis=[0,1,2,3]
    # epsilon = 1e-5
    # batch_iou = inse / (union + epsilon)
    # new haodong
    batch_iou = (inse + smooth) / (union + smooth)
    iou = tf.reduce_mean(batch_iou, name='iou_coe')
    return iou  # , pre, truth, inse, union

# ## test soft/hard dice and iou
# import numpy as np
# y = np.zeros((1,10,10,1))
# # y[0,0:5,0:5]=1.0
# o = np.zeros((1,10,10,1))
# # o[:,:,:,:] = 0               # what we want: dice=0   iou=0  OK
# # o[0,0:2,0:2]=0.3             # what we want: dice larger   iou=0      OK
# # o[0,0:2,0:2]=0.6             # what we want: dice larger   iou small  OK
# # o[0,0:3,0:3]=0.6             # what we want: dice larger   iou larger OK
# # o[0,0:3,0:3]=1               # what we want: dice larger   iou same   OK
# # o[0,0:5,0:5]=1               # what we want: dice=1   iou=1  OK
# # o[0,0:5,0:5]=0.3             # what we want: dice smaller  iou=0      OK
# # o[0,0:5,0:5]=1e-2            # what we want: dice≈0   iou=0  OK
# # o[0,8:10,8:10]=1.0           # what we want: dice=0   iou=0  OK
# # o[0,8:10,8:10]=1e-10         # what we want: dice=0   iou=0  OK
# # y[:,:,:,:] = o[:,:,:,:] = 0  # what we want: dice=1   iou=1  OK
# ## why in u-net, dice=1 hard-dice=1 iou=1 exist?? print bug?
#
# d = dice_coe(o, y, 'jaccard', smooth=1.)
# hd = dice_hard_coe(o, y, smooth=1e-5)
# i = iou_coe(o, y, smooth=1e-5)
# sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
# # sess.run(tf.local_variables_initializer())
# print(sess.run([d, hd, i]))
# # p, t, i, u = sess.run([pre, truth, inse, union])
# # import pprint
# # pprint.pprint(((y>0.5)*(o>0.5)).astype(int).tolist())
# # pprint.pprint(p.tolist())
# # pprint.pprint(t.tolist())
# # pprint.pprint(i)
# # pprint.pprint(u)
# exit()


def sequence_loss_by_example(
    logits, targets, weights, average_across_timesteps=True, softmax_loss_function=None, name=None
):
    """Weighted cross-entropy loss for a sequence of logits (per example). See the original TensorFlow code:
    <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py#L1057>

    Parameters
    ----------
    logits : list
        List of 2D Tensors of shape [batch_size x num_decoder_symbols].
    targets : list
        List of 1D batch-sized int32 Tensors of the same length as logits.
    weights : list
        List of 1D batch-sized float Tensors of the same length as logits.
    average_across_timesteps : boolean
        If set, divide the returned cost by the total label weight.
    softmax_loss_function : None or function
        Function (labels, logits) -> loss-batch to be used instead of the standard softmax (the default if this is None).
        **Note that to avoid confusion, the function is required to accept named arguments.**
    name : None or str
        Optional name for this operation, default: "sequence_loss_by_example".

    Returns
    -------
    1D batch-sized float Tensor: the log-perplexity for each sequence.

    Raises
    ------
    ValueError: if len(logits) is different from len(targets) or len(weights).

    """
    if len(targets) != len(logits) or len(weights) != len(logits):
        raise ValueError(
            "Lengths of logits, weights, and targets must be the same "
            "%d, %d, %d." % (len(logits), len(weights), len(targets))
        )

    with ops.name_scope(name, "sequence_loss_by_example", logits + targets + weights):
        log_perp_list = []
        for logit, target, weight in zip(logits, targets, weights):
            if softmax_loss_function is None:
                # TODO(irving,ebrevdo): This reshape is needed because
                # sequence_loss_by_example is called with scalars sometimes, which
                # violates our general scalar strictness policy.
                target = array_ops.reshape(target, [-1])
                crossent = nn_ops.sparse_softmax_cross_entropy_with_logits(labels=target, logits=logit)
            else:
                crossent = softmax_loss_function(labels=target, logits=logit)
            log_perp_list.append(crossent * weight)
        log_perps = math_ops.add_n(log_perp_list)
        if average_across_timesteps:
            total_size = math_ops.add_n(weights)
            total_size += 1e-12  # Just to avoid division by 0 for all-0 weights.
            log_perps /= total_size
    return log_perps


def cross_entropy_seq(logits, target_seqs, batch_size=None):
    """Returns the expression of cross-entropy of two sequences; the softmax is applied
    internally. Normally used for fixed-length RNN outputs, see the `PTB example <https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_ptb_lstm.py>`__.

    Parameters
    ----------
    logits : Tensor
        2D tensor with shape of `[batch_size * n_steps, n_classes]`.
    target_seqs : Tensor
        The target sequence, 2D tensor `[batch_size, n_steps]`; if the number of steps is dynamic, please use ``tl.cost.cross_entropy_seq_with_mask`` instead.
    batch_size : None or int.
        Whether to divide the cost by batch size.
            - If integer, the returned cost will be divided by `batch_size`.
            - If None (default), the returned cost will not be divided by anything.

    Examples
    --------
    >>> import tensorlayer as tl
    >>> # see the `PTB example <https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_ptb_lstm.py>`__ for more details
    >>> # outputs shape : (batch_size * n_steps, n_classes)
    >>> # targets shape : (batch_size, n_steps)
    >>> cost = tl.cost.cross_entropy_seq(outputs, targets)

    """
    sequence_loss_by_example_fn = sequence_loss_by_example

    loss = sequence_loss_by_example_fn(
        [logits], [tf.reshape(target_seqs, [-1])], [tf.ones_like(tf.reshape(target_seqs, [-1]), dtype=tf.float32)]
    )
    # [tf.ones([batch_size * num_steps])])
    cost = tf.reduce_sum(loss)  # / batch_size
    if batch_size is not None:
        cost = cost / batch_size
    return cost


def cross_entropy_seq_with_mask(logits, target_seqs, input_mask, return_details=False, name=None):
    """Returns the expression of cross-entropy of two sequences; the softmax is applied
    internally. Normally used for dynamic RNNs with synced sequence input and output.

    Parameters
    -----------
    logits : Tensor
        2D tensor with shape of [batch_size * ?, n_classes], `?` means dynamic IDs for each example.
        - Can be obtained from `DynamicRNNLayer` by setting ``return_seq_2d`` to `True`.
    target_seqs : Tensor
        int tensor, like word IDs. [batch_size, ?], `?` means dynamic IDs for each example.
    input_mask : Tensor
        The mask to compute the loss, it has the same size as `target_seqs`, normally 0 or 1.
    return_details : boolean
        Whether to return detailed losses.
            - If False (default), only returns the loss.
            - If True, returns the loss, losses, weights and targets (see source code).

    Examples
    --------
    >>> import tensorlayer as tl
    >>> import tensorflow as tf
    >>> import numpy as np
    >>> batch_size = 64
    >>> vocab_size = 10000
    >>> embedding_size = 256
    >>> ni = tl.layers.Input([batch_size, None], dtype=tf.int64)
    >>> net_list = []
    >>> net_list.append(tl.layers.Embedding(
    ...     vocabulary_size=vocab_size,
    ...     embedding_size=embedding_size,
    ...     name='seq_embedding'))
    >>> net_list.append(tl.layers.RNN(
    ...     cell=tf.keras.layers.LSTMCell(units=embedding_size, dropout=0.1),
    ...     return_seq_2d=True,
    ...     name='dynamicrnn'))
    >>> net_list.append(tl.layers.Dense(n_units=vocab_size, name="output"))
    >>> model = tl.layers.SequentialLayer(net_list)
    >>> input_seqs = np.random.randint(0, 10, size=(batch_size, 10), dtype=np.int64)
    >>> target_seqs = np.random.randint(0, 10, size=(batch_size, 10), dtype=np.int64)
    >>> input_mask = np.random.randint(0, 2, size=(batch_size, 10), dtype=np.int64)
    >>> outputs = model(input_seqs)
    >>> loss = tl.cost.cross_entropy_seq_with_mask(outputs, target_seqs, input_mask)

    """
    targets = tf.reshape(target_seqs, [-1])  # to one vector
    weights = tf.cast(tf.reshape(input_mask, [-1]), dtype=tf.float32)  # to one vector like targets
    losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=targets, name=name) * weights
    # losses = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=targets, name=name))  # for TF1.0 and others

    loss = tf.divide(
        tf.reduce_sum(losses),  # sum of the masked losses (losses are multiplied by the mask before reduce_sum)
        tf.reduce_sum(weights),
        name="seq_loss_with_mask"
    )

    if return_details:
        return loss, losses, weights, targets
    else:
        return loss


def cosine_similarity(v1, v2):
    """Cosine similarity [-1, 1].

    Parameters
    ----------
    v1, v2 : Tensor
        Tensors with the same shape [batch_size, n_feature].

    References
    ----------
    - `Wiki <https://en.wikipedia.org/wiki/Cosine_similarity>`__.

    """
    return tf.reduce_sum(tf.multiply(v1, v2), 1) / \
        (tf.sqrt(tf.reduce_sum(tf.multiply(v1, v1), 1)) * tf.sqrt(tf.reduce_sum(tf.multiply(v2, v2), 1)))
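
# Illustrative usage (a sketch, not part of the original module): returns one similarity
# value per example in the batch; the vectors below are made up.
# import tensorflow as tf
# a = tf.constant([[1.0, 0.0], [1.0, 1.0]])
# b = tf.constant([[1.0, 0.0], [-1.0, -1.0]])
# sim = cosine_similarity(a, b)   # -> approximately [1.0, -1.0]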


# Regularization Functions


def li_regularizer(scale, scope=None):
    """Li regularization removes the neurons of the previous layer. The `i` represents `inputs`.
    Returns a function that can be used to apply group li regularization to weights.
    The implementation follows `TensorFlow contrib <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/layers/python/layers/regularizers.py>`__.

    Parameters
    ----------
    scale : float
        A scalar multiplier `Tensor`. 0.0 disables the regularizer.
    scope : str
        An optional scope name for this function.

    Returns
    --------
    A function with signature `li(weights, name=None)` that applies Li regularization.

    Raises
    ------
    ValueError : if scale is outside of the range [0.0, 1.0] or if scale is not a float.

    """
    if isinstance(scale, numbers.Integral):
        raise ValueError('scale cannot be an integer: %s' % scale)
    if isinstance(scale, numbers.Real):
        if scale < 0.:
            raise ValueError('Setting a scale less than 0 on a regularizer: %g' % scale)
        if scale >= 1.:
            raise ValueError('Setting a scale greater than 1 on a regularizer: %g' % scale)
        if scale == 0.:
            logging.info('Scale of 0 disables regularizer.')
            return lambda _, name=None: None

    def li(weights):
        """Applies li regularization to weights."""
        with tf.name_scope('li_regularizer') as scope:
            my_scale = ops.convert_to_tensor(scale, dtype=weights.dtype.base_dtype, name='scale')
            # if tf.__version__ <= '0.12':
            #     standard_ops_fn = standard_ops.mul
            # else:
            standard_ops_fn = standard_ops.multiply
            return standard_ops_fn(
                my_scale, standard_ops.reduce_sum(standard_ops.sqrt(standard_ops.reduce_sum(tf.square(weights), 1))),
                name=scope
            )

    return li
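
# Illustrative usage (a sketch, not part of the original module): the returned function
# maps a weight matrix to a scalar penalty that can be added to the training loss; the
# weight tensor below is made up.
# import tensorflow as tf
# reg = li_regularizer(0.001)
# weights = tf.random.normal([128, 64])
# penalty = reg(weights)          # scalar Tensor; add it to the loss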


def lo_regularizer(scale):
    """Lo regularization removes the neurons of the current layer. The `o` represents `outputs`.
    Returns a function that can be used to apply group lo regularization to weights.
    The implementation follows `TensorFlow contrib <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/layers/python/layers/regularizers.py>`__.

    Parameters
    ----------
    scale : float
        A scalar multiplier `Tensor`. 0.0 disables the regularizer.

    Returns
    -------
    A function with signature `lo(weights, name=None)` that applies Lo regularization.

    Raises
    ------
    ValueError : if scale is outside of the range [0.0, 1.0] or if scale is not a float.

    """
    if isinstance(scale, numbers.Integral):
        raise ValueError('scale cannot be an integer: %s' % scale)
    if isinstance(scale, numbers.Real):
        if scale < 0.:
            raise ValueError('Setting a scale less than 0 on a regularizer: %g' % scale)
        if scale >= 1.:
            raise ValueError('Setting a scale greater than 1 on a regularizer: %g' % scale)
        if scale == 0.:
            logging.info('Scale of 0 disables regularizer.')
            return lambda _, name=None: None

    def lo(weights, name='lo_regularizer'):
        """Applies group column regularization to weights."""
        with tf.name_scope(name) as scope:
            my_scale = ops.convert_to_tensor(scale, dtype=weights.dtype.base_dtype, name='scale')
            # if tf.__version__ <= '0.12':
            #     standard_ops_fn = standard_ops.mul
            # else:
            standard_ops_fn = standard_ops.multiply
            return standard_ops_fn(
                my_scale, standard_ops.reduce_sum(standard_ops.sqrt(standard_ops.reduce_sum(tf.square(weights), 0))),
                name=scope
            )

    return lo


def maxnorm_regularizer(scale=1.0):
    """Max-norm regularization returns a function that can be used to apply max-norm regularization to weights.
    More about max-norm, see `wiki-max norm <https://en.wikipedia.org/wiki/Matrix_norm#Max_norm>`_.
    The implementation follows `TensorFlow contrib <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/layers/python/layers/regularizers.py>`__.

    Parameters
    ----------
    scale : float
        A scalar multiplier `Tensor`. 0.0 disables the regularizer.

    Returns
    ---------
    A function with signature `mn(weights, name=None)` that applies max-norm regularization.

    Raises
    --------
    ValueError : if scale is outside of the range [0.0, 1.0] or if scale is not a float.

    """
    if isinstance(scale, numbers.Integral):
        raise ValueError('scale cannot be an integer: %s' % scale)
    if isinstance(scale, numbers.Real):
        if scale < 0.:
            raise ValueError('Setting a scale less than 0 on a regularizer: %g' % scale)
        # if scale >= 1.:
        #     raise ValueError('Setting a scale greater than 1 on a regularizer: %g' % scale)
        if scale == 0.:
            logging.info('Scale of 0 disables regularizer.')
            return lambda _, name=None: None

    def mn(weights, name='max_regularizer'):
        """Applies max-norm regularization to weights."""
        with tf.name_scope(name) as scope:
            my_scale = ops.convert_to_tensor(scale, dtype=weights.dtype.base_dtype, name='scale')
            # if tf.__version__ <= '0.12':
            #     standard_ops_fn = standard_ops.mul
            # else:
            standard_ops_fn = standard_ops.multiply
            return standard_ops_fn(my_scale, standard_ops.reduce_max(standard_ops.abs(weights)), name=scope)

    return mn
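
# Illustrative usage (a sketch, not part of the original module): penalizes the largest
# absolute weight; typically applied to a layer's kernel and added to the loss. The
# kernel tensor below is made up.
# import tensorflow as tf
# mn_fn = maxnorm_regularizer(scale=1.0)
# kernel = tf.random.normal([256, 10])
# penalty = mn_fn(kernel)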


def maxnorm_o_regularizer(scale):
    """Max-norm output regularization removes the neurons of the current layer.
    Returns a function that can be used to apply max-norm regularization to each column of the weight matrix.
    The implementation follows `TensorFlow contrib <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/layers/python/layers/regularizers.py>`__.

    Parameters
    ----------
    scale : float
        A scalar multiplier `Tensor`. 0.0 disables the regularizer.

    Returns
    ---------
    A function with signature `mn_o(weights, name=None)` that applies max-norm output regularization.

    Raises
    ---------
    ValueError : if scale is outside of the range [0.0, 1.0] or if scale is not a float.

    """
    if isinstance(scale, numbers.Integral):
        raise ValueError('scale cannot be an integer: %s' % scale)
    if isinstance(scale, numbers.Real):
        if scale < 0.:
            raise ValueError('Setting a scale less than 0 on a regularizer: %g' % scale)
        # if scale >= 1.:
        #     raise ValueError('Setting a scale greater than 1 on a regularizer: %g' % scale)
        if scale == 0.:
            logging.info('Scale of 0 disables regularizer.')
            return lambda _, name=None: None

    def mn_o(weights, name='maxnorm_o_regularizer'):
        """Applies max-norm regularization to weights."""
        with tf.name_scope(name) as scope:
            my_scale = ops.convert_to_tensor(scale, dtype=weights.dtype.base_dtype, name='scale')
            if tf.__version__ <= '0.12':
                standard_ops_fn = standard_ops.mul
            else:
                standard_ops_fn = standard_ops.multiply
            return standard_ops_fn(
                my_scale, standard_ops.reduce_sum(standard_ops.reduce_max(standard_ops.abs(weights), 0)), name=scope
            )

    return mn_o


def maxnorm_i_regularizer(scale):
    """Max-norm input regularization removes the neurons of the previous layer.
    Returns a function that can be used to apply max-norm regularization to each row of the weight matrix.
    The implementation follows `TensorFlow contrib <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/layers/python/layers/regularizers.py>`__.

    Parameters
    ----------
    scale : float
        A scalar multiplier `Tensor`. 0.0 disables the regularizer.

    Returns
    ---------
    A function with signature `mn_i(weights, name=None)` that applies max-norm input regularization.

    Raises
    ---------
    ValueError : if scale is outside of the range [0.0, 1.0] or if scale is not a float.

    """
    if isinstance(scale, numbers.Integral):
        raise ValueError('scale cannot be an integer: %s' % scale)
    if isinstance(scale, numbers.Real):
        if scale < 0.:
            raise ValueError('Setting a scale less than 0 on a regularizer: %g' % scale)
        # if scale >= 1.:
        #     raise ValueError('Setting a scale greater than 1 on a regularizer: %g' % scale)
        if scale == 0.:
            logging.info('Scale of 0 disables regularizer.')
            return lambda _, name=None: None

    def mn_i(weights, name='maxnorm_i_regularizer'):
        """Applies max-norm regularization to weights."""
        with tf.name_scope(name) as scope:
            my_scale = ops.convert_to_tensor(scale, dtype=weights.dtype.base_dtype, name='scale')
            if tf.__version__ <= '0.12':
                standard_ops_fn = standard_ops.mul
            else:
                standard_ops_fn = standard_ops.multiply
            return standard_ops_fn(
                my_scale, standard_ops.reduce_sum(standard_ops.reduce_max(standard_ops.abs(weights), 1)), name=scope
            )

    return mn_i


def huber_loss(
    output, target, is_mean=True, delta=1.0, dynamichuber=False, reverse=False, axis=-1, epsilon=0.00001, name=None
):
    """Huber loss operation, see <https://en.wikipedia.org/wiki/Huber_loss>.
    Reverse Huber loss operation, see <https://statweb.stanford.edu/~owen/reports/hhu.pdf>.
    Dynamic reverse Huber loss operation, see <https://arxiv.org/pdf/1606.00373.pdf>.

    Parameters
    ----------
    output : Tensor
        A distribution with shape: [batch_size, ....], (any dimensions).
    target : Tensor
        The target distribution, format the same with `output`.
    is_mean : boolean
        Whether to compute the mean or the sum for each example.
            - If True, use ``tf.reduce_mean`` to compute the loss between one target and predict data (default).
            - If False, use ``tf.reduce_sum``.
    delta : float
        The point where the Huber loss function changes from quadratic to linear.
    dynamichuber : boolean
        Whether to compute the coefficient c for each batch.
            - If True, c is 20% of the maximal per-batch error.
            - If False, c is delta.
    reverse : boolean
        Whether to compute the reverse Huber loss.
    axis : int or list of int
        The dimensions to reduce.
    epsilon : float
        A small value added to the denominator to avoid division by zero.
    name : string
        Name of this loss.

    """
    if reverse:
        if dynamichuber:
            huber_c = 0.2 * tf.reduce_max(tf.abs(output - target))
        else:
            huber_c = delta
        if is_mean:
            loss = tf.reduce_mean(
                tf.where(
                    tf.less_equal(tf.abs(output - target), huber_c), tf.abs(output - target),
                    tf.multiply(
                        tf.pow(output - target, 2.0) + tf.pow(huber_c, 2.0),
                        tf.math.divide_no_nan(.5, huber_c + epsilon)
                    )
                ), name=name
            )
        else:
            loss = tf.reduce_mean(
                tf.reduce_sum(
                    tf.where(
                        tf.less_equal(tf.abs(output - target), huber_c), tf.abs(output - target),
                        tf.multiply(
                            tf.pow(output - target, 2.0) + tf.pow(huber_c, 2.0),
                            tf.math.divide_no_nan(.5, huber_c + epsilon)
                        )
                    ), axis
                ), name=name
            )
    elif is_mean:
        loss = tf.reduce_mean(
            tf.where(
                tf.less_equal(tf.abs(output - target), delta), 0.5 * tf.pow(output - target, 2),
                delta * (tf.abs(output - target) - 0.5 * delta)
            ), name=name
        )
    else:
        loss = tf.reduce_mean(
            tf.reduce_sum(
                tf.where(
                    tf.less_equal(tf.abs(output - target), delta), 0.5 * tf.pow(output - target, 2),
                    delta * (tf.abs(output - target) - 0.5 * delta)
                ), axis
            ), name=name
        )
    return loss
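
# Illustrative usage (a sketch, not part of the original module): the standard Huber loss
# is quadratic below `delta` and linear above it; `reverse=True` gives the reverse
# ("berHu") variant, optionally with a per-batch dynamic threshold. Made-up tensors.
# import tensorflow as tf
# y_pred = tf.constant([[0.5, 2.0], [3.0, -1.0]])
# y_true = tf.constant([[0.0, 0.0], [3.0, 1.0]])
# h = huber_loss(y_pred, y_true, delta=1.0)
# berhu = huber_loss(y_pred, y_true, reverse=True, dynamichuber=True)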

TensorLayer 3.0 is a deep learning library that supports multiple deep learning frameworks as computational backends. It is planned to be compatible with TensorFlow, PyTorch, MindSpore, and Paddle.