You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

AssignTests.cs 1.7 kB

6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using Tensorflow;
  3. using static Tensorflow.Binding;
  4. namespace TensorFlowNET.UnitTest.Basics
  5. {
  /// <summary>
  /// Unit tests that exercise <c>tf.assign</c> on variables created with <c>tf.Variable</c>.
  /// </summary>
  [TestClass]
  public sealed class AssignTests
  {
      /// <summary>
      /// For every consecutive pair in <c>raw_data</c>, assigns to a boolean variable whether the
      /// value rose by more than 5 relative to its predecessor, then compares the variable's
      /// evaluated value with the matching entry of <c>expected</c>.
      /// </summary>
      [Ignore("Not implemented")]
      [TestMethod]
      public void ShouldAssignVariable()
      {
          var raw_data = new[] { 1.0, 2.0, 8.0, -1.0, 0.0, 5.5, 6.0, 16.0 };
          // expected[i - 1] is the flag for the step from raw_data[i - 1] to raw_data[i].
          var expected = new[] { false, true, false, false, true, false, true };
          var spike = tf.Variable(false);

          using (var sess = new Session())
          {
              spike.initializer.run(session: sess);

              // Cover every consecutive pair rather than only the first one.
              // NOTE(review): assumes range() excludes its end bound (Python-style) — confirm.
              foreach (var i in range(1, raw_data.Length))
              {
                  // Assign the computed flag itself; the former else-branch wrote `true`
                  // as well, so the variable could never become false.
                  bool isSpike = raw_data[i] - raw_data[i - 1] > 5d;
                  tf.assign(spike, tf.constant(isSpike)).eval(sess);

                  // MSTest argument order is (expected, actual).
                  Assert.AreEqual(expected[i - 1], (bool) spike.eval());
              }
          }
      }

      /// <summary>
      /// Regression test for https://github.com/SciSharp/TensorFlow.NET/issues/397:
      /// assigns <c>relu(W)</c> back to a float variable <c>W</c> initialised to -1 and
      /// checks that evaluating the assignment yields 0.
      /// </summary>
      [TestMethod]
      public void Bug397()
      {
          var W = tf.Variable(-1, name: "weight_" + 1, dtype: tf.float32);
          var init = tf.global_variables_initializer();
          var reluEval = tf.nn.relu(W);
          var nonZero = tf.assign(W, reluEval);

          using (var sess = tf.Session())
          {
              sess.run(init);
              float result = nonZero.eval();

              // AreEqual reports the actual value on failure, unlike IsTrue(result == 0f).
              Assert.AreEqual(0f, result);
          }
      }
  }
  49. }