GradientTest.cs 32 kB

using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow.NumPy;
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow;
using static Tensorflow.Binding;
using Tensorflow.Framework;

namespace TensorFlowNET.UnitTest.Gradient
{
    [TestClass]
    public class GradientTest : GraphModeTestBase
    {
        [TestMethod]
        public void BroadcastToGrad()
        {
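            // broadcast_to replicates the scalar x into 2 * 4 * 3 = 24 positions,
            // so the gradient summed back onto x is 24.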
            // NOTE: assuming the usual graph-mode setup here, since the page shows
            // no local declaration for the `graph` passed to tf.Session below.
            var graph = tf.Graph().as_default();
            var x = tf.constant(2, dtype: dtypes.float32);
            var y = tf.broadcast_to(x, (2, 4, 3));
            var grad = tf.gradients(y, x);
            var sess = tf.Session(graph);
            float result = sess.run(grad[0]);
            Assert.AreEqual(result, 24.0f);
        }

        [TestMethod]
        public void CumsumGrad()
        {
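            // cumsum along axis 1 (length 4) gives the input positions weights
            // 4, 3, 2, 1; summed over the 2 * 3 remaining positions:
            // (4 + 3 + 2 + 1) * 6 = 60.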
            // Assuming the same graph-mode setup as above for tf.Session(graph).
            var graph = tf.Graph().as_default();
            var x = tf.constant(2, dtype: dtypes.float32);
            var y = tf.broadcast_to(x, (2, 4, 3));
            var z = tf.cumsum(y, axis: 1);
            var grad = tf.gradients(z, x);
            var sess = tf.Session(graph);
            float result = sess.run(grad[0]);
            Assert.AreEqual(result, 60.0f);
        }

        [TestMethod, Ignore]
        public void testGradients()
        {
            var inp = tf.constant(1.0, shape: new[] { 32, 100 }, name: "in");
            var w = tf.constant(1.0, shape: new[] { 100, 10 }, name: "w");
            var b = tf.Variable(1.0, shape: new[] { 10 }, name: "b");
            var xw = math_ops.matmul(inp, w, name: "xw");
            var h = nn_ops.bias_add(xw, b, name: "h");
            var w_grad = gradients_impl.gradients(new[] { h }, new[] { w })[0];
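            // For xw = inp * w, dxw/dw = inp^T * grad, so the gradient op is a
            // MatMul with transpose_a = true and transpose_b = false.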
            self.assertEquals("MatMul", w_grad.op.type);
            // TODO: Operation._original_op
            //self.assertEquals(w_grad.op._original_op, xw.op);
            self.assertTrue((bool)w_grad.op.get_attr("transpose_a"));
            self.assertFalse((bool)w_grad.op.get_attr("transpose_b"));
        }

        [TestMethod]
        public void testBatchMatMulGradient()
        {
            var a = tf.constant(np.array(Enumerable.Range(1, 18).Select(elem => (float)elem).ToArray()), shape: new[] { 2, 3, 3 });
            var b = tf.divide(a, tf.constant(2.0f));
            var c = tf.batch_matmul(a, b);
            var g = tf.gradients(c, new[] { a, b }, stop_gradients: new[] { a, b });
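            // With an all-ones upstream gradient, dc/da = ones * b^T (every row
            // holds the row sums of b) and dc/db = a^T * ones (row j holds the
            // sum of a's column j), computed per batch.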
            var checkG = new[]
            {
                3.0f, 7.5f, 12.0f,
                3.0f, 7.5f, 12.0f,
                3.0f, 7.5f, 12.0f,
                16.5f, 21.0f, 25.5f,
                16.5f, 21.0f, 25.5f,
                16.5f, 21.0f, 25.5f,
                12.0f, 12.0f, 12.0f,
                15.0f, 15.0f, 15.0f,
                18.0f, 18.0f, 18.0f,
                39.0f, 39.0f, 39.0f,
                42.0f, 42.0f, 42.0f,
                45.0f, 45.0f, 45.0f
            };
            var sess = tf.Session();
            var result = sess.run(g);
            var resultList = result[0].ToArray<float>().ToList();
            resultList.AddRange(result[1].ToArray<float>());
            Console.WriteLine(result.ToString());
            CollectionAssert.AreEqual(resultList.ToArray(), checkG);
        }

        [TestMethod]
        public void testSimpleGradients()
        {
            (T, T) evaluateDerivatives<T>(Func<Tensor, Tensor> f, T xval) where T : unmanaged
            {
                var x = tf.constant(xval);
                var y = f(x);
                var g = tf.gradients(y, x);
                var session = tf.Session();
                var result = session.run(new[] { y, g[0] });
                return (result[0].ToArray<T>()[0], result[1].ToArray<T>()[0]);
            }
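
            // Each case below evaluates f(x) and df/dx in one session run and
            // checks them against closed-form values in both float64 and float32.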
            void test(string name, Func<Tensor, Tensor> tfF, Func<double, (double, double)> targetF, double[] values)
            {
                foreach (var x in values)
                {
                    var (expectedY, expectedDY) = targetF(x);
                    {
                        var (actualY, actualDY) = evaluateDerivatives(tfF, x);
                        self.assertFloat64Equal(expectedY, actualY, $"value {name}/float64 at {x}");
                        self.assertFloat64Equal(expectedDY, actualDY, $"derivative {name}/float64 at {x}");
                    }
                    {
                        var (actualY, actualDY) = evaluateDerivatives(tfF, (float)x);
                        self.assertFloat32Equal((float)expectedY, actualY, $"value {name}/float32 at {x}");
                        self.assertFloat32Equal((float)expectedDY, actualDY, $"derivative {name}/float32 at {x}");
                    }
                }
            }

            test("tf.exp",
                x => tf.exp(5 * x),
                x => (Math.Exp(5.0 * x), 5.0 * Math.Exp(5.0 * x)),
                new[] { -1.0, 0.0, 1.0, 1.5 });
            test("tf.log",
                x => tf.log(x),
                x => (Math.Log(x), 1.0 / x),
                new[] { 0.5, 1.0, 1.5, 2.0 });
            test("tf.sqrt",
                x => tf.sqrt(x),
                x => (Math.Sqrt(x), 0.5 / Math.Sqrt(x)),
                new[] { 0.5, 1.0, 1.1, 1.5, 2.0 });
            test("tf.sin",
                x => tf.sin(x),
                x => (Math.Sin(x), Math.Cos(x)),
                new[] { -1.0, 0.0, 1.0, 1.5, 2.0 });
            test("tf.sinh",
                x => tf.sinh(x),
                x => (Math.Sinh(x), Math.Cosh(x)),
                new[] { -1.0, 0.0, 1.0, 1.5, 2.0 });
            test("tf.cos",
                x => tf.cos(x),
                x => (Math.Cos(x), -Math.Sin(x)),
                new[] { -1.0, 0.0, 1.0, 1.5, 2.0 });
            test("tf.cosh",
                x => tf.cosh(x),
                x => (Math.Cosh(x), Math.Sinh(x)),
                new[] { -1.0, 0.0, 1.0, 1.5, 2.0 });
            test("tf.tanh",
                x => tf.tanh(x),
                x => (Math.Tanh(x), 1.0 - Math.Pow(Math.Tanh(x), 2.0)),
                new[] { -1.0, 0.0, 1.0, 1.5, 2.0 });
            test("tf.maximum",
                x => tf.maximum(x, tf.constant(0.0, dtype: x.dtype)),
                x => (Math.Max(x, 0.0), (x > 0.0) ? 1.0 : 0.0),
                new[] { -1.0, 1.0 });
            test("tf.minimum",
                x => tf.minimum(x, tf.constant(0.0, dtype: x.dtype)),
                x => (Math.Min(x, 0.0), (x < 0.0) ? 1.0 : 0.0),
                new[] { -1.0, 1.0 });
        }

        [TestMethod]
        public void testReduceSumGradients()
        {
            /* python code
            import tensorflow.compat.v1 as tf
            tf.disable_v2_behavior()
            x = tf.placeholder(tf.float64, shape = (1, 1))
            m = tf.broadcast_to(x, (2, 3))
            g0 = tf.gradients(tf.reduce_sum(m), x)[0]
            g1 = tf.gradients(tf.reduce_sum(m, axis = 0)[0], x)[0]
            g2 = tf.gradients(tf.reduce_sum(m, axis = 1)[0], x)[0]
            with tf.compat.v1.Session() as sess:
                (r0, r1, r2) = sess.run((g0, g1, g2), {x: [[1.0]]})
            */
            var x = tf.placeholder(tf.float64, shape: new Shape(1, 1));
            var m = tf.broadcast_to(x, new Shape(2, 3));
            var g0 = tf.gradients(tf.reduce_sum(m), x)[0];
            var g1 = tf.gradients(tf.reduce_sum(m, axis: 0)[0], x)[0];
            var g2 = tf.gradients(tf.reduce_sum(m, axis: 1)[0], x)[0];
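            // m is x broadcast to 2 x 3: the full sum touches x six times (g0 = 6),
            // one element of the axis-0 sum touches it twice (g1 = 2), and one
            // element of the axis-1 sum touches it three times (g2 = 3).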
            var session = tf.Session();
            var (r0, r1, r2) = session.run((g0, g1, g2), new FeedItem(x, new[,] { { 1.0 } }));
            self.assertFloat64Equal(6.0, r0[0], $"tf.reduce_sum(...)");
            self.assertFloat64Equal(2.0, r1[0], $"tf.reduce_sum(..., axis = 0)");
            self.assertFloat64Equal(3.0, r2[0], $"tf.reduce_sum(..., axis = 1)");
        }

        [TestMethod]
        public void testTanhGradient()
        {
            var a = tf.constant(1f);
            var b = tf.tanh(a);
            var g = tf.gradients(b, a);
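            // d tanh(x)/dx = 1 - tanh(x)^2; at x = 1 this is 1 - 0.76159^2 = 0.41997...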
            var sess = tf.Session();
            var result = sess.run(g);
            var actual = result[0];
            Assert.AreEqual(actual, 0.41997434127f);
        }

        [TestMethod]
        public void testLgammaGrad()
        {
            var a = tf.constant(5f);
            var b = tf.lgamma(a);
            var g = tf.gradients(b, a);
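            // d lgamma(x)/dx is the digamma function: digamma(5) = 1.5061177...,
            // and lgamma(5) = ln(4!) = 3.1780539...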
            var sess = tf.Session();
            var result = sess.run(new object[] { g, b });
            var actualDeriv = result[0];
            var actual = result[1];
            Assert.AreEqual(actualDeriv, 1.5061177f);
            Assert.AreEqual(actual, 3.17805386f);
        }

        [TestMethod]
        public void testSliceGrad()
        {
            var a = tf.tanh(tf.constant(new[] { 2f, 3f }, shape: new[] { 2, 1 }));
            var b = tf.strided_slice(a,
                tf.constant(new[] { 0 }, tf.int32, new[] { 1 }),
                tf.constant(new[] { 1 }, tf.int32, new[] { 1 }),
                tf.constant(new[] { 1 }, tf.int32, new[] { 1 })
            );
            var g = tf.gradients(b, a);
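            // The slice keeps only a[0], so its gradient is 1 there and 0 for a[1];
            // the sliced value itself is tanh(2) = 0.9640276...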
            var sess = tf.Session();
            var result = sess.run(new object[] { g, b });
            var actualDeriv = np.squeeze(result[0]);
            var actual = np.squeeze(result[1]);
            Assert.AreEqual(actualDeriv, new float[] { 1, 0 });
            Assert.AreEqual(actual, 0.9640276f);
        }

        [TestMethod]
        public void testConcatGrad()
        {
            var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 });
            var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 });
            var a = tf.concat(new List<Tensor>(new[] { a1, a2 }), 0);
            var g = tf.gradients(a, a1);
            var sess = tf.Session();
            var result = sess.run(new object[] { g, a });
            var actualDeriv = result[0][0];
            var actual = result[1][0];
            Assert.AreEqual(actualDeriv, 1f);
            Assert.AreEqual(actual, 2f);
        }

        [TestMethod]
        public void testStopGradientFunction()
        {
            var ap = tf.constant(1f);
            var b = tf.tanh(ap) + array_ops.stop_gradient(ap);
            var g = tf.gradients(b, ap);
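            // stop_gradient blocks the second term, so db/dap is just the tanh
            // derivative, 1 - tanh(1)^2 = 0.41997...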
            var sess = tf.Session();
            var result = sess.run(g);
            var actual = result[0];
            Assert.AreEqual(actual, 0.41997434127f);
        }

        [Ignore("TODO")]
        [TestMethod]
        public void testUnusedOutput()
        {
            //def testUnusedOutput(self):
            //  with ops.Graph().as_default():
            //    w = constant(1.0, shape=[2, 2])
            //    x = constant(1.0, shape=[2, 2])
            //    wx = math_ops.matmul(w, x)
            //    split_wx = array_ops.split(value=wx, num_or_size_splits=2, axis=0)
            //    c = math_ops.reduce_sum(split_wx[1])
            //    gw = gradients.gradients(c, [w])[0]
            //    self.assertEquals("MatMul", gw.op.type)
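
            // A possible C# port, left commented out until it can be verified;
            // it assumes a tf.split overload mirroring the Python
            // array_ops.split(value, num_or_size_splits, axis) signature.
            //var w = tf.constant(1.0, shape: new[] { 2, 2 });
            //var x = tf.constant(1.0, shape: new[] { 2, 2 });
            //var wx = math_ops.matmul(w, x);
            //var split_wx = tf.split(wx, 2, axis: 0);
            //var c = tf.reduce_sum(split_wx[1]);
            //var gw = tf.gradients(c, new[] { w })[0];
            //self.assertEquals("MatMul", gw.op.type);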
        }

        [Ignore("TODO")]
        [TestMethod]
        public void testColocateGradients()
        {
            //def testColocateGradients(self):
            //  with ops.Graph().as_default() as g:
            //    w = constant(1.0, shape=[1, 1])
            //    x = constant(1.0, shape=[1, 2])
            //    with g.device("/device:GPU:0"):
            //      wx = math_ops.matmul(w, x)
            //    gw = gradients.gradients(wx, [w], colocate_gradients_with_ops=True)[0]
            //    self.assertEqual(gw.op.colocation_groups(), wx.op.colocation_groups())
        }

        [Ignore("TODO")]
        [TestMethod]
        public void testColocateGradientsWithAggregation()
        {
            //def testColocateGradientsWithAggregation(self):
            //  with ops.Graph().as_default() as g:
            //    with g.device("/device:GPU:1"):
            //      w = constant(1.0, shape=[1, 1])
            //      x = constant(1.0, shape=[1, 2])
            //      y = constant(1.0, shape=[1, 2])
            //    wx = math_ops.matmul(w, x)
            //    wy = math_ops.matmul(w, y)
            //    with g.device("/device:GPU:0"):
            //      z = wx + wy
            //    gw1 = gradients.gradients(z, [w], colocate_gradients_with_ops=True)[0]
            //    self.assertEqual(gw1.op.colocation_groups(), wx.op.colocation_groups())
            //    gw2 = gradients.gradients(z, [w], colocate_gradients_with_ops=False)[0]
            //    self.assertTrue(wx.op.colocation_groups() != gw2.op.colocation_groups())
        }

        [Ignore("TODO")]
        [TestMethod]
        public void testColocateGradientsWithAggregationInMultipleDevices()
        {
            //def testColocateGradientsWithAggregationInMultipleDevices(self):
            //  with ops.Graph().as_default() as g:
            //    with g.device("/device:GPU:1"):
            //      w = constant(1.0, shape=[1, 1])
            //      x = constant(1.0, shape=[1, 2])
            //      y = constant(1.0, shape=[1, 2])
            //    with g.device("/task:1"):
            //      wx = math_ops.matmul(w, x)
            //    with g.device("/task:2"):
            //      wy = math_ops.matmul(w, y)
            //    with g.device("/device:GPU:0"):
            //      z = wx + wy
            //    gw1 = gradients.gradients(z, [w], colocate_gradients_with_ops=True)[0]
            //    self.assertEqual(gw1.op.colocation_groups(), w.op.colocation_groups())
            //    gw2 = gradients.gradients(z, [w], colocate_gradients_with_ops=False)[0]
            //    self.assertTrue(w.op.colocation_groups() != gw2.op.colocation_groups())
        }

        [Ignore("TODO")]
        [TestMethod]
        public void testColocateGradientsWithGateGradients()
        {
            //def testColocateGradientsWithGateGradients(self):
            //  if not test_util.is_gpu_available():
            //    self.skipTest("No GPU available")
            //  with ops.Graph().as_default() as g:
            //    with g.device("/device:CPU:0"):
            //      x = constant(1.0, shape=[1, 1])
            //      y = constant(1.0, shape=[1, 1])
            //      s = x + y
            //    with g.device("/device:GPU:0"):
            //      z = math_ops.reduce_sum(s)
            //    gz_x = gradients.gradients(z, [x], colocate_gradients_with_ops=True,
            //                               gate_gradients=True)[0]
            //    with session.Session():
            //      # Make sure the placer doesn't complain.
            //      self.evaluate(gz_x)
        }

        [Ignore("TODO")]
        [TestMethod]
        public void testBoundaryStop()
        {
            //def testBoundaryStop(self):
            //  # Test that we don't differentiate 'x'. The gradient function for 'x' is
            //  # set explicitly to None so we will get an exception if the gradient code
            //  # tries to differentiate 'x'.
            //  with ops.Graph().as_default():
            //    c = constant(1.0)
            //    x = array_ops.identity(c)
            //    y = x + 1.0
            //    z = y + 1
            //    grads = gradients.gradients(z, [x])
            //    self.assertTrue(all(x is not None for x in grads))
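
            // A possible C# port using only APIs already exercised in this file
            // (tf.constant, array_ops.identity, tf.gradients); kept commented out
            // until the None-gradient behaviour for 'x' can be verified here.
            //var c = tf.constant(1.0);
            //var x = array_ops.identity(c);
            //var y = x + 1.0;
            //var z = y + 1.0;
            //var grads = tf.gradients(z, new[] { x });
            //self.assertTrue(grads.All(g => g != null));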
        }

        [TestMethod]
        public void testBoundaryContinue()
        {
            // Test that we differentiate both 'x' and 'y' correctly when x is a
            // predecessor of y.
            //TODO: @test_util.run_v1_only("b/120545219")
            using (self.cached_session())
            {
                var x = tf.constant(1.0);
                var y = x * 2.0;
                var z = y * 3.0;
                var grads = tf.gradients(z, new[] { x, y });
                self.assertTrue(all(grads.Select(x => x != null)));
                self.assertEqual(6.0, grads[0].eval());
            }
        }

        [TestMethod]
        public void testAggregationMethodAccumulateN()
        {
            //TODO: @test_util.run_v1_only("b/120545219")
            using (self.cached_session())
            {
                var x = tf.constant(1.0);
                var y = x * 2.0;
                var z = y + y + y + y + y + y + y + y + y + y;
                var grads = tf.gradients(z, new[] { x, y },
                    aggregation_method: AggregationMethod.EXPERIMENTAL_ACCUMULATE_N);
                self.assertTrue(all(grads.Select(x => x != null)));
                self.assertEqual(20.0, grads[0].eval());
                self.assertEqual(10.0, grads[1].eval());
            }
        }

        [TestMethod]
        public void testAggregationMethodAddN()
        {
            //TODO: @test_util.run_v1_only("b/120545219")
            using (self.cached_session())
            {
                var x = tf.constant(1.0);
                var y = x * 2.0;
                var z = y + y + y + y + y + y + y + y + y + y;
                var grads = tf.gradients(z, new[] { x, y },
                    aggregation_method: AggregationMethod.ADD_N);
                self.assertTrue(grads.All(x => x != null));
                self.assertEqual(20.0, grads[0].eval());
                self.assertEqual(10.0, grads[1].eval());
            }
        }

        [TestMethod]
        public void testAggregationMethodTree()
        {
            //TODO: @test_util.run_v1_only("b/120545219")
            using (self.cached_session())
            {
                var x = tf.constant(1.0);
                var y = x * 2.0;
                var z = y + y + y + y + y + y + y + y + y + y;
                var grads = tf.gradients(z, new[] { x, y },
                    aggregation_method: AggregationMethod.EXPERIMENTAL_TREE);
                self.assertTrue(grads.All(x => x != null));
                self.assertEqual(20.0, grads[0].eval());
                self.assertEqual(10.0, grads[1].eval());
            }
        }

        [Ignore("TODO")]
        [TestMethod]
        public void testNoGradientForStringOutputs()
        {
            //def testNoGradientForStringOutputs(self):
            //  with ops.Graph().as_default():
            //    def _TestOpGrad(_, float_grad, string_grad):
            //      """Gradient function for TestStringOutput."""
            //      self.assertEquals(float_grad.dtype, dtypes.float32)
            //      self.assertFalse(string_grad)
            //      return float_grad
            //    ops.RegisterGradient("TestStringOutput")(_TestOpGrad)
            //    c = constant(1.0)
            //    x, _ = test_ops.test_string_output(c)
            //    z = x * 2.0
            //    w = z * 3.0
            //    grads = gradients.gradients(z, [c])
            //    self.assertTrue(isinstance(grads[0], ops.Tensor))
            //    grads = gradients.gradients(w, [c])
            //    self.assertTrue(isinstance(grads[0], ops.Tensor))
        }

        [TestMethod]
        public void testSingletonIndexedSlices()
        {
            tf.Graph().as_default();
            var x = tf.placeholder(TF_DataType.TF_FLOAT);
            var y = tf.identity(x);
            var dy_indices = tf.placeholder(TF_DataType.TF_INT32);
            var dy_values = tf.placeholder(TF_DataType.TF_FLOAT);
            Tensor dy = new IndexedSlices(dy_values, dy_indices);
            var dx = tf.gradients(new[] { y }, new[] { x }, grad_ys: new[] { dy })[0];
            // The IndexedSlices gradient of tf.identity is the identity map.
            using (var sess = self.cached_session())
            {
                var feed_dict = new FeedItem[]
                {
                    (x, new Tensor(new float[] { 1.0f })),
                    (dy_indices, new Tensor(new int[] { 0 })),
                    (dy_values, new Tensor(new float[] { 2.0f }))
                };
                var result = sess.run(new[] { dx, dy }, feed_dict);
                var vdx = result[0];
                var vdy = result[1];
                self.assertEqual(vdx, vdy);
            }
        }

        [Ignore("TODO")]
        [TestMethod]
        public void testNonDifferentiableSwitchInWhileLoop()
        {
            //@test_util.run_v1_only("b/120545219")
            //def testNonDifferentiableSwitchInWhileLoop(self):
            //  with ops.Graph().as_default():
            //    v = array_ops.placeholder(dtypes.float32, [])
            //    def _Step(i, a, ta):
            //      a += math_ops.cast(v, dtypes.int32)
            //      return (i + 1, a, ta.write(i, a))
            //    n = 4
            //    i, _, ta = control_flow_ops.while_loop(
            //        lambda i, *_: i < n,
            //        _Step, [0, 0, tensor_array_ops.TensorArray(
            //            dtypes.int32, size=n)])
            //    target = ta.read(i - 1)
            //    grad, = gradients.gradients(target, v)
            //    self.assertIsNone(grad)
        }

        [Ignore("TODO")]
        [TestMethod]
        public void testVariableReadValueGradient()
        {
            //def testVariableReadValueGradient(self):
            //  with ops.Graph().as_default():
            //    init = constant_op.constant(100.0)
            //    var = variables.Variable(init)
            //    gradient = gradients.gradients(var.read_value(), var)
            //    self.assertIsNotNone(gradient)
        }

        [Ignore("TODO")]
        [TestMethod]
        public void testVariableAsGraphElementGradient()
        {
            //def testVariableAsGraphElementGradient(self):
            //  with ops.Graph().as_default() as graph:
            //    init = constant_op.constant(100.0)
            //    var = variables.Variable(init)
            //    gradient = gradients.gradients(graph.as_graph_element(var), var)
            //    self.assertIsNotNone(gradient)
        }

        [Ignore("TODO")]
        [TestMethod]
        public void testVariableRefGradient()
        {
            //@test_util.run_v1_only("b/120545219")
            //def testVariableRefGradient(self):
            //  with ops.Graph().as_default():
            //    init = constant_op.constant(100.0)
            //    var = variables.VariableV1(init)
            //    gradient = gradients.gradients(var._ref(), var)
            //    self.assertIsNotNone(gradient)
        }

        [TestMethod]
        public void testDependentYs()
        {
            //TODO: @test_util.run_v1_only("b/120545219")
            using (self.cached_session())
            {
                var x = constant_op.constant(3.0);
                var y = math_ops.square(x);
                var y1 = math_ops.square(y);
                var y2 = math_ops.square(y1);
                var g = tf.gradients(new[] { y, y2 }, new[] { x });
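                // y = x^2 and y2 = x^8, so d(y + y2)/dx = 2x + 8x^7;
                // at x = 3 that is 6 + 8 * 2187 = 17502.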
                self.assertAllClose(17502.0, g[0].eval());
                g = tf.gradients(y + y2, x);
                self.assertAllClose(17502.0, g[0].eval());
                var z = array_ops.identity(y);
                var z2 = array_ops.identity(y2);
                g = tf.gradients(new[] { z, z2 }, new[] { x });
                self.assertAllClose(17502.0, g[0].eval());
            }
        }

        [Ignore("TODO")]
        [TestMethod]
        public void testPartialDerivatives()
        {
            //TODO: @test_util.run_v1_only("b/120545219")
            using (self.cached_session())
            {
                var x = tf.constant(1.0);
                var y = 2 * x;
                var z = x + y;
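                // z = x + 2x, so the total gradients are dz/dx = 3 and dz/dy = 1;
                // with x and y declared stop_gradients, both partials are 1.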
                var totalg = tf.gradients(z, new[] { x, y });
                self.assertEqual(new[] { 3.0, 1.0 }, totalg.Select(g => g.eval()));
                var partialg = tf.gradients(z, new[] { x, y }, stop_gradients: new[] { x, y });
                self.assertEqual(new[] { 1.0, 1.0 }, partialg.Select(g => g.eval()));
            }
        }

        // TODO: remove when np.testing.assert_allclose(a, b) is implemented
        private class CollectionComparer : System.Collections.IComparer
        {
            private readonly double _epsilon = 1e-07;

            public int Compare(object x, object y)
            {
                var a = (double)x;
                var b = (double)y;
                double delta = Math.Abs(a - b);
                if (delta < _epsilon)
                {
                    return 0;
                }
                return a.CompareTo(b);
            }
        }

        private struct Case
        {
            public Tensor[] grad1;
            public Tensor[] grad2;
            public string constants;
            public string variables;
        }

        [Ignore("FIXME")]
        [TestMethod]
        public void testStopGradients()
        {
            //TODO: @test_util.run_v1_only("b/120545219")
            Dictionary<char, Tensor> makeGraph(RandomizedImpl rng, string stop_gradients)
            {
                Tensor functionOf(Tensor[] xs, int k)
                {
                    var shape = new Shape(k, k);
                    // TODO: replace by DefaultIfEmpty() before Aggregate().
                    if (!xs.Any())
                    {
                        return rng.random(shape).astype(np.float32);
                    }
                    return xs.Select(x => gen_math_ops.mat_mul(rng.random(shape).astype(np.float32), x))
                             .Aggregate((t1, t2) => t1 + t2)
                        + rng.random(shape).astype(np.float32);
                }

                var a = functionOf(Array.Empty<Tensor>(), 3);
                if (stop_gradients.Contains('a')) a = array_ops.stop_gradient(a);
                var b = functionOf(new Tensor[] { a }, 3);
                if (stop_gradients.Contains('b')) b = array_ops.stop_gradient(b);
                var c = functionOf(new Tensor[] { a, b }, 3);
                if (stop_gradients.Contains('c')) c = array_ops.stop_gradient(c);
                var d = functionOf(new Tensor[] { b, c }, 3);
                if (stop_gradients.Contains('d')) d = array_ops.stop_gradient(d);
                return new Dictionary<char, Tensor>
                {
                    { 'a', a },
                    { 'b', b },
                    { 'c', c },
                    { 'd', d }
                };
            }

            Tensor[] gradients(Tensor[] ys, Tensor[] xs, Tensor[] stop_gradients = null)
            {
                var dydxs = tf.gradients(ys, xs, stop_gradients);
                dydxs = dydxs.Select((dydx, i) => dydx == null ? xs[i] * 0 : dydx).ToArray();
                return dydxs;
            }

            var seed = np.random.randint(1000);
            // TODO: remove next line when np.random.RandomState is implemented.
            tf.set_random_seed(seed);
            var cases = new List<Case>();
            var subsets = new List<string> { "" }.Concat("a b c d ab ac ad bc bd cd abc abd acd bcd abcd".Split());
            // TODO: pass np.random.RandomState(seed) instead of np.random
            var graph = makeGraph(np.random, string.Empty);
            foreach (var constants in subsets)
            {
                var graphWithStops = makeGraph(np.random, constants);
                foreach (var variables_ in subsets)
                {
                    // compute the gradient when stopped using tf.stop_gradient nodes
                    var grad1 = gradients(
                        new[] { graphWithStops['d'] },
                        variables_.ToCharArray().Select(v => graphWithStops[v]).ToArray()
                    );
                    // compute the gradient when stopped using the stop_gradients argument
                    var grad2 = gradients(
                        new[] { graph['d'] },
                        variables_.ToCharArray().Select(v => graph[v]).ToArray(),
                        constants.ToCharArray().Select(c => graph[c]).DefaultIfEmpty(null)?.ToArray()
                    );
                    cases.Add(new Case
                    {
                        grad1 = grad1,
                        grad2 = grad2,
                        variables = variables_,
                        constants = constants,
                    });
                }
            }

            // evaluate all tensors in one call to session.run for speed
            using (var sess = self.cached_session())
            {
                var results = sess.run(
                    cases.Select(case_ => (
                        case_.grad1,
                        case_.grad2
                    )).ToArray()
                );
                foreach (var (result, case_) in results.Zip(cases))
                {
                    var npgrad1 = result[0];
                    var npgrad2 = result[1];
                    foreach (var (a, b) in npgrad1.Zip(npgrad2))
                    {
                        // TODO: np.testing.assert_allclose(a, b);
                        CollectionAssert.AreEqual(a.ToArray(), b.ToArray(), new CollectionComparer());
                    }
                }
            }
        }

        [Ignore("TODO: Unconnected gradients are not implemented")]
        [TestMethod]
        public void testUnconnectedGradientsNoneUnconnectedGradients()
        {
            //def testUnconnectedGradientsNoneUnconnectedGradients(self):
            //  with ops.Graph().as_default():
            //    x = constant(1.0, shape=[2, 2])
            //    y = constant(3.0, shape=[3, 1])
            //    grad = gradients.gradients(
            //        [y], [x], unconnected_gradients="none")
            //    self.assertIsNone(grad[0])
        }

        [Ignore("TODO: Unconnected gradients are not implemented")]
        [TestMethod]
        public void testUnconnectedGradientsZerosUnconnectedGradients()
        {
            //def testUnconnectedGradientsZerosUnconnectedGradients(self):
            //  with ops.Graph().as_default():
            //    x = constant(1.0, shape=[2, 2])
            //    y = constant(3.0, shape=[3, 1])
            //    grads = gradients.gradients(
            //        [y], [x], unconnected_gradients="zero")
            //    with self.cached_session() as sess:
            //      self.assertAllEqual([[0.0, 0.0], [0.0, 0.0]], self.evaluate(grads)[0])

            //tf.Graph().as_default();
            //var x = tf.constant(1.0, shape: new long[] { 2, 2 });
            //var y = tf.constant(3.0, shape: new long[] { 3, 1 });
            //var grads = tf.gradients(new[] { y }, new[] { x }, unconnected_gradients: "zero");
            //using (self.cached_session())
            //{
            //    self.assertAllEqual(new[,] { { 0.0, 0.0 }, { 0.0, 0.0 } }, self.evaluate(grads)[0]);
            //}
        }

        [Ignore("TODO: Unconnected gradients are not implemented")]
        [TestMethod]
        public void testUnconnectedGradientsZeroConnectedGradients()
        {
            //def testUnconnectedGradientsZeroConnectedGradients(self):
            //  with ops.Graph().as_default():
            //    x = constant(1.0)
            //    y = x * 3.0
            //    grad = gradients.gradients(
            //        [y], [x], unconnected_gradients="zero")
            //    with self.cached_session() as sess:
            //      self.assertEquals(3.0, self.evaluate(grad)[0])

            //tf.Graph().as_default();
            //var x = tf.constant(1.0f);
            //var y = x * 3.0f;
            //var grad = tf.gradients(new[] { y }, new[] { x }, unconnected_gradients: "zero");
            //using (var sess = tf.Session())
            //{
            //    self.assertEquals(3.0, self.evaluate(grad)[0]);
            //}
        }

        [Ignore("TODO: Unconnected gradients are not implemented")]
        [TestMethod]
        public void testUnknownUnconnectedGradientsValueGiven()
        {
            //def testUnknownUnconnectedGradientsValueGiven(self):
            //  with ops.Graph().as_default():
            //    x = constant(1.0)
            //    y = constant(1.0)
            //    with self.assertRaisesRegexp(
            //        ValueError, "Unknown value for unconnected_gradients: 'nonsense'"):
            //      gradients.gradients([y], [x], unconnected_gradients="nonsense")
        }
    }
}