You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

Optimizer.cs 3.4 kB

6 years ago
6 years ago
6 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using System.Text;
  5. using distribute_lib = Tensorflow.Distribute;
  6. namespace Tensorflow
  7. {
  8. /// <summary>
  9. /// Base class for optimizers.
  10. /// This class defines the API to add Ops to train a model. You never use this
  11. /// class directly, but instead instantiate one of its subclasses such as
  12. /// `GradientDescentOptimizer`, `AdagradOptimizer`, or `MomentumOptimizer`.
  13. /// </summary>
  14. public abstract class Optimizer
  15. {
  16. public string Name { get; set; }
  17. public double LearningRate { get; set; }
  18. public Tensor LearningRateTensor { get; set; }
  19. public bool _use_locking;
  20. public Dictionary<string, object> _slots;
  21. public Dictionary<string, object> _non_slot_dict;
  22. public Dictionary<string, object> _deferred_slot_restorations;
  23. public Optimizer(double learning_rate, bool use_locking, string name = "")
  24. {
  25. if (String.IsNullOrEmpty(name))
  26. throw new NotImplementedException("Must specify the optimizer name");
  27. Name = name;
  28. _use_locking = use_locking;
  29. // Dictionary of slots.
  30. _slots = new Dictionary<string, object>();
  31. _non_slot_dict = new Dictionary<string, object>();
  32. _deferred_slot_restorations = new Dictionary<string, object>();
  33. }
  34. /// <summary>
  35. /// Add operations to minimize `loss` by updating `var_list`
  36. /// </summary>
  37. /// <param name="loss"></param>
  38. /// <returns></returns>
  39. public Optimizer minimize(Tensor loss,
  40. GateGradientType gate_gradients = GateGradientType.GATE_OP,
  41. bool colocate_gradients_with_ops = false)
  42. {
  43. compute_gradients(loss,
  44. gate_gradients: gate_gradients,
  45. colocate_gradients_with_ops: colocate_gradients_with_ops);
  46. return this;
  47. }
  48. /// <summary>
  49. /// Compute gradients of `loss` for the variables in `var_list`.
  50. /// </summary>
  51. /// <param name="loss"></param>
  52. /// <param name="gate_gradients"></param>
  53. public List<KeyValuePair<object, object>> compute_gradients(Tensor loss,
  54. List<RefVariable> var_list = null,
  55. int? aggregation_method = null,
  56. GateGradientType gate_gradients = GateGradientType.GATE_OP,
  57. bool colocate_gradients_with_ops = false,
  58. List<Tensor> grad_loss = null)
  59. {
  60. int num_towers = 1;
  61. if(distribute_lib.get_loss_reduction() == VariableAggregationType.MEAN)
  62. {
  63. }
  64. var tmp = variables.trainable_variables();
  65. switch (tmp)
  66. {
  67. case List<RefVariable> values:
  68. var_list = values;
  69. break;
  70. }
  71. var processors = var_list.Select(v => optimizer._get_processor(v)).ToList();
  72. var var_refs = processors.Select(x => x.target()).ToList();
  73. gradients_impl.gradients(loss, var_refs, grad_ys: grad_loss,
  74. gate_gradients: (gate_gradients == GateGradientType.GATE_OP),
  75. aggregation_method: aggregation_method,
  76. colocate_gradients_with_ops: colocate_gradients_with_ops);
  77. return null;
  78. }
  79. }
  80. }

TensorFlow框架的.NET版本，提供了丰富的特性和API，可以借此很方便地在.NET平台下搭建深度学习训练与推理流程。