You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tf.init.cs 5.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Text;
  4. namespace Tensorflow
  5. {
  6. public static partial class tf
  7. {
  /// <summary>
  /// Gets a newly constructed <see cref="Zeros"/> initializer (default dtype) on every access.
  /// </summary>
  public static IInitializer zeros_initializer
  {
      get { return new Zeros(); }
  }

  /// <summary>
  /// Gets a newly constructed <see cref="GlorotUniform"/> initializer (default arguments) on every access.
  /// </summary>
  public static IInitializer glorot_uniform_initializer
  {
      get { return new GlorotUniform(); }
  }

  /// <summary>
  /// Builds a <c>variable_scope</c> identified by a scope name.
  /// </summary>
  /// <param name="name">Scope name forwarded to the <c>variable_scope</c> constructor.</param>
  /// <param name="default_name">Optional fallback name forwarded unchanged.</param>
  /// <param name="values">Optional values forwarded unchanged.</param>
  /// <param name="auxiliary_name_scope">Forwarded unchanged; defaults to <c>true</c>.</param>
  public static variable_scope variable_scope(string name,
      string default_name = null,
      object values = null,
      bool auxiliary_name_scope = true)
  {
      var scopeObject = new variable_scope(name, default_name, values, auxiliary_name_scope);
      return scopeObject;
  }

  /// <summary>
  /// Builds a <c>variable_scope</c> from an existing <see cref="VariableScope"/>.
  /// </summary>
  /// <param name="scope">Existing scope forwarded to the <c>variable_scope</c> constructor.</param>
  /// <param name="default_name">Optional fallback name forwarded unchanged.</param>
  /// <param name="values">Optional values forwarded unchanged.</param>
  /// <param name="auxiliary_name_scope">Forwarded unchanged; defaults to <c>true</c>.</param>
  public static variable_scope variable_scope(VariableScope scope,
      string default_name = null,
      object values = null,
      bool auxiliary_name_scope = true)
  {
      var scopeObject = new variable_scope(scope, default_name, values, auxiliary_name_scope);
      return scopeObject;
  }
  /// <summary>
  /// Initializer that produces tensors filled with zeros.
  /// </summary>
  public class Zeros : IInitializer
  {
      // Element type used whenever call() is given TF_DataType.DtInvalid.
      private readonly TF_DataType _defaultDtype;

      /// <param name="dtype">Element type used when the caller does not specify one.</param>
      public Zeros(TF_DataType dtype = TF_DataType.TF_FLOAT) => _defaultDtype = dtype;

      /// <summary>
      /// Creates a zero-filled tensor via <c>array_ops.zeros</c>.
      /// </summary>
      /// <param name="shape">Shape of the tensor to create.</param>
      /// <param name="dtype">Element type; <c>DtInvalid</c> selects the constructor's dtype.</param>
      public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid)
      {
          TF_DataType effectiveDtype = dtype != TF_DataType.DtInvalid ? dtype : _defaultDtype;
          return array_ops.zeros(shape, effectiveDtype);
      }

      /// <summary>
      /// Returns an anonymous object whose <c>dtype</c> property is the name of the default dtype.
      /// </summary>
      public object get_config() => new { dtype = _defaultDtype.name() };
  }
  /// <summary>
  /// Initializer capable of adapting its scale to the shape of weights tensors.
  /// </summary>
  /// <remarks>
  /// The effective scale is <c>scale / n</c>. For <c>mode == "fan_in"</c>, <c>n</c> is the number
  /// of input units. For <c>mode == "fan_out"</c>, it is the number of output units. For any other
  /// mode (e.g. <c>"fan_avg"</c>), it is their average. Each count is clamped to at least 1.
  /// Only the uniform path is implemented: values are drawn from <c>[-limit, limit]</c> with
  /// <c>limit = sqrt(3 * effective_scale)</c>.
  /// </remarks>
  public class VarianceScaling : IInitializer
  {
      protected float _scale;
      protected string _mode;
      protected string _distribution;
      protected int? _seed;
      protected TF_DataType _dtype;

      /// <param name="scale">Scaling factor; must be strictly positive.</param>
      /// <param name="mode">"fan_in", "fan_out", or anything else (treated as the fan average).</param>
      /// <param name="distribution">
      /// "normal"/"truncated_normal"/"untruncated_normal" (not implemented yet); any other value
      /// selects the uniform distribution.
      /// </param>
      /// <param name="seed">Optional seed forwarded to <c>random_ops.random_uniform</c>.</param>
      /// <param name="dtype">Element type used when <see cref="call"/> receives <c>DtInvalid</c>.</param>
      /// <exception cref="ValueError">Thrown when <paramref name="scale"/> is not positive.</exception>
      public VarianceScaling(float scale = 1.0f,
          string mode = "fan_in",
          string distribution = "truncated_normal",
          int? seed = null,
          TF_DataType dtype = TF_DataType.TF_FLOAT)
      {
          // Zero is rejected as well: the message (and the TF reference) require a positive scale.
          if (scale <= 0)
              throw new ValueError("`scale` must be positive float.");
          _scale = scale;
          _mode = mode;
          _distribution = distribution;
          _seed = seed;
          _dtype = dtype;
      }

      /// <summary>
      /// Creates a tensor of the given shape whose variance is scaled by the shape's fans.
      /// </summary>
      /// <param name="shape">Shape of the weights tensor to initialize.</param>
      /// <param name="dtype">Element type; <c>DtInvalid</c> selects the constructor's dtype.</param>
      /// <returns>The tensor produced by <c>random_ops.random_uniform</c>.</returns>
      /// <exception cref="NotImplementedException">For the normal-family distributions.</exception>
      public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid)
      {
          if (dtype == TF_DataType.DtInvalid)
              dtype = _dtype;

          var (fan_in, fan_out) = _compute_fans(shape);

          // Work on a local copy: the configured _scale must stay untouched so that repeated
          // calls use the same scale and get_config() reports the constructor argument.
          float scale = _scale;
          if (_mode == "fan_in")
              scale /= Math.Max(1, fan_in);
          else if (_mode == "fan_out")
              scale /= Math.Max(1, fan_out);
          else
              // Float division: an odd fan sum must not be truncated.
              scale /= Math.Max(1f, (fan_in + fan_out) / 2f);

          if (_distribution == "normal" || _distribution == "truncated_normal")
              throw new NotImplementedException("truncated_normal");
          if (_distribution == "untruncated_normal")
              throw new NotImplementedException("untruncated_normal");

          var limit = Math.Sqrt(3.0f * scale);
          return random_ops.random_uniform(shape, (float)-limit, (float)limit, dtype, seed: _seed);
      }

      /// <summary>
      /// Computes (fan_in, fan_out) for a weights shape.
      /// </summary>
      /// <remarks>
      /// Rank 0 gives (1, 1). Rank 1 gives (n, n). Rank 2 gives (rows, cols).
      /// For rank > 2, the layout is assumed to be (..., input_depth, output_depth), i.e. a
      /// convolution kernel. Both fans are then multiplied by the product of the leading
      /// (receptive-field) dimensions.
      /// </remarks>
      private (int, int) _compute_fans(int[] shape)
      {
          if (shape.Length < 1)
              return (1, 1);
          if (shape.Length == 1)
              return (shape[0], shape[0]);
          if (shape.Length == 2)
              return (shape[0], shape[1]);

          int receptive_field_size = 1;
          for (int i = 0; i < shape.Length - 2; i++)
              receptive_field_size *= shape[i];
          int fan_in = shape[shape.Length - 2] * receptive_field_size;
          int fan_out = shape[shape.Length - 1] * receptive_field_size;
          return (fan_in, fan_out);
      }

      /// <summary>
      /// Returns an anonymous object holding the constructor arguments
      /// (scale, mode, distribution, seed, dtype).
      /// </summary>
      public virtual object get_config()
      {
          return new
          {
              scale = _scale,
              mode = _mode,
              distribution = _distribution,
              seed = _seed,
              dtype = _dtype
          };
      }
  }
  /// <summary>
  /// Glorot (Xavier) uniform initializer: a <see cref="VarianceScaling"/> whose defaults are
  /// <c>mode = "fan_avg"</c> and <c>distribution = "uniform"</c>.
  /// </summary>
  /// <remarks>
  /// <c>get_config()</c> is inherited from <see cref="VarianceScaling"/>, which already reports
  /// scale, mode, distribution, seed and dtype. No non-virtual copy is redeclared here: such a copy
  /// would hide the virtual member (CS0114). Calls through a base-typed reference and calls
  /// through a <c>GlorotUniform</c> reference would then resolve to different methods.
  /// </remarks>
  public class GlorotUniform : VarianceScaling
  {
      /// <param name="scale">Scaling factor forwarded to <see cref="VarianceScaling"/>.</param>
      /// <param name="mode">Fan mode; defaults to "fan_avg".</param>
      /// <param name="distribution">Sampling distribution; defaults to "uniform".</param>
      /// <param name="seed">Optional random seed.</param>
      /// <param name="dtype">Default element type.</param>
      public GlorotUniform(float scale = 1.0f,
          string mode = "fan_avg",
          string distribution = "uniform",
          int? seed = null,
          TF_DataType dtype = TF_DataType.TF_FLOAT) : base(scale, mode, distribution, seed, dtype)
      {
      }
  }
  133. }
  134. }

TensorFlow 框架的 .NET 版本，提供了丰富的特性和 API，可以借此很方便地在 .NET 平台下搭建深度学习训练与推理流程。