You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Gradient.cs 2.7 kB

2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
2 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System.Linq;
  3. using Tensorflow;
  4. using Tensorflow.Keras.Engine;
  5. using static Tensorflow.Binding;
  6. using static Tensorflow.KerasApi;
  7. using Tensorflow.NumPy;
  8. namespace TensorFlowNET.Keras.UnitTest;
  /// <summary>
  /// Unit test verifying that a gradient tape returns gradients for the
  /// trainable variables of a two-input Keras model whose loss also depends
  /// on the outputs of two other models.
  /// </summary>
  [TestClass]
  public class GradientTest
  {
      /// <summary>
      /// Builds a single-input model: an input of shape <paramref name="num_states"/>
      /// feeding one Dense layer with a single unit and tanh activation.
      /// </summary>
      /// <param name="num_states">Shape passed to the input layer.</param>
      /// <returns>The assembled model.</returns>
      public IModel get_actor(int num_states)
      {
          var state_in = tf.keras.layers.Input(shape: num_states);
          var tanh_dense = tf.keras.layers.Dense(1, activation: keras.activations.Tanh);
          return tf.keras.Model(state_in, tanh_dense.Apply(state_in));
      }

      /// <summary>
      /// Builds a two-input model: both inputs are concatenated along axis 1
      /// and passed through one Dense layer with a single unit. The model
      /// summary is printed before the model is returned.
      /// </summary>
      /// <param name="num_states">Shape passed to the first input layer.</param>
      /// <param name="num_actions">Shape passed to the second input layer.</param>
      /// <returns>The assembled model.</returns>
      public IModel get_critic(int num_states, int num_actions)
      {
          // First input: state.
          var state_in = keras.layers.Input(shape: num_states);
          // Second input: action.
          var action_in = keras.layers.Input(shape: num_actions);

          var joiner = keras.layers.Concatenate(axis: 1);
          var merged = joiner.Apply(new Tensors(state_in, action_in));
          var head = keras.layers.Dense(1);
          var value_out = head.Apply(merged);

          var critic = tf.keras.Model(new Tensors(state_in, action_in), value_out);
          critic.summary();
          return critic;
      }

      /// <summary>
      /// Creates a float tensor of zeros with shape (rows, cols).
      /// </summary>
      private static Tensor ZeroBatch(int rows, int cols)
      {
          return tf.convert_to_tensor(np.zeros((rows, cols)), TF_DataType.TF_FLOAT);
      }

      /// <summary>
      /// Runs the actor and both critic models on all-zero batches while a
      /// gradient tape is open, forms a mean-squared-error loss between the
      /// critic output and a target built from the other two models, and
      /// asserts that the gradient list and its first entry are not null.
      /// </summary>
      [TestMethod]
      public void GetGradient_Test()
      {
          const int stateCount = 3;
          const int actionCount = 1;
          const int batch = 64;
          const float gammaScale = 0.99f;

          // Models are created in this order: actor, target critic, critic.
          var targetActor = get_actor(stateCount);
          var targetCritic = get_critic(stateCount, actionCount);
          var onlineCritic = get_critic(stateCount, actionCount);

          Tensor states = ZeroBatch(batch, stateCount);
          Tensor actions = ZeroBatch(batch, actionCount);
          Tensor rewards = ZeroBatch(batch, 1);
          Tensor nextStates = ZeroBatch(batch, stateCount);

          // The tape stays open until the end of the method, so every
          // forward pass below runs while it is active.
          using var tape = tf.GradientTape();

          var nextActions = targetActor.Apply(nextStates, training: true);
          var targetInputs = new Tensors(new Tensor[] { nextStates, nextActions });
          var targetValue = targetCritic.Apply(targetInputs, training: true);

          // Target = rewards + gammaScale * target critic output.
          var target = rewards + tf.multiply(gammaScale, targetValue);

          var criticInputs = new Tensors(new Tensor[] { states, actions });
          var predicted = onlineCritic.Apply(criticInputs, training: true);

          var squaredError = math_ops.square(target - predicted);
          var loss = math_ops.reduce_mean(squaredError);

          var grads = tape.gradient(loss, onlineCritic.TrainableVariables);
          Assert.IsNotNull(grads);
          Assert.IsNotNull(grads.First());
      }
  }