
gradients_test.py 40 kB

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.gradients."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys
import warnings

import numpy as np

from tensorflow.python.client import session
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function as framework_function
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework.constant_op import constant
from tensorflow.python.layers import core as core_layers
from tensorflow.python.ops import array_grad  # pylint: disable=unused-import
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_grad  # pylint: disable=unused-import
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import data_flow_grad  # pylint: disable=unused-import
from tensorflow.python.ops import data_flow_ops  # pylint: disable=unused-import
from tensorflow.python.ops import functional_ops  # pylint: disable=unused-import
from tensorflow.python.ops import gradients
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import math_grad  # pylint: disable=unused-import
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_grad  # pylint: disable=unused-import
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_grad  # pylint: disable=unused-import
from tensorflow.python.ops import tensor_array_grad  # pylint: disable=unused-import
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops.nn_ops import bias_add
from tensorflow.python.platform import googletest


class GradientsTest(test_util.TensorFlowTestCase):

  def testGradients(self):
    with ops.Graph().as_default():
      inp = constant(1.0, shape=[32, 100], name="in")
      w = constant(1.0, shape=[100, 10], name="w")
      b = constant(1.0, shape=[10], name="b")
      xw = math_ops.matmul(inp, w, name="xw")
      h = bias_add(xw, b, name="h")
      w_grad = gradients.gradients(h, w)[0]
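      # For h = matmul(inp, w), dL/dw = matmul(inp, dL/dh, transpose_a=True),
      # so the backprop op should be a MatMul that transposes only its first
      # input.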
  65. self.assertEquals("MatMul", w_grad.op.type)
  66. self.assertEquals(w_grad.op._original_op, xw.op)
  67. self.assertTrue(w_grad.op.get_attr("transpose_a"))
  68. self.assertFalse(w_grad.op.get_attr("transpose_b"))

  def testUnusedOutput(self):
    with ops.Graph().as_default():
      w = constant(1.0, shape=[2, 2])
      x = constant(1.0, shape=[2, 2])
      wx = math_ops.matmul(w, x)
      split_wx = array_ops.split(value=wx, num_or_size_splits=2, axis=0)
      c = math_ops.reduce_sum(split_wx[1])
      gw = gradients.gradients(c, [w])[0]
      self.assertEquals("MatMul", gw.op.type)

  def testColocateGradients(self):
    with ops.Graph().as_default() as g:
      w = constant(1.0, shape=[1, 1])
      x = constant(1.0, shape=[1, 2])
      with g.device("/device:GPU:0"):
        wx = math_ops.matmul(w, x)
      gw = gradients.gradients(wx, [w], colocate_gradients_with_ops=True)[0]
      self.assertEqual(gw.op.colocation_groups(), wx.op.colocation_groups())

  def testColocateGradientsWithAggregation(self):
    with ops.Graph().as_default() as g:
      with g.device("/device:GPU:1"):
        w = constant(1.0, shape=[1, 1])
        x = constant(1.0, shape=[1, 2])
        y = constant(1.0, shape=[1, 2])
      wx = math_ops.matmul(w, x)
      wy = math_ops.matmul(w, y)
      with g.device("/device:GPU:0"):
        z = wx + wy

      gw1 = gradients.gradients(z, [w], colocate_gradients_with_ops=True)[0]
      self.assertEqual(gw1.op.colocation_groups(), wx.op.colocation_groups())

      gw2 = gradients.gradients(z, [w], colocate_gradients_with_ops=False)[0]
      self.assertTrue(wx.op.colocation_groups() != gw2.op.colocation_groups())

  def testColocateGradientsWithAggregationInMultipleDevices(self):
    with ops.Graph().as_default() as g:
      with g.device("/device:GPU:1"):
        w = constant(1.0, shape=[1, 1])
        x = constant(1.0, shape=[1, 2])
        y = constant(1.0, shape=[1, 2])
      with g.device("/task:1"):
        wx = math_ops.matmul(w, x)
      with g.device("/task:2"):
        wy = math_ops.matmul(w, y)
      with g.device("/device:GPU:0"):
        z = wx + wy

      gw1 = gradients.gradients(z, [w], colocate_gradients_with_ops=True)[0]
      self.assertEqual(gw1.op.colocation_groups(), w.op.colocation_groups())

      gw2 = gradients.gradients(z, [w], colocate_gradients_with_ops=False)[0]
      self.assertTrue(w.op.colocation_groups() != gw2.op.colocation_groups())

  def testColocateGradientsWithGateGradients(self):
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")
    with ops.Graph().as_default() as g:
      with g.device("/device:CPU:0"):
        x = constant(1.0, shape=[1, 1])
        y = constant(1.0, shape=[1, 1])
        s = x + y
      with g.device("/device:GPU:0"):
        z = math_ops.reduce_sum(s)

      gz_x = gradients.gradients(z, [x], colocate_gradients_with_ops=True,
                                 gate_gradients=True)[0]
      with session.Session():
        # Make sure the placer doesn't complain.
        self.evaluate(gz_x)

  def testBoundaryStop(self):
    # Test that we don't differentiate 'x'. The gradient function for 'x' is
    # set explicitly to None so we will get an exception if the gradient code
    # tries to differentiate 'x'.
    with ops.Graph().as_default():
      c = constant(1.0)
      x = array_ops.identity(c)
      y = x + 1.0
      z = y + 1
      grads = gradients.gradients(z, [x])
      self.assertTrue(all(x is not None for x in grads))

  @test_util.run_v1_only("b/120545219")
  def testBoundaryContinue(self):
    # Test that we differentiate both 'x' and 'y' correctly when x is a
    # predecessor of y.
    with self.cached_session():
      x = constant(1.0)
      y = x * 2.0
      z = y * 3.0
      grads = gradients.gradients(z, [x, y])
      self.assertTrue(all(x is not None for x in grads))
      self.assertEqual(6.0, grads[0].eval())

  @test_util.run_v1_only("b/120545219")
  def testAggregationMethodAccumulateN(self):
    with self.cached_session():
      x = constant(1.0)
      y = x * 2.0
      z = y + y + y + y + y + y + y + y + y + y
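      # z = 10 * y and y = 2 * x, so dz/dy = 10 and dz/dx = 20 regardless of
      # the aggregation method used to sum the ten contributions.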
      grads = gradients.gradients(
          z, [x, y],
          aggregation_method=gradients.AggregationMethod.
          EXPERIMENTAL_ACCUMULATE_N)
      self.assertTrue(all(x is not None for x in grads))
      self.assertEqual(20.0, grads[0].eval())
      self.assertEqual(10.0, grads[1].eval())

  @test_util.run_v1_only("b/120545219")
  def testAggregationMethodAddN(self):
    with self.cached_session():
      x = constant(1.0)
      y = x * 2.0
      z = y + y + y + y + y + y + y + y + y + y
      grads = gradients.gradients(
          z, [x, y], aggregation_method=gradients.AggregationMethod.ADD_N)
      self.assertTrue(all(x is not None for x in grads))
      self.assertEqual(20.0, grads[0].eval())
      self.assertEqual(10.0, grads[1].eval())

  @test_util.run_v1_only("b/120545219")
  def testAggregationMethodTree(self):
    with self.cached_session():
      x = constant(1.0)
      y = x * 2.0
      z = y + y + y + y + y + y + y + y + y + y
      grads = gradients.gradients(
          z, [x, y],
          aggregation_method=gradients.AggregationMethod.EXPERIMENTAL_TREE)
      self.assertTrue(all(x is not None for x in grads))
      self.assertEqual(20.0, grads[0].eval())
      self.assertEqual(10.0, grads[1].eval())

  def testNoGradientForStringOutputs(self):
    with ops.Graph().as_default():

      def _TestOpGrad(_, float_grad, string_grad):
        """Gradient function for TestStringOutput."""
        self.assertEquals(float_grad.dtype, dtypes.float32)
        self.assertFalse(string_grad)
        return float_grad

      ops.RegisterGradient("TestStringOutput")(_TestOpGrad)

      c = constant(1.0)
      x, _ = test_ops.test_string_output(c)
      z = x * 2.0
      w = z * 3.0
      grads = gradients.gradients(z, [c])
      self.assertTrue(isinstance(grads[0], ops.Tensor))
      grads = gradients.gradients(w, [c])
      self.assertTrue(isinstance(grads[0], ops.Tensor))

  def testSingletonIndexedSlices(self):
    with ops.Graph().as_default():
      x = array_ops.placeholder(dtypes.float32)
      y = array_ops.identity(x)
      dy = ops.IndexedSlices(
          array_ops.placeholder(dtypes.float32),
          array_ops.placeholder(dtypes.int32))
      dx, = gradients.gradients(y, x, grad_ys=dy)
      # The IndexedSlices gradient of tf.identity is the identity map.
      with self.cached_session() as sess:
        vdx, vdy = sess.run(
            [dx, dy], feed_dict={x: [1.0], dy.indices: [0], dy.values: [2.0]})
      self.assertEqual(vdx, vdy)

  @test_util.run_v1_only("b/120545219")
  def testNonDifferentiableSwitchInWhileLoop(self):
    with ops.Graph().as_default():
      v = array_ops.placeholder(dtypes.float32, [])

      def _Step(i, a, ta):
        a += math_ops.cast(v, dtypes.int32)
        return (i + 1, a, ta.write(i, a))

      n = 4
      i, _, ta = control_flow_ops.while_loop(
          lambda i, *_: i < n,
          _Step, [0, 0, tensor_array_ops.TensorArray(
              dtypes.int32, size=n)])
      target = ta.read(i - 1)
      grad, = gradients.gradients(target, v)
      self.assertIsNone(grad)

  def testVariableReadValueGradient(self):
    with ops.Graph().as_default():
      init = constant_op.constant(100.0)
      var = variables.Variable(init)
      gradient = gradients.gradients(var.read_value(), var)
      self.assertIsNotNone(gradient)

  def testVariableAsGraphElementGradient(self):
    with ops.Graph().as_default() as graph:
      init = constant_op.constant(100.0)
      var = variables.Variable(init)
      gradient = gradients.gradients(graph.as_graph_element(var), var)
      self.assertIsNotNone(gradient)

  @test_util.run_v1_only("b/120545219")
  def testVariableRefGradient(self):
    with ops.Graph().as_default():
      init = constant_op.constant(100.0)
      var = variables.VariableV1(init)
      gradient = gradients.gradients(var._ref(), var)
      self.assertIsNotNone(gradient)

  @test_util.run_v1_only("b/120545219")
  def testDependentYs(self):
    with self.cached_session():
      x = constant_op.constant(3.0)
      y = math_ops.square(x)
      y1 = math_ops.square(y)
      y2 = math_ops.square(y1)
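      # y = x^2 and y2 = x^8, so d(y + y2)/dx = 2*x + 8*x^7, which at x = 3 is
      # 6 + 17496 = 17502.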
      g = gradients.gradients([y, y2], x)
      self.assertAllClose(17502.0, g[0].eval())
      g = gradients.gradients(y + y2, x)
      self.assertAllClose(17502.0, g[0].eval())
      z = array_ops.identity(y)
      z2 = array_ops.identity(y2)
      g = gradients.gradients([z, z2], x)
      self.assertAllClose(17502.0, g[0].eval())

  @test_util.run_v1_only("b/120545219")
  def testPartialDerivatives(self):
    with self.cached_session():
      x = constant_op.constant(1.)
      y = 2 * x
      z = x + y
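      # Total derivative: dz/dx = 1 + (dz/dy)(dy/dx) = 1 + 2 = 3; with
      # stop_gradients=[x, y] the path through y is cut, leaving 1.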
      totalg = gradients.gradients(z, [x, y])
      self.assertEqual([3.0, 1.0], [g.eval() for g in totalg])
      partialg = gradients.gradients(z, [x, y], stop_gradients=[x, y])
      self.assertEqual([1.0, 1.0], [g.eval() for g in partialg])

  @test_util.run_v1_only("b/120545219")
  def testStopGradients(self):

    def _MakeGraph(rng, stop_gradients=()):
      def _FunctionOf(xs, k=3):
        return ops.convert_to_tensor(
            sum(math_ops.matmul(rng.rand(k, k), x) for x in xs)
            + rng.rand(k, k))
      a = _FunctionOf([])
      if "a" in stop_gradients: a = array_ops.stop_gradient(a)
      b = _FunctionOf([a])
      if "b" in stop_gradients: b = array_ops.stop_gradient(b)
      c = _FunctionOf([a, b])
      if "c" in stop_gradients: c = array_ops.stop_gradient(c)
      d = _FunctionOf([b, c])
      if "d" in stop_gradients: d = array_ops.stop_gradient(d)
      return dict(a=a, b=b, c=c, d=d)

    def _Gradients(ys, xs, **kwargs):
      dydxs = gradients.gradients(ys, xs, **kwargs)
      dydxs = [0. * x if dydx is None else dydx
               for x, dydx in zip(xs, dydxs)]
      return dydxs

    seed = np.random.randint(1000)
    cases = []
    subsets = [""] + "a b c d ab ac ad bc bd cd abc abd acd bcd abcd".split()
    graph = _MakeGraph(np.random.RandomState(seed))
    for constants in subsets:
      graph_with_stops = _MakeGraph(np.random.RandomState(seed), constants)
      for variables_ in subsets:
        # compute the gradient when stopped using tf.stop_gradients
        grad1 = _Gradients([graph_with_stops["d"]],
                           [graph_with_stops[v] for v in variables_])
        # compute the gradient when stopped using the stop_gradients kwarg
        grad2 = _Gradients([graph["d"]],
                           [graph[v] for v in variables_],
                           stop_gradients=[graph[v] for v in constants])
        cases.append(dict(grad1=grad1, grad2=grad2,
                          constants=constants, variables=variables_))

    # evaluate all tensors in one call to session.run for speed
    with self.cached_session() as sess:
      results = sess.run([(case["grad1"], case["grad2"]) for case in cases])

    for (npgrad1, npgrad2), case in zip(results, cases):
      for a, b in zip(npgrad1, npgrad2):
        np.testing.assert_allclose(a, b)

  def testUnconnectedGradientsNoneUnconnectedGradients(self):
    with ops.Graph().as_default():
      x = constant(1.0, shape=[2, 2])
      y = constant(3.0, shape=[3, 1])
      grad = gradients.gradients(
          [y], [x], unconnected_gradients="none")
      self.assertIsNone(grad[0])

  def testUnconnectedGradientsZerosUnconnectedGradients(self):
    with ops.Graph().as_default():
      x = constant(1.0, shape=[2, 2])
      y = constant(3.0, shape=[3, 1])
      grads = gradients.gradients(
          [y], [x], unconnected_gradients="zero")
      with self.cached_session() as sess:
        self.assertAllEqual([[0.0, 0.0], [0.0, 0.0]], self.evaluate(grads)[0])

  def testUnconnectedGradientsZeroConnectedGradients(self):
    with ops.Graph().as_default():
      x = constant(1.0)
      y = x * 3.0
      grad = gradients.gradients(
          [y], [x], unconnected_gradients="zero")
      with self.cached_session() as sess:
        self.assertEquals(3.0, self.evaluate(grad)[0])

  def testUnknownUnconnectedGradientsValueGiven(self):
    with ops.Graph().as_default():
      x = constant(1.0)
      y = constant(1.0)
      with self.assertRaisesRegexp(
          ValueError, "Unknown value for unconnected_gradients: 'nonsense'"):
        gradients.gradients([y], [x], unconnected_gradients="nonsense")


class FunctionGradientsTest(test_util.TensorFlowTestCase):

  @classmethod
  def XSquarePlusB(cls, x, b):
    return x * x + b

  @classmethod
  def XSquarePlusBGradient(cls, x, b, g):
    # Perturb gradients (multiply by 2), so we can test that this was called.
    g *= 2.0
    return g * 2.0 * x, g

  @classmethod
  def _PythonGradient(cls, op, grad):
    # Perturb gradients (multiply by 3), so we can test that this was called.
    grad *= 3.0
    return grad * op.inputs[0] * 2.0, grad

  @classmethod
  def _GetFunc(cls, **kwargs):
    return framework_function.Defun(dtypes.float32, dtypes.float32, **
                                    kwargs)(cls.XSquarePlusB)

  def _GetFuncGradients(self, f, x_value, b_value):
    x = constant_op.constant(x_value, name="x")
    b = constant_op.constant(b_value, name="b")
    y = f(x, b)
    grads = gradients.gradients(y, [x, b])
    with self.cached_session() as sess:
      return sess.run(grads)

  def testFunctionGradientsBasic(self):
    g = ops.Graph()
    with g.as_default():
      f = self._GetFunc()
      # Get gradients (should add SymbolicGradient node for function).
      grads = self._GetFuncGradients(f, [2.0], [1.0])
      self.assertAllEqual([4.0], grads[0])
      self.assertAllEqual([1.0], grads[1])

  def testFunctionGradientsComposition(self):
    with ops.Graph().as_default():
      f = self._GetFunc()
      x = constant_op.constant([2.0], name="x")
      b1 = constant_op.constant([1.0], name="b1")
      b2 = constant_op.constant([1.0], name="b2")
      y = f(f(x, b1), b2)
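      # y = (x^2 + b1)^2 + b2, so dy/dx = 4*x*(x^2 + b1) = 40 and
      # dy/db1 = 2*(x^2 + b1) = 10 at x = 2, b1 = 1.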
      # Build gradient graph (should add SymbolicGradient node for function).
      grads = gradients.gradients(y, [x, b1])
      with self.cached_session() as sess:
        self.assertAllEqual([40.0], self.evaluate(grads)[0])
        self.assertAllEqual([10.0], self.evaluate(grads)[1])

  def testFunctionGradientsWithGradFunc(self):
    g = ops.Graph()
    with g.as_default():
      grad_func = framework_function.Defun(dtypes.float32, dtypes.float32,
                                           dtypes.float32)(
                                               self.XSquarePlusBGradient)
      f = self._GetFunc(grad_func=grad_func)
      # Get gradients (should add SymbolicGradient node for function, which
      # uses the grad_func above, which multiplies all gradients by 2).
      grads = self._GetFuncGradients(f, [2.0], [1.0])
      self.assertAllEqual([4.0 * 2], grads[0])
      self.assertAllEqual([1.0 * 2], grads[1])

  def testFunctionGradientWithRegistration(self):
    g = ops.Graph()
    with g.as_default():
      f = self._GetFunc(python_grad_func=self._PythonGradient)
      # Get gradients, using the python gradient function. It multiplies the
      # gradients by 3.
      grads = self._GetFuncGradients(f, [2.0], [1.0])
      self.assertAllEqual([4.0 * 3], grads[0])
      self.assertAllEqual([1.0 * 3], grads[1])

  def testFunctionGradientWithGradFuncAndRegistration(self):
    g = ops.Graph()
    with g.as_default():
      grad_func = framework_function.Defun(dtypes.float32, dtypes.float32,
                                           dtypes.float32)(
                                               self.XSquarePlusBGradient)
      with self.assertRaisesRegexp(ValueError, "Gradient defined twice"):
        f = self._GetFunc(
            grad_func=grad_func, python_grad_func=self._PythonGradient)
        f.add_to_graph(ops.Graph())

  def testGradientWrtCaptured(self):
    with ops.Graph().as_default():
      x = constant_op.constant(1.0, name="x")

      @function.defun()
      def Foo():
        y = math_ops.multiply(x, 2.0, name="y")
        g = gradients_impl.gradients(y, x)
        return g[0]

      f = Foo()
      with self.cached_session() as sess:
        self.assertEqual(self.evaluate(f), 2.0)

  def testGradientOfCaptured(self):
    with ops.Graph().as_default():
      x = constant_op.constant(1.0, name="x")
      y = math_ops.multiply(x, 2.0, name="y")

      @framework_function.Defun()
      def Foo():
        g = gradients_impl.gradients(y, x)
        return g[0]

      f = Foo()
      with self.cached_session() as sess:
        self.assertEqual(self.evaluate(f), 2.0)

  def testCapturedResourceVariable(self):
    with ops.Graph().as_default():
      var = resource_variable_ops.ResourceVariable(1.0, name="var")

      @function.defun()
      def Foo():
        y = math_ops.multiply(var, 2.0, name="y")
        g = gradients_impl.gradients(y, var)
        return g[0]

      f = Foo()
      with self.cached_session() as sess:
        self.evaluate(variables.global_variables_initializer())
        self.assertEqual(self.evaluate(f), 2.0)

  def testCapturedNested(self):
    with ops.Graph().as_default():
      x1 = constant_op.constant(1.0, name="x1")
      x2 = constant_op.constant(2.0, name="x2")
      x3 = math_ops.multiply(x1, x2, name="x3")

      @function.defun()
      def Outer():
        outer1 = array_ops.identity(x1, name="outer1")

        @function.defun()
        def Inner():
          inner1 = array_ops.identity(outer1, name="inner1")
          inner2 = array_ops.identity(x2, name="inner2")
          inner3 = array_ops.identity(x3, name="inner3")
          return gradients_impl.gradients([inner1, inner2, inner3, x1],
                                          [x1, x2])

        return Inner()

      x1_grad, x2_grad = Outer()
      with self.cached_session() as sess:
        # 1.0 + None + 2.0 + 1.0 = 4.0
        self.assertEqual(self.evaluate(x1_grad), 4.0)
        # None + 1.0 + 1.0 + None = 2.0
        self.assertEqual(self.evaluate(x2_grad), 2.0)

  def testCapturedFromFunction(self):
    with ops.Graph().as_default():
      x = constant_op.constant(1.0, name="x")

      @function.defun()
      def Outer():
        y = math_ops.multiply(x, 2.0, name="y")

        @function.defun()
        def Inner():
          z = math_ops.multiply(y, 3.0, name="z")
          g = gradients_impl.gradients(z, y)
          return g[0]

        return Inner()

      z_grad = Outer()
      with self.cached_session() as sess:
        self.assertEqual(self.evaluate(z_grad), 3.0)

  def testCapturedEagerTensors(self):
    # Test that we can handle captured eager tensors unrelated to the gradient
    # computation (i.e. we need to ignore them).
    # TODO(skyewm): make it an error if you try to take the gradient wrt a
    # captured EagerTensor
    with context.eager_mode():
      c = constant_op.constant(2.0, name="c")

      @function.defun
      def Foo():
        x = constant_op.constant(10.0, name="x")
        y = math_ops.multiply(x, c, name="y")
        z = math_ops.multiply(y, 3.0, name="z")
        g = gradients_impl.gradients(z, x)
        return g[0]

      self.assertEqual(Foo().numpy(), 6.0)


class StopGradientTest(test_util.TensorFlowTestCase):

  def testStopGradient(self):
    with ops.Graph().as_default():
      inp = constant(1.0, shape=[100, 32], name="in")
      out = array_ops.stop_gradient(inp)
      igrad = gradients.gradients(out, inp)[0]
      assert igrad is None


class PreventGradientTest(test_util.TensorFlowTestCase):

  def testPreventGradient(self):
    with ops.Graph().as_default():
      inp = constant(1.0, shape=[100, 32], name="in")
      out = array_ops.prevent_gradient(inp)
      with self.assertRaisesRegexp(LookupError, "explicitly disabled"):
        _ = gradients.gradients(out, inp)


class HessianVectorProductTest(test_util.TensorFlowTestCase):

  @test_util.run_v1_only("b/120545219")
  def testHessianVectorProduct(self):
    # Manually compute the Hessian explicitly for a low-dimensional problem
    # and check that HessianVectorProduct matches multiplication by the
    # explicit Hessian.
    # Specifically, the Hessian of f(x) = x^T A x is
    # H = A + A^T.
    # We expect HessianVectorProduct(f(x), x, v) to be H v.
    m = 4
    rng = np.random.RandomState([1, 2, 3])
    mat_value = rng.randn(m, m).astype("float32")
    v_value = rng.randn(m, 1).astype("float32")
    x_value = rng.randn(m, 1).astype("float32")
    hess_value = mat_value + mat_value.T
    hess_v_value = np.dot(hess_value, v_value)
    for use_gpu in [False, True]:
      with self.cached_session(use_gpu=use_gpu):
        mat = constant_op.constant(mat_value)
        v = constant_op.constant(v_value)
        x = constant_op.constant(x_value)
        mat_x = math_ops.matmul(mat, x, name="Ax")
        x_mat_x = math_ops.matmul(array_ops.transpose(x), mat_x, name="xAx")
        hess_v = gradients_impl._hessian_vector_product(x_mat_x, [x], [v])[0]
        hess_v_actual = self.evaluate(hess_v)
      self.assertAllClose(hess_v_value, hess_v_actual)


class HessianTest(test_util.TensorFlowTestCase):

  @test_util.run_v1_only("b/120545219")
  def testHessian1D(self):
    # Manually compute the Hessian explicitly for a low-dimensional problem
    # and check that `hessian` matches. Specifically, the Hessian of
    # f(x) = x^T A x is H = A + A^T.
    m = 4
    rng = np.random.RandomState([1, 2, 3])
    mat_value = rng.randn(m, m).astype("float32")
    x_value = rng.randn(m).astype("float32")
    hess_value = mat_value + mat_value.T
    with self.session(use_gpu=True):
      mat = constant_op.constant(mat_value)
      x = constant_op.constant(x_value)
      x_mat_x = math_ops.reduce_sum(x[:, None] * mat * x[None, :])
      hess = gradients.hessians(x_mat_x, x)[0]
      hess_actual = self.evaluate(hess)
    self.assertAllClose(hess_value, hess_actual)

  @test_util.run_v1_only("b/120545219")
  def testHessian1D_multi(self):
    # Test the computation of the hessian with respect to multiple tensors
    m = 4
    n = 3
    rng = np.random.RandomState([1, 2, 3])
    mat_values = [rng.randn(m, m).astype("float32") for _ in range(n)]
    x_values = [rng.randn(m).astype("float32") for _ in range(n)]
    hess_values = [mat_value + mat_value.T for mat_value in mat_values]
    with self.session(use_gpu=True):
      mats = [constant_op.constant(mat_value) for mat_value in mat_values]
      xs = [constant_op.constant(x_value) for x_value in x_values]
      xs_mats_xs = [
          math_ops.reduce_sum(x[:, None] * mat * x[None, :])
          for x, mat in zip(xs, mats)
      ]
      hessians = gradients.hessians(xs_mats_xs, xs)
      hessians_actual = [hess.eval() for hess in hessians]
    for hess_value, hess_actual in zip(hess_values, hessians_actual):
      self.assertAllClose(hess_value, hess_actual)

  @test_util.run_v1_only("b/120545219")
  def testHessianInvalidDimension(self):
    for shape in [(10, 10), None]:
      with self.cached_session(use_gpu=True):
        x = array_ops.placeholder(dtypes.float32, shape)
        # Expect a ValueError because the dimensions are wrong
        with self.assertRaises(ValueError):
          gradients.hessians(x, x)

  @test_util.run_v1_only("b/120545219")
  def testHessian2D_square_matrix(self):
    # Manually compute the Hessian explicitly for a low-dimensional problem
    # and check that `hessian` matches. Specifically, the Hessian of
    # f(x) = 1/2 * x^T * x is H = constant (block identity matrix)
    m = 3
    rng = np.random.RandomState([1, 2, 3])
    x_value = rng.randn(m, m).astype("float32")
    with self.session(use_gpu=True):
      x = constant_op.constant(x_value)
      x_square = math_ops.reduce_sum(
          math_ops.matmul(array_ops.transpose(x), x) * 0.5
      )
      hess = gradients.hessians(x_square, x)[0]
      hess_actual = self.evaluate(hess)
    hess_value = np.bmat([
        [elem*np.ones((m, m)) for elem in vec]
        for vec in np.eye(m)
    ]).astype("float32")
    self.assertAllEqual((m, m, m, m), hess_actual.shape)
    self.assertAllClose(hess_value, hess_actual.reshape((m * m, m * m)))

  @test_util.run_v1_only("b/120545219")
  def testHessian2D_non_square_matrix(self):
    m = 3
    n = 4
    rng = np.random.RandomState([1, 2, 3])
    x_value = rng.randn(m, n).astype("float32")
    with self.session(use_gpu=True):
      x = constant_op.constant(x_value)
      x_square = math_ops.reduce_sum(
          math_ops.matmul(array_ops.transpose(x), x) * 0.5
      )
      hess = gradients.hessians(x_square, x)[0]
      hess_actual = self.evaluate(hess)
    hess_value = np.bmat([
        [elem*np.ones((n, n)) for elem in vec]
        for vec in np.eye(m)
    ]).astype("float32")
    self.assertAllEqual((m, n, m, n), hess_actual.shape)
    self.assertAllClose(hess_value, hess_actual.reshape((m * n, m * n)))


class IndexedSlicesToTensorTest(test_util.TensorFlowTestCase):

  @test_util.run_v1_only("b/120545219")
  def testIndexedSlicesToTensor(self):
    with self.cached_session():
      np_val = np.random.rand(4, 4, 4, 4).astype(np.float32)
      c = constant_op.constant(np_val)
      c_sparse = math_ops._as_indexed_slices(c)
      self.assertAllEqual(np_val.shape, c_sparse.dense_shape.eval())
      c_dense = math_ops.multiply(c_sparse, 1.0)
      self.assertAllClose(np_val, self.evaluate(c_dense))

  @test_util.run_v1_only("b/120545219")
  def testIndexedSlicesToTensorList(self):
    with self.cached_session():
      numpy_list = []
      dense_list = []
      sparse_list = []
      for _ in range(3):
        np_val = np.random.rand(4, 4, 4, 4).astype(np.float32)
        c = constant_op.constant(np_val)
        c_sparse = math_ops._as_indexed_slices(c)
        numpy_list.append(np_val)
        dense_list.append(c)
        sparse_list.append(c_sparse)
      packed_dense = array_ops.stack(dense_list)
      packed_sparse = array_ops.stack(sparse_list)
      self.assertAllClose(packed_dense.eval(), self.evaluate(packed_sparse))

  @test_util.run_v1_only("b/120545219")
  def testInt64Indices(self):
    with self.cached_session():
      np_val = np.random.rand(4, 4, 4, 4).astype(np.float32)
      c = constant_op.constant(np_val)
      c_sparse = math_ops._as_indexed_slices(c)
      c_sparse = ops.IndexedSlices(
          c_sparse.values,
          math_ops.cast(c_sparse.indices, dtypes.int64), c_sparse.dense_shape)
      self.assertAllEqual(np_val.shape, c_sparse.dense_shape.eval())
      c_dense = math_ops.multiply(c_sparse, 1.0)
      self.assertAllClose(np_val, self.evaluate(c_dense))

  @test_util.run_v1_only("b/120545219")
  def testWarnings(self):
    # TODO(gunan) Reenable after this issue is fixed:
    # https://github.com/google/protobuf/issues/2812
    if sys.version_info >= (3, 5):
      self.skipTest("Skipped test for Python 3.5+")
    # Smaller than the threshold: no warning.
    c_sparse = ops.IndexedSlices(
        array_ops.placeholder(dtypes.float32),
        array_ops.placeholder(dtypes.int32), constant([4, 4, 4, 4]))
    with warnings.catch_warnings(record=True) as w:
      math_ops.multiply(c_sparse, 1.0)
    self.assertEqual(0, len(w))

    # Greater than or equal to the threshold: warning.
    c_sparse = ops.IndexedSlices(
        array_ops.placeholder(dtypes.float32),
        array_ops.placeholder(dtypes.int32), constant([100, 100, 100, 100]))
    # "always" filter prevents the warning from being suppressed if it was
    # already triggered in a different test.
    warnings.simplefilter("always")
    with warnings.catch_warnings(record=True) as w:
      math_ops.multiply(c_sparse, 1.0)
    self.assertEqual(1, len(w))
    self.assertTrue(
        "with 100000000 elements. This may consume a large amount of memory." in
        str(w[0].message))

    # Unknown dense shape: warning.
    c_sparse = ops.IndexedSlices(
        array_ops.placeholder(dtypes.float32),
        array_ops.placeholder(dtypes.int32),
        array_ops.placeholder(dtypes.int32))
    with warnings.catch_warnings(record=True) as w:
      math_ops.multiply(c_sparse, 1.0)
    self.assertEqual(1, len(w))
    self.assertTrue(
        "of unknown shape. This may consume a large amount of memory." in
        str(w[0].message))


class OnlyRealGradientsTest(test_util.TensorFlowTestCase):

  @test_util.run_v1_only("b/120545219")
  def testRealOnly(self):
    x = constant_op.constant(7+3j, dtype=dtypes.complex64)
    y = math_ops.square(x)
    with self.assertRaisesRegexp(
        TypeError,
        r"Gradients of complex tensors must set grad_ys "
        r"\(y\.dtype = tf\.complex64\)"):
      gradients.gradients(y, x)


class ResourceCondTest(test_util.TensorFlowTestCase):

  @test_util.run_v1_only("b/120545219")
  def testBasic(self):
    gamma = resource_variable_ops.ResourceVariable(
        np.random.random((3,)),
        dtype="float32", name="gamma")
    inputs = array_ops.ones(shape=(3,), dtype="float32")

    def TestFn():
      output = inputs + gamma
      return output

    training = array_ops.placeholder_with_default(True, shape=())
    output = control_flow_ops.cond(
        training, TestFn, lambda: inputs)
    loss = output
    grads = gradients.gradients(
        loss, [gamma])
    self.assertTrue(None not in grads)


class CustomGradientTest(test_util.TensorFlowTestCase):

  def testCustomGradientTrivial(self):

    @custom_gradient.custom_gradient
    def MyIdentity(x):

      def Grad(dy):
        return [3 * dy]

      return x, Grad

    with ops.Graph().as_default():
      x = constant(3.)
      y = MyIdentity(MyIdentity(x))
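      # Each MyIdentity multiplies the incoming gradient by 3, so two nested
      # applications give dy/dx = 3 * 3 = 9.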
      dy = gradients.gradients(y, x)[0]
      with session.Session():
        self.assertEqual(9., self.evaluate(dy))

  def testCustomGradient(self):

    @custom_gradient.custom_gradient
    def MyMultiply(x1, x2):
      result = x1 * x2

      def Grad(dy):
        # Switched the ordering here.
        return [dy * x1, dy * x2]

      return result, Grad

    with ops.Graph().as_default():
      x1 = constant(3.)
      x2 = constant(5.)
      y = MyMultiply(x1, x2)
      dy = gradients.gradients(y, [x1, x2])
      with session.Session() as sess:
        self.assertAllEqual([3., 5.], self.evaluate(dy))

  def testCustomGradientErrors(self):

    @custom_gradient.custom_gradient
    def F(x):

      def Grad(_):
        raise RuntimeError("x")

      return x, Grad

    with ops.Graph().as_default():
      x = constant(1.0)
      y = F(x)
      with self.assertRaises(RuntimeError):
        gradients.gradients(y, x)

  def testCustomGradientWithVariables(self):

    @custom_gradient.custom_gradient
    def F(x):
      out = core_layers.dense(x, 3, use_bias=False)

      def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
        self.assertEqual(1, len(variables))
        grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
        return grads[0], [array_ops.ones((4, 3))]

      return out, Grad

    with ops.Graph().as_default():
      x = array_ops.ones((2, 4))
      with variable_scope.variable_scope("f", use_resource=True) as vs:
        y = F(x)
        all_vars = vs.global_variables()
        assert len(all_vars) == 1
      grads = gradients.gradients(y, [x, all_vars[0]])
      for g in grads:
        self.assertTrue(g is not None)
      with session.Session() as sess:
        self.evaluate(variables.global_variables_initializer())
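        # Grad returns ones((4, 3)) as the variable gradient, so its sum is 12.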
        dw = sess.run(math_ops.reduce_sum(grads[1]))
        self.assertEqual(12., dw)

  def testCustomGradientWithVariablesEager(self):
    with context.eager_mode():
      layer = core_layers.Dense(4, use_bias=False)

      @custom_gradient.custom_gradient
      def F(x):
        out = layer(x)

        def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
          del out_grad
          self.assertEqual(1, len(variables))
          return (array_ops.ones((3, 2)),
                  [array_ops.ones((2, 4))])

        return out, Grad

      x = array_ops.ones((3, 2)) + 2.
      with backprop.GradientTape() as tape:
        tape.watch(x)
        y = F(x)
      w, = layer.variables
      dx, dw = tape.gradient(y, [x, w])
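      # Grad returns all-ones tensors, so each sum equals the element count:
      # 3 * 2 = 6 for dx and 2 * 4 = 8 for dw.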
      self.assertEqual(6., math_ops.reduce_sum(dx).numpy())
      self.assertEqual(8., math_ops.reduce_sum(dw).numpy())

  @test_util.run_v1_only("b/120545219")
  def testCustomGradientErrorsWithNonResourceVariables(self):

    def F(x, use_resource=False):
      with variable_scope.variable_scope("f", use_resource=use_resource):
        out = core_layers.dense(x, 4, use_bias=False)

      def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
        del out_grad
        self.assertEqual(1, len(variables))
        return (array_ops.ones((3, 2)), [array_ops.ones((2, 4))])

      return out, Grad

    @custom_gradient.custom_gradient
    def FResource(x):
      return F(x, use_resource=True)

    @custom_gradient.custom_gradient
    def FNonResource(x):
      return F(x, use_resource=False)

    x = array_ops.ones((3, 2)) + 2.

    # Wrapping scope has use_resource=True but inner scope sets to False. Fails.
    with variable_scope.variable_scope("vs1", use_resource=True):
      with self.assertRaisesWithPredicateMatch(TypeError,
                                               "must be `ResourceVariable`s"):
        FNonResource(x)

    # Wrapping scope has use_resource=False but inner scope sets to True.
    # Passes.
    with variable_scope.variable_scope("vs2", use_resource=False):
      FResource(x)

  def testWithNumpyInputs(self):
    with context.eager_mode():

      @custom_gradient.custom_gradient
      def F(x):
        out = x

        def Grad(_):
          return (None, None)

        return out, Grad

      x = np.ones((3, 2), dtype=np.float32)
      # Smoke test to ensure numpy inputs are accepted
      F(x)

  @test_util.run_v1_only("b/120545219")
  def testRVGradientsDynamicCond(self):
    with self.cached_session():
      alpha = resource_variable_ops.ResourceVariable(
          np.random.random((1,)),
          dtype="float32")

      conditional = array_ops.placeholder_with_default(True, shape=())
      output = control_flow_ops.cond(
          conditional, lambda: alpha * 2, lambda: alpha * 3)

      g, = gradients_impl.gradients(output, alpha)
      self.evaluate(variables.global_variables_initializer())
      self.assertAllEqual(g.eval(), [2.0])
      self.assertAllEqual(g.eval(feed_dict={conditional: False}), [3.0])


class AggregateIndexedSlicesGradientsTest(test_util.TensorFlowTestCase):

  def _assert_indexed_slices_equal(self, left, right):
    self.assertAllEqual(
        self.evaluate(ops.convert_to_tensor(left)),
        self.evaluate(ops.convert_to_tensor(right)))

  def testNoGradients(self):
    self.assertIsNone(gradients_impl._AggregateIndexedSlicesGradients([]))

  def testOneGradient(self):
    t = math_ops._as_indexed_slices(constant_op.constant(
        [[1., 2.], [0, 0], [3., 4.]]))
    result = gradients_impl._AggregateIndexedSlicesGradients([t])
    self._assert_indexed_slices_equal(t, result)

  def testMultipleGradients(self):
    t0 = math_ops._as_indexed_slices(constant_op.constant(
        [[1., 2.], [0, 0], [3., 4.]]))
    t1 = math_ops._as_indexed_slices(constant_op.constant(
        [[0., 0.], [5, 6], [7., 8.]]))
    total = constant_op.constant(
        [[1., 2.], [5, 6], [10., 12.]])
    result = gradients_impl._AggregateIndexedSlicesGradients([t0, t1])
    self._assert_indexed_slices_equal(total, result)

  def testMultipleGradientsWithNones(self):
    t0 = math_ops._as_indexed_slices(constant_op.constant(
        [[1., 2.], [0, 0], [3., 4.]]))
    t1 = math_ops._as_indexed_slices(constant_op.constant(
        [[0., 0.], [5, 6], [7., 8.]]))
    t3 = None
    total = constant_op.constant(
        [[1., 2.], [5, 6], [10., 12.]])
    result = gradients_impl._AggregateIndexedSlicesGradients([t0, t1, t3])
    self._assert_indexed_slices_equal(total, result)

  def testMixedTensorAndIndexedSlices(self):
    t0 = math_ops._as_indexed_slices(constant_op.constant(
        [[1., 2.], [0, 0], [3., 4.]]))
    t1 = constant_op.constant(
        [[0., 0.], [5, 6], [7., 8.]])
    total = constant_op.constant(
        [[1., 2.], [5, 6], [10., 12.]])
    result = gradients_impl._AggregateIndexedSlicesGradients([t0, t1])
    self._assert_indexed_slices_equal(total, result)


class TensorListGradientsTest(test_util.TensorFlowTestCase):

  def testDefaultGradYs(self):
    with ops.Graph().as_default():
      tl = list_ops.empty_tensor_list(
          element_dtype=dtypes.float32,
          element_shape=ops.convert_to_tensor([], dtype=dtypes.int32))
      a = constant(1.0)
      tl = list_ops.tensor_list_push_back(tl, a)

      grad_tl = list_ops.empty_tensor_list(
          element_dtype=dtypes.float32,
          element_shape=ops.convert_to_tensor([], dtype=dtypes.int32))
      grad_tl = list_ops.tensor_list_push_back(tl, constant(5.0))

      grad = gradients.gradients(tl, a, grad_ys=grad_tl)[0]
      with self.cached_session() as sess:
        self.assertEquals(self.evaluate(grad), 5.)


if __name__ == "__main__":
  googletest.main()