using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
    public static partial class tf
    {
        /// <summary>
        /// Returns a new <see cref="Zeros"/> initializer on every access.
        /// </summary>
        public static IInitializer zeros_initializer => new Zeros();

        /// <summary>
        /// Returns a new <see cref="GlorotUniform"/> initializer on every access.
        /// </summary>
        public static IInitializer glorot_uniform_initializer => new GlorotUniform();

        /// <summary>
        /// Creates a <c>variable_scope</c> from a scope name.
        /// All arguments are forwarded unchanged to the <c>variable_scope</c> constructor.
        /// </summary>
        public static variable_scope variable_scope(string name,
            string default_name = null,
            object values = null,
            bool auxiliary_name_scope = true)
            => new variable_scope(name, default_name, values, auxiliary_name_scope);

        /// <summary>
        /// Creates a <c>variable_scope</c> from an existing <c>VariableScope</c> instance.
        /// All arguments are forwarded unchanged to the <c>variable_scope</c> constructor.
        /// </summary>
        public static variable_scope variable_scope(VariableScope scope,
            string default_name = null,
            object values = null,
            bool auxiliary_name_scope = true)
            => new variable_scope(scope, default_name, values, auxiliary_name_scope);

        /// <summary>
        /// Initializer that produces tensors filled with zeros.
        /// </summary>
        public class Zeros : IInitializer
        {
            // Data type used when the caller of call() does not specify one.
            private readonly TF_DataType dtype;

            /// <param name="dtype">Default data type of the generated tensors.</param>
            public Zeros(TF_DataType dtype = TF_DataType.TF_FLOAT)
            {
                this.dtype = dtype;
            }

            /// <summary>
            /// Builds a zero tensor of the requested shape via <c>array_ops.zeros</c>.
            /// </summary>
            /// <param name="shape">Shape of the tensor to create.</param>
            /// <param name="dtype">
            /// Data type of the tensor; <c>DtInvalid</c> selects the type given to the constructor.
            /// </param>
            public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid)
            {
                if (dtype == TF_DataType.DtInvalid)
                    dtype = this.dtype;

                return array_ops.zeros(shape, dtype);
            }

            /// <summary>
            /// Returns an anonymous object whose <c>dtype</c> property is the name of the configured data type.
            /// </summary>
            public object get_config()
            {
                return new { dtype = dtype.name() };
            }
        }

        /// <summary>
        /// Initializer capable of adapting its scale to the shape of weights tensors.
        /// </summary>
        public class VarianceScaling : IInitializer
        {
            protected float _scale;
            protected string _mode;
            protected string _distribution;
            protected int? _seed;
            protected TF_DataType _dtype;

            /// <param name="scale">Scaling factor; must be strictly positive.</param>
            /// <param name="mode">
            /// "fan_in" or "fan_out"; any other value is treated as the average of both fans.
            /// </param>
            /// <param name="distribution">
            /// "normal", "truncated_normal" and "untruncated_normal" are not implemented yet;
            /// any other value selects uniform sampling.
            /// </param>
            /// <param name="seed">Optional seed forwarded to the random op.</param>
            /// <param name="dtype">Data type used when call() receives <c>DtInvalid</c>.</param>
            /// <exception cref="ValueError">Thrown when <paramref name="scale"/> is not positive.</exception>
            public VarianceScaling(float scale = 1.0f,
                string mode = "fan_in",
                string distribution = "truncated_normal",
                int? seed = null,
                TF_DataType dtype = TF_DataType.TF_FLOAT)
            {
                // The message promises a positive value, so zero has to be rejected as well.
                if (scale <= 0)
                    throw new ValueError("`scale` must be positive float.");

                _scale = scale;
                _mode = mode;
                _distribution = distribution;
                _seed = seed;
                _dtype = dtype;
            }

            /// <summary>
            /// Samples a tensor of the given shape whose spread is derived from the configured
            /// scale divided by the fan selected through the mode.
            /// </summary>
            /// <param name="shape">Shape of the tensor to create.</param>
            /// <param name="dtype">
            /// Data type of the tensor; <c>DtInvalid</c> selects the type given to the constructor.
            /// </param>
            /// <exception cref="NotImplementedException">
            /// Thrown for the normal-family distributions, which are not implemented yet.
            /// </exception>
            public Tensor call(TensorShape shape, TF_DataType dtype)
            {
                // Same fallback as Zeros.call: an unspecified dtype means "use the configured one".
                if (dtype == TF_DataType.DtInvalid)
                    dtype = _dtype;

                var (fan_in, fan_out) = _compute_fans(shape);

                // Work on a local copy: dividing the field itself would shrink the scale a
                // little more on every call and make get_config() report the wrong value.
                float scale = _scale;
                if (_mode == "fan_in")
                    scale /= Math.Max(1, fan_in);
                else if (_mode == "fan_out")
                    scale /= Math.Max(1, fan_out);
                else
                    // Floating-point average; integer division would truncate odd sums.
                    scale /= Math.Max(1f, (fan_in + fan_out) / 2f);

                if (_distribution == "normal" || _distribution == "truncated_normal")
                {
                    throw new NotImplementedException("truncated_normal");
                }
                else if (_distribution == "untruncated_normal")
                {
                    throw new NotImplementedException("untruncated_normal");
                }
                else
                {
                    // Uniform samples are drawn from [-limit, limit] with limit = sqrt(3 * scale).
                    float limit = (float)Math.Sqrt(3.0f * scale);
                    return random_ops.random_uniform(shape, -limit, limit, dtype, seed: _seed);
                }
            }

            /// <summary>
            /// Computes the (fan_in, fan_out) pair for a weight shape.
            /// Rank 0 yields (1, 1), rank 1 uses the single dimension for both fans, rank 2 uses
            /// (shape[0], shape[1]). For higher ranks the last two dimensions are taken as the
            /// input/output depths and both are multiplied by the product of the leading
            /// dimensions (the receptive field of a convolution kernel).
            /// </summary>
            private (int, int) _compute_fans(int[] shape)
            {
                if (shape.Length < 1)
                    return (1, 1);
                if (shape.Length == 1)
                    return (shape[0], shape[0]);
                if (shape.Length == 2)
                    return (shape[0], shape[1]);

                // Kernel layout assumed to be (..., input_depth, output_depth).
                int receptiveFieldSize = 1;
                for (int i = 0; i < shape.Length - 2; i++)
                    receptiveFieldSize *= shape[i];

                int fanIn = shape[shape.Length - 2] * receptiveFieldSize;
                int fanOut = shape[shape.Length - 1] * receptiveFieldSize;
                return (fanIn, fanOut);
            }

            /// <summary>
            /// Returns an anonymous object exposing the scale, mode, distribution, seed and dtype
            /// this initializer was configured with.
            /// </summary>
            public virtual object get_config()
            {
                return new
                {
                    scale = _scale,
                    mode = _mode,
                    distribution = _distribution,
                    seed = _seed,
                    dtype = _dtype
                };
            }
        }

        /// <summary>
        /// <see cref="VarianceScaling"/> preset whose defaults are the "fan_avg" mode and the
        /// "uniform" distribution.
        /// </summary>
        public class GlorotUniform : VarianceScaling
        {
            public GlorotUniform(float scale = 1.0f,
                string mode = "fan_avg",
                string distribution = "uniform",
                int? seed = null,
                TF_DataType dtype = TF_DataType.TF_FLOAT)
                : base(scale, mode, distribution, seed, dtype)
            {
            }

            /// <summary>
            /// Returns the same configuration object as the base class. Declared as an override
            /// so the virtual base member is no longer hidden.
            /// </summary>
            public override object get_config()
            {
                return new
                {
                    scale = _scale,
                    mode = _mode,
                    distribution = _distribution,
                    seed = _seed,
                    dtype = _dtype
                };
            }
        }
    }
}