You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

GradientTest.cs 2.7 kB

2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System.Linq;
  3. using Tensorflow.Keras.Engine;
  4. using Tensorflow.NumPy;
  5. using static Tensorflow.Binding;
  6. using static Tensorflow.KerasApi;
  7. namespace Tensorflow.Keras.UnitTest;
  /// <summary>
  /// Eager-mode test that builds small Keras functional models and checks that
  /// <c>tf.GradientTape</c> yields a gradient for every trainable variable of a
  /// two-input "critic" model.
  /// </summary>
  [TestClass]
  public class GradientTest : EagerModeTestBase
  {
      /// <summary>
      /// Builds a single-input model: an input of width <paramref name="num_states"/>
      /// followed by one Dense unit with a Tanh activation.
      /// </summary>
      /// <param name="num_states">Value passed as the input layer's shape.</param>
      /// <returns>The assembled model.</returns>
      public IModel get_actor(int num_states)
      {
          var inputs = tf.keras.layers.Input(shape: num_states);
          var outputs = tf.keras.layers.Dense(1, activation: keras.activations.Tanh).Apply(inputs);

          return tf.keras.Model(inputs, outputs);
      }

      /// <summary>
      /// Builds a two-input model: the state and action inputs are concatenated
      /// along axis 1 and fed through one Dense unit. The model summary is printed
      /// before the model is returned.
      /// </summary>
      /// <param name="num_states">Value passed as the state input's shape.</param>
      /// <param name="num_actions">Value passed as the action input's shape.</param>
      /// <returns>The assembled model.</returns>
      public IModel get_critic(int num_states, int num_actions)
      {
          // State as input
          var state_input = keras.layers.Input(shape: num_states);
          // Action as input
          var action_input = keras.layers.Input(shape: num_actions);

          var concat = keras.layers.Concatenate(axis: 1).Apply(new Tensors(state_input, action_input));
          var outputs = keras.layers.Dense(1).Apply(concat);

          var model = tf.keras.Model(new Tensors(state_input, action_input), outputs);
          model.summary();
          return model;
      }

      /// <summary>
      /// Creates an all-zero float tensor with the given number of rows and columns.
      /// </summary>
      private static Tensor ZeroBatch(int rows, int columns)
          => tf.convert_to_tensor(np.zeros((rows, columns)), TF_DataType.TF_FLOAT);

      /// <summary>
      /// Computes a mean-squared loss between a target value (reward plus the
      /// discounted target-critic output) and the critic output, then verifies
      /// that the tape returns a non-null gradient for each trainable variable
      /// of the critic model.
      /// </summary>
      [TestMethod]
      public void GetGradientTest()
      {
          // Arrange
          var numStates = 3;
          var numActions = 1;
          var batchSize = 64;
          var gamma = 0.99f;

          var target_actor_model = get_actor(numStates);
          var target_critic_model = get_critic(numStates, numActions);
          var critic_model = get_critic(numStates, numActions);

          Tensor state_batch = ZeroBatch(batchSize, numStates);
          Tensor action_batch = ZeroBatch(batchSize, numActions);
          Tensor reward_batch = ZeroBatch(batchSize, 1);
          Tensor next_state_batch = ZeroBatch(batchSize, numStates);

          using (var tape = tf.GradientTape())
          {
              // Act: the forward passes run while the tape is active so they are recorded.
              var target_actions = target_actor_model.Apply(next_state_batch, training: true);
              var target_critic_value = target_critic_model.Apply(new Tensors(new Tensor[] { next_state_batch, target_actions }), training: true);
              var y = reward_batch + tf.multiply(gamma, target_critic_value);

              var critic_value = critic_model.Apply(new Tensors(new Tensor[] { state_batch, action_batch }), training: true);
              var critic_loss = math_ops.reduce_mean(math_ops.square(y - critic_value));

              // The gradient is requested before the tape leaves the using block.
              var critic_variables = critic_model.TrainableVariables;
              var critic_grad = tape.gradient(critic_loss, critic_variables);

              // Assert: every trainable variable must receive a gradient, not only
              // the first one, so a missing gradient cannot slip through.
              Assert.IsNotNull(critic_grad);

              var gradients = critic_grad.ToList();
              Assert.AreNotEqual(0, gradients.Count, "Expected at least one gradient.");
              Assert.AreEqual(critic_variables.Count(), gradients.Count,
                  "Expected one gradient per trainable variable.");

              foreach (var gradient in gradients)
              {
                  Assert.IsNotNull(gradient);
              }
          }
      }
  }