
#! /usr/bin/python
# -*- coding: utf-8 -*-

import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import moving_averages
from math import floor, ceil

# loss functions
sparse_softmax_cross_entropy_with_logits = tf.nn.sparse_softmax_cross_entropy_with_logits
sigmoid_cross_entropy_with_logits = tf.nn.sigmoid_cross_entropy_with_logits

def padding_format(padding):
    """
    Checks that the padding format is supported and returns its canonical form.

    Parameters
    ----------
    padding : str
        Must be one of the following: "same", "SAME", "VALID", "valid".

    Returns
    -------
    str "SAME" or "VALID"
    """
    if padding in ["SAME", "same"]:
        padding = "SAME"
    elif padding in ["VALID", "valid"]:
        padding = "VALID"
    elif padding is None:
        padding = None
    else:
        raise Exception("Unsupported padding: " + str(padding))
    return padding

def preprocess_1d_format(data_format, padding):
    """
    Checks that the 1-D data format is supported and returns its canonical form.

    Parameters
    ----------
    data_format : str
        Must be one of the following: "channels_last", "NWC", "NCW", "channels_first".
    padding : str
        Must be one of the following: "same", "valid", "SAME", "VALID".

    Returns
    -------
    str "NWC" or "NCW" and "SAME" or "VALID"
    """
    if data_format in ["channels_last", "NWC"]:
        data_format = "NWC"
    elif data_format in ["channels_first", "NCW"]:
        data_format = "NCW"
    elif data_format is None:
        data_format = None
    else:
        raise Exception("Unsupported data format: " + str(data_format))
    padding = padding_format(padding)
    return data_format, padding

def preprocess_2d_format(data_format, padding):
    """
    Checks that the 2-D data format is supported and returns its canonical form.

    Parameters
    ----------
    data_format : str
        Must be one of the following: "channels_last", "NHWC", "NCHW", "channels_first".
    padding : str
        Must be one of the following: "same", "valid", "SAME", "VALID".

    Returns
    -------
    str "NHWC" or "NCHW" and "SAME" or "VALID"
    """
    if data_format in ["channels_last", "NHWC"]:
        data_format = "NHWC"
    elif data_format in ["channels_first", "NCHW"]:
        data_format = "NCHW"
    elif data_format is None:
        data_format = None
    else:
        raise Exception("Unsupported data format: " + str(data_format))
    padding = padding_format(padding)
    return data_format, padding

def preprocess_3d_format(data_format, padding):
    """
    Checks that the 3-D data format is supported and returns its canonical form.

    Parameters
    ----------
    data_format : str
        Must be one of the following: "channels_last", "NDHWC", "NCDHW", "channels_first".
    padding : str
        Must be one of the following: "same", "valid", "SAME", "VALID".

    Returns
    -------
    str "NDHWC" or "NCDHW" and "SAME" or "VALID"
    """
    if data_format in ['channels_last', 'NDHWC']:
        data_format = 'NDHWC'
    elif data_format in ['channels_first', 'NCDHW']:
        data_format = 'NCDHW'
    elif data_format is None:
        data_format = None
    else:
        raise Exception("Unsupported data format: " + str(data_format))
    padding = padding_format(padding)
    return data_format, padding

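# A minimal usage sketch (illustrative only, not part of the original file):
# the preprocess helpers normalize user-facing aliases to TF's canonical strings.
# >>> preprocess_2d_format('channels_first', 'same')
# ('NCHW', 'SAME')
# >>> preprocess_1d_format('NWC', 'valid')
# ('NWC', 'VALID')
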
def nchw_to_nhwc(x):
    """
    Converts a channels-first tensor to channels-last.

    Parameters
    ----------
    x : tensor
        channels-first tensor data

    Returns
    -------
    channels-last tensor data
    """
    if len(x.shape) == 3:
        x = tf.transpose(x, (0, 2, 1))
    elif len(x.shape) == 4:
        x = tf.transpose(x, (0, 2, 3, 1))
    elif len(x.shape) == 5:
        x = tf.transpose(x, (0, 2, 3, 4, 1))
    else:
        raise Exception("Unsupported dimensions")
    return x

def nhwc_to_nchw(x):
    """
    Converts a channels-last tensor to channels-first.

    Parameters
    ----------
    x : tensor
        channels-last tensor data

    Returns
    -------
    channels-first tensor data
    """
    if len(x.shape) == 3:
        x = tf.transpose(x, (0, 2, 1))
    elif len(x.shape) == 4:
        x = tf.transpose(x, (0, 3, 1, 2))
    elif len(x.shape) == 5:
        x = tf.transpose(x, (0, 4, 1, 2, 3))
    else:
        raise Exception("Unsupported dimensions")
    return x

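# Sketch of a layout round trip (assumes TF2 eager execution; shapes are
# illustrative assumptions):
# >>> x = tf.zeros([8, 3, 32, 32])               # NCHW
# >>> nchw_to_nhwc(x).shape                      # -> [8, 32, 32, 3]
# >>> nhwc_to_nchw(nchw_to_nhwc(x)).shape        # -> [8, 3, 32, 32]
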
class ReLU(object):

    def __init__(self):
        pass

    def __call__(self, x):
        return tf.nn.relu(x)


def relu(x):
    """
    Computes rectified linear: max(features, 0).

    Parameters
    ----------
    x : tensor
        Must be one of the following types: float32, float64, int32, uint8, int16,
        int8, int64, bfloat16, uint16, half, uint32, uint64, qint8.

    Returns
    -------
    A Tensor. Has the same type as features.
    """
    return tf.nn.relu(x)

class ReLU6(object):

    def __init__(self):
        pass

    def __call__(self, x):
        return tf.nn.relu6(x)


def relu6(x):
    """
    Computes Rectified Linear 6: min(max(features, 0), 6).

    Parameters
    ----------
    x : tensor
        Must be one of the following types: float32, float64, int32, uint8, int16,
        int8, int64, bfloat16, uint16, half, uint32, uint64, qint8.

    Returns
    -------
    A Tensor with the same type as features.
    """
    return tf.nn.relu6(x)

class LeakyReLU(object):

    def __init__(self, alpha=0.2):
        self.alpha = alpha

    def __call__(self, x):
        return tf.nn.leaky_relu(x, alpha=self.alpha)


def leaky_relu(x, alpha=0.2):
    """
    Compute the Leaky ReLU activation function.

    Parameters
    ----------
    x : tensor
        Representing preactivation values. Must be one of the following types:
        float16, float32, float64, int32, int64.
    alpha : float
        The slope of the activation function at x < 0.

    Returns
    -------
    The activation value.
    """
    return tf.nn.leaky_relu(x, alpha=alpha)

class Softplus(object):

    def __init__(self):
        pass

    def __call__(self, x):
        return tf.nn.softplus(x)


def softplus(x):
    """
    Computes softplus: log(exp(features) + 1).

    Parameters
    ----------
    x : tensor
        Must be one of the following types: half, bfloat16, float32, float64.

    Returns
    -------
    A Tensor. Has the same type as features.
    """
    return tf.nn.softplus(x)


class Tanh(object):

    def __init__(self):
        pass

    def __call__(self, x):
        return tf.nn.tanh(x)


def tanh(x):
    """
    Computes hyperbolic tangent of x element-wise.

    Parameters
    ----------
    x : tensor
        Must be one of the following types: bfloat16, half, float32, float64, complex64, complex128.

    Returns
    -------
    A Tensor. Has the same type as x.
    """
    return tf.nn.tanh(x)


class Sigmoid(object):

    def __init__(self):
        pass

    def __call__(self, x):
        return tf.nn.sigmoid(x)


def sigmoid(x):
    """
    Computes sigmoid of x element-wise.

    Parameters
    ----------
    x : tensor
        A Tensor with type float16, float32, float64, complex64, or complex128.

    Returns
    -------
    A Tensor with the same type as x.
    """
    return tf.nn.sigmoid(x)

class Softmax(object):

    def __init__(self):
        pass

    def __call__(self, x):
        return tf.nn.softmax(x)


def softmax(logits, axis=None):
    """
    Computes softmax activations.

    Parameters
    ----------
    logits : tensor
        Must be one of the following types: half, float32, float64.
    axis : int
        The dimension softmax would be performed on. The default is -1, which indicates the last dimension.

    Returns
    -------
    A Tensor. Has the same type and shape as logits.
    """
    return tf.nn.softmax(logits, axis)

class Dropout(object):

    def __init__(self, keep, seed=0):
        self.keep = keep
        self.seed = seed

    def __call__(self, inputs, *args, **kwargs):
        outputs = tf.nn.dropout(inputs, rate=1 - self.keep, seed=self.seed)
        return outputs

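# Note: Dropout is parameterized by the *keep* probability and converts it to
# tf.nn.dropout's drop rate. A minimal usage sketch (values are assumptions):
# >>> drop = Dropout(keep=0.8)          # drops ~20% of activations
# >>> y = drop(tf.ones([4, 10]))
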
class BiasAdd(object):
    """
    Adds bias to value.

    Parameters
    ----------
    x : tensor
        A Tensor with type float, double, int64, int32, uint8, int16, int8, complex64, or complex128.
    bias : tensor
        Must be the same type as value unless value is a quantized type,
        in which case a different quantized type may be used.

    Returns
    -------
    A Tensor with the same type as value.
    """

    def __init__(self, data_format=None):
        self.data_format = data_format

    def __call__(self, x, bias):
        return tf.nn.bias_add(x, bias, data_format=self.data_format)


def bias_add(x, bias, data_format=None, name=None):
    """
    Adds bias to value.

    Parameters
    ----------
    x : tensor
        A Tensor with type float, double, int64, int32, uint8, int16, int8, complex64, or complex128.
    bias : tensor
        Must be the same type as value unless value is a quantized type,
        in which case a different quantized type may be used.
    data_format : str
        'N...C' and 'NC...' are supported.
    name : str
        A name for the operation (optional).

    Returns
    -------
    A Tensor with the same type as value.
    """
    x = tf.nn.bias_add(x, bias, data_format=data_format, name=name)
    return x

class Conv1D(object):

    def __init__(self, stride, padding, data_format='NWC', dilations=None, out_channel=None, k_size=None):
        self.stride = stride
        self.dilations = dilations
        self.data_format, self.padding = preprocess_1d_format(data_format, padding)

    def __call__(self, input, filters):
        outputs = tf.nn.conv1d(
            input=input,
            filters=filters,
            stride=self.stride,
            padding=self.padding,
            data_format=self.data_format,
            dilations=self.dilations,
            # name=name
        )
        return outputs

def conv1d(input, filters, stride, padding, data_format='NWC', dilations=None):
    """
    Computes a 1-D convolution given 3-D input and filter tensors.

    Parameters
    ----------
    input : tensor
        A 3-D Tensor. Must be of type float16, float32, or float64.
    filters : tensor
        A 3-D Tensor. Must have the same type as input.
    stride : int or list of ints
        An int or list of ints that has length 1 or 3. The number of entries by which the filter is moved right at each step.
    padding : string
        'SAME' or 'VALID'
    data_format : string
        An optional string from "NWC", "NCW". Defaults to "NWC", the data is stored in the order of
        [batch, in_width, in_channels]. The "NCW" format stores data as [batch, in_channels, in_width].
    dilations : int or list of ints
        An int or list of ints that has length 1 or 3 which defaults to 1.
        The dilation factor for each dimension of input. If set to k > 1,
        there will be k-1 skipped cells between each filter element on that dimension.
        Dilations in the batch and depth dimensions must be 1.

    Returns
    -------
    A Tensor. Has the same type as input.
    """
    data_format, padding = preprocess_1d_format(data_format, padding)
    outputs = tf.nn.conv1d(
        input=input,
        filters=filters,
        stride=stride,
        padding=padding,
        data_format=data_format,
        dilations=dilations,
        # name=name
    )
    return outputs

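# Illustrative sketch of a single 1-D convolution (shapes are assumptions):
# >>> x = tf.random.normal([2, 100, 8])        # [batch, width, in_channels]
# >>> w = tf.random.normal([5, 8, 16])         # [filter_width, in, out]
# >>> conv1d(x, w, stride=1, padding='same').shape   # -> [2, 100, 16]
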
class Conv2D(object):

    def __init__(self, strides, padding, data_format='NHWC', dilations=None, out_channel=None, k_size=None):
        self.strides = strides
        self.dilations = dilations
        self.data_format, self.padding = preprocess_2d_format(data_format, padding)

    def __call__(self, input, filters):
        outputs = tf.nn.conv2d(
            input=input,
            filters=filters,
            strides=self.strides,
            padding=self.padding,
            data_format=self.data_format,
            dilations=self.dilations,
        )
        return outputs

def conv2d(input, filters, strides, padding, data_format='NHWC', dilations=None):
    """
    Computes a 2-D convolution given 4-D input and filters tensors.

    Parameters
    ----------
    input : tensor
        Must be one of the following types: half, bfloat16, float32, float64. A 4-D tensor.
        The dimension order is interpreted according to the value of data_format; see below for details.
    filters : tensor
        Must have the same type as input. A 4-D tensor of shape [filter_height, filter_width, in_channels, out_channels].
    strides : int or list of ints
        The stride of the sliding window for each dimension of input. If a single value is given it is replicated in the H and W dimension.
        By default the N and C dimensions are set to 1. The dimension order is determined by the value of data_format; see below for details.
    padding : string
        "SAME" or "VALID"
    data_format : string
        "NHWC", "NCHW". Defaults to "NHWC".
    dilations : int or list of ints
        An int or list of ints that has length 1, 2 or 4, defaults to 1. The dilation factor for each dimension of input.

    Returns
    -------
    A Tensor. Has the same type as input.
    """
    data_format, padding = preprocess_2d_format(data_format, padding)
    outputs = tf.nn.conv2d(
        input=input,
        filters=filters,
        strides=strides,
        padding=padding,
        data_format=data_format,
        dilations=dilations,
    )
    return outputs

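# Illustrative sketch of a 2-D convolution (shapes are assumptions):
# >>> x = tf.random.normal([1, 28, 28, 3])     # NHWC input
# >>> w = tf.random.normal([3, 3, 3, 8])       # [h, w, in_channels, out_channels]
# >>> conv2d(x, w, strides=[1, 1, 1, 1], padding='SAME').shape   # -> [1, 28, 28, 8]
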
class Conv3D(object):

    def __init__(self, strides, padding, data_format='NDHWC', dilations=None, out_channel=None, k_size=None):
        self.strides = strides
        self.dilations = dilations
        self.data_format, self.padding = preprocess_3d_format(data_format, padding)

    def __call__(self, input, filters):
        outputs = tf.nn.conv3d(
            input=input,
            filters=filters,
            strides=self.strides,
            padding=self.padding,
            data_format=self.data_format,
            dilations=self.dilations,
        )
        return outputs

def conv3d(input, filters, strides, padding, data_format='NDHWC', dilations=None):
    """
    Computes a 3-D convolution given 5-D input and filters tensors.

    Parameters
    ----------
    input : tensor
        Must be one of the following types: half, bfloat16, float32, float64.
        Shape [batch, in_depth, in_height, in_width, in_channels].
    filters : tensor
        Must have the same type as input. Shape [filter_depth, filter_height, filter_width, in_channels, out_channels].
        in_channels must match between input and filters.
    strides : list of ints
        A list of ints that has length >= 5. 1-D tensor of length 5.
        The stride of the sliding window for each dimension of input.
        Must have strides[0] = strides[4] = 1.
    padding : string
        A string from: "SAME", "VALID". The type of padding algorithm to use.
    data_format : string
        An optional string from: "NDHWC", "NCDHW". Defaults to "NDHWC". The data format of the input and output data.
        With the default format "NDHWC", the data is stored in the order of: [batch, in_depth, in_height, in_width, in_channels].
        Alternatively, the format could be "NCDHW", the data storage order is: [batch, in_channels, in_depth, in_height, in_width].
    dilations : list of ints
        Defaults to [1, 1, 1, 1, 1]. 1-D tensor of length 5. The dilation factor for each dimension of input.
        If set to k > 1, there will be k-1 skipped cells between each filter element on that dimension.
        The dimension order is determined by the value of data_format, see above for details.
        Dilations in the batch and depth dimensions must be 1.

    Returns
    -------
    A Tensor. Has the same type as input.
    """
    data_format, padding = preprocess_3d_format(data_format, padding)
    outputs = tf.nn.conv3d(
        input=input,
        filters=filters,
        strides=strides,
        padding=padding,
        data_format=data_format,  # 'NDHWC',
        dilations=dilations,  # [1, 1, 1, 1, 1],
        # name=name,
    )
    return outputs

def lrn(inputs, depth_radius, bias, alpha, beta):
    """
    Local Response Normalization.

    Parameters
    ----------
    inputs : tensor
        Must be one of the following types: half, bfloat16, float32. 4-D.
    depth_radius : int
        Defaults to 5. 0-D. Half-width of the 1-D normalization window.
    bias : float
        Defaults to 1. An offset (usually positive to avoid dividing by 0).
    alpha : float
        Defaults to 1. A scale factor, usually positive.
    beta : float
        Defaults to 0.5. An exponent.

    Returns
    -------
    A Tensor. Has the same type as input.
    """
    outputs = tf.nn.lrn(inputs, depth_radius=depth_radius, bias=bias, alpha=alpha, beta=beta)
    return outputs

def moments(x, axes, shift=None, keepdims=False):
    """
    Calculates the mean and variance of x.

    Parameters
    ----------
    x : tensor
        A Tensor.
    axes : list of ints
        Axes along which to compute mean and variance.
    shift : int
        Not used in the current implementation.
    keepdims : bool
        Produce moments with the same dimensionality as the input.

    Returns
    -------
    Two Tensor objects: mean and variance.
    """
    outputs = tf.nn.moments(x, axes, shift, keepdims)
    return outputs

class MaxPool(object):

    def __init__(self, ksize, strides, padding, data_format=None):
        self.ksize = ksize
        self.strides = strides
        self.data_format = data_format
        self.padding = padding

    def __call__(self, inputs):
        if inputs.ndim == 3:
            self.data_format, self.padding = preprocess_1d_format(data_format=self.data_format, padding=self.padding)
        elif inputs.ndim == 4:
            self.data_format, self.padding = preprocess_2d_format(data_format=self.data_format, padding=self.padding)
        elif inputs.ndim == 5:
            self.data_format, self.padding = preprocess_3d_format(data_format=self.data_format, padding=self.padding)
        outputs = tf.nn.max_pool(
            input=inputs, ksize=self.ksize, strides=self.strides, padding=self.padding, data_format=self.data_format
        )
        return outputs

def max_pool(input, ksize, strides, padding, data_format=None):
    """
    Performs the max pooling on the input.

    Parameters
    ----------
    input : tensor
        Tensor of rank N+2, of shape [batch_size] + input_spatial_shape + [num_channels] if data_format does not start
        with "NC" (default), or [batch_size, num_channels] + input_spatial_shape if data_format starts with "NC".
        Pooling happens over the spatial dimensions only.
    ksize : int or list of ints
        An int or list of ints that has length 1, N or N+2.
        The size of the window for each dimension of the input tensor.
    strides : int or list of ints
        An int or list of ints that has length 1, N or N+2.
        The stride of the sliding window for each dimension of the input tensor.
    padding : string
        'VALID' or 'SAME'. The padding algorithm. See the "returns" section of tf.nn.convolution for details.

    Returns
    -------
    A Tensor of format specified by data_format. The max pooled output tensor.
    """
    if input.ndim == 3:
        data_format, padding = preprocess_1d_format(data_format=data_format, padding=padding)
    elif input.ndim == 4:
        data_format, padding = preprocess_2d_format(data_format=data_format, padding=padding)
    elif input.ndim == 5:
        data_format, padding = preprocess_3d_format(data_format=data_format, padding=padding)
    outputs = tf.nn.max_pool(input=input, ksize=ksize, strides=strides, padding=padding, data_format=data_format)
    return outputs

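# Sketch: the rank-based dispatch picks the 2-D path for a 4-D input
# (shapes below are assumptions):
# >>> x = tf.random.normal([1, 32, 32, 4])
# >>> max_pool(x, ksize=2, strides=2, padding='SAME').shape   # -> [1, 16, 16, 4]
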
class AvgPool(object):

    def __init__(self, ksize, strides, padding, data_format=None):
        self.ksize = ksize
        self.strides = strides
        self.data_format = data_format
        self.padding = padding_format(padding)

    def __call__(self, inputs):
        outputs = tf.nn.avg_pool(
            input=inputs, ksize=self.ksize, strides=self.strides, padding=self.padding, data_format=self.data_format
        )
        return outputs


def avg_pool(input, ksize, strides, padding):
    """
    Performs the average pooling on the input.

    Parameters
    ----------
    input : tensor
        Tensor of rank N+2, of shape [batch_size] + input_spatial_shape + [num_channels]
        if data_format does not start with "NC" (default), or [batch_size, num_channels] + input_spatial_shape
        if data_format starts with "NC". Pooling happens over the spatial dimensions only.
    ksize : int or list of ints
        An int or list of ints that has length 1, N or N+2.
        The size of the window for each dimension of the input tensor.
    strides : int or list of ints
        An int or list of ints that has length 1, N or N+2.
        The stride of the sliding window for each dimension of the input tensor.
    padding : string
        'VALID' or 'SAME'. The padding algorithm. See the "returns" section of tf.nn.convolution for details.

    Returns
    -------
    A Tensor of format specified by data_format. The average pooled output tensor.
    """
    padding = padding_format(padding)
    outputs = tf.nn.avg_pool(
        input=input,
        ksize=ksize,
        strides=strides,
        padding=padding,
    )
    return outputs

def max_pool3d(input, ksize, strides, padding, data_format=None):
    """
    Performs the max pooling on the input.

    Parameters
    ----------
    input : tensor
        A 5-D Tensor of the format specified by data_format.
    ksize : int or list of ints
        An int or list of ints that has length 1, 3 or 5.
        The size of the window for each dimension of the input tensor.
    strides : int or list of ints
        An int or list of ints that has length 1, 3 or 5.
        The stride of the sliding window for each dimension of the input tensor.
    padding : string
        'VALID' or 'SAME'. The padding algorithm. See the "returns" section of tf.nn.convolution for details.
    data_format : string
        "NDHWC", "NCDHW". Defaults to "NDHWC". The data format of the input and output data.
        With the default format "NDHWC", the data is stored in the order of: [batch, in_depth, in_height, in_width, in_channels].
        Alternatively, the format could be "NCDHW", the data storage order is: [batch, in_channels, in_depth, in_height, in_width].

    Returns
    -------
    A Tensor of format specified by data_format. The max pooled output tensor.
    """
    data_format, padding = preprocess_3d_format(data_format, padding)
    outputs = tf.nn.max_pool3d(
        input=input,
        ksize=ksize,
        strides=strides,
        padding=padding,
        data_format=data_format,
    )
    return outputs

def avg_pool3d(input, ksize, strides, padding, data_format=None):
    """
    Performs the average pooling on the input.

    Parameters
    ----------
    input : tensor
        A 5-D Tensor of shape [batch, depth, height, width, channels] and type float32, float64, qint8, quint8, or qint32.
    ksize : int or list of ints
        An int or list of ints that has length 1, 3 or 5. The size of the window for each dimension of the input tensor.
    strides : int or list of ints
        An int or list of ints that has length 1, 3 or 5.
        The stride of the sliding window for each dimension of the input tensor.
    padding : string
        'VALID' or 'SAME'. The padding algorithm. See the "returns" section of tf.nn.convolution for details.
    data_format : string
        'NDHWC' and 'NCDHW' are supported.

    Returns
    -------
    A Tensor with the same type as value. The average pooled output tensor.
    """
    data_format, padding = preprocess_3d_format(data_format, padding)
    outputs = tf.nn.avg_pool3d(
        input=input,
        ksize=ksize,
        strides=strides,
        padding=padding,
        data_format=data_format,
    )
    return outputs

def pool(input, window_shape, pooling_type, strides=None, padding='VALID', data_format=None, dilations=None, name=None):
    """
    Performs an N-D pooling operation.

    Parameters
    ----------
    input : tensor
        Tensor of rank N+2, of shape [batch_size] + input_spatial_shape + [num_channels]
        if data_format does not start with "NC" (default), or [batch_size, num_channels] + input_spatial_shape
        if data_format starts with "NC". Pooling happens over the spatial dimensions only.
    window_shape : list of ints
        Sequence of N ints >= 1.
    pooling_type : string
        Specifies pooling operation, must be "AVG" or "MAX".
    strides : list of ints
        Sequence of N ints >= 1. Defaults to [1]*N. If any value of strides is > 1, then all values of dilation_rate must be 1.
    padding : string
        The padding algorithm, must be "SAME" or "VALID". Defaults to "VALID".
        See the "returns" section of tf.nn.convolution for details.
    data_format : string
        Specifies whether the channel dimension of the input and output is the last dimension (default, or if data_format does not start with "NC"),
        or the second dimension (if data_format starts with "NC").
        For N=1, the valid values are "NWC" (default) and "NCW". For N=2, the valid values are "NHWC" (default) and "NCHW".
        For N=3, the valid values are "NDHWC" (default) and "NCDHW".
    dilations : list of ints
        Dilation rate. List of N ints >= 1. Defaults to [1]*N. If any value of dilation_rate is > 1, then all values of strides must be 1.
    name : string
        Optional. Name of the op.

    Returns
    -------
    Tensor of rank N+2, of shape [batch_size] + output_spatial_shape + [num_channels]
    """
    if pooling_type in ["MAX", "max"]:
        pooling_type = "MAX"
    elif pooling_type in ["AVG", "avg"]:
        pooling_type = "AVG"
    else:
        raise ValueError('Unsupported pool_mode: ' + str(pooling_type))
    padding = padding_format(padding)
    outputs = tf.nn.pool(
        input=input,
        window_shape=window_shape,
        pooling_type=pooling_type,
        strides=strides,
        padding=padding,
        data_format=data_format,
        dilations=dilations,
        name=name,
    )
    return outputs

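# Sketch of the N-D entry point (window and stride values are assumptions):
# >>> x = tf.random.normal([1, 16, 16, 3])
# >>> pool(x, window_shape=[2, 2], pooling_type='avg', strides=[2, 2]).shape
# TensorShape([1, 8, 8, 3])
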
class DepthwiseConv2d(object):

    def __init__(self, strides, padding, data_format=None, dilations=None, ksize=None, channel_multiplier=1):
        self.data_format, self.padding = preprocess_2d_format(data_format, padding)
        self.strides = strides
        self.dilations = dilations

    def __call__(self, input, filter):
        outputs = tf.nn.depthwise_conv2d(
            input=input,
            filter=filter,
            strides=self.strides,
            padding=self.padding,
            data_format=self.data_format,
            dilations=self.dilations,
        )
        return outputs

def depthwise_conv2d(input, filter, strides, padding, data_format=None, dilations=None, name=None):
    """
    Depthwise 2-D convolution.

    Parameters
    ----------
    input : tensor
        4-D with shape according to data_format.
    filter : tensor
        4-D with shape [filter_height, filter_width, in_channels, channel_multiplier].
    strides : list
        1-D of size 4. The stride of the sliding window for each dimension of input.
    padding : string
        'VALID' or 'SAME'. The padding algorithm. See the "returns" section of tf.nn.convolution for details.
    data_format : string
        The data format for input. Either "NHWC" (default) or "NCHW".
    dilations : list
        1-D of size 2. The dilation rate in which we sample input values across the height and width dimensions in atrous convolution.
        If it is greater than 1, then all values of strides must be 1.
    name : string
        A name for this operation (optional).

    Returns
    -------
    A 4-D Tensor with shape according to data_format.
    E.g., for "NHWC" format, shape is [batch, out_height, out_width, in_channels * channel_multiplier].
    """
    data_format, padding = preprocess_2d_format(data_format, padding)
    outputs = tf.nn.depthwise_conv2d(
        input=input,
        filter=filter,
        strides=strides,
        padding=padding,
        data_format=data_format,
        dilations=dilations,
        name=name,
    )
    return outputs

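# Sketch: depthwise filtering multiplies the channel count by
# channel_multiplier (shapes below are assumptions):
# >>> x = tf.random.normal([1, 8, 8, 3])
# >>> w = tf.random.normal([3, 3, 3, 2])       # channel_multiplier = 2
# >>> depthwise_conv2d(x, w, strides=[1, 1, 1, 1], padding='SAME').shape
# TensorShape([1, 8, 8, 6])
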
class Conv1d_transpose(object):

    def __init__(
        self, strides, padding, data_format='NWC', dilations=None, out_channel=None, k_size=None, in_channels=None
    ):
        self.strides = strides
        self.dilations = dilations
        self.data_format, self.padding = preprocess_1d_format(data_format, padding)

    def __call__(self, input, filters):
        batch_size = input.shape[0]
        if self.data_format == 'NWC':
            w_axis, c_axis = 1, 2
        else:
            w_axis, c_axis = 2, 1
        input_shape = input.shape.as_list()
        filters_shape = filters.shape.as_list()
        input_w = input_shape[w_axis]
        filters_w = filters_shape[0]
        output_channels = filters_shape[1]
        dilations_w = 1
        if isinstance(self.strides, int):
            strides_w = self.strides
        else:
            strides_list = list(self.strides)
            strides_w = strides_list[w_axis]
        if self.dilations is not None:
            if isinstance(self.dilations, int):
                dilations_w = self.dilations
            else:
                dilations_list = list(self.dilations)
                dilations_w = dilations_list[w_axis]
        filters_w = filters_w + (filters_w - 1) * (dilations_w - 1)
        assert self.padding in {'SAME', 'VALID'}
        if self.padding == 'VALID':
            output_w = input_w * strides_w + max(filters_w - strides_w, 0)
        elif self.padding == 'SAME':
            output_w = input_w * strides_w
        if self.data_format == 'NCW':
            output_shape = (batch_size, output_channels, output_w)
        else:
            output_shape = (batch_size, output_w, output_channels)
        output_shape = tf.stack(output_shape)
        outputs = tf.nn.conv1d_transpose(
            input=input,
            filters=filters,
            output_shape=output_shape,
            strides=self.strides,
            padding=self.padding,
            data_format=self.data_format,
            dilations=self.dilations,
        )
        return outputs

def conv1d_transpose(
    input, filters, output_shape, strides, padding='SAME', data_format='NWC', dilations=None, name=None
):
    """
    The transpose of conv1d.

    Parameters
    ----------
    input : tensor
        A 3-D Tensor of type float and shape [batch, in_width, in_channels]
        for NWC data format or [batch, in_channels, in_width] for NCW data format.
    filters : tensor
        A 3-D Tensor with the same type as value and shape [filter_width, output_channels, in_channels].
        filter's in_channels dimension must match that of value.
    output_shape : tensor
        A 1-D Tensor, containing three elements, representing the output shape of the deconvolution op.
    strides : int or list of ints
        An int or list of ints that has length 1 or 3. The number of entries by which the filter is moved right at each step.
    padding : string
        'VALID' or 'SAME'. The padding algorithm. See the "returns" section of tf.nn.convolution for details.
    data_format : string
        'NWC' and 'NCW' are supported.
    dilations : int or list of ints
        An int or list of ints that has length 1 or 3 which defaults to 1.
        The dilation factor for each dimension of input. If set to k > 1,
        there will be k-1 skipped cells between each filter element on that dimension.
        Dilations in the batch and depth dimensions must be 1.
    name : string
        Optional name for the returned tensor.

    Returns
    -------
    A Tensor with the same type as value.
    """
    data_format, padding = preprocess_1d_format(data_format, padding)
    outputs = tf.nn.conv1d_transpose(
        input=input,
        filters=filters,
        output_shape=output_shape,
        strides=strides,
        padding=padding,
        data_format=data_format,
        dilations=dilations,
        name=name,
    )
    return outputs

class Conv2d_transpose(object):

    def __init__(
        self, strides, padding, data_format='NHWC', dilations=None, name=None, out_channel=None, k_size=None,
        in_channels=None
    ):
        self.strides = strides
        self.dilations = dilations
        self.name = name
        self.data_format, self.padding = preprocess_2d_format(data_format, padding)

    def __call__(self, input, filters):
        if self.data_format == 'NHWC':
            h_axis, w_axis = 1, 2
        else:
            h_axis, w_axis = 2, 3
        input_shape = input.shape.as_list()
        filters_shape = filters.shape.as_list()
        batch_size = input.shape[0]
        input_h, input_w = input_shape[h_axis], input_shape[w_axis]
        kernel_h, kernel_w = filters_shape[0], filters_shape[1]
        output_channels = filters_shape[2]
        dilations_h, dilations_w = 1, 1
        if isinstance(self.strides, int):
            strides_h = self.strides
            strides_w = self.strides
        else:
            strides_list = list(self.strides)
            if len(strides_list) != 4:
                strides_h = strides_list[0]
                strides_w = strides_list[1]
            else:
                strides_h = strides_list[h_axis]
                strides_w = strides_list[w_axis]
        if self.dilations is not None:
            if isinstance(self.dilations, int):
                dilations_h = self.dilations
                dilations_w = self.dilations
            else:
                dilations_list = list(self.dilations)
                if len(dilations_list) != 4:
                    dilations_h = dilations_list[0]
                    dilations_w = dilations_list[1]
                else:
                    dilations_h = dilations_list[h_axis]
                    dilations_w = dilations_list[w_axis]
        kernel_h = kernel_h + (kernel_h - 1) * (dilations_h - 1)
        kernel_w = kernel_w + (kernel_w - 1) * (dilations_w - 1)
        assert self.padding in {'SAME', 'VALID'}
        if self.padding == 'VALID':
            output_h = input_h * strides_h + max(kernel_h - strides_h, 0)
            output_w = input_w * strides_w + max(kernel_w - strides_w, 0)
        elif self.padding == 'SAME':
            output_h = input_h * strides_h
            output_w = input_w * strides_w
        if self.data_format == 'NCHW':
            out_shape = (batch_size, output_channels, output_h, output_w)
        else:
            out_shape = (batch_size, output_h, output_w, output_channels)
        output_shape = tf.stack(out_shape)
        outputs = tf.nn.conv2d_transpose(
            input=input, filters=filters, output_shape=output_shape, strides=self.strides, padding=self.padding,
            data_format=self.data_format, dilations=self.dilations, name=self.name
        )
        return outputs

def conv2d_transpose(
    input, filters, output_shape, strides, padding='SAME', data_format='NHWC', dilations=None, name=None
):
    """
    The transpose of conv2d.

    Parameters
    ----------
    input : tensor
        A 4-D Tensor of type float and shape [batch, height, width, in_channels]
        for NHWC data format or [batch, in_channels, height, width] for NCHW data format.
    filters : tensor
        A 4-D Tensor with the same type as input and shape [height, width,
        output_channels, in_channels]. filter's in_channels dimension must match that of input.
    output_shape : tensor
        A 1-D Tensor representing the output shape of the deconvolution op.
    strides : int or list of ints
        An int or list of ints that has length 1, 2 or 4. The stride of the sliding window for each dimension of input.
        If a single value is given it is replicated in the H and W dimension.
        The N and C dimensions always use a stride of 1.
        The dimension order is determined by the value of data_format, see below for details.
    padding : string
        'VALID' or 'SAME'. The padding algorithm. See the "returns" section of tf.nn.convolution for details.
    data_format : string
        'NHWC' and 'NCHW' are supported.
    dilations : int or list of ints
        An int or list of ints that has length 1, 2 or 4, defaults to 1.
    name : string
        Optional name for the returned tensor.

    Returns
    -------
    A Tensor with the same type as input.
    """
    data_format, padding = preprocess_2d_format(data_format, padding)
    outputs = tf.nn.conv2d_transpose(
        input=input,
        filters=filters,
        output_shape=output_shape,
        strides=strides,
        padding=padding,
        data_format=data_format,
        dilations=dilations,
        name=name,
    )
    return outputs

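# Sketch of an explicit-output-shape deconvolution (shapes are assumptions):
# >>> x = tf.random.normal([1, 8, 8, 16])
# >>> w = tf.random.normal([3, 3, 8, 16])      # [h, w, out_channels, in_channels]
# >>> conv2d_transpose(x, w, output_shape=[1, 16, 16, 8], strides=2).shape
# TensorShape([1, 16, 16, 8])
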
class Conv3d_transpose(object):

    def __init__(
        self, strides, padding, data_format='NDHWC', dilations=None, name=None, out_channel=None, k_size=None,
        in_channels=None
    ):
        self.strides = strides
        self.dilations = dilations
        self.name = name
        self.out_channel = out_channel
        self.data_format, self.padding = preprocess_3d_format(data_format, padding)

    def __call__(self, input, filters):
        if self.data_format == 'NDHWC':
            d_axis, h_axis, w_axis = 1, 2, 3
        else:
            d_axis, h_axis, w_axis = 2, 3, 4
        input_shape = input.shape.as_list()
        filters_shape = filters.shape.as_list()
        batch_size = input_shape[0]
        input_d, input_h, input_w = input_shape[d_axis], input_shape[h_axis], input_shape[w_axis]
        kernel_d, kernel_h, kernel_w = filters_shape[0], filters_shape[1], filters_shape[2]
        dilations_d, dilations_h, dilations_w = 1, 1, 1
        if isinstance(self.strides, int):
            # a scalar stride applies to all three spatial dimensions
            strides_d = strides_h = strides_w = self.strides
        else:
            strides_list = list(self.strides)
            if len(strides_list) != 5:
                strides_d, strides_h, strides_w = strides_list[0], strides_list[1], strides_list[2]
            else:
                strides_d, strides_h, strides_w = strides_list[d_axis], strides_list[h_axis], strides_list[w_axis]
        if self.dilations is not None:
            if isinstance(self.dilations, int):
                # a scalar dilation applies to all three spatial dimensions
                dilations_d = dilations_h = dilations_w = self.dilations
            else:
                dilations_list = list(self.dilations)
                if len(dilations_list) != 5:
                    dilations_d, dilations_h, dilations_w = dilations_list[0], dilations_list[1], dilations_list[2]
                else:
                    dilations_d, dilations_h, dilations_w = \
                        dilations_list[d_axis], dilations_list[h_axis], dilations_list[w_axis]
        assert self.padding in {'VALID', 'SAME'}
        kernel_d = kernel_d + (kernel_d - 1) * (dilations_d - 1)
        kernel_h = kernel_h + (kernel_h - 1) * (dilations_h - 1)
        kernel_w = kernel_w + (kernel_w - 1) * (dilations_w - 1)
        if self.padding == 'VALID':
            output_d = input_d * strides_d + max(kernel_d - strides_d, 0)
            output_h = input_h * strides_h + max(kernel_h - strides_h, 0)
            output_w = input_w * strides_w + max(kernel_w - strides_w, 0)
        elif self.padding == 'SAME':
            output_d = input_d * strides_d
            output_h = input_h * strides_h
            output_w = input_w * strides_w
        if self.data_format == 'NDHWC':
            output_shape = (batch_size, output_d, output_h, output_w, self.out_channel)
        else:
            output_shape = (batch_size, self.out_channel, output_d, output_h, output_w)
        output_shape = tf.stack(output_shape)
        outputs = tf.nn.conv3d_transpose(
            input=input, filters=filters, output_shape=output_shape, strides=self.strides, padding=self.padding,
            data_format=self.data_format, dilations=self.dilations, name=self.name
        )
        return outputs

def conv3d_transpose(
    input, filters, output_shape, strides, padding='SAME', data_format='NDHWC', dilations=None, name=None
):
    """
    The transpose of conv3d.

    Parameters
    ----------
    input : tensor
        A 5-D Tensor of type float and shape [batch, depth, height, width, in_channels]
        for NDHWC data format or [batch, in_channels, depth, height, width] for NCDHW data format.
    filters : tensor
        A 5-D Tensor with the same type as value and shape [depth, height, width, output_channels, in_channels].
        filter's in_channels dimension must match that of value.
    output_shape : tensor
        A 1-D Tensor representing the output shape of the deconvolution op.
    strides : int or list of ints
        An int or list of ints that has length 1, 3 or 5.
    padding : string
        'VALID' or 'SAME'. The padding algorithm. See the "returns" section of tf.nn.convolution for details.
    data_format : string
        'NDHWC' and 'NCDHW' are supported.
    dilations : int or list of ints
        An int or list of ints that has length 1, 3 or 5, defaults to 1.
    name : string
        Optional name for the returned tensor.

    Returns
    -------
    A Tensor with the same type as value.
    """
    data_format, padding = preprocess_3d_format(data_format, padding)
    outputs = tf.nn.conv3d_transpose(
        input=input, filters=filters, output_shape=output_shape, strides=strides, padding=padding,
        data_format=data_format, dilations=dilations, name=name
    )
    return outputs

def depthwise_conv2d(input, filters, strides, padding='SAME', data_format='NHWC', dilations=None, name=None):
    """
    Depthwise 2-D convolution.

    Note: this definition shadows the earlier depthwise_conv2d above; it differs
    only in argument names and defaults.

    Parameters
    ----------
    input : tensor
        4-D with shape according to data_format.
    filters : tensor
        4-D with shape [filter_height, filter_width, in_channels, channel_multiplier].
    strides : tuple
        1-D of size 4. The stride of the sliding window for each dimension of input.
    padding : string
        'VALID' or 'SAME'
    data_format : string
        "NHWC" (default) or "NCHW".
    dilations : tuple
        The dilation rate in which we sample input values across the height and width dimensions in atrous convolution.
        If it is greater than 1, then all values of strides must be 1.
    name : string
        A name for this operation (optional).

    Returns
    -------
    A 4-D Tensor with shape according to data_format.
    """
    data_format, padding = preprocess_2d_format(data_format, padding)
    outputs = tf.nn.depthwise_conv2d(
        input=input,
        filter=filters,
        strides=strides,
        padding=padding,
        data_format=data_format,
        dilations=dilations,
        name=name,
    )
    return outputs

def _to_channel_first_bias(b):
    """Reshape [c] to [c, 1, 1]."""
    channel_size = int(b.shape[0])
    new_shape = (channel_size, 1, 1)
    return tf.reshape(b, new_shape)


def _bias_scale(x, b, data_format):
    """The multiplication counterpart of tf.nn.bias_add."""
    if data_format == 'NHWC':
        return x * b
    elif data_format == 'NCHW':
        return x * _to_channel_first_bias(b)
    else:
        raise ValueError('invalid data_format: %s' % data_format)


def _bias_add(x, b, data_format):
    """An alternative implementation of tf.nn.bias_add that is compatible with TensorRT."""
    if data_format == 'NHWC':
        return tf.add(x, b)
    elif data_format == 'NCHW':
        return tf.add(x, _to_channel_first_bias(b))
    else:
        raise ValueError('invalid data_format: %s' % data_format)

def batch_normalization(x, mean, variance, offset, scale, variance_epsilon, data_format, name=None):
    """Data-format-aware version of tf.nn.batch_normalization."""
    if data_format == 'channels_last':
        mean = tf.reshape(mean, [1] * (len(x.shape) - 1) + [-1])
        variance = tf.reshape(variance, [1] * (len(x.shape) - 1) + [-1])
        offset = tf.reshape(offset, [1] * (len(x.shape) - 1) + [-1])
        scale = tf.reshape(scale, [1] * (len(x.shape) - 1) + [-1])
    elif data_format == 'channels_first':
        mean = tf.reshape(mean, [1] + [-1] + [1] * (len(x.shape) - 2))
        variance = tf.reshape(variance, [1] + [-1] + [1] * (len(x.shape) - 2))
        offset = tf.reshape(offset, [1] + [-1] + [1] * (len(x.shape) - 2))
        scale = tf.reshape(scale, [1] + [-1] + [1] * (len(x.shape) - 2))
    else:
        raise ValueError('invalid data_format: %s' % data_format)
    with ops.name_scope(name, 'batchnorm', [x, mean, variance, scale, offset]):
        inv = math_ops.rsqrt(variance + variance_epsilon)
        if scale is not None:
            inv *= scale
        a = math_ops.cast(inv, x.dtype)
        b = math_ops.cast(offset - mean * inv if offset is not None else -mean * inv, x.dtype)
        # Return a * x + b with a customized data_format.
        # Currently TF doesn't have bias_scale, and TensorRT has a bug in converting tf.nn.bias_add,
        # so we re-implemented them to make the model work with TensorRT.
        # See https://github.com/tensorlayer/openpose-plus/issues/75 for more details.
        # df = {'channels_first': 'NCHW', 'channels_last': 'NHWC'}
        # return _bias_add(_bias_scale(x, a, df[data_format]), b, df[data_format])
        return a * x + b

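# Sketch of inference-style use of the data-format-aware normalization
# (statistics and parameter shapes below are assumptions):
# >>> x = tf.random.normal([4, 10, 10, 32])
# >>> mean, var = tf.nn.moments(x, axes=[0, 1, 2])
# >>> y = batch_normalization(x, mean, var, offset=tf.zeros([32]),
# ...                         scale=tf.ones([32]), variance_epsilon=1e-5,
# ...                         data_format='channels_last')
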
class BatchNorm(object):
    """
    The :class:`BatchNorm` is a batch normalization layer for both fully-connected and convolution outputs.
    See ``tf.nn.batch_normalization`` and ``tf.nn.moments``.

    Parameters
    ----------
    decay : float
        A decay factor for `ExponentialMovingAverage`.
        A large value is suggested for large datasets.
    epsilon : float
        Epsilon.
    act : activation function
        The activation function of this layer.
    is_train : boolean
        Is being used for training or inference.
    beta_init : initializer or None
        The initializer for initializing beta, if None, skip beta.
        Usually you should not skip beta unless you know what will happen.
    gamma_init : initializer or None
        The initializer for initializing gamma, if None, skip gamma.
        When the batch normalization layer is used instead of 'biases', or the next layer is linear, this can be
        disabled since the scaling can be done by the next layer. See `Inception-ResNet-v2 <https://github.com/tensorflow/models/blob/master/research/slim/nets/inception_resnet_v2.py>`__
    moving_mean_init : initializer or None
        The initializer for initializing moving mean, if None, skip moving mean.
    moving_var_init : initializer or None
        The initializer for initializing moving var, if None, skip moving var.
    num_features : int
        Number of features of the input tensor. Useful when building a layer with BatchNorm1d, BatchNorm2d or BatchNorm3d,
        but should be left as None when using BatchNorm. Default None.
    data_format : str
        'channels_last' (default) or 'channels_first'.
    name : None or str
        A unique layer name.

    Examples
    ---------
    With TensorLayer

    >>> net = tl.layers.Input([None, 50, 50, 32], name='input')
    >>> net = tl.layers.BatchNorm()(net)

    Notes
    -----
    The :class:`BatchNorm` is universally suitable for 3D/4D/5D input in a static model, but should not be used
    in a dynamic model where the layer is built upon class initialization. So the argument 'num_features' should only
    be used by the subclasses :class:`BatchNorm1d`, :class:`BatchNorm2d` and :class:`BatchNorm3d`. All three
    subclasses are suitable under all kinds of conditions.

    References
    ----------
    - `Source <https://github.com/ry/tensorflow-resnet/blob/master/resnet.py>`__
    - `stackoverflow <http://stackoverflow.com/questions/38312668/how-does-one-do-inference-with-batch-normalization-with-tensor-flow>`__
    """


    def __init__(
        self, decay=0.9, epsilon=0.00001, beta=None, gamma=None, moving_mean=None, moving_var=None, num_features=None,
        data_format='channels_last', is_train=False
    ):
        self.decay = decay
        self.epsilon = epsilon
        self.data_format = data_format
        self.beta = beta
        self.gamma = gamma
        self.moving_mean = moving_mean
        self.moving_var = moving_var
        self.num_features = num_features
        self.is_train = is_train
        self.axes = None

        if self.decay < 0.0 or 1.0 < self.decay:
            raise ValueError("decay should be between 0 and 1")

    def _get_param_shape(self, inputs_shape):
        if self.data_format == 'channels_last':
            axis = -1
        elif self.data_format == 'channels_first':
            axis = 1
        else:
            raise ValueError('data_format should be either %s or %s' % ('channels_last', 'channels_first'))

        channels = inputs_shape[axis]
        params_shape = [channels]
        return params_shape

    def _check_input_shape(self, inputs):
        if inputs.ndim <= 1:
            raise ValueError('expected input at least 2D, but got {}D input'.format(inputs.ndim))

    def __call__(self, inputs):
        self._check_input_shape(inputs)
        self.channel_axis = len(inputs.shape) - 1 if self.data_format == 'channels_last' else 1
        if self.axes is None:
            self.axes = [i for i in range(len(inputs.shape)) if i != self.channel_axis]

        mean, var = tf.nn.moments(inputs, self.axes, keepdims=False)
        if self.is_train:
            # update moving_mean and moving_var
            self.moving_mean = moving_averages.assign_moving_average(
                self.moving_mean, mean, self.decay, zero_debias=False
            )
            self.moving_var = moving_averages.assign_moving_average(self.moving_var, var, self.decay, zero_debias=False)
            outputs = batch_normalization(inputs, mean, var, self.beta, self.gamma, self.epsilon, self.data_format)
        else:
            outputs = batch_normalization(
                inputs, self.moving_mean, self.moving_var, self.beta, self.gamma, self.epsilon, self.data_format
            )
        return outputs
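
# A minimal usage sketch of the op above (assuming the enclosing class is
# named BatchNorm, as its docstring indicates). The helper, shapes and
# initial statistics below are illustrative, not part of the module API.
# In training mode the op normalizes with the batch moments and updates the
# moving statistics; in inference mode it normalizes with the stored ones.
def _demo_batch_norm():
    x = tf.random.normal([4, 8, 8, 16])  # NHWC input, 16 channels
    bn = BatchNorm(
        decay=0.9, epsilon=1e-5,
        beta=tf.Variable(tf.zeros([16])),
        gamma=tf.Variable(tf.ones([16])),
        moving_mean=tf.Variable(tf.zeros([16]), trainable=False),
        moving_var=tf.Variable(tf.ones([16]), trainable=False),
        data_format='channels_last', is_train=True,
    )
    y = bn(x)  # per-channel mean ~0 and variance ~1 over the (N, H, W) axes
    return y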


class GroupConv2D(object):

    def __init__(self, strides, padding, data_format, dilations, out_channel, k_size, groups):
        self.data_format, self.padding = preprocess_2d_format(data_format, padding)
        self.strides = strides
        self.dilations = dilations
        self.groups = groups
        if self.data_format == 'NHWC':
            self.channels_axis = 3
        else:
            self.channels_axis = 1

    def __call__(self, input, filters):
        if self.groups == 1:
            outputs = tf.nn.conv2d(
                input=input,
                filters=filters,
                strides=self.strides,
                padding=self.padding,
                data_format=self.data_format,
                dilations=self.dilations,
            )
        else:
            # Split the input along its channel axis and the filters along
            # their output-channel axis (tf.nn.conv2d filters are always laid
            # out as [h, w, in_channel // groups, out_channel], regardless of
            # the input data_format), then convolve each group separately.
            inputgroups = tf.split(input, num_or_size_splits=self.groups, axis=self.channels_axis)
            weightsgroups = tf.split(filters, num_or_size_splits=self.groups, axis=3)
            convgroups = []
            for i, k in zip(inputgroups, weightsgroups):
                convgroups.append(
                    tf.nn.conv2d(
                        input=i,
                        filters=k,
                        strides=self.strides,
                        padding=self.padding,
                        data_format=self.data_format,
                        dilations=self.dilations,
                    )
                )
            outputs = tf.concat(axis=self.channels_axis, values=convgroups)
        return outputs
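
# A sketch of the grouped convolution above: the input channels and the
# filter bank are split into `groups` slices, each slice is convolved
# independently, and the results are concatenated back along the channel
# axis. Shapes below are illustrative assumptions; each group sees
# in_channel // groups input channels.
def _demo_group_conv2d():
    groups, in_ch, out_ch = 2, 8, 16
    x = tf.random.normal([1, 32, 32, in_ch])
    w = tf.random.normal([3, 3, in_ch // groups, out_ch])
    conv = GroupConv2D(
        strides=[1, 1, 1, 1], padding='SAME', data_format='NHWC',
        dilations=[1, 1, 1, 1], out_channel=out_ch, k_size=3, groups=groups
    )
    y = conv(x, w)
    assert y.shape == (1, 32, 32, out_ch)
    return y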


class SeparableConv1D(object):

    def __init__(self, stride, padding, data_format, dilations, out_channel, k_size, in_channel, depth_multiplier):
        self.data_format, self.padding = preprocess_1d_format(data_format, padding)
        if self.data_format == 'NWC':
            self.spatial_start_dim = 1
            self.strides = (1, stride, stride, 1)
            self.data_format = 'NHWC'
        else:
            self.spatial_start_dim = 2
            self.strides = (1, 1, stride, stride)
            self.data_format = 'NCHW'
        self.dilation_rate = (1, dilations)

    def __call__(self, inputs, depthwise_filters, pointwise_filters):
        # Insert a dummy spatial dimension so the 1D convolution can be
        # expressed with tf.nn.separable_conv2d, then squeeze it back out.
        inputs = tf.expand_dims(inputs, axis=self.spatial_start_dim)
        depthwise_filters = tf.expand_dims(depthwise_filters, 0)
        pointwise_filters = tf.expand_dims(pointwise_filters, 0)
        outputs = tf.nn.separable_conv2d(
            inputs, depthwise_filters, pointwise_filters, strides=self.strides, padding=self.padding,
            dilations=self.dilation_rate, data_format=self.data_format
        )
        outputs = tf.squeeze(outputs, axis=self.spatial_start_dim)
        return outputs
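
# SeparableConv1D above has no native 1D separable kernel to call, so it
# reuses tf.nn.separable_conv2d through the expand/squeeze trick. A minimal
# sketch with assumed shapes (kernel size 5, depth_multiplier=1, 8 -> 16
# channels):
def _demo_separable_conv1d():
    x = tf.random.normal([2, 100, 8])             # NWC
    depthwise = tf.random.normal([5, 8, 1])       # [k_size, in_channel, depth_multiplier]
    pointwise = tf.random.normal([1, 8 * 1, 16])  # [1, in_channel * depth_multiplier, out_channel]
    conv = SeparableConv1D(
        stride=1, padding='SAME', data_format='NWC', dilations=1,
        out_channel=16, k_size=5, in_channel=8, depth_multiplier=1
    )
    y = conv(x, depthwise, pointwise)
    assert y.shape == (2, 100, 16)
    return y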


class SeparableConv2D(object):

    def __init__(self, strides, padding, data_format, dilations, out_channel, k_size, in_channel, depth_multiplier):
        self.data_format, self.padding = preprocess_2d_format(data_format, padding)
        self.strides = strides
        self.dilations = (dilations[2], dilations[2])

    def __call__(self, inputs, depthwise_filters, pointwise_filters):
        outputs = tf.nn.separable_conv2d(
            inputs, depthwise_filters, pointwise_filters, strides=self.strides, padding=self.padding,
            dilations=self.dilations, data_format=self.data_format
        )
        return outputs
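
# The 2D variant is a thin wrapper over tf.nn.separable_conv2d: a depthwise
# pass followed by a 1x1 pointwise pass. A minimal sketch with assumed
# shapes (depth_multiplier=1, 8 -> 16 channels):
def _demo_separable_conv2d():
    x = tf.random.normal([1, 28, 28, 8])         # NHWC
    depthwise = tf.random.normal([3, 3, 8, 1])   # [kh, kw, in_channel, depth_multiplier]
    pointwise = tf.random.normal([1, 1, 8, 16])  # [1, 1, in_channel * depth_multiplier, out_channel]
    conv = SeparableConv2D(
        strides=[1, 1, 1, 1], padding='SAME', data_format='NHWC',
        dilations=[1, 1, 1, 1], out_channel=16, k_size=3, in_channel=8,
        depth_multiplier=1
    )
    y = conv(x, depthwise, pointwise)
    assert y.shape == (1, 28, 28, 16)
    return y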


class AdaptiveMeanPool1D(object):

    def __init__(self, output_size, data_format):
        self.data_format, _ = preprocess_1d_format(data_format, None)
        self.output_size = output_size

    def __call__(self, input):
        if self.data_format == 'NWC':
            n, w, c = input.shape
        else:
            n, c, w = input.shape
        # Pick stride and kernel so that a VALID pool emits exactly
        # output_size windows, with the last window ending at the boundary.
        stride = floor(w / self.output_size)
        kernel = w - (self.output_size - 1) * stride
        output = tf.nn.avg_pool1d(input, ksize=kernel, strides=stride, data_format=self.data_format, padding='VALID')
        return output


class AdaptiveMeanPool2D(object):

    def __init__(self, output_size, data_format):
        self.data_format, _ = preprocess_2d_format(data_format, None)
        self.output_size = output_size

    def __call__(self, inputs):
        if self.data_format == 'NHWC':
            n, h, w, c = inputs.shape
        else:
            n, c, h, w = inputs.shape
        out_h, out_w = self.output_size
        stride_h = floor(h / out_h)
        kernel_h = h - (out_h - 1) * stride_h
        stride_w = floor(w / out_w)
        kernel_w = w - (out_w - 1) * stride_w
        outputs = tf.nn.avg_pool2d(
            inputs, ksize=(kernel_h, kernel_w), strides=(stride_h, stride_w), data_format=self.data_format,
            padding='VALID'
        )
        return outputs
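
# Why the adaptive pools hit the requested output size exactly: with
# stride = floor(in / out) and kernel = in - (out - 1) * stride, a VALID
# pool yields floor((in - kernel) / stride) + 1 = out windows, and the last
# window ends exactly at the input boundary. A small sketch (shapes are
# illustrative assumptions):
def _demo_adaptive_mean_pool2d():
    x = tf.random.normal([1, 10, 7, 3])  # NHWC
    pool = AdaptiveMeanPool2D(output_size=(4, 3), data_format='NHWC')
    y = pool(x)
    # h: stride = floor(10/4) = 2, kernel = 10 - 3*2 = 4 -> (10-4)//2 + 1 = 4
    # w: stride = floor(7/3)  = 2, kernel = 7 - 2*2 = 3  -> (7-3)//2 + 1  = 3
    assert y.shape == (1, 4, 3, 3)
    return y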


class AdaptiveMeanPool3D(object):

    def __init__(self, output_size, data_format):
        self.data_format, _ = preprocess_3d_format(data_format, None)
        self.output_size = output_size

    def __call__(self, inputs):
        if self.data_format == 'NDHWC':
            n, d, h, w, c = inputs.shape
        else:
            n, c, d, h, w = inputs.shape
        out_d, out_h, out_w = self.output_size
        stride_d = floor(d / out_d)
        kernel_d = d - (out_d - 1) * stride_d
        stride_h = floor(h / out_h)
        kernel_h = h - (out_h - 1) * stride_h
        stride_w = floor(w / out_w)
        kernel_w = w - (out_w - 1) * stride_w
        outputs = tf.nn.avg_pool3d(
            inputs, ksize=(kernel_d, kernel_h, kernel_w), strides=(stride_d, stride_h, stride_w),
            data_format=self.data_format, padding='VALID'
        )
        return outputs


class AdaptiveMaxPool1D(object):

    def __init__(self, output_size, data_format):
        self.data_format, _ = preprocess_1d_format(data_format, None)
        self.output_size = output_size

    def __call__(self, input):
        if self.data_format == 'NWC':
            n, w, c = input.shape
        else:
            n, c, w = input.shape
        stride = floor(w / self.output_size)
        kernel = w - (self.output_size - 1) * stride
        output = tf.nn.max_pool1d(input, ksize=kernel, strides=stride, data_format=self.data_format, padding='VALID')
        return output


class AdaptiveMaxPool2D(object):

    def __init__(self, output_size, data_format):
        self.data_format, _ = preprocess_2d_format(data_format, None)
        self.output_size = output_size

    def __call__(self, inputs):
        if self.data_format == 'NHWC':
            n, h, w, c = inputs.shape
        else:
            n, c, h, w = inputs.shape
        out_h, out_w = self.output_size
        stride_h = floor(h / out_h)
        kernel_h = h - (out_h - 1) * stride_h
        stride_w = floor(w / out_w)
        kernel_w = w - (out_w - 1) * stride_w
        outputs = tf.nn.max_pool2d(
            inputs, ksize=(kernel_h, kernel_w), strides=(stride_h, stride_w), data_format=self.data_format,
            padding='VALID'
        )
        return outputs


class AdaptiveMaxPool3D(object):

    def __init__(self, output_size, data_format):
        self.data_format, _ = preprocess_3d_format(data_format, None)
        self.output_size = output_size

    def __call__(self, inputs):
        if self.data_format == 'NDHWC':
            n, d, h, w, c = inputs.shape
        else:
            n, c, d, h, w = inputs.shape
        out_d, out_h, out_w = self.output_size
        stride_d = floor(d / out_d)
        kernel_d = d - (out_d - 1) * stride_d
        stride_h = floor(h / out_h)
        kernel_h = h - (out_h - 1) * stride_h
        stride_w = floor(w / out_w)
        kernel_w = w - (out_w - 1) * stride_w
        outputs = tf.nn.max_pool3d(
            inputs, ksize=(kernel_d, kernel_h, kernel_w), strides=(stride_d, stride_h, stride_w),
            data_format=self.data_format, padding='VALID'
        )
        return outputs


class BinaryConv2D(object):

    def __init__(self, strides, padding, data_format, dilations, out_channel, k_size, in_channel):
        self.data_format, self.padding = preprocess_2d_format(data_format, padding)
        self.strides = strides
        self.dilations = dilations

    # @tf.RegisterGradient("TL_Sign_QuantizeGrad")
    # def _quantize_grad(op, grad):
    #     """Clip and binarize tensor using the straight through estimator (STE) for the gradient."""
    #     return tf.clip_by_value(grad, -1, 1)

    def quantize(self, x):
        # ref: https://github.com/AngusG/tensorflow-xnor-bnn/blob/master/models/binary_net.py#L70
        # https://github.com/itayhubara/BinaryNet.tf/blob/master/nnUtils.py
        with tf.compat.v1.get_default_graph().gradient_override_map({"Sign": "TL_Sign_QuantizeGrad"}):
            return tf.sign(x)

    def __call__(self, inputs, filters):
        filters = self.quantize(filters)
        outputs = tf.nn.conv2d(
            input=inputs, filters=filters, strides=self.strides, padding=self.padding, data_format=self.data_format,
            dilations=self.dilations
        )
        return outputs
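
# The straight-through estimator for BinaryConv2D depends on the custom
# "TL_Sign_QuantizeGrad" gradient that is commented out above, and
# gradient_override_map only takes effect inside a TF1-style graph. A hedged
# sketch: register the gradient once at import time; the forward pass then
# binarizes the filters to {-1, +1}. In eager execution the conv still runs,
# but the STE gradient is not applied. Shapes below are illustrative.
@tf.RegisterGradient("TL_Sign_QuantizeGrad")
def _tl_sign_quantize_grad(op, grad):
    # STE: pass the gradient straight through, clipped to [-1, 1].
    return tf.clip_by_value(grad, -1, 1)


def _demo_binary_conv2d():
    conv = BinaryConv2D(
        strides=[1, 1, 1, 1], padding='SAME', data_format='NHWC',
        dilations=[1, 1, 1, 1], out_channel=8, k_size=3, in_channel=4
    )
    x = tf.random.normal([1, 16, 16, 4])
    w = tf.random.normal([3, 3, 4, 8])
    return conv(x, w)  # filters are binarized before the convolution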


class DorefaConv2D(object):

    def __init__(self, bitW, bitA, strides, padding, data_format, dilations, out_channel, k_size, in_channel):
        self.data_format, self.padding = preprocess_2d_format(data_format, padding)
        self.strides = strides
        self.dilations = dilations
        self.bitW = bitW
        self.bitA = bitA

    def _quantize_dorefa(self, x, k):
        G = tf.compat.v1.get_default_graph()
        n = float(2**k - 1)
        with G.gradient_override_map({"Round": "Identity"}):
            return tf.round(x * n) / n

    def cabs(self, x):
        return tf.minimum(1.0, tf.abs(x), name='cabs')

    def quantize_active(self, x, bitA):
        if bitA == 32:
            return x
        return self._quantize_dorefa(x, bitA)

    def quantize_weight(self, x, bitW, force_quantization=False):
        G = tf.compat.v1.get_default_graph()
        if bitW == 32 and not force_quantization:
            return x
        if bitW == 1:  # BWN
            with G.gradient_override_map({"Sign": "Identity"}):
                E = tf.stop_gradient(tf.reduce_mean(input_tensor=tf.abs(x)))
                return tf.sign(x / E) * E
        x = tf.clip_by_value(
            x * 0.5 + 0.5, 0.0, 1.0
        )  # it seems as though most weights are within the -1 to 1 region anyway
        return 2 * self._quantize_dorefa(x, bitW) - 1

    def __call__(self, inputs, filters):
        inputs = self.quantize_active(self.cabs(inputs), self.bitA)
        filters = self.quantize_weight(filters, self.bitW)
        outputs = tf.nn.conv2d(
            input=inputs,
            filters=filters,
            strides=self.strides,
            padding=self.padding,
            data_format=self.data_format,
            dilations=self.dilations,
        )
        return outputs
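
# A worked sketch of the DoReFa weight quantization above. For bitW = 2,
# n = 2**2 - 1 = 3: weights are squashed into [0, 1] via x/2 + 0.5, rounded
# onto the grid {0, 1/3, 2/3, 1}, and mapped back to [-1, 1]. The values
# below are illustrative:
def _demo_dorefa_quantize_weight():
    conv = DorefaConv2D(
        bitW=2, bitA=2, strides=[1, 1, 1, 1], padding='SAME',
        data_format='NHWC', dilations=[1, 1, 1, 1], out_channel=8, k_size=3,
        in_channel=4
    )
    w = tf.constant([-0.9, -0.2, 0.1, 0.8])
    wq = conv.quantize_weight(w, bitW=2)
    # x/2 + 0.5        -> [0.05, 0.40, 0.55, 0.90]
    # round(x * 3) / 3 -> [0, 1/3, 2/3, 1]
    # 2*q - 1          -> [-1.0, -1/3, 1/3, 1.0]
    return wq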
