@@ -17,7 +17,9 @@
using NumSharp;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using static Tensorflow.Binding;

namespace Tensorflow
{
@@ -76,7 +78,14 @@ namespace Tensorflow
        public Tensor concat(IList<Tensor> values, int axis, string name = "concat")
        {
            if (values.Count == 1)
                throw new NotImplementedException("tf.concat length is 1");
            {
                return tf_with(ops.name_scope(name), scope =>
                {
                    var tensor = ops.convert_to_tensor(axis, name: "concat_dim", dtype: dtypes.int32);
                    Debug.Assert(tensor.TensorShape.ndim == 0);
                    return identity(values[0], name: scope);
                });
            }

            return gen_array_ops.concat_v2(values.ToArray(), axis, name: name);
        }

@@ -111,7 +120,7 @@ namespace Tensorflow
        /// <param name="input"></param>
        /// <param name="name"></param>
        /// <returns></returns>
        public static Tensor identity(Tensor input, string name = null)
        public Tensor identity(Tensor input, string name = null)
            => array_ops.identity(input, name: name);

        /// <summary>
@@ -150,10 +159,10 @@ namespace Tensorflow
        /// <param name="axis"></param>
        /// <param name="name"></param>
        /// <returns></returns>
        public static Tensor reverse(Tensor tensor, int[] axis, string name = null)
        public Tensor reverse(Tensor tensor, int[] axis, string name = null)
            => gen_array_ops.reverse(tensor, axis, name: name);

        public static Tensor reverse(Tensor tensor, Tensor axis, string name = null)
        public Tensor reverse(Tensor tensor, Tensor axis, string name = null)
            => gen_array_ops.reverse(tensor, axis, name: name);

        /// <summary>
@@ -277,5 +286,14 @@ namespace Tensorflow
        /// <returns>A `Tensor` with all elements set to zero.</returns>
        public Tensor zeros_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true)
            => array_ops.zeros_like(tensor, dtype: dtype, name: name, optimize: optimize);

        /// <summary>
        /// Stops gradient computation.
        /// </summary>
        /// <param name="x"></param>
        /// <param name="name"></param>
        /// <returns></returns>
        public Tensor stop_gradient(Tensor x, string name = null)
            => gen_array_ops.stop_gradient(x, name: name);
    }
}
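For reference, a quick sketch of how the new single-tensor concat path and the stop_gradient wrapper are expected to behave; tensor values and names are illustrative only, and the snippet assumes the usual TensorFlow.NET usings.

// assumes: using System.Collections.Generic; using Tensorflow; using static Tensorflow.Binding;
var a = tf.constant(new float[] { 1f, 2f });
var b = tf.constant(new float[] { 3f, 4f });

// A one-element list no longer throws NotImplementedException; it returns identity(values[0]).
var single = tf.concat(new List<Tensor> { a }, axis: 0);

// Two or more tensors still go through gen_array_ops.concat_v2.
var joined = tf.concat(new List<Tensor> { a, b }, axis: 0);

// stop_gradient forwards the value but blocks backpropagation through it.
var frozen = tf.stop_gradient(joined);

using (var sess = tf.Session())
{
    var result = sess.run(frozen);   // expected value: [1, 2, 3, 4]
}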
@@ -434,11 +434,14 @@ namespace Tensorflow
        public Tensor reduce_sum(Tensor input, int? axis = null, int? reduction_indices = null,
            bool keepdims = false, string name = null)
        {
            if (!axis.HasValue && reduction_indices.HasValue)
            if (!axis.HasValue && reduction_indices.HasValue && !keepdims)
                return math_ops.reduce_sum(input, reduction_indices.Value);
            else if (axis.HasValue && !reduction_indices.HasValue)
            else if (axis.HasValue && !reduction_indices.HasValue && !keepdims)
                return math_ops.reduce_sum(input, axis.Value);
            return math_ops.reduce_sum(input, keepdims: keepdims, name: name);
            else if (axis.HasValue && !reduction_indices.HasValue && keepdims)
                return math_ops.reduce_sum(input, keepdims: keepdims, axis: axis.Value, name: name);
            else
                return math_ops.reduce_sum(input, keepdims: keepdims, name: name);
        }

        public Tensor reduce_sum(Tensor input, TensorShape axis, int? reduction_indices = null,
@@ -471,6 +474,9 @@ namespace Tensorflow
        public Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null, int? reduction_indices = null)
            => math_ops.reduce_mean(input_tensor, axis: axis, keepdims: keepdims, name: name, reduction_indices: reduction_indices);

        public Tensor reduce_mean(Tensor[] input_tensors, int axis, bool keepdims = false, string name = null)
            => math_ops.reduce_mean(input_tensors, axis: axis, keepdims: keepdims, name: name);

        public Tensor round(Tensor x, string name = null)
            => gen_math_ops.round(x, name: name);
@@ -65,5 +65,10 @@ namespace Tensorflow
        public void set_random_seed(int seed)
            => ops.get_default_graph().seed = seed;

        public Tensor multinomial(Tensor logits, int num_samples, int? seed = null,
            string name = null, TF_DataType output_dtype = TF_DataType.DtInvalid)
            => random_ops.multinomial(logits, num_samples, seed: seed,
                name: name, output_dtype: output_dtype);
    }
}
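The reduce_sum overload previously dropped keepdims whenever axis or reduction_indices was supplied; the extra branches now forward it. A hedged usage sketch, with constant construction and shapes purely illustrative:

// assumes: using Tensorflow; using static Tensorflow.Binding;
var x = tf.constant(new float[,] { { 1f, 2f }, { 3f, 4f } });

// keepdims is now honoured when axis is given: result shape (1, 2) instead of (2).
var colSums = tf.reduce_sum(x, axis: 0, keepdims: true);

// New tf.multinomial: draws num_samples class indices per row of logits, int64 by default.
var logits = tf.constant(new float[,] { { 0f, 1f, 2f } });
var samples = tf.multinomial(logits, num_samples: 5);   // shape (1, 5)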
@@ -15,6 +15,7 @@
 ******************************************************************************/

using System.Collections.Generic;
using Tensorflow.Keras.Optimizers;
using Tensorflow.Train;

namespace Tensorflow
@@ -73,6 +74,26 @@ namespace Tensorflow
            public CheckpointState get_checkpoint_state(string checkpoint_dir, string latest_filename = null)
                => checkpoint_management.get_checkpoint_state(checkpoint_dir, latest_filename: latest_filename);

            public Tensor polynomial_decay(float learning_rate,
                RefVariable global_step,
                float decay_steps,
                float end_learning_rate = 0.0001f,
                float power = 1.0f,
                bool cycle = false,
                string name = null)
            {
                var decayed = new PolynomialDecay(learning_rate,
                    decay_steps,
                    end_learning_rate: end_learning_rate,
                    power: power,
                    cycle: cycle,
                    name: name);

                var decayed_lr = decayed.__call__(global_step);

                return decayed_lr;
            }
        }
    }
}
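Intended usage of the new tf.train.polynomial_decay wrapper, mirroring the Python API. PolynomialDecay.__call__ below still throws NotImplementedException for the final computation, so this shows the target shape of the API rather than something runnable against this commit; the tf.Variable call is assumed from the existing helpers.

// assumes: using Tensorflow; using static Tensorflow.Binding;
var global_step = tf.Variable(0, trainable: false, name: "global_step");
var lr = tf.train.polynomial_decay(0.1f, global_step, 10000f,
    end_learning_rate: 0.01f,
    power: 0.5f);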
@@ -27,6 +27,15 @@ namespace Tensorflow
                .ToArray();
        }

        /// <summary>
        /// Returns an Op that initializes a list of variables.
        /// </summary>
        /// <param name="var_list">List of `Variable` objects to initialize.</param>
        /// <param name="name">Optional name for the returned operation.</param>
        /// <returns>An Op that runs the initializers of all the specified variables.</returns>
        public Operation variables_initializer(VariableV1[] var_list, string name = "init")
            => variables.variables_initializer(var_list, name: name);

        public Operation global_variables_initializer()
        {
            var g = variables.global_variables();
@@ -115,6 +115,7 @@ namespace Tensorflow
            return instance;
        }

        [DebuggerStepThrough]
        [DebuggerNonUserCode()] // with "Just My Code" enabled this lets the debugger break at the origin of the exception
        public static void tf_with(IObjectLife py, Action<IObjectLife> action)
        {
@@ -273,7 +274,10 @@ namespace Tensorflow
            return sum;
        }

        public static double sum(IEnumerable<int> enumerable)
        public static float sum(IEnumerable<float> enumerable)
            => enumerable.Sum();

        public static int sum(IEnumerable<int> enumerable)
            => enumerable.Sum();

        public static double sum<TKey, TValue>(Dictionary<TKey, TValue> values)
@@ -0,0 +1,16 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace Tensorflow.Keras.Optimizers
{
    public class LearningRateSchedule
    {
        public LearningRateSchedule()
        {
        }
    }
}
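A minimal sketch of the new variables_initializer overload, assuming RefVariable derives from VariableV1 on this branch and that the usual Session helpers are in scope; variable names are illustrative.

// assumes: using Tensorflow; using static Tensorflow.Binding;
var w = tf.Variable(1.0f, name: "w");
var b = tf.Variable(0.0f, name: "b");

// Initialize only the listed variables instead of every global variable.
var init = tf.variables_initializer(new VariableV1[] { w, b });

using (var sess = tf.Session())
{
    sess.run(init);
}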
@@ -0,0 +1,62 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Optimizers
{
    /// <summary>
    /// A LearningRateSchedule that uses a polynomial decay schedule.
    /// </summary>
    public class PolynomialDecay : LearningRateSchedule
    {
        float initial_learning_rate;
        float decay_steps;
        float end_learning_rate;
        float power;
        bool cycle;
        string name;

        public PolynomialDecay(float initial_learning_rate,
            float decay_steps,
            float end_learning_rate = 0.0001f,
            float power = 1.0f,
            bool cycle = false,
            string name = null) : base()
        {
            this.initial_learning_rate = initial_learning_rate;
            this.decay_steps = decay_steps;
            this.end_learning_rate = end_learning_rate;
            this.power = power;
            this.cycle = cycle;
            this.name = name;
        }

        public Tensor __call__(RefVariable step)
        {
            tf_with(ops.name_scope(name ?? "PolynomialDecay"), scope =>
            {
                name = scope;
                var initial_learning_rate_tensor = ops.convert_to_tensor(initial_learning_rate, name: "initial_learning_rate");
                var dtype = initial_learning_rate_tensor.dtype;
                var end_learning_rate_tensor = math_ops.cast(end_learning_rate, dtype);
                var power_tensor = math_ops.cast(power, dtype);

                var global_step_recomp = math_ops.cast(step, dtype);
                var decay_steps_recomp = math_ops.cast(decay_steps, dtype);

                if (cycle)
                {
                    throw new NotImplementedException("PolynomialDecay cycle");
                }
                else
                {
                }
            });

            throw new NotImplementedException("");
        }
    }
}
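For reference, the non-cycle branch that is still left as NotImplementedException corresponds to the standard polynomial decay formula. A plain-float sketch of that computation (not the tensor-graph version), to document the expected behaviour:

// decayed_lr = (initial_lr - end_lr) * (1 - step / decay_steps) ^ power + end_lr,
// with step clamped to decay_steps when cycle == false.
static float PolynomialDecayValue(float initialLr, float endLr, float step, float decaySteps, float power)
{
    step = Math.Min(step, decaySteps);            // non-cycle case clamps the step
    var p = step / decaySteps;
    return (initialLr - endLr) * (float)Math.Pow(1f - p, power) + endLr;
}
// e.g. initial 0.1, end 0.01, power 1, decay_steps 100: step 0 -> 0.1, step 50 -> 0.055, step 100 -> 0.01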
@@ -19,8 +19,7 @@ namespace Tensorflow.Operations.Initializers
    public class GlorotUniform : VarianceScaling
    {
        public GlorotUniform(float scale = 1.0f,
            string mode = "fan_avg",
            string distribution = "uniform",
            string mode = "FAN_AVG",
            int? seed = null,
            TF_DataType dtype = TF_DataType.TF_FLOAT) : base(factor: scale,
                mode: mode,
@@ -36,7 +35,6 @@ namespace Tensorflow.Operations.Initializers
            {
                scale = _scale,
                mode = _mode,
                distribution = _distribution,
                seed = _seed,
                dtype = _dtype
            };
@@ -30,6 +30,7 @@ namespace Tensorflow.Operations.Initializers
        protected string _distribution;
        protected int? _seed;
        protected TF_DataType _dtype;
        protected bool _uniform;

        public VarianceScaling(float factor = 2.0f,
            string mode = "FAN_IN",
@@ -49,31 +50,31 @@ namespace Tensorflow.Operations.Initializers
            _mode = mode;
            _seed = seed;
            _dtype = dtype;
            _uniform = uniform;
        }

        public Tensor call(TensorShape shape, TF_DataType dtype, bool? verify_shape = null)
        {
            float n = 0;
            var (fan_in, fan_out) = _compute_fans(shape);
            if (_mode == "fan_in")
                _scale /= Math.Max(1, fan_in);
            else if (_mode == "fan_out")
                _scale /= Math.Max(1, fan_out);
            else
                _scale /= Math.Max(1, (fan_in + fan_out) / 2);
            if (_mode == "FAN_IN")
                n = fan_in;
            else if (_mode == "FAN_OUT")
                n = fan_out;
            else if (_mode == "FAN_AVG")
                n = (fan_in + fan_out) / 2.0f;

            if (_distribution == "normal" || _distribution == "truncated_normal")
            {
                float stddev = (float)Math.Sqrt(_scale) / .87962566103423978f;
                return random_ops.truncated_normal(shape, mean: 0.0f, stddev: stddev, dtype: dtype, seed: _seed);
            }
            else if (_distribution == "untruncated_normal")
            if (_uniform)
            {
                throw new NotImplementedException("truncated_normal");
                var limit = Convert.ToSingle(Math.Sqrt(3.0f * _scale / n));
                return random_ops.random_uniform(shape, -limit, limit,
                    dtype, seed: _seed);
            }
            else
            {
                var limit = Math.Sqrt(3.0f * _scale);
                return random_ops.random_uniform(shape, (float)-limit, (float)limit, dtype, seed: _seed);
                var trunc_stddev = Convert.ToSingle(Math.Sqrt(1.3f * _scale / n));
                return random_ops.truncated_normal(shape, 0.0f, trunc_stddev, dtype,
                    seed: _seed);
            }
        }

@@ -106,6 +107,7 @@ namespace Tensorflow.Operations.Initializers
                mode = _mode,
                distribution = _distribution,
                seed = _seed,
                uniform = _uniform,
                dtype = _dtype
            };
        }
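With these changes GlorotUniform (scale 1.0, mode FAN_AVG, uniform) draws from uniform(-limit, limit) with limit = sqrt(3 * scale / n) and n = (fan_in + fan_out) / 2, which is the standard sqrt(6 / (fan_in + fan_out)) Glorot bound; the old uniform branch ignored the fan sizes entirely. A quick numeric check:

// fan_in = 100, fan_out = 10  =>  n = 55, limit = sqrt(3 / 55) = sqrt(6 / 110) ≈ 0.2335
var fanIn = 100f;
var fanOut = 10f;
var n = (fanIn + fanOut) / 2f;
var limit = Math.Sqrt(3f * 1.0f / n);   // ≈ 0.2335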
@@ -383,7 +383,7 @@ namespace Tensorflow
        {
            var _op = _op_def_lib._apply_op_helper("StopGradient", name, args: new { input = x, name });
            return _op.outputs[0];
            return _op.output;
        }

        public static Tensor strided_slice(Tensor input, Tensor begin, Tensor end, Tensor strides,
@@ -115,7 +115,7 @@ namespace Tensorflow
        {
            var _op = _op_def_lib._apply_op_helper("Mean", name, args: new { input, reduction_indices = axis, keep_dims = keep_dims });
            return _op.outputs[0];
            return _op.output;
        }

        public static Tensor prod<T1, T2>(T1 input, T2 axis, bool keep_dims = false, string name = null)
@@ -98,7 +98,8 @@ namespace Tensorflow
        /// <param name="seed2"></param>
        /// <param name="name"></param>
        /// <returns></returns>
        public static Tensor random_shuffle(Tensor value, int seed = 0, int seed2 = 0, string name = null)
        public static Tensor random_shuffle(Tensor value, int seed = 0, int seed2 = 0,
            string name = null)
        {
            var _op = _op_def_lib._apply_op_helper("RandomShuffle",
                name: name,
@@ -116,7 +117,8 @@ namespace Tensorflow
        /// <param name="seed2"></param>
        /// <param name="name"></param>
        /// <returns></returns>
        public static Tensor truncated_normal(Tensor shape, TF_DataType dtype, int? seed = 0, int? seed2 = 0, string name = null)
        public static Tensor truncated_normal(Tensor shape, TF_DataType dtype, int? seed = 0,
            int? seed2 = 0, string name = null)
        {
            if (!seed.HasValue)
                seed = 0;
@@ -127,7 +129,24 @@ namespace Tensorflow
                name: name,
                args: new { shape, dtype, seed, seed2 });

            return _op.outputs[0];
            return _op.output;
        }

        public static Tensor multinomial(Tensor logits, int num_samples, int? seed = 0,
            int? seed2 = 0, TF_DataType output_dtype = TF_DataType.TF_INT64, string name = null)
        {
            if (!seed.HasValue)
                seed = 0;
            if (!seed2.HasValue)
                seed2 = 0;

            if (output_dtype == TF_DataType.DtInvalid)
                output_dtype = TF_DataType.TF_INT64;

            var _op = _op_def_lib._apply_op_helper("Multinomial",
                name: name,
                args: new { logits, num_samples, seed, seed2, output_dtype });

            return _op.output;
        }
    }
}
@@ -81,6 +81,21 @@ namespace Tensorflow
            });
        }

        public static Tensor cast(float x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null)
        {
            var base_type = dtype.as_base_dtype();
            return tf_with(ops.name_scope(name, "Cast", new { x }), scope =>
            {
                name = scope;
                var x_tensor = ops.convert_to_tensor(x, name: "x");
                if (x_tensor.dtype.as_base_dtype() != base_type)
                    x_tensor = gen_math_ops.cast(x_tensor, base_type, name: name);

                return x_tensor;
            });
        }

        public static Tensor cumsum<T>(Tensor x, T axis = default, bool exclusive = false, bool reverse = false, string name = null)
        {
            return tf_with(ops.name_scope(name, "Cumsum", new { x }), scope =>
@@ -204,6 +219,12 @@ namespace Tensorflow
            }
        }

        public static Tensor reduce_mean(Tensor[] input_tensors, int axis, bool keepdims = false, string name = null)
        {
            var m = gen_math_ops.mean(input_tensors, axis, keepdims, name);
            return _may_reduce_to_scalar(keepdims, axis, m);
        }

        /// <summary>
        /// Computes the product of elements across dimensions of a tensor.
        /// </summary>
@@ -142,6 +142,35 @@ namespace Tensorflow
        {
            return ops.convert_to_tensor(shape, name: "shape");
        }

        public static Tensor multinomial(Tensor logits, int num_samples, int? seed = null,
            string name = null, TF_DataType output_dtype = TF_DataType.DtInvalid)
        {
            return tf_with(ops.name_scope(name, "multinomial", new { logits }), delegate
            {
                return multinomial_categorical_impl(logits, num_samples, output_dtype, seed);
            });
        }

        /// <summary>
        /// Implementation for random.categorical (v1) and random.categorical (v2).
        /// </summary>
        /// <param name="logits"></param>
        /// <param name="num_samples"></param>
        /// <param name="output_dtype"></param>
        /// <param name="seed"></param>
        /// <returns></returns>
        private static Tensor multinomial_categorical_impl(Tensor logits, int num_samples, TF_DataType dtype = TF_DataType.DtInvalid,
            int? seed = null)
        {
            logits = ops.convert_to_tensor(logits, name: "logits");
            var (seed1, seed2) = random_seed.get_seed(seed);
            return gen_random_ops.multinomial(logits,
                num_samples,
                seed: seed1,
                seed2: seed2,
                output_dtype: dtype);
        }
    }
}
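The new float overload of math_ops.cast is what PolynomialDecay uses to bring scalar hyper-parameters into the graph dtype; a small sketch of what it returns, with the values purely illustrative:

// assumes: using Tensorflow;
var end_lr = math_ops.cast(0.0001f, TF_DataType.TF_FLOAT);   // already float32, returned as a constant Tensor without a Cast op
var power  = math_ops.cast(1.0f, TF_DataType.TF_DOUBLE);     // base dtypes differ, so gen_math_ops.cast is inserted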
@@ -0,0 +1,29 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace Tensorflow.Training
{
    public class learning_rate_decay
    {
        /// <summary>
        /// Applies a polynomial decay to the learning rate.
        /// </summary>
        /// <param name="learning_rate"></param>
        /// <param name="global_step"></param>
        /// <param name="decay_steps"></param>
        /// <param name="end_learning_rate"></param>
        /// <param name="power"></param>
        /// <param name="cycle"></param>
        /// <param name="name"></param>
        /// <returns></returns>
        public static Tensor polynomial_decay(float learning_rate, RefVariable global_step, float decay_steps,
            float end_learning_rate = 0.0001f, float power = 1.0f, bool cycle = false,
            string name = null)
        {
            throw new NotImplementedException("");
        }
    }
}
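This Tensorflow.Training.learning_rate_decay.polynomial_decay stub could delegate to the Keras PolynomialDecay schedule, the same way the tf.train.polynomial_decay wrapper added earlier in this change does. A sketch only, not part of this change; it assumes `using Tensorflow.Keras.Optimizers;` is added:

public static Tensor polynomial_decay(float learning_rate, RefVariable global_step, float decay_steps,
    float end_learning_rate = 0.0001f, float power = 1.0f, bool cycle = false,
    string name = null)
{
    // Build the schedule object and evaluate it at the current global step.
    var schedule = new PolynomialDecay(learning_rate, decay_steps,
        end_learning_rate: end_learning_rate, power: power, cycle: cycle, name: name);
    return schedule.__call__(global_step);
}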