@@ -17,7 +17,9 @@ | |||||
using NumSharp; | using NumSharp; | ||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Diagnostics; | |||||
using System.Linq; | using System.Linq; | ||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
@@ -76,7 +78,14 @@ namespace Tensorflow | |||||
public Tensor concat(IList<Tensor> values, int axis, string name = "concat") | public Tensor concat(IList<Tensor> values, int axis, string name = "concat") | ||||
{ | { | ||||
if (values.Count == 1) | if (values.Count == 1) | ||||
throw new NotImplementedException("tf.concat length is 1"); | |||||
{ | |||||
return tf_with(ops.name_scope(name), scope => | |||||
{ | |||||
var tensor = ops.convert_to_tensor(axis, name: "concat_dim", dtype: dtypes.int32); | |||||
Debug.Assert(tensor.TensorShape.ndim == 0); | |||||
return identity(values[0], name: scope); | |||||
}); | |||||
} | |||||
return gen_array_ops.concat_v2(values.ToArray(), axis, name: name); | return gen_array_ops.concat_v2(values.ToArray(), axis, name: name); | ||||
} | } | ||||
@@ -111,7 +120,7 @@ namespace Tensorflow | |||||
/// <param name="input"></param> | /// <param name="input"></param> | ||||
/// <param name="name"></param> | /// <param name="name"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
public static Tensor identity(Tensor input, string name = null) | |||||
public Tensor identity(Tensor input, string name = null) | |||||
=> array_ops.identity(input, name: name); | => array_ops.identity(input, name: name); | ||||
/// <summary> | /// <summary> | ||||
@@ -150,10 +159,10 @@ namespace Tensorflow | |||||
/// <param name="axis"></param> | /// <param name="axis"></param> | ||||
/// <param name="name"></param> | /// <param name="name"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
public static Tensor reverse(Tensor tensor, int[] axis, string name = null) | |||||
public Tensor reverse(Tensor tensor, int[] axis, string name = null) | |||||
=> gen_array_ops.reverse(tensor, axis, name: name); | => gen_array_ops.reverse(tensor, axis, name: name); | ||||
public static Tensor reverse(Tensor tensor, Tensor axis, string name = null) | |||||
public Tensor reverse(Tensor tensor, Tensor axis, string name = null) | |||||
=> gen_array_ops.reverse(tensor, axis, name: name); | => gen_array_ops.reverse(tensor, axis, name: name); | ||||
/// <summary> | /// <summary> | ||||
@@ -277,5 +286,14 @@ namespace Tensorflow | |||||
/// <returns>A `Tensor` with all elements set to zero.</returns> | /// <returns>A `Tensor` with all elements set to zero.</returns> | ||||
public Tensor zeros_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) | public Tensor zeros_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) | ||||
=> array_ops.zeros_like(tensor, dtype: dtype, name: name, optimize: optimize); | => array_ops.zeros_like(tensor, dtype: dtype, name: name, optimize: optimize); | ||||
/// <summary> | |||||
/// Stops gradient computation. | |||||
/// </summary> | |||||
/// <param name="x"></param> | |||||
/// <param name="name"></param> | |||||
/// <returns></returns> | |||||
public Tensor stop_gradient(Tensor x, string name = null) | |||||
=> gen_array_ops.stop_gradient(x, name: name); | |||||
} | } | ||||
} | } |
@@ -434,11 +434,14 @@ namespace Tensorflow | |||||
public Tensor reduce_sum(Tensor input, int? axis = null, int? reduction_indices = null, | public Tensor reduce_sum(Tensor input, int? axis = null, int? reduction_indices = null, | ||||
bool keepdims = false, string name = null) | bool keepdims = false, string name = null) | ||||
{ | { | ||||
if(!axis.HasValue && reduction_indices.HasValue) | |||||
if (!axis.HasValue && reduction_indices.HasValue && !keepdims) | |||||
return math_ops.reduce_sum(input, reduction_indices.Value); | return math_ops.reduce_sum(input, reduction_indices.Value); | ||||
else if (axis.HasValue && !reduction_indices.HasValue) | |||||
else if (axis.HasValue && !reduction_indices.HasValue && !keepdims) | |||||
return math_ops.reduce_sum(input, axis.Value); | return math_ops.reduce_sum(input, axis.Value); | ||||
return math_ops.reduce_sum(input, keepdims: keepdims, name: name); | |||||
else if (axis.HasValue && !reduction_indices.HasValue && keepdims) | |||||
return math_ops.reduce_sum(input, keepdims: keepdims, axis: axis.Value, name: name); | |||||
else | |||||
return math_ops.reduce_sum(input, keepdims: keepdims, name: name); | |||||
} | } | ||||
public Tensor reduce_sum(Tensor input, TensorShape axis, int? reduction_indices = null, | public Tensor reduce_sum(Tensor input, TensorShape axis, int? reduction_indices = null, | ||||
@@ -471,6 +474,9 @@ namespace Tensorflow | |||||
public Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null, int? reduction_indices = null) | public Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null, int? reduction_indices = null) | ||||
=> math_ops.reduce_mean(input_tensor, axis: axis, keepdims: keepdims, name: name, reduction_indices: reduction_indices); | => math_ops.reduce_mean(input_tensor, axis: axis, keepdims: keepdims, name: name, reduction_indices: reduction_indices); | ||||
public Tensor reduce_mean(Tensor[] input_tensors, int axis, bool keepdims = false, string name = null) | |||||
=> math_ops.reduce_mean(input_tensors, axis: axis, keepdims: keepdims, name: name); | |||||
public Tensor round(Tensor x, string name = null) | public Tensor round(Tensor x, string name = null) | ||||
=> gen_math_ops.round(x, name: name); | => gen_math_ops.round(x, name: name); | ||||
@@ -65,5 +65,10 @@ namespace Tensorflow | |||||
public void set_random_seed(int seed) | public void set_random_seed(int seed) | ||||
=> ops.get_default_graph().seed = seed; | => ops.get_default_graph().seed = seed; | ||||
public Tensor multinomial(Tensor logits, int num_samples, int? seed = null, | |||||
string name = null, TF_DataType output_dtype = TF_DataType.DtInvalid) | |||||
=> random_ops.multinomial(logits, num_samples, seed: seed, | |||||
name: name, output_dtype: output_dtype); | |||||
} | } | ||||
} | } |
@@ -15,6 +15,7 @@ | |||||
******************************************************************************/ | ******************************************************************************/ | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using Tensorflow.Keras.Optimizers; | |||||
using Tensorflow.Train; | using Tensorflow.Train; | ||||
namespace Tensorflow | namespace Tensorflow | ||||
@@ -73,6 +74,26 @@ namespace Tensorflow | |||||
public CheckpointState get_checkpoint_state(string checkpoint_dir, string latest_filename = null) | public CheckpointState get_checkpoint_state(string checkpoint_dir, string latest_filename = null) | ||||
=> checkpoint_management.get_checkpoint_state(checkpoint_dir, latest_filename: latest_filename); | => checkpoint_management.get_checkpoint_state(checkpoint_dir, latest_filename: latest_filename); | ||||
public Tensor polynomial_decay(float learning_rate, | |||||
RefVariable global_step, | |||||
float decay_steps, | |||||
float end_learning_rate = 0.0001f, | |||||
float power = 1.0f, | |||||
bool cycle = false, | |||||
string name = null) | |||||
{ | |||||
var decayed = new PolynomialDecay(learning_rate, | |||||
decay_steps, | |||||
end_learning_rate: end_learning_rate, | |||||
power: power, | |||||
cycle: cycle, | |||||
name: name); | |||||
var decayed_lr = decayed.__call__(global_step); | |||||
return decayed_lr; | |||||
} | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -27,6 +27,15 @@ namespace Tensorflow | |||||
.ToArray(); | .ToArray(); | ||||
} | } | ||||
/// <summary> | |||||
/// Returns an Op that initializes a list of variables. | |||||
/// </summary> | |||||
/// <param name="var_list">List of `Variable` objects to initialize.</param> | |||||
/// <param name="name">Optional name for the returned operation.</param> | |||||
/// <returns>An Op that run the initializers of all the specified variables.</returns> | |||||
public Operation variables_initializer(VariableV1[] var_list, string name = "init") | |||||
=> variables.variables_initializer(var_list, name: name); | |||||
public Operation global_variables_initializer() | public Operation global_variables_initializer() | ||||
{ | { | ||||
var g = variables.global_variables(); | var g = variables.global_variables(); | ||||
@@ -115,6 +115,7 @@ namespace Tensorflow | |||||
return instance; | return instance; | ||||
} | } | ||||
[DebuggerStepThrough] | |||||
[DebuggerNonUserCode()] // with "Just My Code" enabled this lets the debugger break at the origin of the exception | [DebuggerNonUserCode()] // with "Just My Code" enabled this lets the debugger break at the origin of the exception | ||||
public static void tf_with(IObjectLife py, Action<IObjectLife> action) | public static void tf_with(IObjectLife py, Action<IObjectLife> action) | ||||
{ | { | ||||
@@ -273,7 +274,10 @@ namespace Tensorflow | |||||
return sum; | return sum; | ||||
} | } | ||||
public static double sum(IEnumerable<int> enumerable) | |||||
public static float sum(IEnumerable<float> enumerable) | |||||
=> enumerable.Sum(); | |||||
public static int sum(IEnumerable<int> enumerable) | |||||
=> enumerable.Sum(); | => enumerable.Sum(); | ||||
public static double sum<TKey, TValue>(Dictionary<TKey, TValue> values) | public static double sum<TKey, TValue>(Dictionary<TKey, TValue> values) | ||||
@@ -0,0 +1,16 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using System.Text; | |||||
using System.Threading.Tasks; | |||||
namespace Tensorflow.Keras.Optimizers | |||||
{ | |||||
public class LearningRateSchedule | |||||
{ | |||||
public LearningRateSchedule() | |||||
{ | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,62 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using System.Text; | |||||
using System.Threading.Tasks; | |||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow.Keras.Optimizers | |||||
{ | |||||
/// <summary> | |||||
/// A LearningRateSchedule that uses a polynomial decay schedule. | |||||
/// </summary> | |||||
public class PolynomialDecay : LearningRateSchedule | |||||
{ | |||||
float initial_learning_rate; | |||||
float decay_steps; | |||||
float end_learning_rate; | |||||
float power; | |||||
bool cycle; | |||||
string name; | |||||
public PolynomialDecay(float initial_learning_rate, | |||||
float decay_steps, | |||||
float end_learning_rate = 0.0001f, | |||||
float power = 1.0f, | |||||
bool cycle = false, | |||||
string name = null) : base() | |||||
{ | |||||
this.initial_learning_rate = initial_learning_rate; | |||||
this.decay_steps = decay_steps; | |||||
this.end_learning_rate = end_learning_rate; | |||||
this.power = power; | |||||
this.cycle = cycle; | |||||
this.name = name; | |||||
} | |||||
public Tensor __call__(RefVariable step) | |||||
{ | |||||
tf_with(ops.name_scope(name ?? "PolynomialDecay"), scope => | |||||
{ | |||||
name = scope; | |||||
var initial_learning_rate_tensor = ops.convert_to_tensor(initial_learning_rate, name: "initial_learning_rate"); | |||||
var dtype = initial_learning_rate_tensor.dtype; | |||||
var end_learning_rate_tensor = math_ops.cast(end_learning_rate, dtype); | |||||
var power_tensor = math_ops.cast(power, dtype); | |||||
var global_step_recomp = math_ops.cast(step, dtype); | |||||
var decay_steps_recomp = math_ops.cast(decay_steps, dtype); | |||||
if(cycle) | |||||
{ | |||||
throw new NotImplementedException("PolynomialDecay cycle"); | |||||
} | |||||
else | |||||
{ | |||||
} | |||||
}); | |||||
throw new NotImplementedException(""); | |||||
} | |||||
} | |||||
} |
@@ -19,8 +19,7 @@ namespace Tensorflow.Operations.Initializers | |||||
public class GlorotUniform : VarianceScaling | public class GlorotUniform : VarianceScaling | ||||
{ | { | ||||
public GlorotUniform(float scale = 1.0f, | public GlorotUniform(float scale = 1.0f, | ||||
string mode = "fan_avg", | |||||
string distribution = "uniform", | |||||
string mode = "FAN_AVG", | |||||
int? seed = null, | int? seed = null, | ||||
TF_DataType dtype = TF_DataType.TF_FLOAT) : base(factor: scale, | TF_DataType dtype = TF_DataType.TF_FLOAT) : base(factor: scale, | ||||
mode: mode, | mode: mode, | ||||
@@ -36,7 +35,6 @@ namespace Tensorflow.Operations.Initializers | |||||
{ | { | ||||
scale = _scale, | scale = _scale, | ||||
mode = _mode, | mode = _mode, | ||||
distribution = _distribution, | |||||
seed = _seed, | seed = _seed, | ||||
dtype = _dtype | dtype = _dtype | ||||
}; | }; | ||||
@@ -30,6 +30,7 @@ namespace Tensorflow.Operations.Initializers | |||||
protected string _distribution; | protected string _distribution; | ||||
protected int? _seed; | protected int? _seed; | ||||
protected TF_DataType _dtype; | protected TF_DataType _dtype; | ||||
protected bool _uniform; | |||||
public VarianceScaling(float factor = 2.0f, | public VarianceScaling(float factor = 2.0f, | ||||
string mode = "FAN_IN", | string mode = "FAN_IN", | ||||
@@ -49,31 +50,31 @@ namespace Tensorflow.Operations.Initializers | |||||
_mode = mode; | _mode = mode; | ||||
_seed = seed; | _seed = seed; | ||||
_dtype = dtype; | _dtype = dtype; | ||||
_uniform = uniform; | |||||
} | } | ||||
public Tensor call(TensorShape shape, TF_DataType dtype, bool? verify_shape = null) | public Tensor call(TensorShape shape, TF_DataType dtype, bool? verify_shape = null) | ||||
{ | { | ||||
float n = 0; | |||||
var (fan_in, fan_out) = _compute_fans(shape); | var (fan_in, fan_out) = _compute_fans(shape); | ||||
if (_mode == "fan_in") | |||||
_scale /= Math.Max(1, fan_in); | |||||
else if (_mode == "fan_out") | |||||
_scale /= Math.Max(1, fan_out); | |||||
else | |||||
_scale /= Math.Max(1, (fan_in + fan_out) / 2); | |||||
if (_mode == "FAN_IN") | |||||
n = fan_in; | |||||
else if (_mode == "FAN_OUT") | |||||
n = fan_out; | |||||
else if(_mode == "FAN_AVG") | |||||
n = (fan_in + fan_out) / 2.0f; | |||||
if (_distribution == "normal" || _distribution == "truncated_normal") | |||||
{ | |||||
float stddev = (float)Math.Sqrt(_scale) / .87962566103423978f; | |||||
return random_ops.truncated_normal(shape, mean: 0.0f, stddev: stddev, dtype: dtype, seed: _seed); | |||||
} | |||||
else if (_distribution == "untruncated_normal") | |||||
if(_uniform) | |||||
{ | { | ||||
throw new NotImplementedException("truncated_normal"); | |||||
var limit = Convert.ToSingle(Math.Sqrt(3.0f * _scale / n)); | |||||
return random_ops.random_uniform(shape, -limit, limit, | |||||
dtype, seed: _seed); | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
var limit = Math.Sqrt(3.0f * _scale); | |||||
return random_ops.random_uniform(shape, (float)-limit, (float)limit, dtype, seed: _seed); | |||||
var trunc_stddev = Convert.ToSingle(Math.Sqrt(1.3f * _scale / n)); | |||||
return random_ops.truncated_normal(shape, 0.0f, trunc_stddev, dtype, | |||||
seed: _seed); | |||||
} | } | ||||
} | } | ||||
@@ -106,6 +107,7 @@ namespace Tensorflow.Operations.Initializers | |||||
mode = _mode, | mode = _mode, | ||||
distribution = _distribution, | distribution = _distribution, | ||||
seed = _seed, | seed = _seed, | ||||
uniform = _uniform, | |||||
dtype = _dtype | dtype = _dtype | ||||
}; | }; | ||||
} | } | ||||
@@ -383,7 +383,7 @@ namespace Tensorflow | |||||
{ | { | ||||
var _op = _op_def_lib._apply_op_helper("StopGradient", name, args: new { input = x, name }); | var _op = _op_def_lib._apply_op_helper("StopGradient", name, args: new { input = x, name }); | ||||
return _op.outputs[0]; | |||||
return _op.output; | |||||
} | } | ||||
public static Tensor strided_slice(Tensor input, Tensor begin, Tensor end, Tensor strides, | public static Tensor strided_slice(Tensor input, Tensor begin, Tensor end, Tensor strides, | ||||
@@ -115,7 +115,7 @@ namespace Tensorflow | |||||
{ | { | ||||
var _op = _op_def_lib._apply_op_helper("Mean", name, args: new { input, reduction_indices = axis, keep_dims = keep_dims }); | var _op = _op_def_lib._apply_op_helper("Mean", name, args: new { input, reduction_indices = axis, keep_dims = keep_dims }); | ||||
return _op.outputs[0]; | |||||
return _op.output; | |||||
} | } | ||||
public static Tensor prod<T1, T2>(T1 input, T2 axis, bool keep_dims = false, string name = null) | public static Tensor prod<T1, T2>(T1 input, T2 axis, bool keep_dims = false, string name = null) | ||||
@@ -98,7 +98,8 @@ namespace Tensorflow | |||||
/// <param name="seed2"></param> | /// <param name="seed2"></param> | ||||
/// <param name="name"></param> | /// <param name="name"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
public static Tensor random_shuffle(Tensor value, int seed = 0, int seed2 = 0, string name = null) | |||||
public static Tensor random_shuffle(Tensor value, int seed = 0, int seed2 = 0, | |||||
string name = null) | |||||
{ | { | ||||
var _op = _op_def_lib._apply_op_helper("RandomShuffle", | var _op = _op_def_lib._apply_op_helper("RandomShuffle", | ||||
name: name, | name: name, | ||||
@@ -116,7 +117,8 @@ namespace Tensorflow | |||||
/// <param name="seed2"></param> | /// <param name="seed2"></param> | ||||
/// <param name="name"></param> | /// <param name="name"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
public static Tensor truncated_normal(Tensor shape, TF_DataType dtype, int? seed = 0, int? seed2 = 0, string name = null) | |||||
public static Tensor truncated_normal(Tensor shape, TF_DataType dtype, int? seed = 0, | |||||
int? seed2 = 0, string name = null) | |||||
{ | { | ||||
if (!seed.HasValue) | if (!seed.HasValue) | ||||
seed = 0; | seed = 0; | ||||
@@ -127,7 +129,24 @@ namespace Tensorflow | |||||
name: name, | name: name, | ||||
args: new { shape, dtype, seed, seed2 }); | args: new { shape, dtype, seed, seed2 }); | ||||
return _op.outputs[0]; | |||||
return _op.output; | |||||
} | |||||
public static Tensor multinomial(Tensor logits, int num_samples, int? seed = 0, | |||||
int? seed2 = 0, TF_DataType output_dtype = TF_DataType.TF_INT64, string name = null) | |||||
{ | |||||
if (!seed.HasValue) | |||||
seed = 0; | |||||
if (!seed2.HasValue) | |||||
seed2 = 0; | |||||
if (output_dtype == TF_DataType.DtInvalid) | |||||
output_dtype = TF_DataType.TF_INT64; | |||||
var _op = _op_def_lib._apply_op_helper("Multinomial", | |||||
name: name, | |||||
args: new { logits, num_samples, seed, seed2, output_dtype }); | |||||
return _op.output; | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -81,6 +81,21 @@ namespace Tensorflow | |||||
}); | }); | ||||
} | } | ||||
public static Tensor cast(float x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null) | |||||
{ | |||||
var base_type = dtype.as_base_dtype(); | |||||
return tf_with(ops.name_scope(name, "Cast", new { x }), scope => | |||||
{ | |||||
name = scope; | |||||
var x_tensor = ops.convert_to_tensor(x, name: "x"); | |||||
if (x_tensor.dtype.as_base_dtype() != base_type) | |||||
x_tensor = gen_math_ops.cast(x_tensor, base_type, name: name); | |||||
return x_tensor; | |||||
}); | |||||
} | |||||
public static Tensor cumsum<T>(Tensor x, T axis = default, bool exclusive = false, bool reverse = false, string name = null) | public static Tensor cumsum<T>(Tensor x, T axis = default, bool exclusive = false, bool reverse = false, string name = null) | ||||
{ | { | ||||
return tf_with(ops.name_scope(name, "Cumsum", new {x}), scope => | return tf_with(ops.name_scope(name, "Cumsum", new {x}), scope => | ||||
@@ -204,6 +219,12 @@ namespace Tensorflow | |||||
} | } | ||||
} | } | ||||
public static Tensor reduce_mean(Tensor[] input_tensors, int axis, bool keepdims = false, string name = null) | |||||
{ | |||||
var m = gen_math_ops.mean(input_tensors, axis, keepdims, name); | |||||
return _may_reduce_to_scalar(keepdims, axis, m); | |||||
} | |||||
/// <summary> | /// <summary> | ||||
/// Computes the product of elements across dimensions of a tensor. | /// Computes the product of elements across dimensions of a tensor. | ||||
/// </summary> | /// </summary> | ||||
@@ -142,6 +142,35 @@ namespace Tensorflow | |||||
{ | { | ||||
return ops.convert_to_tensor(shape, name: "shape"); | return ops.convert_to_tensor(shape, name: "shape"); | ||||
} | } | ||||
public static Tensor multinomial(Tensor logits, int num_samples, int? seed = null, | |||||
string name = null, TF_DataType output_dtype = TF_DataType.DtInvalid) | |||||
{ | |||||
return tf_with(ops.name_scope(name, "multinomial", new { logits }), delegate | |||||
{ | |||||
return multinomial_categorical_impl(logits, num_samples, output_dtype, seed); | |||||
}); | |||||
} | |||||
/// <summary> | |||||
/// Implementation for random.categorical (v1) and random.categorical (v2). | |||||
/// </summary> | |||||
/// <param name="logits"></param> | |||||
/// <param name="num_samples"></param> | |||||
/// <param name="output_dtype"></param> | |||||
/// <param name="seed"></param> | |||||
/// <returns></returns> | |||||
private static Tensor multinomial_categorical_impl(Tensor logits, int num_samples, TF_DataType dtype = TF_DataType.DtInvalid, | |||||
int? seed = null) | |||||
{ | |||||
logits = ops.convert_to_tensor(logits, name: "logits"); | |||||
var (seed1, seed2) = random_seed.get_seed(seed); | |||||
return gen_random_ops.multinomial(logits, | |||||
num_samples, | |||||
seed: seed1, | |||||
seed2: seed2, | |||||
output_dtype: dtype); | |||||
} | |||||
} | } | ||||
} | } | ||||
@@ -0,0 +1,29 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using System.Text; | |||||
using System.Threading.Tasks; | |||||
namespace Tensorflow.Training | |||||
{ | |||||
public class learning_rate_decay | |||||
{ | |||||
/// <summary> | |||||
/// Applies a polynomial decay to the learning rate. | |||||
/// </summary> | |||||
/// <param name="learning_rate"></param> | |||||
/// <param name="global_step"></param> | |||||
/// <param name="decay_steps"></param> | |||||
/// <param name="end_learning_rate"></param> | |||||
/// <param name="power"></param> | |||||
/// <param name="cycle"></param> | |||||
/// <param name="name"></param> | |||||
/// <returns></returns> | |||||
public static Tensor polynomial_decay(float learning_rate, RefVariable global_step, float decay_steps, | |||||
float end_learning_rate = 0.0001f, float power = 1.0f, bool cycle = false, | |||||
string name = null) | |||||
{ | |||||
throw new NotImplementedException(""); | |||||
} | |||||
} | |||||
} |