You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensorflow_backend.py 25 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033
  1. #! /usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. from __future__ import absolute_import, division, print_function
  4. from .tensorflow_nn import nchw_to_nhwc, nhwc_to_nchw
  5. import tensorflow as tf
  6. _dtypeDict = {
  7. 'DType': tf.DType,
  8. 'float16': tf.float16,
  9. 'float32': tf.float32,
  10. 'float64': tf.float64,
  11. 'int8': tf.int8,
  12. 'int16': tf.int16,
  13. 'int32': tf.int32,
  14. 'int64': tf.int64,
  15. 'uint8': tf.uint8,
  16. 'uint16': tf.uint16,
  17. 'uint32': tf.uint32,
  18. 'uint64': tf.uint64
  19. }
  20. DType = tf.DType
  21. float16 = tf.float16
  22. float32 = tf.float32
  23. float64 = tf.float64
  24. int8 = tf.int8
  25. int16 = tf.int16
  26. int32 = tf.int32
  27. int64 = tf.int64
  28. uint8 = tf.uint8
  29. uint16 = tf.uint16
  30. uint32 = tf.uint32
  31. uint64 = tf.uint64
  32. # isinstance input output
  33. # TensorLike = tf_ops._TensorLike
  34. def set_context(**kwargs):
  35. raise Exception("Using TenosrFlow backend,You don't need to set context")
  36. def get_tensor_shape(x):
  37. return x.get_shape().as_list()
  38. # initializers
  39. def zeros(shape, dtype=tf.float32):
  40. """
  41. Creates a tensor with all elements set to zero.
  42. Parameters
  43. ----------
  44. shape : A list of integers
  45. a tuple of integers, or a 1-D Tensor of type int32.
  46. dtype : tensor
  47. The DType of an element in the resulting Tensor
  48. Returns
  49. -------
  50. A Tensor with all elements set to zero.
  51. """
  52. return tf.zeros(shape=shape, dtype=dtype)
  53. def ones(shape, dtype=tf.float32):
  54. """
  55. Creates a tensor with all elements set to ones.
  56. Parameters
  57. ----------
  58. shape : A list of integers
  59. a tuple of integers, or a 1-D Tensor of type int32.
  60. dtype : tensor
  61. The DType of an element in the resulting Tensor
  62. Returns
  63. -------
  64. A Tensor with all elements set to zero.
  65. """
  66. return tf.ones(shape=shape, dtype=dtype)
  67. def constant(value, dtype=tf.float32, shape=None):
  68. """
  69. Creates a constant tensor from a tensor-like object.
  70. Parameters
  71. ----------
  72. value : list
  73. A constant value (or list) of output type dtype.
  74. dtype : tensor
  75. The type of the elements of the resulting tensor.
  76. shape : tuple
  77. Optional dimensions of resulting tensor.
  78. Returns
  79. -------
  80. A Constant Tensor.
  81. """
  82. return tf.constant(value=value, dtype=dtype, shape=shape)
  83. def random_uniform(shape, minval=0, maxval=None, dtype=tf.float32, seed=None):
  84. """
  85. Outputs random values from a uniform distribution.
  86. Parameters
  87. ----------
  88. shape : tuple
  89. A 1-D integer Tensor or Python array. The shape of the output tensor.
  90. minval : int
  91. The lower bound on the range of random values to generate (inclusive). Defaults to 0.
  92. maxval : int
  93. The upper bound on the range of random values to generate (exclusive). Defaults to 1 if dtype is floating point.
  94. dtype : tensor
  95. The type of the output: float16, float32, float64, int32, or int64.
  96. seed : int
  97. Used in combination with tf.random.set_seed to create a reproducible sequence of tensors across multiple calls.
  98. Returns
  99. -------
  100. A tensor of the specified shape filled with random uniform values.
  101. """
  102. outputs = tf.random.uniform(shape=shape, minval=minval, maxval=maxval, dtype=dtype, seed=seed)
  103. return outputs
  104. def random_normal(shape, mean=0.0, stddev=1.0, dtype=tf.dtypes.float32, seed=None):
  105. """
  106. Outputs random values from a normal distribution.
  107. Parameters
  108. ----------
  109. shape : tuple
  110. A 1-D integer Tensor or Python array. The shape of the output tensor.
  111. mean : float
  112. The mean of the normal distribution
  113. stddev : float
  114. The standard deviation of the normal distribution.
  115. dtype : tensor
  116. The type of the output.
  117. seed : A Python integer
  118. Used to create a random seed for the distribution
  119. Returns
  120. -------
  121. A tensor of the specified shape filled with random normal values.
  122. """
  123. outputs = tf.random.normal(shape=shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed)
  124. return outputs
  125. def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=tf.float32, seed=None):
  126. """
  127. Outputs random values from a truncated normal distribution.
  128. Parameters
  129. ----------
  130. shape : tuple
  131. A 1-D integer Tensor or Python array. The shape of the output tensor.
  132. mean : float
  133. The mean of the normal distribution
  134. stddev : float
  135. The standard deviation of the normal distribution.
  136. dtype : tensor
  137. The type of the output.
  138. seed : A Python integer
  139. Used to create a random seed for the distribution
  140. Returns
  141. -------
  142. A tensor of the specified shape filled with random truncated normal values.
  143. """
  144. outputs = tf.random.truncated_normal(shape=shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed)
  145. return outputs
  146. def he_normal(shape, dtype, seed=None):
  147. """
  148. He normal initializer.
  149. Parameters
  150. ----------
  151. seed : A Python integer.
  152. Used to seed the random generator.
  153. shape : tuple
  154. A 1-D integer Tensor or Python array. The shape of the output tensor.
  155. dtype : tensor
  156. The type of the output.
  157. Returns
  158. -------
  159. A tensor of the specified shape filled with he normal values.
  160. """
  161. return tf.initializers.he_normal(seed)(shape=shape, dtype=dtype)
  162. def Variable(initial_value, name, trainable=True):
  163. """
  164. Creates a new variable with value initial_value.
  165. Parameters
  166. ----------
  167. initial_value : tensor
  168. A Tensor, or Python object convertible to a Tensor
  169. name : str
  170. Optional name for the variable. Defaults to 'Variable' and gets uniquified automatically.
  171. Returns
  172. -------
  173. Variable
  174. """
  175. var = tf.Variable(initial_value=initial_value, name=name, trainable=trainable)
  176. return var
  177. class MatMul(object):
  178. def __init__(self):
  179. pass
  180. def __call__(self, a, b):
  181. return tf.matmul(a, b)
  182. def matmul(a, b):
  183. """
  184. Multiplies matrix a by matrix b, producing a * b.
  185. Parameters
  186. ----------
  187. a : tensor
  188. type float16, float32, float64, int32, complex64, complex128 and rank > 1.
  189. b : tensor
  190. with same type and rank as a.
  191. Returns
  192. -------
  193. A Tensor of the same type as a and b
  194. """
  195. outputs = tf.matmul(a, b)
  196. return outputs
  197. def add(value, bias):
  198. """
  199. Returns x + y element-wise.
  200. Parameters
  201. ----------
  202. value : tensor.
  203. Must be one of the following types: bfloat16, half, float32, float64,
  204. uint8, int8, int16, int32, int64, complex64, complex128, string.
  205. bias : tensor
  206. Must have the same type as a
  207. Returns
  208. -------
  209. A Tensor. Has the same type as a.
  210. """
  211. outputs = tf.add(value, bias)
  212. return outputs
  213. def dtypes(dt):
  214. """
  215. Data dtypes.
  216. Parameters
  217. ----------
  218. dt : string
  219. It could be 'uint8', 'uint16', 'uint32', 'uint64', 'int8', 'int16',
  220. 'int32', 'int64', 'float16', 'float32', 'float64', 'DType'.
  221. Returns
  222. -------
  223. Data dtypes
  224. """
  225. if dt not in _dtypeDict.keys():
  226. raise Exception("Unsupported dtype: {}".format(dt))
  227. return _dtypeDict[dt]
  228. class Maximum(object):
  229. def __init__(self):
  230. pass
  231. def __call__(self, x, y):
  232. return tf.maximum(x=x, y=y)
  233. class Minimum(object):
  234. def __init__(self):
  235. pass
  236. def __call__(self, x, y):
  237. return tf.minimum(x=x, y=y)
  238. def minimum(x, y):
  239. """
  240. Returns the min of x and y (i.e. x < y ? x : y) element-wise.
  241. Parameters
  242. ----------
  243. x : tensor.
  244. Must be one of the following types: bfloat16, half, float32, float64, int32, int64.
  245. y : A Tensor.
  246. Must have the same type as x.
  247. Returns
  248. -------
  249. A Tensor. Has the same type as x
  250. """
  251. outputs = tf.minimum(x=x, y=y)
  252. return outputs
  253. class FlattenReshape(object):
  254. def __init__(self):
  255. pass
  256. def __call__(self, inputs):
  257. dim = 1
  258. for d in get_tensor_shape(inputs)[1:]:
  259. dim *= d
  260. return tf.reshape(inputs, [-1, dim])
  261. class Reshape(object):
  262. def __init__(self, shape):
  263. self.shape = shape
  264. def __call__(self, tensor):
  265. return tf.reshape(tensor, self.shape)
  266. def reshape(tensor, shape):
  267. """
  268. Reshapes a tensor.
  269. Parameters
  270. ----------
  271. tensor : tensor
  272. A Tensor.
  273. shape : tensor
  274. Defines the shape of the output tensor.
  275. Returns
  276. -------
  277. A Tensor. Has the same type as tensor
  278. """
  279. return tf.reshape(tensor, shape)
  280. class Concat(object):
  281. def __init__(self, axis):
  282. super(Concat, self).__init__()
  283. self.axis = axis
  284. def __call__(self, values):
  285. return tf.concat(values=values, axis=self.axis)
  286. def concat(values, axis):
  287. """
  288. Concatenates tensors along one dimension.
  289. Parameters
  290. ----------
  291. values : list
  292. A list of Tensor objects or a single Tensor
  293. axis : int
  294. 0-D int32 Tensor. Dimension along which to concatenate
  295. Returns
  296. -------
  297. A Tensor resulting from concatenation of the input tensors.
  298. """
  299. return tf.concat(values, axis)
  300. def convert_to_tensor(value, dtype=None):
  301. """
  302. Converts the given value to a Tensor.
  303. Parameters
  304. ----------
  305. value : object
  306. An object whose type has a registered Tensor conversion function.
  307. dtype : optional
  308. Optional element type for the returned tensor. If missing, the type is inferred from the type of value.
  309. Returns
  310. -------
  311. A Tensor based on value.
  312. """
  313. return tf.convert_to_tensor(value, dtype)
  314. def sqrt(x):
  315. """
  316. Computes square root of x element-wise.
  317. Parameters
  318. ----------
  319. x : tensor
  320. Must be one of the following types: bfloat16, half, float32, float64, complex64, complex128.
  321. Returns
  322. -------
  323. A Tensor. Has the same type as x.
  324. """
  325. return tf.sqrt(x)
  326. class ReduceSum(object):
  327. def __init__(self, axis=None):
  328. self.axis = axis
  329. def __call__(self, input):
  330. return tf.reduce_sum(input, axis=self.axis)
  331. class ReduceMean(object):
  332. def __init__(self, axis):
  333. self.axis = axis
  334. def __call__(self, inputs):
  335. output = tf.reduce_mean(inputs, self.axis)
  336. return output
  337. def reduce_mean(input_tensor, axis=None):
  338. """
  339. Computes the mean of elements across dimensions of a tensor.
  340. Parameters
  341. ----------
  342. input_tensor : tensor
  343. The tensor to reduce. Should have numeric type.
  344. axis : list
  345. The dimensions to reduce. If None (the default), reduces all dimensions.
  346. Must be in the range [-rank(input_tensor), rank(input_tensor)).
  347. name : str
  348. A name for the operation (optional).
  349. Returns
  350. -------
  351. The reduced tensor.
  352. """
  353. return tf.reduce_mean(input_tensor, axis=axis)
  354. class ReduceMax(object):
  355. def __init__(self, axis):
  356. self.axis = axis
  357. def __call__(self, inputs):
  358. output = tf.reduce_max(inputs, self.axis)
  359. return output
  360. def reduce_max(input_tensor, axis=None):
  361. """
  362. Computes the maximum of elements across dimensions of a tensor.
  363. Parameters
  364. ----------
  365. input_tensor : tensor
  366. The tensor to reduce. Should have real numeric type.
  367. axis : int
  368. The dimensions to reduce. If None (the default), reduces all dimensions.
  369. Must be in the range [-rank(input_tensor), rank(input_tensor)).
  370. name : str
  371. A name for the operation (optional).
  372. Returns
  373. -------
  374. The reduced tensor.
  375. """
  376. return tf.reduce_max(input_tensor, axis=axis)
  377. def reduce_min(input_tensor, axis=None):
  378. """
  379. Computes the minimum of elements across dimensions of a tensor.
  380. Parameters
  381. ----------
  382. input_tensor : tensor
  383. The tensor to reduce. Should have real numeric type.
  384. axis : int
  385. The dimensions to reduce. If None (the default), reduces all dimensions.
  386. Must be in the range [-rank(input_tensor), rank(input_tensor)).
  387. name : str
  388. A name for the operation (optional).
  389. Returns
  390. -------
  391. The reduced tensor.
  392. """
  393. return tf.reduce_min(input_tensor, axis=axis)
  394. class Pad(object):
  395. def __init__(self, paddings, mode="REFLECT"):
  396. if mode not in ['CONSTANT', 'REFLECT', 'SYMMETRIC']:
  397. raise Exception("Unsupported mode: {}".format(mode))
  398. self.paddings = paddings
  399. self.mode = mode
  400. def __call__(self, x):
  401. outputs = tf.pad(x, self.paddings, mode=self.mode, constant_values=0)
  402. return outputs
  403. def pad(tensor, paddings, mode='CONSTANT', constant_values=0):
  404. """
  405. Pads a tensor.
  406. Parameters
  407. ----------
  408. tensor : tensor
  409. A Tensor.
  410. paddings : tensor
  411. A Tensor of type int32.
  412. mode : str
  413. One of "CONSTANT", "REFLECT", or "SYMMETRIC" (case-insensitive)
  414. constant_values : int
  415. In "CONSTANT" mode, the scalar pad value to use. Must be same type as tensor.
  416. Returns
  417. -------
  418. A Tensor. Has the same type as tensor.
  419. """
  420. if mode not in ['CONSTANT', 'REFLECT', 'SYMMETRIC']:
  421. raise Exception("Unsupported mode: {}".format(mode))
  422. outputs = tf.pad(tensor, paddings, mode=mode, constant_values=constant_values)
  423. return outputs
  424. class Unstack(object):
  425. def __init__(self, axis, num=None):
  426. self.axis = axis
  427. self.num = num
  428. def __call__(self, values):
  429. return tf.unstack(values, num=self.num, axis=self.axis)
  430. class Stack(object):
  431. def __init__(self, axis=0):
  432. self.axis = axis
  433. def __call__(self, values):
  434. return tf.stack(values, axis=self.axis)
  435. def stack(values, axis=0):
  436. """
  437. Stacks a list of rank-R tensors into one rank-(R+1) tensor.
  438. Parameters
  439. ----------
  440. values : list
  441. A list of Tensor objects with the same shape and type.
  442. axis : int
  443. An int. The axis to stack along. Defaults to the first dimension.
  444. Negative values wrap around, so the valid range is [-(R+1), R+1).
  445. Returns
  446. -------
  447. A stacked Tensor with the same type as values.
  448. """
  449. return tf.stack(values, axis=axis)
  450. class Meshgrid(object):
  451. def __init__(self, indexing='xy'):
  452. super(Meshgrid, self).__init__()
  453. self.index = indexing
  454. def __call__(self, inputs):
  455. return tf.meshgrid(inputs)
  456. def meshgrid(*args, **kwargs):
  457. """
  458. Broadcasts parameters for evaluation on an N-D grid.
  459. Parameters
  460. ----------
  461. x : tensor
  462. Tensors with rank 1.
  463. y : tensor
  464. Tensors with rank 1.
  465. Returns
  466. -------
  467. A list of N Tensors with rank N.
  468. """
  469. return tf.meshgrid(*args, **kwargs)
  470. def range(start, limit=None, delta=1, dtype=None):
  471. """
  472. Creates a sequence of numbers.
  473. Parameters
  474. ----------
  475. start : tensor
  476. A 0-D Tensor (scalar). Acts as first entry in the range if limit is not None;
  477. otherwise, acts as range limit and first entry defaults to 0.
  478. limit : tensor
  479. A 0-D Tensor (scalar). Upper limit of sequence, exclusive. If None,
  480. defaults to the value of start while the first entry of the range defaults to 0.
  481. delta : tensor
  482. A 0-D Tensor (scalar). Number that increments start. Defaults to 1.
  483. dtype : type
  484. The type of the elements of the resulting tensor.
  485. Returns
  486. -------
  487. An 1-D Tensor of type dtype.
  488. """
  489. if limit is None:
  490. outputs = tf.range(start, delta=delta, dtype=dtype)
  491. else:
  492. outputs = tf.range(start, limit, delta=delta, dtype=dtype)
  493. return outputs
  494. class ExpandDims(object):
  495. def __init__(self, axis):
  496. self.axis = axis
  497. def __call__(self, input):
  498. return tf.expand_dims(input, axis=self.axis)
  499. def expand_dims(input, axis):
  500. """
  501. Inserts a dimension of 1 into a tensor's shape.
  502. Parameters
  503. ----------
  504. input : tensor
  505. A Tensor.
  506. axis : int
  507. 0-D (scalar). Specifies the dimension index at which to expand the shape of input.
  508. Must be in the range [-rank(input) - 1, rank(input)].
  509. Returns
  510. -------
  511. A Tensor with the same data as input, but its shape has an additional dimension of size 1 added.
  512. """
  513. return tf.expand_dims(input, axis)
  514. class Tile(object):
  515. def __init__(self):
  516. pass
  517. def __call__(self, input, multiples):
  518. return tf.tile(input, multiples)
  519. def tile(input, multiples):
  520. """
  521. Constructs a tensor by tiling a given tensor.
  522. Parameters
  523. ----------
  524. input : tensor
  525. A Tensor. 1-D or higher.
  526. multiples : tensor
  527. Must be one of the following types: int32, int64. 1-D.
  528. Length must be the same as the number of dimensions in input
  529. Returns
  530. -------
  531. A Tensor. Has the same type as input.
  532. """
  533. return tf.tile(input, multiples)
  534. class Cast(object):
  535. def __init__(self, dtype):
  536. self.dtype = dtype
  537. def __call__(self, x):
  538. return tf.cast(x, dtype=self.dtype)
  539. def cast(x, dtype):
  540. """
  541. Casts a tensor to a new type.
  542. Parameters
  543. ----------
  544. x : tensor
  545. A Tensor or SparseTensor or IndexedSlices of numeric type.
  546. It could be uint8, uint16, uint32, uint64, int8, int16, int32, int64, float16, float32, float64.
  547. dtype : dtpye
  548. The destination type. The list of supported dtypes is the same as x
  549. Returns
  550. -------
  551. A Tensor or SparseTensor or IndexedSlices with same shape as x and same type as dtype.
  552. """
  553. return tf.cast(x, dtype=dtype)
  554. class Transpose(object):
  555. def __init__(self, perm, conjugate=False):
  556. self.perm = perm
  557. self.conjugate = conjugate
  558. def __call__(self, a):
  559. return tf.transpose(a, self.perm, self.conjugate)
  560. def transpose(a, perm=None, conjugate=False):
  561. """
  562. Transposes a.
  563. Parameters
  564. ----------
  565. a : tensor
  566. A Tensor.
  567. perm : list / int
  568. A permutation of the dimensions of a.
  569. conjugate : bool
  570. Setting it to True is mathematically equivalent to tf.math.conj(tf.transpose(input)).
  571. Returns
  572. -------
  573. A transposed Tensor.
  574. """
  575. return tf.transpose(a, perm, conjugate)
  576. def gather_nd(params, indices, batch_dims=0):
  577. """
  578. Gather slices from params into a Tensor with shape specified by indices.
  579. Parameters
  580. ----------
  581. params : tensor
  582. The tensor from which to gather values.
  583. indices : tensor
  584. Must be one of the following types: int32, int64. Index tensor.
  585. batch_dims : int
  586. An integer or a scalar 'Tensor'. The number of batch dimensions.
  587. Returns
  588. -------
  589. A Tensor. Has the same type as params.
  590. """
  591. return tf.gather_nd(params, indices, batch_dims)
  592. def clip_by_value(t, clip_value_min, clip_value_max):
  593. """
  594. Clips tensor values to a specified min and max.
  595. Parameters
  596. ----------
  597. t : tensor
  598. A Tensor or IndexedSlices
  599. clip_value_min : tensor
  600. A 0-D (scalar) Tensor, or a Tensor with the same shape as t. The minimum value to clip by
  601. clip_value_max : tensor
  602. A 0-D (scalar) Tensor, or a Tensor with the same shape as t. The minimum value to clip by
  603. Returns
  604. -------
  605. A clipped Tensor or IndexedSlices.
  606. """
  607. return tf.clip_by_value(t, clip_value_min, clip_value_max)
  608. def split(value, num_or_size_splits, axis=0, num=None):
  609. """
  610. Splits a tensor into sub tensors.
  611. Parameters
  612. ----------
  613. value : tensor
  614. The Tensor to split.
  615. num_or_size_splits : list
  616. Either an integer indicating the number of splits along split_dim or a 1-D integer Tensor or
  617. Python list containing the sizes of each output tensor along split_dim.
  618. axis : int
  619. The dimension along which to split. Must be in the range [-rank(value), rank(value)). Defaults to 0.
  620. num : int
  621. used to specify the number of outputs when it cannot be inferred from the shape of size_splits.
  622. Returns
  623. -------
  624. Tensor objects resulting from splitting value.
  625. """
  626. return tf.split(value=value, num_or_size_splits=num_or_size_splits, axis=axis, num=num)
  627. def floor(x):
  628. return tf.floor(x)
  629. def gather(params, indices):
  630. return tf.gather(params, indices)
  631. def linspace(start, stop, num):
  632. return tf.linspace(start, stop, num)
  633. def slice(inputs, starts, sizes):
  634. return tf.slice(inputs, starts, sizes)
  635. def add_n(inputs):
  636. return tf.add_n(inputs)
  637. class OneHot(object):
  638. def __init__(self, depth, on_value, off_value, axis, dtype):
  639. self.depth = depth
  640. self.on_value = on_value
  641. self.off_value = off_value
  642. self.axis = axis
  643. self.dtype = dtype
  644. def __call__(self, inputs, *args, **kwargs):
  645. outputs = tf.one_hot(
  646. inputs, self.depth, on_value=self.on_value, off_value=self.off_value, axis=self.axis, dtype=self.dtype
  647. )
  648. return outputs
  649. class L2Normalize(object):
  650. def __init__(self, axis=None, epsilon=1e-12):
  651. self.axis = axis
  652. self.epsilon = epsilon
  653. def __call__(self, input, *args, **kwargs):
  654. outputs = tf.math.l2_normalize(input, axis=self.axis, epsilon=self.epsilon)
  655. return outputs
  656. class EmbeddingLookup(object):
  657. def __init__(self, max_norm=None):
  658. self.max_norm = max_norm
  659. def __call__(self, params, ids, *args, **kwargs):
  660. outputs = tf.nn.embedding_lookup(params=params, ids=ids, max_norm=self.max_norm)
  661. return outputs
  662. class NCELoss(object):
  663. def __init__(self, num_true=1, sampled_values=None, remove_accidental_hits=False):
  664. self.num_true = num_true
  665. self.sampled_values = sampled_values
  666. self.remove_accidental_hits = remove_accidental_hits
  667. def __call__(self, weights, biases, labels, inputs, num_sampled, num_classes):
  668. outputs = tf.nn.nce_loss(
  669. weights=weights, biases=biases, inputs=inputs, labels=labels, num_sampled=num_sampled,
  670. num_classes=num_classes
  671. )
  672. return outputs
  673. class Not_equal(object):
  674. def __init__(self):
  675. pass
  676. def __call__(self, x, y):
  677. return tf.not_equal(x, y)
  678. class Count_nonzero(object):
  679. def __init__(self, keepdims=None, dtype=int64):
  680. self.keepdims = keepdims
  681. self.dtype = dtype
  682. def __call__(self, input, axis=None):
  683. return tf.math.count_nonzero(input, axis=axis, keepdims=self.keepdims, dtype=self.dtype)
  684. class Resize:
  685. def __init__(self, scale, method, antialias=False, data_format='channels_last', ksize=None):
  686. self.method = method
  687. self.antialias = antialias
  688. self.scale = scale
  689. self.data_format = data_format
  690. def __call__(self, inputs):
  691. if self.data_format == 'channels_first':
  692. inputs = nchw_to_nhwc(inputs)
  693. if len(get_tensor_shape(inputs)) == 4:
  694. output_size = [int(inputs.shape[1] * self.scale[0]), int(inputs.shape[2] * self.scale[1])]
  695. else:
  696. raise ("The inputs shape must be 4-D Tensor.")
  697. outputs = tf.image.resize(inputs, size=output_size, method=self.method, antialias=self.antialias)
  698. if self.data_format == 'channels_first':
  699. outputs = nhwc_to_nchw(outputs)
  700. return outputs
  701. def resize(inputs, output_size, method, antialias):
  702. return tf.image.resize(inputs, size=output_size, method=method, antialias=antialias)
  703. class ZeroPadding1D(object):
  704. def __init__(self, padding):
  705. self.zeropad = tf.keras.layers.ZeroPadding1D(padding=padding)
  706. def __call__(self, inputs):
  707. return self.zeropad(inputs)
  708. class ZeroPadding2D(object):
  709. def __init__(self, padding):
  710. self.zeropad = tf.keras.layers.ZeroPadding2D(padding=padding)
  711. def __call__(self, inputs):
  712. return self.zeropad(inputs)
  713. class ZeroPadding3D(object):
  714. def __init__(self, padding):
  715. self.zeropad = tf.keras.layers.ZeroPadding3D(padding=padding)
  716. def __call__(self, inputs):
  717. return self.zeropad(inputs)
  718. class Sign(object):
  719. def __init__(self):
  720. pass
  721. def __call__(self, x):
  722. return tf.sign(x)
  723. def ceil(x):
  724. return tf.math.ceil(x)
  725. def multiply(x, y):
  726. return tf.multiply(x, y)
  727. def divide(x, y):
  728. return tf.divide(x, y)
  729. def identity(x):
  730. return tf.identity(x)
  731. class BatchToSpace(object):
  732. def __init__(self, block_size, crops):
  733. self.bolock_size = block_size
  734. self.crops = crops
  735. def __call__(self, input_x):
  736. return tf.batch_to_space(input=input_x, block_shape=self.bolock_size, crops=self.crops)
  737. class DepthToSpace(object):
  738. def __init__(self, block_size, data_format='NHWC'):
  739. self.block_size = block_size
  740. self.data_format = data_format
  741. def __call__(self, input):
  742. return tf.nn.depth_to_space(input, block_size=self.block_size, data_format=self.data_format)

TensorLayer 3.0 是一款支持以多种深度学习框架作为计算后端的深度学习库,计划兼容 TensorFlow、PyTorch、MindSpore 和 Paddle。