#! /usr/bin/python
# -*- coding: utf-8 -*-

import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import moving_averages
from math import floor, ceil
import numpy as np

# loss functions
sparse_softmax_cross_entropy_with_logits = tf.nn.sparse_softmax_cross_entropy_with_logits
sigmoid_cross_entropy_with_logits = tf.nn.sigmoid_cross_entropy_with_logits


def padding_format(padding):
    """
    Check that the padding is a supported format and normalize it.

    Parameters
    ----------
    padding : str
        Must be one of the following: "same", "SAME", "valid", "VALID".

    Returns
    -------
    str
        "SAME" or "VALID".
    """
    if padding in ["SAME", "same"]:
        padding = "SAME"
    elif padding in ["VALID", "valid"]:
        padding = "VALID"
    elif padding is None:
        padding = None
    else:
        raise Exception("Unsupported padding: " + str(padding))
    return padding


def preprocess_1d_format(data_format, padding):
    """
    Check that the 1-D data format and padding are supported and normalize them.

    Parameters
    ----------
    data_format : str
        Must be one of the following: "channels_last", "NWC", "channels_first", "NCW".
    padding : str
        Must be one of the following: "same", "valid", "SAME", "VALID".

    Returns
    -------
    tuple of str
        "NWC" or "NCW", and "SAME" or "VALID".
    """
    if data_format in ["channels_last", "NWC"]:
        data_format = "NWC"
    elif data_format in ["channels_first", "NCW"]:
        data_format = "NCW"
    elif data_format is None:
        data_format = None
    else:
        raise Exception("Unsupported data format: " + str(data_format))
    padding = padding_format(padding)
    return data_format, padding


def preprocess_2d_format(data_format, padding):
    """
    Check that the 2-D data format and padding are supported and normalize them.

    Parameters
    ----------
    data_format : str
        Must be one of the following: "channels_last", "NHWC", "channels_first", "NCHW".
    padding : str
        Must be one of the following: "same", "valid", "SAME", "VALID".

    Returns
    -------
    tuple of str
        "NHWC" or "NCHW", and "SAME" or "VALID".
    """
    if data_format in ["channels_last", "NHWC"]:
        data_format = "NHWC"
    elif data_format in ["channels_first", "NCHW"]:
        data_format = "NCHW"
    elif data_format is None:
        data_format = None
    else:
        raise Exception("Unsupported data format: " + str(data_format))
    padding = padding_format(padding)
    return data_format, padding


def preprocess_3d_format(data_format, padding):
    """
    Check that the 3-D data format and padding are supported and normalize them.

    Parameters
    ----------
    data_format : str
        Must be one of the following: "channels_last", "NDHWC", "channels_first", "NCDHW".
    padding : str
        Must be one of the following: "same", "valid", "SAME", "VALID".

    Returns
    -------
    tuple of str
        "NDHWC" or "NCDHW", and "SAME" or "VALID".
    """
    if data_format in ['channels_last', 'NDHWC']:
        data_format = 'NDHWC'
    elif data_format in ['channels_first', 'NCDHW']:
        data_format = 'NCDHW'
    elif data_format is None:
        data_format = None
    else:
        raise Exception("Unsupported data format: " + str(data_format))
    padding = padding_format(padding)
    return data_format, padding
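
# Example (illustrative sketch, not part of the original file): the helpers
# above normalize user-facing aliases before they reach the tf.nn ops below.
# >>> preprocess_2d_format("channels_first", "same")
# ('NCHW', 'SAME')
# >>> padding_format("valid")
# 'VALID'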


def nchw_to_nhwc(x):
    """
    Channels first to channels last.

    Parameters
    ----------
    x : tensor
        channels-first tensor data

    Returns
    -------
    channels-last tensor data
    """
    if len(x.shape) == 3:
        x = tf.transpose(x, (0, 2, 1))
    elif len(x.shape) == 4:
        x = tf.transpose(x, (0, 2, 3, 1))
    elif len(x.shape) == 5:
        x = tf.transpose(x, (0, 2, 3, 4, 1))
    else:
        raise Exception("Unsupported dimensions")
    return x


def nhwc_to_nchw(x):
    """
    Channels last to channels first.

    Parameters
    ----------
    x : tensor
        channels-last tensor data

    Returns
    -------
    channels-first tensor data
    """
    if len(x.shape) == 3:
        x = tf.transpose(x, (0, 2, 1))
    elif len(x.shape) == 4:
        x = tf.transpose(x, (0, 3, 1, 2))
    elif len(x.shape) == 5:
        x = tf.transpose(x, (0, 4, 1, 2, 3))
    else:
        raise Exception("Unsupported dimensions")
    return x
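
# Example (illustrative): converting a 4-D NCHW tensor to channels-last and
# back is a lossless round trip.
# >>> x = tf.ones([8, 3, 32, 32])    # NCHW
# >>> y = nchw_to_nhwc(x)            # shape (8, 32, 32, 3)
# >>> z = nhwc_to_nchw(y)            # shape (8, 3, 32, 32) again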


class ReLU(object):

    def __init__(self):
        pass

    def __call__(self, x):
        return tf.nn.relu(x)


def relu(x):
    """
    Computes rectified linear: max(features, 0).

    Parameters
    ----------
    x : tensor
        Must be one of the following types: float32, float64, int32, uint8, int16,
        int8, int64, bfloat16, uint16, half, uint32, uint64, qint8.

    Returns
    -------
    A Tensor. Has the same type as x.
    """
    return tf.nn.relu(x)


class ReLU6(object):

    def __init__(self):
        pass

    def __call__(self, x):
        return tf.nn.relu6(x)


def relu6(x):
    """
    Computes Rectified Linear 6: min(max(features, 0), 6).

    Parameters
    ----------
    x : tensor
        Must be one of the following types: float32, float64, int32, uint8, int16,
        int8, int64, bfloat16, uint16, half, uint32, uint64, qint8.

    Returns
    -------
    A Tensor with the same type as x.
    """
    return tf.nn.relu6(x)


class LeakyReLU(object):

    def __init__(self, alpha=0.2):
        self.alpha = alpha

    def __call__(self, x):
        return tf.nn.leaky_relu(x, alpha=self.alpha)


def leaky_relu(x, alpha=0.2):
    """
    Compute the Leaky ReLU activation function.

    Parameters
    ----------
    x : tensor
        A tensor of preactivation values. Must be one of the following types:
        float16, float32, float64, int32, int64.

    Returns
    -------
    The activation value.
    """
    return tf.nn.leaky_relu(x, alpha=alpha)


class Softplus(object):

    def __init__(self):
        pass

    def __call__(self, x):
        return tf.nn.softplus(x)


def softplus(x):
    """
    Computes softplus: log(exp(features) + 1).

    Parameters
    ----------
    x : tensor
        Must be one of the following types: half, bfloat16, float32, float64.

    Returns
    -------
    A Tensor. Has the same type as x.
    """
    return tf.nn.softplus(x)


class Tanh(object):

    def __init__(self):
        pass

    def __call__(self, x):
        return tf.nn.tanh(x)


def tanh(x):
    """
    Computes hyperbolic tangent of x element-wise.

    Parameters
    ----------
    x : tensor
        Must be one of the following types: bfloat16, half, float32, float64, complex64, complex128.

    Returns
    -------
    A Tensor. Has the same type as x.
    """
    return tf.nn.tanh(x)


class Sigmoid(object):

    def __init__(self):
        pass

    def __call__(self, x):
        return tf.nn.sigmoid(x)


def sigmoid(x):
    """
    Computes sigmoid of x element-wise.

    Parameters
    ----------
    x : tensor
        A Tensor with type float16, float32, float64, complex64, or complex128.

    Returns
    -------
    A Tensor with the same type as x.
    """
    return tf.nn.sigmoid(x)


class Softmax(object):

    def __init__(self):
        pass

    def __call__(self, x):
        return tf.nn.softmax(x)


def softmax(logits, axis=None):
    """
    Computes softmax activations.

    Parameters
    ----------
    logits : tensor
        Must be one of the following types: half, float32, float64.
    axis : int
        The dimension softmax would be performed on. The default is -1, which indicates the last dimension.

    Returns
    -------
    A Tensor. Has the same type and shape as logits.
    """
    return tf.nn.softmax(logits, axis)
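
# Example (illustrative): each activation is exposed both as a class and as a
# plain function; the two forms are interchangeable.
# >>> act = ReLU()
# >>> act(tf.constant([-1.0, 2.0]))         # [0., 2.]
# >>> relu(tf.constant([-1.0, 2.0]))        # same result
# >>> softmax(tf.constant([[1.0, 1.0]]))    # [[0.5, 0.5]]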


class Dropout(object):

    def __init__(self, keep, seed=0):
        self.keep = keep
        self.seed = seed

    def __call__(self, inputs, *args, **kwargs):
        outputs = tf.nn.dropout(inputs, rate=1 - self.keep, seed=self.seed)
        return outputs


class BiasAdd(object):
    """
    Adds bias to value.

    Parameters
    ----------
    x : tensor
        A Tensor with type float, double, int64, int32, uint8, int16, int8, complex64, or complex128.
    bias : tensor
        Must be the same type as value unless value is a quantized type,
        in which case a different quantized type may be used.

    Returns
    -------
    A Tensor with the same type as value.
    """

    def __init__(self, data_format=None):
        self.data_format = data_format

    def __call__(self, x, bias):
        return tf.nn.bias_add(x, bias, data_format=self.data_format)


def bias_add(x, bias, data_format=None, name=None):
    """
    Adds bias to value.

    Parameters
    ----------
    x : tensor
        A Tensor with type float, double, int64, int32, uint8, int16, int8, complex64, or complex128.
    bias : tensor
        Must be the same type as value unless value is a quantized type,
        in which case a different quantized type may be used.
    data_format : str
        'N...C' and 'NC...' are supported.
    name : str
        A name for the operation (optional).

    Returns
    -------
    A Tensor with the same type as value.
    """
    x = tf.nn.bias_add(x, bias, data_format=data_format, name=name)
    return x


class Conv1D(object):

    def __init__(self, stride, padding, data_format='NWC', dilations=None, out_channel=None, k_size=None):
        self.stride = stride
        self.dilations = dilations
        self.data_format, self.padding = preprocess_1d_format(data_format, padding)

    def __call__(self, input, filters):
        outputs = tf.nn.conv1d(
            input=input,
            filters=filters,
            stride=self.stride,
            padding=self.padding,
            data_format=self.data_format,
            dilations=self.dilations,
            # name=name
        )
        return outputs


def conv1d(input, filters, stride, padding, data_format='NWC', dilations=None):
    """
    Computes a 1-D convolution given 3-D input and filter tensors.

    Parameters
    ----------
    input : tensor
        A 3-D Tensor. Must be of type float16, float32, or float64.
    filters : tensor
        A 3-D Tensor. Must have the same type as input.
    stride : int or list of ints
        An int or list of ints that has length 1 or 3. The number of entries by which the filter is moved right at each step.
    padding : string
        'SAME' or 'VALID'
    data_format : string
        An optional string from "NWC", "NCW". Defaults to "NWC"; the data is stored in the order of
        [batch, in_width, in_channels]. The "NCW" format stores data as [batch, in_channels, in_width].
    dilations : int or list of ints
        An int or list of ints that has length 1 or 3, which defaults to 1.
        The dilation factor for each dimension of input. If set to k > 1,
        there will be k-1 skipped cells between each filter element on that dimension.
        Dilations in the batch and depth dimensions must be 1.

    Returns
    -------
    A Tensor. Has the same type as input.
    """
    data_format, padding = preprocess_1d_format(data_format, padding)
    outputs = tf.nn.conv1d(
        input=input,
        filters=filters,
        stride=stride,
        padding=padding,
        data_format=data_format,
        dilations=dilations,
        # name=name
    )
    return outputs
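
# Example (illustrative shapes): a 1-D convolution over an 'NWC' input; with
# 'SAME' padding and stride 1 the width dimension is preserved.
# >>> x = tf.random.normal([4, 100, 16])    # [batch, width, in_channels]
# >>> w = tf.random.normal([5, 16, 32])     # [filter_width, in, out]
# >>> y = conv1d(x, w, stride=1, padding='SAME')    # shape (4, 100, 32)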


class Conv2D(object):

    def __init__(self, strides, padding, data_format='NHWC', dilations=None, out_channel=None, k_size=None):
        self.strides = strides
        self.dilations = dilations
        self.data_format, self.padding = preprocess_2d_format(data_format, padding)

    def __call__(self, input, filters):
        outputs = tf.nn.conv2d(
            input=input,
            filters=filters,
            strides=self.strides,
            padding=self.padding,
            data_format=self.data_format,
            dilations=self.dilations,
        )
        return outputs


def conv2d(input, filters, strides, padding, data_format='NHWC', dilations=None):
    """
    Computes a 2-D convolution given 4-D input and filters tensors.

    Parameters
    ----------
    input : tensor
        Must be one of the following types: half, bfloat16, float32, float64. A 4-D tensor.
        The dimension order is interpreted according to the value of data_format; see below for details.
    filters : tensor
        Must have the same type as input. A 4-D tensor of shape [filter_height, filter_width, in_channels, out_channels].
    strides : int or list of ints
        The stride of the sliding window for each dimension of input. If a single value is given, it is replicated in the H and W dimensions.
        By default the N and C dimensions are set to 1. The dimension order is determined by the value of data_format; see below for details.
    padding : string
        "SAME" or "VALID"
    data_format : string
        "NHWC", "NCHW". Defaults to "NHWC".
    dilations : int or list of ints
        An int or list of ints that has length 1, 2 or 4, defaults to 1. The dilation factor for each dimension of input.

    Returns
    -------
    A Tensor. Has the same type as input.
    """
    data_format, padding = preprocess_2d_format(data_format, padding)
    outputs = tf.nn.conv2d(
        input=input,
        filters=filters,
        strides=strides,
        padding=padding,
        data_format=data_format,
        dilations=dilations,
    )
    return outputs
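
# Example (illustrative shapes): a stride-2 2-D convolution halves the spatial
# dimensions under 'SAME' padding.
# >>> x = tf.random.normal([4, 32, 32, 3])  # NHWC
# >>> w = tf.random.normal([3, 3, 3, 8])    # [h, w, in, out]
# >>> y = conv2d(x, w, strides=2, padding='SAME')   # shape (4, 16, 16, 8)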


class Conv3D(object):

    def __init__(self, strides, padding, data_format='NDHWC', dilations=None, out_channel=None, k_size=None):
        self.strides = strides
        self.dilations = dilations
        self.data_format, self.padding = preprocess_3d_format(data_format, padding)

    def __call__(self, input, filters):
        outputs = tf.nn.conv3d(
            input=input,
            filters=filters,
            strides=self.strides,
            padding=self.padding,
            data_format=self.data_format,
            dilations=self.dilations,
        )
        return outputs


def conv3d(input, filters, strides, padding, data_format='NDHWC', dilations=None):
    """
    Computes a 3-D convolution given 5-D input and filters tensors.

    Parameters
    ----------
    input : tensor
        Must be one of the following types: half, bfloat16, float32, float64.
        Shape [batch, in_depth, in_height, in_width, in_channels].
    filters : tensor
        Must have the same type as input. Shape [filter_depth, filter_height, filter_width, in_channels, out_channels].
        in_channels must match between input and filters.
    strides : list of ints
        A list of ints that has length 5. 1-D tensor of length 5.
        The stride of the sliding window for each dimension of input.
        Must have strides[0] = strides[4] = 1.
    padding : string
        A string from: "SAME", "VALID". The type of padding algorithm to use.
    data_format : string
        An optional string from: "NDHWC", "NCDHW". Defaults to "NDHWC". The data format of the input and output data.
        With the default format "NDHWC", the data is stored in the order of: [batch, in_depth, in_height, in_width, in_channels].
        Alternatively, the format could be "NCDHW"; the data storage order is: [batch, in_channels, in_depth, in_height, in_width].
    dilations : list of ints
        Defaults to [1, 1, 1, 1, 1]. 1-D tensor of length 5. The dilation factor for each dimension of input.
        If set to k > 1, there will be k-1 skipped cells between each filter element on that dimension.
        The dimension order is determined by the value of data_format; see above for details.
        Dilations in the batch and depth dimensions must be 1.

    Returns
    -------
    A Tensor. Has the same type as input.
    """
    data_format, padding = preprocess_3d_format(data_format, padding)
    outputs = tf.nn.conv3d(
        input=input,
        filters=filters,
        strides=strides,
        padding=padding,
        data_format=data_format,  # 'NDHWC'
        dilations=dilations,  # [1, 1, 1, 1, 1]
        # name=name,
    )
    return outputs


def lrn(inputs, depth_radius, bias, alpha, beta):
    """
    Local Response Normalization.

    Parameters
    ----------
    inputs : tensor
        Must be one of the following types: half, bfloat16, float32. 4-D.
    depth_radius : int
        Defaults to 5. 0-D. Half-width of the 1-D normalization window.
    bias : float
        Defaults to 1. An offset (usually positive to avoid dividing by 0).
    alpha : float
        Defaults to 1. A scale factor, usually positive.
    beta : float
        Defaults to 0.5. An exponent.

    Returns
    -------
    A Tensor. Has the same type as input.
    """
    outputs = tf.nn.lrn(inputs, depth_radius=depth_radius, bias=bias, alpha=alpha, beta=beta)
    return outputs


def moments(x, axes, shift=None, keepdims=False):
    """
    Calculates the mean and variance of x.

    Parameters
    ----------
    x : tensor
        A Tensor.
    axes : list of ints
        Axes along which to compute mean and variance.
    shift : int
        Not used in the current implementation.
    keepdims : bool
        Produce moments with the same dimensionality as the input.

    Returns
    -------
    Two Tensor objects: mean and variance.
    """
    outputs = tf.nn.moments(x, axes, shift, keepdims)
    return outputs


class MaxPool1d(object):

    def __init__(self, ksize, strides, padding, data_format=None):
        self.data_format, self.padding = preprocess_1d_format(data_format=data_format, padding=padding)
        self.ksize = ksize
        self.strides = strides

    def __call__(self, inputs):
        outputs = tf.nn.max_pool(
            input=inputs, ksize=self.ksize, strides=self.strides, padding=self.padding, data_format=self.data_format
        )
        return outputs


class MaxPool(object):

    def __init__(self, ksize, strides, padding, data_format=None):
        self.ksize = ksize
        self.strides = strides
        self.data_format = data_format
        self.padding = padding

    def __call__(self, inputs):
        if inputs.ndim == 3:
            self.data_format, self.padding = preprocess_1d_format(data_format=self.data_format, padding=self.padding)
        elif inputs.ndim == 4:
            self.data_format, self.padding = preprocess_2d_format(data_format=self.data_format, padding=self.padding)
        elif inputs.ndim == 5:
            self.data_format, self.padding = preprocess_3d_format(data_format=self.data_format, padding=self.padding)
        outputs = tf.nn.max_pool(
            input=inputs, ksize=self.ksize, strides=self.strides, padding=self.padding, data_format=self.data_format
        )
        return outputs


def max_pool(input, ksize, strides, padding, data_format=None):
    """
    Performs the max pooling on the input.

    Parameters
    ----------
    input : tensor
        Tensor of rank N+2, of shape [batch_size] + input_spatial_shape + [num_channels] if data_format does not start
        with "NC" (default), or [batch_size, num_channels] + input_spatial_shape if data_format starts with "NC".
        Pooling happens over the spatial dimensions only.
    ksize : int or list of ints
        An int or list of ints that has length 1, N or N+2.
        The size of the window for each dimension of the input tensor.
    strides : int or list of ints
        An int or list of ints that has length 1, N or N+2.
        The stride of the sliding window for each dimension of the input tensor.
    padding : string
        'VALID' or 'SAME'. The padding algorithm. See the "returns" section of tf.nn.convolution for details.

    Returns
    -------
    A Tensor of format specified by data_format. The max pooled output tensor.
    """
    if input.ndim == 3:
        data_format, padding = preprocess_1d_format(data_format=data_format, padding=padding)
    elif input.ndim == 4:
        data_format, padding = preprocess_2d_format(data_format=data_format, padding=padding)
    elif input.ndim == 5:
        data_format, padding = preprocess_3d_format(data_format=data_format, padding=padding)
    outputs = tf.nn.max_pool(input=input, ksize=ksize, strides=strides, padding=padding, data_format=data_format)
    return outputs
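
# Example (illustrative): max_pool dispatches on the input rank, so the same
# wrapper handles 1-D, 2-D and 3-D pooling.
# >>> x = tf.random.normal([4, 32, 32, 3])  # rank 4 -> 2-D pooling
# >>> y = max_pool(x, ksize=2, strides=2, padding='SAME')    # shape (4, 16, 16, 3)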


class AvgPool1d(object):

    def __init__(self, ksize, strides, padding, data_format=None):
        self.data_format, self.padding = preprocess_1d_format(data_format=data_format, padding=padding)
        self.ksize = ksize
        self.strides = strides

    def __call__(self, inputs):
        outputs = tf.nn.pool(
            input=inputs,
            window_shape=self.ksize,
            pooling_type="AVG",
            strides=self.strides,
            padding=self.padding,
            data_format=self.data_format,
        )
        return outputs


class AvgPool(object):

    def __init__(self, ksize, strides, padding, data_format=None):
        self.ksize = ksize
        self.strides = strides
        self.data_format = data_format
        self.padding = padding_format(padding)

    def __call__(self, inputs):
        outputs = tf.nn.avg_pool(
            input=inputs, ksize=self.ksize, strides=self.strides, padding=self.padding, data_format=self.data_format
        )
        return outputs


def avg_pool(input, ksize, strides, padding):
    """
    Performs the avg pooling on the input.

    Parameters
    ----------
    input : tensor
        Tensor of rank N+2, of shape [batch_size] + input_spatial_shape + [num_channels]
        if data_format does not start with "NC" (default), or [batch_size, num_channels] + input_spatial_shape
        if data_format starts with "NC". Pooling happens over the spatial dimensions only.
    ksize : int or list of ints
        An int or list of ints that has length 1, N or N+2.
        The size of the window for each dimension of the input tensor.
    strides : int or list of ints
        An int or list of ints that has length 1, N or N+2.
        The stride of the sliding window for each dimension of the input tensor.
    padding : string
        'VALID' or 'SAME'. The padding algorithm. See the "returns" section of tf.nn.convolution for details.

    Returns
    -------
    A Tensor of format specified by data_format. The average pooled output tensor.
    """
    padding = padding_format(padding)
    outputs = tf.nn.avg_pool(
        input=input,
        ksize=ksize,
        strides=strides,
        padding=padding,
    )
    return outputs
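
# Example (illustrative): average pooling over an NHWC input.
# >>> x = tf.random.normal([4, 32, 32, 3])
# >>> y = avg_pool(x, ksize=2, strides=2, padding='VALID')   # shape (4, 16, 16, 3)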


class MaxPool3d(object):

    def __init__(self, ksize, strides, padding, data_format=None):
        self.data_format, self.padding = preprocess_3d_format(data_format, padding)
        self.ksize = ksize
        self.strides = strides

    def __call__(self, inputs):
        outputs = tf.nn.max_pool3d(
            input=inputs,
            ksize=self.ksize,
            strides=self.strides,
            padding=self.padding,
            data_format=self.data_format,
        )
        return outputs


def max_pool3d(input, ksize, strides, padding, data_format=None):
    """
    Performs the max pooling on the input.

    Parameters
    ----------
    input : tensor
        A 5-D Tensor of the format specified by data_format.
    ksize : int or list of ints
        An int or list of ints that has length 1, 3 or 5.
        The size of the window for each dimension of the input tensor.
    strides : int or list of ints
        An int or list of ints that has length 1, 3 or 5.
        The stride of the sliding window for each dimension of the input tensor.
    padding : string
        'VALID' or 'SAME'. The padding algorithm. See the "returns" section of tf.nn.convolution for details.
    data_format : string
        "NDHWC", "NCDHW". Defaults to "NDHWC". The data format of the input and output data.
        With the default format "NDHWC", the data is stored in the order of: [batch, in_depth, in_height, in_width, in_channels].
        Alternatively, the format could be "NCDHW"; the data storage order is: [batch, in_channels, in_depth, in_height, in_width].

    Returns
    -------
    A Tensor of format specified by data_format. The max pooled output tensor.
    """
    data_format, padding = preprocess_3d_format(data_format, padding)
    outputs = tf.nn.max_pool3d(
        input=input,
        ksize=ksize,
        strides=strides,
        padding=padding,
        data_format=data_format,
    )
    return outputs


class AvgPool3d(object):

    def __init__(self, ksize, strides, padding, data_format=None):
        self.data_format, self.padding = preprocess_3d_format(data_format, padding)
        self.ksize = ksize
        self.strides = strides

    def __call__(self, inputs):
        outputs = tf.nn.avg_pool3d(
            input=inputs,
            ksize=self.ksize,
            strides=self.strides,
            padding=self.padding,
            data_format=self.data_format,
        )
        return outputs


def avg_pool3d(input, ksize, strides, padding, data_format=None):
    """
    Performs the average pooling on the input.

    Parameters
    ----------
    input : tensor
        A 5-D Tensor of shape [batch, depth, height, width, channels] and type float32, float64, qint8, quint8, or qint32.
    ksize : int or list of ints
        An int or list of ints that has length 1, 3 or 5. The size of the window for each dimension of the input tensor.
    strides : int or list of ints
        An int or list of ints that has length 1, 3 or 5.
        The stride of the sliding window for each dimension of the input tensor.
    padding : string
        'VALID' or 'SAME'. The padding algorithm. See the "returns" section of tf.nn.convolution for details.
    data_format : string
        'NDHWC' and 'NCDHW' are supported.

    Returns
    -------
    A Tensor with the same type as value. The average pooled output tensor.
    """
    data_format, padding = preprocess_3d_format(data_format, padding)
    outputs = tf.nn.avg_pool3d(
        input=input,
        ksize=ksize,
        strides=strides,
        padding=padding,
        data_format=data_format,
    )
    return outputs


def pool(input, window_shape, pooling_type, strides=None, padding='VALID', data_format=None, dilations=None, name=None):
    """
    Performs an N-D pooling operation.

    Parameters
    ----------
    input : tensor
        Tensor of rank N+2, of shape [batch_size] + input_spatial_shape + [num_channels]
        if data_format does not start with "NC" (default), or [batch_size, num_channels] + input_spatial_shape
        if data_format starts with "NC". Pooling happens over the spatial dimensions only.
    window_shape : list of ints
        Sequence of N ints >= 1.
    pooling_type : string
        Specifies the pooling operation, must be "AVG" or "MAX".
    strides : list of ints
        Sequence of N ints >= 1. Defaults to [1]*N. If any value of strides is > 1, then all values of dilations must be 1.
    padding : string
        The padding algorithm, must be "SAME" or "VALID". Defaults to "VALID".
        See the "returns" section of tf.nn.convolution for details.
    data_format : string
        Specifies whether the channel dimension of the input and output is the last dimension (default, or if data_format does not start with "NC"),
        or the second dimension (if data_format starts with "NC").
        For N=1, the valid values are "NWC" (default) and "NCW". For N=2, the valid values are "NHWC" (default) and "NCHW".
        For N=3, the valid values are "NDHWC" (default) and "NCDHW".
    dilations : list of ints
        Dilation rate. List of N ints >= 1. Defaults to [1]*N. If any value of dilations is > 1, then all values of strides must be 1.
    name : string
        Optional. Name of the op.

    Returns
    -------
    Tensor of rank N+2, of shape [batch_size] + output_spatial_shape + [num_channels].
    """
    if pooling_type in ["MAX", "max"]:
        pooling_type = "MAX"
    elif pooling_type in ["AVG", "avg"]:
        pooling_type = "AVG"
    else:
        raise ValueError('Unsupported pool_mode: ' + str(pooling_type))
    padding = padding_format(padding)
    outputs = tf.nn.pool(
        input=input,
        window_shape=window_shape,
        pooling_type=pooling_type,
        strides=strides,
        padding=padding,
        data_format=data_format,
        dilations=dilations,
        name=name,
    )
    return outputs
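
# Example (illustrative): the generic N-D pool can express the wrappers above;
# here it performs 2-D average pooling with a 3x3 window and stride 1.
# >>> x = tf.random.normal([4, 32, 32, 3])
# >>> y = pool(x, window_shape=[3, 3], pooling_type='avg',
# ...          strides=[1, 1], padding='same')     # shape (4, 32, 32, 3)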


class DepthwiseConv2d(object):

    def __init__(self, strides, padding, data_format=None, dilations=None, ksize=None, channel_multiplier=1):
        self.data_format, self.padding = preprocess_2d_format(data_format, padding)
        self.strides = strides
        self.dilations = dilations

    def __call__(self, input, filter):
        outputs = tf.nn.depthwise_conv2d(
            input=input,
            filter=filter,
            strides=self.strides,
            padding=self.padding,
            data_format=self.data_format,
            dilations=self.dilations,
        )
        return outputs


def depthwise_conv2d(input, filter, strides, padding, data_format=None, dilations=None, name=None):
    """
    Depthwise 2-D convolution.

    Parameters
    ----------
    input : tensor
        4-D with shape according to data_format.
    filter : tensor
        4-D with shape [filter_height, filter_width, in_channels, channel_multiplier].
    strides : list
        1-D of size 4. The stride of the sliding window for each dimension of input.
    padding : string
        'VALID' or 'SAME'. The padding algorithm. See the "returns" section of tf.nn.convolution for details.
    data_format : string
        The data format for input. Either "NHWC" (default) or "NCHW".
    dilations : list
        1-D of size 2. The dilation rate in which we sample input values across the height and width dimensions in atrous convolution.
        If it is greater than 1, then all values of strides must be 1.
    name : string
        A name for this operation (optional).

    Returns
    -------
    A 4-D Tensor with shape according to data_format.
    E.g., for "NHWC" format, shape is [batch, out_height, out_width, in_channels * channel_multiplier].
    """
    data_format, padding = preprocess_2d_format(data_format, padding)
    outputs = tf.nn.depthwise_conv2d(
        input=input,
        filter=filter,
        strides=strides,
        padding=padding,
        data_format=data_format,
        dilations=dilations,
        name=name,
    )
    return outputs
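
# Example (illustrative): a depthwise convolution applies channel_multiplier
# filters per input channel, so channels are multiplied rather than replaced.
# >>> x = tf.random.normal([4, 32, 32, 3])
# >>> w = tf.random.normal([3, 3, 3, 2])    # channel_multiplier = 2
# >>> y = depthwise_conv2d(x, w, strides=[1, 1, 1, 1],
# ...                      padding='SAME')  # shape (4, 32, 32, 6)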


class Conv1d_transpose(object):

    def __init__(
        self, stride, padding, data_format='NWC', dilations=None, out_channel=None, k_size=None, in_channels=None
    ):
        self.stride = stride
        self.dilations = dilations
        self.data_format, self.padding = preprocess_1d_format(data_format, padding)

    def __call__(self, input, filters):
        batch_size = input.shape[0]
        if self.data_format == 'NWC':
            w_axis, c_axis = 1, 2
        else:
            w_axis, c_axis = 2, 1
        input_shape = input.shape.as_list()
        filters_shape = filters.shape.as_list()
        input_w = input_shape[w_axis]
        filters_w = filters_shape[0]
        output_channels = filters_shape[1]
        dilations_w = 1
        if isinstance(self.stride, int):
            strides_w = self.stride
        else:
            strides_list = list(self.stride)
            strides_w = strides_list[w_axis]
        if self.dilations is not None:
            if isinstance(self.dilations, int):
                dilations_w = self.dilations
            else:
                dilations_list = list(self.dilations)
                dilations_w = dilations_list[w_axis]
        # Effective filter width after dilation.
        filters_w = filters_w + (filters_w - 1) * (dilations_w - 1)
        assert self.padding in {'SAME', 'VALID'}
        if self.padding == 'VALID':
            output_w = input_w * strides_w + max(filters_w - strides_w, 0)
        elif self.padding == 'SAME':
            output_w = input_w * strides_w
        if self.data_format == 'NCW':
            output_shape = (batch_size, output_channels, output_w)
        else:
            output_shape = (batch_size, output_w, output_channels)
        output_shape = tf.stack(output_shape)
        outputs = tf.nn.conv1d_transpose(
            input=input,
            filters=filters,
            output_shape=output_shape,
            strides=self.stride,
            padding=self.padding,
            data_format=self.data_format,
            dilations=self.dilations,
        )
        return outputs


def conv1d_transpose(
    input, filters, output_shape, strides, padding='SAME', data_format='NWC', dilations=None, name=None
):
    """
    The transpose of conv1d.

    Parameters
    ----------
    input : tensor
        A 3-D Tensor of type float and shape [batch, in_width, in_channels]
        for NWC data format or [batch, in_channels, in_width] for NCW data format.
    filters : tensor
        A 3-D Tensor with the same type as input and shape [filter_width, output_channels, in_channels].
        filter's in_channels dimension must match that of input.
    output_shape : tensor
        A 1-D Tensor, containing three elements, representing the output shape of the deconvolution op.
    strides : int or list of ints
        An int or list of ints that has length 1 or 3. The number of entries by which the filter is moved right at each step.
    padding : string
        'VALID' or 'SAME'. The padding algorithm. See the "returns" section of tf.nn.convolution for details.
    data_format : string
        'NWC' and 'NCW' are supported.
    dilations : int or list of ints
        An int or list of ints that has length 1 or 3, which defaults to 1.
        The dilation factor for each dimension of input. If set to k > 1,
        there will be k-1 skipped cells between each filter element on that dimension.
        Dilations in the batch and depth dimensions must be 1.
    name : string
        Optional name for the returned tensor.

    Returns
    -------
    A Tensor with the same type as input.
    """
    data_format, padding = preprocess_1d_format(data_format, padding)
    outputs = tf.nn.conv1d_transpose(
        input=input,
        filters=filters,
        output_shape=output_shape,
        strides=strides,
        padding=padding,
        data_format=data_format,
        dilations=dilations,
        name=name,
    )
    return outputs
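
# Example (illustrative): a stride-2 transposed 1-D convolution doubles the
# width under 'SAME' padding; output_shape must be given explicitly.
# >>> x = tf.random.normal([4, 50, 32])     # [batch, width, in_channels]
# >>> w = tf.random.normal([5, 16, 32])     # [width, out_channels, in_channels]
# >>> y = conv1d_transpose(x, w, output_shape=[4, 100, 16],
# ...                      strides=2)       # shape (4, 100, 16)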


class Conv2d_transpose(object):

    def __init__(
        self, strides, padding, data_format='NHWC', dilations=None, name=None, out_channel=None, k_size=None,
        in_channels=None
    ):
        self.strides = strides
        self.dilations = dilations
        self.name = name
        self.data_format, self.padding = preprocess_2d_format(data_format, padding)

    def __call__(self, input, filters):
        if self.data_format == 'NHWC':
            h_axis, w_axis = 1, 2
        else:
            h_axis, w_axis = 2, 3
        input_shape = input.shape.as_list()
        filters_shape = filters.shape.as_list()
        batch_size = input.shape[0]
        input_h, input_w = input_shape[h_axis], input_shape[w_axis]
        kernel_h, kernel_w = filters_shape[0], filters_shape[1]
        output_channels = filters_shape[2]
        dilations_h, dilations_w = 1, 1
        if isinstance(self.strides, int):
            strides_h = self.strides
            strides_w = self.strides
        else:
            strides_list = list(self.strides)
            if len(strides_list) == 2:
                strides_h = strides_list[0]
                strides_w = strides_list[1]
            elif len(strides_list) == 4:
                strides_h = strides_list[h_axis]
                strides_w = strides_list[w_axis]
        if self.dilations is not None:
            if isinstance(self.dilations, int):
                dilations_h = self.dilations
                dilations_w = self.dilations
            else:
                dilations_list = list(self.dilations)
                if len(dilations_list) == 2:
                    dilations_h = dilations_list[0]
                    dilations_w = dilations_list[1]
                elif len(dilations_list) == 4:
                    dilations_h = dilations_list[h_axis]
                    dilations_w = dilations_list[w_axis]
        # Effective kernel size after dilation.
        kernel_h = kernel_h + (kernel_h - 1) * (dilations_h - 1)
        kernel_w = kernel_w + (kernel_w - 1) * (dilations_w - 1)
        assert self.padding in {'SAME', 'VALID'}
        if self.padding == 'VALID':
            output_h = input_h * strides_h + max(kernel_h - strides_h, 0)
            output_w = input_w * strides_w + max(kernel_w - strides_w, 0)
        elif self.padding == 'SAME':
            output_h = input_h * strides_h
            output_w = input_w * strides_w
        if self.data_format == 'NCHW':
            out_shape = (batch_size, output_channels, output_h, output_w)
        else:
            out_shape = (batch_size, output_h, output_w, output_channels)
        output_shape = tf.stack(out_shape)
        outputs = tf.nn.conv2d_transpose(
            input=input, filters=filters, output_shape=output_shape, strides=self.strides, padding=self.padding,
            data_format=self.data_format, dilations=self.dilations, name=self.name
        )
        return outputs


def conv2d_transpose(
    input, filters, output_shape, strides, padding='SAME', data_format='NHWC', dilations=None, name=None
):
    """
    The transpose of conv2d.

    Parameters
    ----------
    input : tensor
        A 4-D Tensor of type float and shape [batch, height, width, in_channels]
        for NHWC data format or [batch, in_channels, height, width] for NCHW data format.
    filters : tensor
        A 4-D Tensor with the same type as input and shape [height, width,
        output_channels, in_channels]. filter's in_channels dimension must match that of input.
    output_shape : tensor
        A 1-D Tensor representing the output shape of the deconvolution op.
    strides : int or list of ints
        An int or list of ints that has length 1, 2 or 4. The stride of the sliding window for each dimension of input.
        If a single value is given, it is replicated in the H and W dimensions.
        By default the N and C dimensions are set to 1.
        The dimension order is determined by the value of data_format; see below for details.
    padding : string
        'VALID' or 'SAME'. The padding algorithm. See the "returns" section of tf.nn.convolution for details.
    data_format : string
        'NHWC' and 'NCHW' are supported.
    dilations : int or list of ints
        An int or list of ints that has length 1, 2 or 4, defaults to 1.
    name : string
        Optional name for the returned tensor.

    Returns
    -------
    A Tensor with the same type as input.
    """
    data_format, padding = preprocess_2d_format(data_format, padding)
    outputs = tf.nn.conv2d_transpose(
        input=input,
        filters=filters,
        output_shape=output_shape,
        strides=strides,
        padding=padding,
        data_format=data_format,
        dilations=dilations,
        name=name,
    )
    return outputs
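
# Example (illustrative): upsampling feature maps by a factor of two with a
# transposed 2-D convolution.
# >>> x = tf.random.normal([4, 16, 16, 8])
# >>> w = tf.random.normal([3, 3, 3, 8])    # [h, w, out_channels, in_channels]
# >>> y = conv2d_transpose(x, w, output_shape=[4, 32, 32, 3],
# ...                      strides=2)       # shape (4, 32, 32, 3)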


class Conv3d_transpose(object):

    def __init__(
        self, strides, padding, data_format='NDHWC', dilations=None, name=None, out_channel=None, k_size=None,
        in_channels=None
    ):
        self.strides = strides
        self.dilations = dilations
        self.name = name
        self.out_channel = out_channel
        self.data_format, self.padding = preprocess_3d_format(data_format, padding)

    def __call__(self, input, filters):
        if self.data_format == 'NDHWC':
            d_axis, h_axis, w_axis = 1, 2, 3
        else:
            d_axis, h_axis, w_axis = 2, 3, 4
        input_shape = input.shape.as_list()
        filters_shape = filters.shape.as_list()
        batch_size = input_shape[0]
        input_d, input_h, input_w = input_shape[d_axis], input_shape[h_axis], input_shape[w_axis]
        kernel_d, kernel_h, kernel_w = filters_shape[0], filters_shape[1], filters_shape[2]
        dilations_d, dilations_h, dilations_w = 1, 1, 1
        if isinstance(self.strides, int):
            # A scalar stride applies to all three spatial dimensions.
            strides_d = strides_h = strides_w = self.strides
        else:
            strides_list = list(self.strides)
            if len(strides_list) == 3:
                strides_d, strides_h, strides_w = \
                    strides_list[0], \
                    strides_list[1], \
                    strides_list[2]
            elif len(strides_list) == 5:
                strides_d, strides_h, strides_w = \
                    strides_list[d_axis], \
                    strides_list[h_axis], \
                    strides_list[w_axis]
        if self.dilations is not None:
            if isinstance(self.dilations, int):
                # A scalar dilation applies to all three spatial dimensions.
                dilations_d = dilations_h = dilations_w = self.dilations
            else:
                dilations_list = list(self.dilations)
                if len(dilations_list) == 3:
                    dilations_d, dilations_h, dilations_w = \
                        dilations_list[0], \
                        dilations_list[1], \
                        dilations_list[2]
                elif len(dilations_list) == 5:
                    dilations_d, dilations_h, dilations_w = \
                        dilations_list[d_axis], \
                        dilations_list[h_axis], \
                        dilations_list[w_axis]
        assert self.padding in {'VALID', 'SAME'}
        # Effective kernel size after dilation.
        kernel_d = kernel_d + (kernel_d - 1) * (dilations_d - 1)
        kernel_h = kernel_h + (kernel_h - 1) * (dilations_h - 1)
        kernel_w = kernel_w + (kernel_w - 1) * (dilations_w - 1)
        if self.padding == 'VALID':
            output_d = input_d * strides_d + max(kernel_d - strides_d, 0)
            output_h = input_h * strides_h + max(kernel_h - strides_h, 0)
            output_w = input_w * strides_w + max(kernel_w - strides_w, 0)
        elif self.padding == 'SAME':
            output_d = input_d * strides_d
            output_h = input_h * strides_h
            output_w = input_w * strides_w
        if self.data_format == 'NDHWC':
            output_shape = (batch_size, output_d, output_h, output_w, self.out_channel)
        else:
            output_shape = (batch_size, self.out_channel, output_d, output_h, output_w)
        output_shape = tf.stack(output_shape)
        outputs = tf.nn.conv3d_transpose(
            input=input, filters=filters, output_shape=output_shape, strides=self.strides, padding=self.padding,
            data_format=self.data_format, dilations=self.dilations, name=self.name
        )
        return outputs


def conv3d_transpose(
    input, filters, output_shape, strides, padding='SAME', data_format='NDHWC', dilations=None, name=None
):
    """
    The transpose of conv3d.

    Parameters
    ----------
    input : tensor
        A 5-D Tensor of type float and shape [batch, depth, height, width, in_channels] for
        NDHWC data format or [batch, in_channels, depth, height, width] for NCDHW data format.
    filters : tensor
        A 5-D Tensor with the same type as input and shape [depth, height, width, output_channels, in_channels].
        filter's in_channels dimension must match that of input.
    output_shape : tensor
        A 1-D Tensor representing the output shape of the deconvolution op.
    strides : int or list of ints
        An int or list of ints that has length 1, 3 or 5.
    padding : string
        'VALID' or 'SAME'. The padding algorithm. See the "returns" section of tf.nn.convolution for details.
    data_format : string
        'NDHWC' and 'NCDHW' are supported.
    dilations : int or list of ints
        An int or list of ints that has length 1, 3 or 5, defaults to 1.
    name : string
        Optional name for the returned tensor.

    Returns
    -------
    A Tensor with the same type as input.
    """
    data_format, padding = preprocess_3d_format(data_format, padding)
    outputs = tf.nn.conv3d_transpose(
        input=input, filters=filters, output_shape=output_shape, strides=strides, padding=padding,
        data_format=data_format, dilations=dilations, name=name
    )
    return outputs


# NOTE: this definition shadows the depthwise_conv2d defined above; it takes the
# kernel as `filters` rather than `filter` and gives padding/data_format defaults.
def depthwise_conv2d(input, filters, strides, padding='SAME', data_format='NHWC', dilations=None, name=None):
    """
    Depthwise 2-D convolution.

    Parameters
    ----------
    input : tensor
        4-D with shape according to data_format.
    filters : tensor
        4-D with shape [filter_height, filter_width, in_channels, channel_multiplier].
    strides : tuple
        1-D of size 4. The stride of the sliding window for each dimension of input.
    padding : string
        'VALID' or 'SAME'
    data_format : string
        "NHWC" (default) or "NCHW".
    dilations : tuple
        The dilation rate in which we sample input values across the height and width dimensions in atrous convolution.
        If it is greater than 1, then all values of strides must be 1.
    name : string
        A name for this operation (optional).

    Returns
    -------
    A 4-D Tensor with shape according to data_format.
    """
    data_format, padding = preprocess_2d_format(data_format, padding)
    outputs = tf.nn.depthwise_conv2d(
        input=input,
        filter=filters,
        strides=strides,
        padding=padding,
        data_format=data_format,
        dilations=dilations,
        name=name,
    )
    return outputs
  1164. def _to_channel_first_bias(b):
  1165. """Reshape [c] to [c, 1, 1]."""
  1166. channel_size = int(b.shape[0])
  1167. new_shape = (channel_size, 1, 1)
  1168. return tf.reshape(b, new_shape)
  1169. def _bias_scale(x, b, data_format):
  1170. """The multiplication counter part of tf.nn.bias_add."""
  1171. if data_format == 'NHWC':
  1172. return x * b
  1173. elif data_format == 'NCHW':
  1174. return x * _to_channel_first_bias(b)
  1175. else:
  1176. raise ValueError('invalid data_format: %s' % data_format)
  1177. def _bias_add(x, b, data_format):
  1178. """Alternative implementation of tf.nn.bias_add which is compatiable with tensorRT."""
  1179. if data_format == 'NHWC':
  1180. return tf.add(x, b)
  1181. elif data_format == 'NCHW':
  1182. return tf.add(x, _to_channel_first_bias(b))
  1183. else:
  1184. raise ValueError('invalid data_format: %s' % data_format)


def batch_normalization(x, mean, variance, offset, scale, variance_epsilon, data_format, name=None):
    """Data-format-aware version of tf.nn.batch_normalization."""
    if data_format == 'channels_last':
        mean = tf.reshape(mean, [1] * (len(x.shape) - 1) + [-1])
        variance = tf.reshape(variance, [1] * (len(x.shape) - 1) + [-1])
        offset = tf.reshape(offset, [1] * (len(x.shape) - 1) + [-1])
        scale = tf.reshape(scale, [1] * (len(x.shape) - 1) + [-1])
    elif data_format == 'channels_first':
        mean = tf.reshape(mean, [1] + [-1] + [1] * (len(x.shape) - 2))
        variance = tf.reshape(variance, [1] + [-1] + [1] * (len(x.shape) - 2))
        offset = tf.reshape(offset, [1] + [-1] + [1] * (len(x.shape) - 2))
        scale = tf.reshape(scale, [1] + [-1] + [1] * (len(x.shape) - 2))
    else:
        raise ValueError('invalid data_format: %s' % data_format)
    with ops.name_scope(name, 'batchnorm', [x, mean, variance, scale, offset]):
        inv = math_ops.rsqrt(variance + variance_epsilon)
        if scale is not None:
            inv *= scale
        a = math_ops.cast(inv, x.dtype)
        b = math_ops.cast(offset - mean * inv if offset is not None else -mean * inv, x.dtype)
        # Return a * x + b with customized data_format.
        # Currently TF doesn't have bias_scale, and TensorRT has a bug when converting tf.nn.bias_add,
        # so we reimplemented them to make the model work with TensorRT.
        # See https://github.com/tensorlayer/openpose-plus/issues/75 for more details.
        # df = {'channels_first': 'NCHW', 'channels_last': 'NHWC'}
        # return _bias_add(_bias_scale(x, a, df[data_format]), b, df[data_format])
        return a * x + b
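

# Sanity sketch (illustrative only): for 'channels_last' this helper should
# agree with tf.nn.batch_normalization, since a = scale / sqrt(var + eps) and
# b = offset - mean * a reproduce the standard normalization a * x + b.
def _example_batch_normalization():
    x = tf.random.normal([2, 5, 5, 3])
    mean, var = tf.zeros([3]), tf.ones([3])
    offset, scale = tf.zeros([3]), tf.ones([3])
    out = batch_normalization(x, mean, var, offset, scale, 1e-5, 'channels_last')
    ref = tf.nn.batch_normalization(x, mean, var, offset, scale, 1e-5)
    return bool(tf.reduce_all(tf.abs(out - ref) < 1e-5))  # expected: True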


class BatchNorm(object):
    """
    The :class:`BatchNorm` is a batch normalization layer for both fully-connected and convolution outputs.
    See ``tf.nn.batch_normalization`` and ``tf.nn.moments``.

    Parameters
    ----------
    decay : float
        A decay factor for `ExponentialMovingAverage`.
        Suggest to use a large value for large datasets.
    epsilon : float
        A small float added to the variance to avoid dividing by zero.
    beta : tensor or None
        The beta (offset) parameter. If None, beta is skipped.
        Usually you should not skip beta unless you know what you are doing.
    gamma : tensor or None
        The gamma (scale) parameter. If None, gamma is skipped.
        When the batch normalization layer is used instead of 'biases', or the next layer is linear, gamma can be
        disabled since the scaling can be done by the next layer. See `Inception-ResNet-v2 <https://github.com/tensorflow/models/blob/master/research/slim/nets/inception_resnet_v2.py>`__.
    moving_mean : tensor or None
        The moving mean. If None, the moving mean is skipped.
    moving_var : tensor or None
        The moving variance. If None, the moving variance is skipped.
    num_features : int
        Number of features of the input tensor. Useful to build the layer when using BatchNorm1d, BatchNorm2d or
        BatchNorm3d, but should be left as None when using BatchNorm. Default None.
    data_format : str
        'channels_last' (default) or 'channels_first'.
    is_train : boolean
        Whether the layer is being used for training or inference.

    Examples
    ---------
    With TensorLayer

    >>> net = tl.layers.Input([None, 50, 50, 32], name='input')
    >>> net = tl.layers.BatchNorm()(net)

    Notes
    -----
    The :class:`BatchNorm` is universally suitable for 3D/4D/5D inputs in static models, but should not be used
    in dynamic models where the layer is built upon class initialization. So the argument 'num_features' should
    only be used by the subclasses :class:`BatchNorm1d`, :class:`BatchNorm2d` and :class:`BatchNorm3d`. All three
    subclasses are suitable under all kinds of conditions.

    References
    ----------
    - `Source <https://github.com/ry/tensorflow-resnet/blob/master/resnet.py>`__
    - `stackoverflow <http://stackoverflow.com/questions/38312668/how-does-one-do-inference-with-batch-normalization-with-tensor-flow>`__

    """

    def __init__(
        self, decay=0.9, epsilon=0.00001, beta=None, gamma=None, moving_mean=None, moving_var=None, num_features=None,
        data_format='channels_last', is_train=False
    ):
        self.decay = decay
        self.epsilon = epsilon
        self.data_format = data_format
        self.beta = beta
        self.gamma = gamma
        self.moving_mean = moving_mean
        self.moving_var = moving_var
        self.num_features = num_features
        self.is_train = is_train
        self.axes = None

        if self.decay < 0.0 or 1.0 < self.decay:
            raise ValueError("decay should be between 0 and 1")

    def _get_param_shape(self, inputs_shape):
        if self.data_format == 'channels_last':
            axis = -1
        elif self.data_format == 'channels_first':
            axis = 1
        else:
            raise ValueError('data_format should be either %s or %s' % ('channels_last', 'channels_first'))

        channels = inputs_shape[axis]
        params_shape = [channels]
        return params_shape

    def _check_input_shape(self, inputs):
        if inputs.ndim <= 1:
            raise ValueError('expected input at least 2D, but got {}D input'.format(inputs.ndim))

    def __call__(self, inputs):
        self._check_input_shape(inputs)
        self.channel_axis = len(inputs.shape) - 1 if self.data_format == 'channels_last' else 1
        if self.axes is None:
            self.axes = [i for i in range(len(inputs.shape)) if i != self.channel_axis]

        mean, var = tf.nn.moments(inputs, self.axes, keepdims=False)
        if self.is_train:
            # update moving_mean and moving_var
            self.moving_mean = moving_averages.assign_moving_average(
                self.moving_mean, mean, self.decay, zero_debias=False
            )
            self.moving_var = moving_averages.assign_moving_average(self.moving_var, var, self.decay, zero_debias=False)
            outputs = batch_normalization(inputs, mean, var, self.beta, self.gamma, self.epsilon, self.data_format)
        else:
            outputs = batch_normalization(
                inputs, self.moving_mean, self.moving_var, self.beta, self.gamma, self.epsilon, self.data_format
            )
        return outputs
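

# Minimal inference-mode sketch of BatchNorm (illustrative only). The moving
# statistics are passed as plain tensors here; inside TensorLayer they would
# normally be variables managed by the wrapping layer.
def _example_batch_norm():
    x = tf.random.normal([4, 8, 8, 16])
    bn = BatchNorm(
        decay=0.9, epsilon=1e-5, beta=tf.zeros([16]), gamma=tf.ones([16]),
        moving_mean=tf.zeros([16]), moving_var=tf.ones([16]),
        data_format='channels_last', is_train=False,
    )
    return bn(x).shape  # TensorShape([4, 8, 8, 16])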


class GroupConv2D(object):

    def __init__(self, strides, padding, data_format, dilations, out_channel, k_size, groups):
        self.data_format, self.padding = preprocess_2d_format(data_format, padding)
        self.strides = strides
        self.dilations = dilations
        self.groups = groups
        if self.data_format == 'NHWC':
            self.channels_axis = 3
        else:
            self.channels_axis = 1

    def __call__(self, input, filters):
        if self.groups == 1:
            outputs = tf.nn.conv2d(
                input=input,
                filters=filters,
                strides=self.strides,
                padding=self.padding,
                data_format=self.data_format,
                dilations=self.dilations,
            )
        else:
            inputgroups = tf.split(input, num_or_size_splits=self.groups, axis=self.channels_axis)
            # conv2d filters are laid out as [h, w, in_channels, out_channels] regardless of
            # data_format, so they are always split along the last (output-channel) axis.
            weightsgroups = tf.split(filters, num_or_size_splits=self.groups, axis=3)
            convgroups = []
            for i, k in zip(inputgroups, weightsgroups):
                convgroups.append(
                    tf.nn.conv2d(
                        input=i,
                        filters=k,
                        strides=self.strides,
                        padding=self.padding,
                        data_format=self.data_format,
                        dilations=self.dilations,
                    )
                )
            outputs = tf.concat(axis=self.channels_axis, values=convgroups)
        return outputs
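

# Grouped-convolution sketch (illustrative only). With groups=2 the 4 input
# channels are split into two halves and each half is convolved with its own
# [3, 3, 2, 3] filter slice, giving 6 output channels in total.
def _example_group_conv2d():
    x = tf.random.normal([1, 8, 8, 4])
    w = tf.random.normal([3, 3, 2, 6])  # [h, w, in_channels // groups, out_channels]
    conv = GroupConv2D(
        strides=(1, 1, 1, 1), padding='SAME', data_format='NHWC', dilations=(1, 1, 1, 1),
        out_channel=6, k_size=3, groups=2
    )
    return conv(x, w).shape  # TensorShape([1, 8, 8, 6])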


class SeparableConv1D(object):

    def __init__(self, stride, padding, data_format, dilations, out_channel, k_size, in_channel, depth_multiplier):
        self.data_format, self.padding = preprocess_1d_format(data_format, padding)
        if self.data_format == 'NWC':
            self.spatial_start_dim = 1
            self.strides = (1, stride, stride, 1)
            self.data_format = 'NHWC'
        else:
            self.spatial_start_dim = 2
            self.strides = (1, 1, stride, stride)
            self.data_format = 'NCHW'
        self.dilation_rate = (1, dilations)

    def __call__(self, inputs, depthwise_filters, pointwise_filters):
        # Lift the 1-D case to 2-D by inserting a dummy spatial dimension,
        # run separable_conv2d, then squeeze the dummy dimension back out.
        inputs = tf.expand_dims(inputs, axis=self.spatial_start_dim)
        depthwise_filters = tf.expand_dims(depthwise_filters, 0)
        pointwise_filters = tf.expand_dims(pointwise_filters, 0)
        outputs = tf.nn.separable_conv2d(
            inputs, depthwise_filters, pointwise_filters, strides=self.strides, padding=self.padding,
            dilations=self.dilation_rate, data_format=self.data_format
        )
        outputs = tf.squeeze(outputs, axis=self.spatial_start_dim)
        return outputs


class SeparableConv2D(object):

    def __init__(self, strides, padding, data_format, dilations, out_channel, k_size, in_channel, depth_multiplier):
        self.data_format, self.padding = preprocess_2d_format(data_format, padding)
        self.strides = strides
        self.dilations = (dilations[2], dilations[2])

    def __call__(self, inputs, depthwise_filters, pointwise_filters):
        outputs = tf.nn.separable_conv2d(
            inputs, depthwise_filters, pointwise_filters, strides=self.strides, padding=self.padding,
            dilations=self.dilations, data_format=self.data_format
        )
        return outputs


class AdaptiveMeanPool1D(object):

    def __init__(self, output_size, data_format):
        self.data_format, _ = preprocess_1d_format(data_format, None)
        self.output_size = output_size

    def __call__(self, input):
        if self.data_format == 'NWC':
            n, w, c = input.shape
        else:
            n, c, w = input.shape
        stride = floor(w / self.output_size)
        kernel = w - (self.output_size - 1) * stride
        output = tf.nn.avg_pool1d(input, ksize=kernel, strides=stride, data_format=self.data_format, padding='VALID')
        return output


class AdaptiveMeanPool2D(object):

    def __init__(self, output_size, data_format):
        self.data_format, _ = preprocess_2d_format(data_format, None)
        self.output_size = output_size

    def __call__(self, inputs):
        if self.data_format == 'NHWC':
            n, h, w, c = inputs.shape
        else:
            n, c, h, w = inputs.shape
        out_h, out_w = self.output_size
        stride_h = floor(h / out_h)
        kernel_h = h - (out_h - 1) * stride_h
        stride_w = floor(w / out_w)
        kernel_w = w - (out_w - 1) * stride_w
        outputs = tf.nn.avg_pool2d(
            inputs, ksize=(kernel_h, kernel_w), strides=(stride_h, stride_w), data_format=self.data_format,
            padding='VALID'
        )
        return outputs
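

# Sketch of the adaptive-pooling arithmetic used above (illustrative only):
# for input size w and target size out_w, stride = floor(w / out_w) and
# kernel = w - (out_w - 1) * stride, which yields exactly out_w 'VALID'
# windows. E.g. w=10, out_w=4 gives stride=2, kernel=4.
def _example_adaptive_mean_pool2d():
    x = tf.random.normal([1, 10, 10, 3])
    pool = AdaptiveMeanPool2D(output_size=(4, 4), data_format='NHWC')
    return pool(x).shape  # TensorShape([1, 4, 4, 3])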


class AdaptiveMeanPool3D(object):

    def __init__(self, output_size, data_format):
        self.data_format, _ = preprocess_3d_format(data_format, None)
        self.output_size = output_size

    def __call__(self, inputs):
        if self.data_format == 'NDHWC':
            n, d, h, w, c = inputs.shape
        else:
            n, c, d, h, w = inputs.shape
        out_d, out_h, out_w = self.output_size
        stride_d = floor(d / out_d)
        kernel_d = d - (out_d - 1) * stride_d
        stride_h = floor(h / out_h)
        kernel_h = h - (out_h - 1) * stride_h
        stride_w = floor(w / out_w)
        kernel_w = w - (out_w - 1) * stride_w
        outputs = tf.nn.avg_pool3d(
            inputs, ksize=(kernel_d, kernel_h, kernel_w), strides=(stride_d, stride_h, stride_w),
            data_format=self.data_format, padding='VALID'
        )
        return outputs


class AdaptiveMaxPool1D(object):

    def __init__(self, output_size, data_format):
        self.data_format, _ = preprocess_1d_format(data_format, None)
        self.output_size = output_size

    def __call__(self, input):
        if self.data_format == 'NWC':
            n, w, c = input.shape
        else:
            n, c, w = input.shape
        stride = floor(w / self.output_size)
        kernel = w - (self.output_size - 1) * stride
        output = tf.nn.max_pool1d(input, ksize=kernel, strides=stride, data_format=self.data_format, padding='VALID')
        return output


class AdaptiveMaxPool2D(object):

    def __init__(self, output_size, data_format):
        self.data_format, _ = preprocess_2d_format(data_format, None)
        self.output_size = output_size

    def __call__(self, inputs):
        if self.data_format == 'NHWC':
            n, h, w, c = inputs.shape
        else:
            n, c, h, w = inputs.shape
        out_h, out_w = self.output_size
        stride_h = floor(h / out_h)
        kernel_h = h - (out_h - 1) * stride_h
        stride_w = floor(w / out_w)
        kernel_w = w - (out_w - 1) * stride_w
        outputs = tf.nn.max_pool2d(
            inputs, ksize=(kernel_h, kernel_w), strides=(stride_h, stride_w), data_format=self.data_format,
            padding='VALID'
        )
        return outputs


class AdaptiveMaxPool3D(object):

    def __init__(self, output_size, data_format):
        self.data_format, _ = preprocess_3d_format(data_format, None)
        self.output_size = output_size

    def __call__(self, inputs):
        if self.data_format == 'NDHWC':
            n, d, h, w, c = inputs.shape
        else:
            n, c, d, h, w = inputs.shape
        out_d, out_h, out_w = self.output_size
        stride_d = floor(d / out_d)
        kernel_d = d - (out_d - 1) * stride_d
        stride_h = floor(h / out_h)
        kernel_h = h - (out_h - 1) * stride_h
        stride_w = floor(w / out_w)
        kernel_w = w - (out_w - 1) * stride_w
        outputs = tf.nn.max_pool3d(
            inputs, ksize=(kernel_d, kernel_h, kernel_w), strides=(stride_d, stride_h, stride_w),
            data_format=self.data_format, padding='VALID'
        )
        return outputs


class BinaryConv2D(object):

    def __init__(self, strides, padding, data_format, dilations, out_channel, k_size, in_channel):
        self.data_format, self.padding = preprocess_2d_format(data_format, padding)
        self.strides = strides
        self.dilations = dilations

    # @tf.RegisterGradient("TL_Sign_QuantizeGrad")
    # def _quantize_grad(op, grad):
    #     """Clip and binarize tensor using the straight through estimator (STE) for the gradient."""
    #     return tf.clip_by_value(grad, -1, 1)

    def quantize(self, x):
        # ref: https://github.com/AngusG/tensorflow-xnor-bnn/blob/master/models/binary_net.py#L70
        #      https://github.com/itayhubara/BinaryNet.tf/blob/master/nnUtils.py
        with tf.compat.v1.get_default_graph().gradient_override_map({"Sign": "TL_Sign_QuantizeGrad"}):
            return tf.sign(x)

    def __call__(self, inputs, filters):
        filters = self.quantize(filters)
        outputs = tf.nn.conv2d(
            input=inputs, filters=filters, strides=self.strides, padding=self.padding, data_format=self.data_format,
            dilations=self.dilations
        )
        return outputs
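

# Forward-pass sketch of BinaryConv2D (illustrative only): the filters are
# binarized to {-1, +1} by tf.sign before a regular conv2d. The straight-through
# gradient only takes effect if the "TL_Sign_QuantizeGrad" gradient sketched in
# the comment above is actually registered.
def _example_binary_conv2d():
    x = tf.random.normal([1, 8, 8, 3])
    w = tf.random.normal([3, 3, 3, 8])
    conv = BinaryConv2D(
        strides=(1, 1, 1, 1), padding='SAME', data_format='NHWC', dilations=(1, 1, 1, 1),
        out_channel=8, k_size=3, in_channel=3
    )
    return conv(x, w).shape  # TensorShape([1, 8, 8, 8])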


class DorefaConv2D(object):

    def __init__(self, bitW, bitA, strides, padding, data_format, dilations, out_channel, k_size, in_channel):
        self.data_format, self.padding = preprocess_2d_format(data_format, padding)
        self.strides = strides
        self.dilations = dilations
        self.bitW = bitW
        self.bitA = bitA

    def _quantize_dorefa(self, x, k):
        G = tf.compat.v1.get_default_graph()
        n = float(2**k - 1)
        with G.gradient_override_map({"Round": "Identity"}):
            return tf.round(x * n) / n

    def cabs(self, x):
        return tf.minimum(1.0, tf.abs(x), name='cabs')

    def quantize_active(self, x, bitA):
        if bitA == 32:
            return x
        return self._quantize_dorefa(x, bitA)

    def quantize_weight(self, x, bitW, force_quantization=False):
        G = tf.compat.v1.get_default_graph()
        if bitW == 32 and not force_quantization:
            return x
        if bitW == 1:  # BWN
            with G.gradient_override_map({"Sign": "Identity"}):
                E = tf.stop_gradient(tf.reduce_mean(input_tensor=tf.abs(x)))
                return tf.sign(x / E) * E
        x = tf.clip_by_value(
            x * 0.5 + 0.5, 0.0, 1.0
        )  # it seems as though most weights are within the -1 to 1 region anyway
        return 2 * self._quantize_dorefa(x, bitW) - 1

    def __call__(self, inputs, filters):
        inputs = self.quantize_active(self.cabs(inputs), self.bitA)
        filters = self.quantize_weight(filters, self.bitW)
        outputs = tf.nn.conv2d(
            input=inputs,
            filters=filters,
            strides=self.strides,
            padding=self.padding,
            data_format=self.data_format,
            dilations=self.dilations,
        )
        return outputs
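

# Sketch of the DoReFa activation quantizer (illustrative only): cabs() maps
# activations into [0, 1] via min(1, |x|), and _quantize_dorefa rounds them to
# 2**bitA - 1 levels. With bitA=2 there are 3 levels, so 0.4 -> round(0.4 * 3) / 3 = 1/3.
def _example_dorefa_quantize():
    conv = DorefaConv2D(
        bitW=1, bitA=2, strides=(1, 1, 1, 1), padding='SAME', data_format='NHWC',
        dilations=(1, 1, 1, 1), out_channel=8, k_size=3, in_channel=3
    )
    a = conv.quantize_active(conv.cabs(tf.constant([0.4, 2.0, -0.7])), conv.bitA)
    return a  # approximately [0.3333, 1.0, 0.6667]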


class rnncell(object):

    def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh, act):
        self.weight_ih = weight_ih
        self.weight_hh = weight_hh
        self.bias_ih = bias_ih
        self.bias_hh = bias_hh
        self.act_fn = tf.nn.relu if act == 'relu' else tf.nn.tanh

    def __call__(self, input, h, c=None):
        i2h = tf.matmul(input, self.weight_ih, transpose_b=True)
        if self.bias_ih is not None:
            i2h += self.bias_ih
        h2h = tf.matmul(h, self.weight_hh, transpose_b=True)
        if self.bias_hh is not None:
            h2h += self.bias_hh
        h = self.act_fn(i2h + h2h)
        return h, h


class lstmcell(object):

    def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh, act=None):
        self.weight_ih = weight_ih
        self.weight_hh = weight_hh
        self.bias_ih = bias_ih
        self.bias_hh = bias_hh
        self.gate_act_fn = tf.sigmoid
        self.act_fn = tf.tanh

    def __call__(self, input, h, c):
        gates = tf.matmul(input, self.weight_ih, transpose_b=True)
        if self.bias_ih is not None:
            gates = gates + self.bias_ih
        gates += tf.matmul(h, self.weight_hh, transpose_b=True)
        if self.bias_hh is not None:
            gates += self.bias_hh
        # gate layout along the last axis: input, forget, cell candidate, output
        gate_slices = tf.split(gates, num_or_size_splits=4, axis=-1)
        i = self.gate_act_fn(gate_slices[0])
        f = self.gate_act_fn(gate_slices[1])
        o = self.gate_act_fn(gate_slices[3])
        c = f * c + i * self.act_fn(gate_slices[2])
        h = o * self.act_fn(c)
        return h, h, c
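

# Single-step sketch of lstmcell (illustrative only). The weights follow the
# [4 * hidden_size, fan_in] layout implied by the transpose_b matmuls, with
# gates ordered [input, forget, cell, output] along the first axis.
def _example_lstm_step():
    input_size, hidden_size, batch = 3, 4, 2
    cell = lstmcell(
        weight_ih=tf.random.normal([4 * hidden_size, input_size]),
        weight_hh=tf.random.normal([4 * hidden_size, hidden_size]),
        bias_ih=tf.zeros([4 * hidden_size]),
        bias_hh=tf.zeros([4 * hidden_size]),
    )
    x = tf.random.normal([batch, input_size])
    h = tf.zeros([batch, hidden_size])
    c = tf.zeros([batch, hidden_size])
    out, h, c = cell(x, h, c)
    return out.shape, h.shape, c.shape  # each (2, 4)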


class grucell(object):

    def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh, act=None):
        self.weight_ih = weight_ih
        self.weight_hh = weight_hh
        self.bias_ih = bias_ih
        self.bias_hh = bias_hh
        self.gate_act_fn = tf.sigmoid
        self.act_fn = tf.tanh

    def __call__(self, input, h, c=None):
        x_gates = tf.matmul(input, self.weight_ih, transpose_b=True)
        if self.bias_ih is not None:
            x_gates = x_gates + self.bias_ih
        h_gates = tf.matmul(h, self.weight_hh, transpose_b=True)
        if self.bias_hh is not None:
            h_gates = h_gates + self.bias_hh
        x_r, x_z, x_c = tf.split(x_gates, num_or_size_splits=3, axis=-1)
        h_r, h_z, h_c = tf.split(h_gates, num_or_size_splits=3, axis=-1)
        r = self.gate_act_fn(x_r + h_r)
        z = self.gate_act_fn(x_z + h_z)
        c = self.act_fn(x_c + r * h_c)
        h = (h - c) * z + c
        return h, h
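

# Single-step sketch of grucell (illustrative only). With the update gate z the
# new state is h' = z * h + (1 - z) * c, which the code above writes as
# (h - c) * z + c.
def _example_gru_step():
    input_size, hidden_size, batch = 3, 4, 2
    cell = grucell(
        weight_ih=tf.random.normal([3 * hidden_size, input_size]),
        weight_hh=tf.random.normal([3 * hidden_size, hidden_size]),
        bias_ih=tf.zeros([3 * hidden_size]),
        bias_hh=tf.zeros([3 * hidden_size]),
    )
    x = tf.random.normal([batch, input_size])
    h = tf.zeros([batch, hidden_size])
    out, h = cell(x, h)
    return out.shape, h.shape  # both (2, 4)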


class rnnbase(object):

    def __init__(
        self,
        mode,
        input_size,
        hidden_size,
        num_layers,
        bias,
        batch_first,
        dropout,
        bidirectional,
        is_train,
        weights_fw,
        weights_bw,
        bias_fw,
        bias_bw,
    ):
        self.mode = mode
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        self.dropout = float(dropout)
        self.train = is_train
        if not 0 <= dropout < 1:
            raise ValueError("dropout should be a number in range [0, 1).")
        if dropout > 0 and num_layers == 1:
            raise ValueError(
                "dropout option adds dropout after all but last "
                "recurrent layer, so non-zero dropout expects "
                "num_layers greater than 1, but got dropout={} and "
                "num_layers={}".format(dropout, num_layers)
            )
        self.bidirect = 2 if bidirectional else 1
        self.weights_fw = weights_fw
        self.bias_fw = bias_fw
        self.weights_bw = weights_bw
        self.bias_bw = bias_bw
        # stdv = 1.0 / np.sqrt(self.hidden_size)
        # _init = tf.random_uniform_initializer(minval=-stdv, maxval=stdv)
        self.act_fn = None
        if mode == 'LSTM':
            # gate_size = 4 * hidden_size
            self.rnn_cell = lstmcell
        elif mode == 'GRU':
            # gate_size = 3 * hidden_size
            self.rnn_cell = grucell
        elif mode == 'RNN_TANH':
            # gate_size = hidden_size
            self.rnn_cell = rnncell
            self.act_fn = 'tanh'
        elif mode == 'RNN_RELU':
            # gate_size = hidden_size
            self.rnn_cell = rnncell
            self.act_fn = 'relu'
        else:
            raise ValueError("Unrecognized mode: {}. Expected 'LSTM', 'GRU', 'RNN_TANH' or 'RNN_RELU'.".format(mode))
        # for layer in range(num_layers):
        #     for direction in range(self.bidirect):
        #         layer_input_size = input_size if layer == 0 else hidden_size * self.bidirect
        #         if direction == 0:
        #             self.w_ih = tf.Variable(initial_value=_init(shape=(gate_size, layer_input_size)),
        #                                     name='weight_ih_l' + str(layer), trainable=True)
        #             self.w_hh = tf.Variable(initial_value=_init(shape=(gate_size, hidden_size)),
        #                                     name='weight_hh_l' + str(layer), trainable=True)
        #             self.weights_fw.append(self.w_ih)
        #             self.weights_fw.append(self.w_hh)
        #             if bias:
        #                 self.b_ih = tf.Variable(initial_value=_init(shape=(gate_size,)),
        #                                         name='bias_ih_l' + str(layer), trainable=True)
        #                 self.b_hh = tf.Variable(initial_value=_init(shape=(gate_size,)),
        #                                         name='bias_hh_l' + str(layer), trainable=True)
        #                 self.bias_fw.append(self.b_ih)
        #                 self.bias_fw.append(self.b_hh)
        #         else:
        #             self.w_ih = tf.Variable(initial_value=_init(shape=(gate_size, layer_input_size)),
        #                                     name='weight_ih_l' + str(layer) + '_reverse', trainable=True)
        #             self.w_hh = tf.Variable(initial_value=_init(shape=(gate_size, hidden_size)),
        #                                     name='weight_hh_l' + str(layer) + '_reverse', trainable=True)
        #             self.weights_bw.append(self.w_ih)
        #             self.weights_bw.append(self.w_hh)
        #             if bias:
        #                 self.b_ih = tf.Variable(initial_value=_init(shape=(gate_size,)),
        #                                         name='bias_ih_l' + str(layer) + '_reverse', trainable=True)
        #                 self.b_hh = tf.Variable(initial_value=_init(shape=(gate_size,)),
        #                                         name='bias_hh_l' + str(layer) + '_reverse', trainable=True)
        #                 self.bias_bw.append(self.b_ih)
        #                 self.bias_bw.append(self.b_hh)

    def _bi_rnn_forward(self, x, h, c=None):
        time_step, batch_size, input_size = x.shape
        h_out = []
        c_out = []
        pre_layer = x
        for i in range(self.num_layers):
            weight_ih_fw = self.weights_fw[2 * i]
            weight_hh_fw = self.weights_fw[2 * i + 1]
            weight_ih_bw = self.weights_bw[2 * i]
            weight_hh_bw = self.weights_bw[2 * i + 1]
            if self.bias:
                bias_ih_fw = self.bias_fw[2 * i]
                bias_hh_fw = self.bias_fw[2 * i + 1]
                bias_ih_bw = self.bias_bw[2 * i]
                bias_hh_bw = self.bias_bw[2 * i + 1]
            else:
                bias_ih_fw = None
                bias_hh_fw = None
                bias_ih_bw = None
                bias_hh_bw = None
            # states are laid out as [layer0_fw, layer0_bw, layer1_fw, layer1_bw, ...]
            h_i_fw = h[2 * i, :, :]
            h_i_bw = h[2 * i + 1, :, :]
            if i != 0 and self.train:
                pre_layer = tf.nn.dropout(pre_layer, rate=self.dropout)
            # The cells are stateless wrappers around the weights, so they are
            # built once per layer rather than once per time step.
            cell_fw = self.rnn_cell(weight_ih_fw, weight_hh_fw, bias_ih_fw, bias_hh_fw, self.act_fn)
            cell_bw = self.rnn_cell(weight_ih_bw, weight_hh_bw, bias_ih_bw, bias_hh_bw, self.act_fn)
            y_fw = []
            y_bw = []
            if c is not None:
                c_i_fw = c[2 * i, :, :]
                c_i_bw = c[2 * i + 1, :, :]
                for j in range(time_step):
                    # the backward direction walks the sequence from the last
                    # time step towards the first
                    step_out_fw, h_i_fw, c_i_fw = cell_fw(pre_layer[j, :, :], h_i_fw, c_i_fw)
                    step_out_bw, h_i_bw, c_i_bw = cell_bw(pre_layer[time_step - 1 - j, :, :], h_i_bw, c_i_bw)
                    y_fw.append(step_out_fw)
                    y_bw.append(step_out_bw)
                c_out.append(c_i_fw)
                c_out.append(c_i_bw)
            else:
                for j in range(time_step):
                    step_out_fw, h_i_fw = cell_fw(pre_layer[j, :, :], h_i_fw)
                    step_out_bw, h_i_bw = cell_bw(pre_layer[time_step - 1 - j, :, :], h_i_bw)
                    y_fw.append(step_out_fw)
                    y_bw.append(step_out_bw)
            h_out.append(h_i_fw)
            h_out.append(h_i_bw)
            # re-align the backward outputs with forward time order before
            # concatenating the two directions along the feature axis
            y_bw.reverse()
            pre_layer = tf.stack([tf.concat([f, b], axis=-1) for f, b in zip(y_fw, y_bw)])
        h_out = tf.stack(h_out)
        c_out = tf.stack(c_out) if c is not None else None
        return pre_layer, h_out, c_out

    def _rnn_forward(self, x, h, c=None):
        pre_layer = x
        h_out = []
        c_out = []
        y = []
        time_step, batch_size, input_size = x.shape
        for i in range(self.num_layers):
            weight_ih = self.weights_fw[2 * i]
            weight_hh = self.weights_fw[2 * i + 1]
            if self.bias:
                bias_ih = self.bias_fw[2 * i]
                bias_hh = self.bias_fw[2 * i + 1]
            else:
                bias_ih = None
                bias_hh = None
            h_i = h[i, :, :]
            if i != 0 and self.train:
                pre_layer = tf.nn.dropout(pre_layer, rate=self.dropout)
            # build the (stateless) cell once per layer
            cell = self.rnn_cell(weight_ih, weight_hh, bias_ih, bias_hh, self.act_fn)
            if c is not None:
                c_i = c[i, :, :]
                for j in range(time_step):
                    step_out, h_i, c_i = cell(pre_layer[j, :, :], h_i, c_i)
                    y.append(step_out)
                h_out.append(h_i)
                c_out.append(c_i)
                pre_layer = tf.stack(y)
                y = []
            else:
                for j in range(time_step):
                    step_out, h_i = cell(pre_layer[j, :, :], h_i)
                    y.append(step_out)
                h_out.append(h_i)
                pre_layer = tf.stack(y)
                y = []
        h_out = tf.stack(h_out)
        c_out = tf.stack(c_out) if c is not None else None
        return pre_layer, h_out, c_out

    def check_input(self, input_shape):
        if len(input_shape) != 3:
            raise ValueError("input must have 3 dimensions, but got {}".format(len(input_shape)))
        if self.input_size != input_shape[-1]:
            raise ValueError(
                "The last dimension of input should be equal to input_size {}, but got {}".format(
                    self.input_size, input_shape[-1]
                )
            )

    def check_hidden(self, h, batch_size):
        expected_hidden_size = (self.num_layers * self.bidirect, batch_size, self.hidden_size)
        if h.shape != expected_hidden_size:
            raise ValueError('Expected hidden size {}, got {}.'.format(expected_hidden_size, h.shape))

    def __call__(self, input, states):
        if self.batch_first:
            input = tf.transpose(input, perm=(1, 0, 2))
        input_dtype = input.dtype
        input_shape = input.shape
        time_step, batch_size, input_size = input_shape
        self.check_input(input_shape)
        if self.mode == "LSTM":
            if states is not None:
                h, c = states
                self.check_hidden(h, batch_size)
                self.check_hidden(c, batch_size)
            else:
                h = tf.zeros(shape=(self.num_layers * self.bidirect, batch_size, self.hidden_size), dtype=input_dtype)
                c = tf.zeros(shape=(self.num_layers * self.bidirect, batch_size, self.hidden_size), dtype=input_dtype)
            if self.bidirect == 1:
                y, new_h, new_c = self._rnn_forward(input, h, c)
            else:
                y, new_h, new_c = self._bi_rnn_forward(input, h, c)
            new_states = (new_h, new_c)
        else:
            if states is not None:
                h = states
                self.check_hidden(h, batch_size)
            else:
                h = tf.zeros(shape=(self.num_layers * self.bidirect, batch_size, self.hidden_size), dtype=input_dtype)
            if self.bidirect == 1:
                y, new_h, _ = self._rnn_forward(input, h)
            else:
                y, new_h, _ = self._bi_rnn_forward(input, h)
            new_states = new_h
        if self.batch_first:
            y = tf.transpose(y, perm=(1, 0, 2))
        return y, new_states
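

# End-to-end sketch of rnnbase (illustrative only): a single-layer,
# unidirectional LSTM over a [time, batch, feature] sequence, with randomly
# initialized weights in the [gate_size, fan_in] layout the cells expect.
def _example_rnnbase_lstm():
    T, B, I, H = 5, 2, 3, 4
    rnn = rnnbase(
        mode='LSTM', input_size=I, hidden_size=H, num_layers=1, bias=True,
        batch_first=False, dropout=0, bidirectional=False, is_train=False,
        weights_fw=[tf.random.normal([4 * H, I]), tf.random.normal([4 * H, H])],
        weights_bw=[], bias_fw=[tf.zeros([4 * H]), tf.zeros([4 * H])], bias_bw=[],
    )
    x = tf.random.normal([T, B, I])
    y, (h, c) = rnn(x, states=None)
    return y.shape, h.shape, c.shape  # (5, 2, 4), (1, 2, 4), (1, 2, 4)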

# TensorLayer 3.0 is a deep learning library that supports multiple deep learning
# frameworks as computation backends. Planned backends: TensorFlow, PyTorch,
# MindSpore, and Paddle.