You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

GradientTest.cs 2.8 kB

2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System.Linq;
  3. using Tensorflow;
  4. using Tensorflow.Keras.Engine;
  5. using static Tensorflow.Binding;
  6. using static Tensorflow.KerasApi;
  7. using Tensorflow.NumPy;
  8. using System;
  9. using Tensorflow.Keras.Optimizers;
  10. namespace TensorFlowNET.Keras.UnitTest;
  /// <summary>
  /// Builds small Keras actor and critic models and checks that a gradient tape
  /// yields gradients of a squared-error loss with respect to the critic's
  /// trainable variables.
  /// </summary>
  [TestClass]
  public class GradientTest : EagerModeTestBase
  {
      /// <summary>
      /// Creates a model that maps an input of shape <paramref name="num_states"/>
      /// through a single one-unit dense layer with tanh activation.
      /// </summary>
      /// <param name="num_states">Shape passed to the model's input layer.</param>
      /// <returns>The assembled model.</returns>
      public IModel get_actor(int num_states)
      {
          var stateInput = tf.keras.layers.Input(shape: num_states);
          var tanhLayer = tf.keras.layers.Dense(1, activation: keras.activations.Tanh);
          return tf.keras.Model(stateInput, tanhLayer.Apply(stateInput));
      }

      /// <summary>
      /// Creates a two-input model: the state and action inputs are concatenated
      /// along axis 1 and fed to a one-unit dense layer. The model summary is
      /// printed before the model is returned.
      /// </summary>
      /// <param name="num_states">Shape passed to the state input layer.</param>
      /// <param name="num_actions">Shape passed to the action input layer.</param>
      /// <returns>The assembled model.</returns>
      public IModel get_critic(int num_states, int num_actions)
      {
          // Two separate inputs: one for the state, one for the action.
          var stateIn = keras.layers.Input(shape: num_states);
          var actionIn = keras.layers.Input(shape: num_actions);

          // Join both inputs side by side, then reduce to a single output unit.
          var joinLayer = keras.layers.Concatenate(axis: 1);
          var joined = joinLayer.Apply(new Tensors(stateIn, actionIn));
          var valueOut = keras.layers.Dense(1).Apply(joined);

          var critic = tf.keras.Model(new Tensors(stateIn, actionIn), valueOut);
          critic.summary();
          return critic;
      }

      /// <summary>
      /// Runs the target actor, target critic and critic on all-zero batches while
      /// a gradient tape is recording, forms the mean squared difference between
      /// <c>reward + gamma * targetCriticValue</c> and the critic's output, and
      /// asserts that the gradient collection for the critic's trainable variables
      /// and its first entry are both non-null.
      /// </summary>
      [TestMethod]
      public void GetGradientTest()
      {
          int stateCount = 3;
          int actionCount = 1;
          int rows = 64;
          float discount = 0.99f;

          // All four batches are float tensors of zeros that differ only in width.
          Tensor ZeroBatch(int width) =>
              tf.convert_to_tensor(np.zeros((rows, width)), TF_DataType.TF_FLOAT);

          var targetActor = get_actor(stateCount);
          var targetCritic = get_critic(stateCount, actionCount);
          var critic = get_critic(stateCount, actionCount);

          Tensor states = ZeroBatch(stateCount);
          Tensor actions = ZeroBatch(actionCount);
          Tensor rewards = ZeroBatch(1);
          Tensor nextStates = ZeroBatch(stateCount);

          // The tape stays open until the end of the method, so every forward pass
          // and the gradient call below happen while it is alive.
          using var tape = tf.GradientTape();

          var nextActions = targetActor.Apply(nextStates, training: true);
          var targetValue = targetCritic.Apply(
              new Tensors(new Tensor[] { nextStates, nextActions }), training: true);
          var discounted = tf.multiply(discount, targetValue);
          var expected = rewards + discounted;

          var predicted = critic.Apply(
              new Tensors(new Tensor[] { states, actions }), training: true);
          var squaredError = math_ops.square(expected - predicted);
          var loss = math_ops.reduce_mean(squaredError);

          var gradients = tape.gradient(loss, critic.TrainableVariables);

          Assert.IsNotNull(gradients);
          Assert.IsNotNull(gradients.First());
      }
  }