
math.reduce_sum, tf.variables_initializer

tags/v0.13
Oceania2018 committed 5 years ago
commit 0e2488ca7a
17 changed files with 943 additions and 705 deletions
  1. +22 -4    src/TensorFlowNET.Core/APIs/tf.array.cs
  2. +9 -3     src/TensorFlowNET.Core/APIs/tf.math.cs
  3. +5 -0     src/TensorFlowNET.Core/APIs/tf.random.cs
  4. +21 -0    src/TensorFlowNET.Core/APIs/tf.train.cs
  5. +9 -0     src/TensorFlowNET.Core/APIs/tf.variable.cs
  6. +5 -1     src/TensorFlowNET.Core/Binding.Util.cs
  7. +16 -0    src/TensorFlowNET.Core/Keras/Optimizers/LearningRateSchedule.cs
  8. +62 -0    src/TensorFlowNET.Core/Keras/Optimizers/PolynomialDecay.cs
  9. +1 -3     src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs
  10. +17 -15  src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs
  11. +673 -674 src/TensorFlowNET.Core/Operations/array_ops.cs
  12. +1 -1    src/TensorFlowNET.Core/Operations/gen_array_ops.cs
  13. +1 -1    src/TensorFlowNET.Core/Operations/gen_math_ops.cs
  14. +22 -3   src/TensorFlowNET.Core/Operations/gen_random_ops.cs
  15. +21 -0   src/TensorFlowNET.Core/Operations/math_ops.cs
  16. +29 -0   src/TensorFlowNET.Core/Operations/random_ops.cs
  17. +29 -0   src/TensorFlowNET.Core/Training/learning_rate_decay.cs

+22 -4  src/TensorFlowNET.Core/APIs/tf.array.cs

@@ -17,7 +17,9 @@
using NumSharp;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using static Tensorflow.Binding;

namespace Tensorflow
{
@@ -76,7 +78,14 @@ namespace Tensorflow
public Tensor concat(IList<Tensor> values, int axis, string name = "concat")
{
if (values.Count == 1)
- throw new NotImplementedException("tf.concat length is 1");
+ {
+ return tf_with(ops.name_scope(name), scope =>
+ {
+ var tensor = ops.convert_to_tensor(axis, name: "concat_dim", dtype: dtypes.int32);
+ Debug.Assert(tensor.TensorShape.ndim == 0);
+ return identity(values[0], name: scope);
+ });
+ }

return gen_array_ops.concat_v2(values.ToArray(), axis, name: name);
}
@@ -111,7 +120,7 @@ namespace Tensorflow
/// <param name="input"></param>
/// <param name="name"></param>
/// <returns></returns>
- public static Tensor identity(Tensor input, string name = null)
+ public Tensor identity(Tensor input, string name = null)
=> array_ops.identity(input, name: name);

/// <summary>
@@ -150,10 +159,10 @@ namespace Tensorflow
/// <param name="axis"></param>
/// <param name="name"></param>
/// <returns></returns>
- public static Tensor reverse(Tensor tensor, int[] axis, string name = null)
+ public Tensor reverse(Tensor tensor, int[] axis, string name = null)
=> gen_array_ops.reverse(tensor, axis, name: name);

- public static Tensor reverse(Tensor tensor, Tensor axis, string name = null)
+ public Tensor reverse(Tensor tensor, Tensor axis, string name = null)
=> gen_array_ops.reverse(tensor, axis, name: name);

/// <summary>
@@ -277,5 +286,14 @@ namespace Tensorflow
/// <returns>A `Tensor` with all elements set to zero.</returns>
public Tensor zeros_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true)
=> array_ops.zeros_like(tensor, dtype: dtype, name: name, optimize: optimize);

/// <summary>
/// Stops gradient computation.
/// </summary>
/// <param name="x"></param>
/// <param name="name"></param>
/// <returns></returns>
public Tensor stop_gradient(Tensor x, string name = null)
=> gen_array_ops.stop_gradient(x, name: name);
}
}
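
Taken together, the tf.array.cs changes make a single-element concat a cheap identity (after checking that the axis is a scalar) and expose stop_gradient as an instance method on tf. A minimal usage sketch, assuming the usual "using static Tensorflow.Binding;" entry point; the tensor values are illustrative:

var a = tf.constant(new[] { 1, 2, 3 });

// With only one input, concat now returns identity(a) under a "concat" name scope
// instead of throwing NotImplementedException.
var c = tf.concat(new[] { a }, axis: 0);

// stop_gradient blocks back-propagation through a in downstream gradient computations.
var frozen = tf.stop_gradient(a);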

+9 -3  src/TensorFlowNET.Core/APIs/tf.math.cs

@@ -434,11 +434,14 @@ namespace Tensorflow
public Tensor reduce_sum(Tensor input, int? axis = null, int? reduction_indices = null,
bool keepdims = false, string name = null)
{
- if(!axis.HasValue && reduction_indices.HasValue)
+ if (!axis.HasValue && reduction_indices.HasValue && !keepdims)
return math_ops.reduce_sum(input, reduction_indices.Value);
- else if (axis.HasValue && !reduction_indices.HasValue)
+ else if (axis.HasValue && !reduction_indices.HasValue && !keepdims)
return math_ops.reduce_sum(input, axis.Value);
- return math_ops.reduce_sum(input, keepdims: keepdims, name: name);
+ else if (axis.HasValue && !reduction_indices.HasValue && keepdims)
+ return math_ops.reduce_sum(input, keepdims: keepdims, axis: axis.Value, name: name);
+ else
+ return math_ops.reduce_sum(input, keepdims: keepdims, name: name);
}

public Tensor reduce_sum(Tensor input, TensorShape axis, int? reduction_indices = null,
@@ -471,6 +474,9 @@ namespace Tensorflow
public Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null, int? reduction_indices = null)
=> math_ops.reduce_mean(input_tensor, axis: axis, keepdims: keepdims, name: name, reduction_indices: reduction_indices);

public Tensor reduce_mean(Tensor[] input_tensors, int axis, bool keepdims = false, string name = null)
=> math_ops.reduce_mean(input_tensors, axis: axis, keepdims: keepdims, name: name);

public Tensor round(Tensor x, string name = null)
=> gen_math_ops.round(x, name: name);
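
The reworked dispatch means keepdims is no longer dropped when an explicit axis is passed. A short sketch of how the overload now routes; the values are illustrative:

var x = tf.constant(new[,] { { 1f, 2f }, { 3f, 4f } });

var s0 = tf.reduce_sum(x);                          // sum over all elements: 10
var s1 = tf.reduce_sum(x, axis: 1);                 // per-row sums: [3, 7]
var s2 = tf.reduce_sum(x, axis: 1, keepdims: true); // shape (2, 1): [[3], [7]]
var s3 = tf.reduce_sum(x, reduction_indices: 0);    // legacy alias for an axis argument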



+5 -0  src/TensorFlowNET.Core/APIs/tf.random.cs

@@ -65,5 +65,10 @@ namespace Tensorflow

public void set_random_seed(int seed)
=> ops.get_default_graph().seed = seed;

public Tensor multinomial(Tensor logits, int num_samples, int? seed = null,
string name = null, TF_DataType output_dtype = TF_DataType.DtInvalid)
=> random_ops.multinomial(logits, num_samples, seed: seed,
name: name, output_dtype: output_dtype);
}
}
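
The new tf.multinomial wrapper forwards to random_ops.multinomial, which splits the seed and emits the Multinomial op. A rough sketch (logits are unnormalized log-probabilities of shape [batch, classes]; the result holds sampled class indices of shape [batch, num_samples]):

var logits = tf.constant(new[,] { { 1.0f, 1.0f, 1.0f }, { 0.0f, 5.0f, 0.0f } });

// Draw 4 class indices per row; the output dtype defaults to int64 in gen_random_ops.
var samples = tf.multinomial(logits, num_samples: 4, seed: 42);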

+21 -0  src/TensorFlowNET.Core/APIs/tf.train.cs

@@ -15,6 +15,7 @@
******************************************************************************/

using System.Collections.Generic;
using Tensorflow.Keras.Optimizers;
using Tensorflow.Train;

namespace Tensorflow
@@ -73,6 +74,26 @@ namespace Tensorflow

public CheckpointState get_checkpoint_state(string checkpoint_dir, string latest_filename = null)
=> checkpoint_management.get_checkpoint_state(checkpoint_dir, latest_filename: latest_filename);

public Tensor polynomial_decay(float learning_rate,
RefVariable global_step,
float decay_steps,
float end_learning_rate = 0.0001f,
float power = 1.0f,
bool cycle = false,
string name = null)
{
var decayed = new PolynomialDecay(learning_rate,
decay_steps,
end_learning_rate: end_learning_rate,
power: power,
cycle: cycle,
name: name);

var decayed_lr = decayed.__call__(global_step);

return decayed_lr;
}
}
}
}
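
tf.train.polynomial_decay is a thin wrapper that builds the Keras PolynomialDecay schedule and evaluates it at the given global step. A sketch of how it would be wired up; note that PolynomialDecay.__call__ is still a stub in this commit, so the call only becomes useful once that body is completed:

var global_step = tf.Variable(0, trainable: false, name: "global_step");

// Decay from 0.1 towards 0.001 over 10000 steps with power 1 (linear decay).
var lr = tf.train.polynomial_decay(0.1f,
    global_step,
    decay_steps: 10000,
    end_learning_rate: 0.001f,
    power: 1.0f);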

+9 -0  src/TensorFlowNET.Core/APIs/tf.variable.cs

@@ -27,6 +27,15 @@ namespace Tensorflow
.ToArray();
}

/// <summary>
/// Returns an Op that initializes a list of variables.
/// </summary>
/// <param name="var_list">List of `Variable` objects to initialize.</param>
/// <param name="name">Optional name for the returned operation.</param>
/// <returns>An Op that run the initializers of all the specified variables.</returns>
public Operation variables_initializer(VariableV1[] var_list, string name = "init")
=> variables.variables_initializer(var_list, name: name);

public Operation global_variables_initializer()
{
var g = variables.global_variables();
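
variables_initializer generalizes global_variables_initializer to an arbitrary subset of variables. A small sketch using the classic session API; the variables are hypothetical:

var w = tf.Variable(0.1f, name: "w");
var b = tf.Variable(0.0f, name: "b");

// Initialize just these two variables rather than every global variable in the graph.
var init_wb = tf.variables_initializer(new VariableV1[] { w, b }, name: "init_wb");

using (var sess = tf.Session())
    sess.run(init_wb);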


+5 -1  src/TensorFlowNET.Core/Binding.Util.cs

@@ -115,6 +115,7 @@ namespace Tensorflow
return instance;
}

[DebuggerStepThrough]
[DebuggerNonUserCode()] // with "Just My Code" enabled this lets the debugger break at the origin of the exception
public static void tf_with(IObjectLife py, Action<IObjectLife> action)
{
@@ -273,7 +274,10 @@ namespace Tensorflow
return sum;
}

- public static double sum(IEnumerable<int> enumerable)
+ public static float sum(IEnumerable<float> enumerable)
=> enumerable.Sum();

+ public static int sum(IEnumerable<int> enumerable)
+ => enumerable.Sum();

public static double sum<TKey, TValue>(Dictionary<TKey, TValue> values)
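
With the split overloads, summing a float sequence no longer widens the result to double. A trivial sketch, assuming "using static Tensorflow.Binding;":

int count_total = sum(new[] { 1, 2, 3 });       // int overload: 6
float loss_total = sum(new[] { 0.5f, 0.25f });  // new float overload: 0.75f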


+16 -0  src/TensorFlowNET.Core/Keras/Optimizers/LearningRateSchedule.cs

@@ -0,0 +1,16 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace Tensorflow.Keras.Optimizers
{
public class LearningRateSchedule
{
public LearningRateSchedule()
{

}
}
}

+62 -0  src/TensorFlowNET.Core/Keras/Optimizers/PolynomialDecay.cs

@@ -0,0 +1,62 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Optimizers
{
/// <summary>
/// A LearningRateSchedule that uses a polynomial decay schedule.
/// </summary>
public class PolynomialDecay : LearningRateSchedule
{
float initial_learning_rate;
float decay_steps;
float end_learning_rate;
float power;
bool cycle;
string name;

public PolynomialDecay(float initial_learning_rate,
float decay_steps,
float end_learning_rate = 0.0001f,
float power = 1.0f,
bool cycle = false,
string name = null) : base()
{
this.initial_learning_rate = initial_learning_rate;
this.decay_steps = decay_steps;
this.end_learning_rate = end_learning_rate;
this.power = power;
this.cycle = cycle;
this.name = name;
}

public Tensor __call__(RefVariable step)
{
tf_with(ops.name_scope(name ?? "PolynomialDecay"), scope =>
{
name = scope;
var initial_learning_rate_tensor = ops.convert_to_tensor(initial_learning_rate, name: "initial_learning_rate");
var dtype = initial_learning_rate_tensor.dtype;
var end_learning_rate_tensor = math_ops.cast(end_learning_rate, dtype);
var power_tensor = math_ops.cast(power, dtype);

var global_step_recomp = math_ops.cast(step, dtype);
var decay_steps_recomp = math_ops.cast(decay_steps, dtype);

if(cycle)
{
throw new NotImplementedException("PolynomialDecay cycle");
}
else
{

}
});
throw new NotImplementedException("");
}
}
}
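
For reference, the schedule this class is meant to implement (TensorFlow's polynomial decay without cycling) is decayed_lr = (initial_lr - end_lr) * (1 - min(step, decay_steps) / decay_steps)^power + end_lr. A rough sketch of what the empty non-cycle branch could compute, reusing the locals above; this is an assumption about a later implementation, not what this commit ships, and it presumes helpers such as gen_math_ops.minimum, tf.pow and the Tensor arithmetic operators are available:

// Hypothetical completion of the else branch (not part of this commit).
global_step_recomp = gen_math_ops.minimum(global_step_recomp, decay_steps_recomp);
var p = global_step_recomp / decay_steps_recomp;
var decayed_lr = (initial_learning_rate_tensor - end_learning_rate_tensor)
    * tf.pow(tf.constant(1.0f) - p, power_tensor)
    + end_learning_rate_tensor;
// __call__ would then return decayed_lr, which requires switching tf_with to its
// value-returning overload instead of the Action overload used here.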

+1 -3  src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs

@@ -19,8 +19,7 @@ namespace Tensorflow.Operations.Initializers
public class GlorotUniform : VarianceScaling
{
public GlorotUniform(float scale = 1.0f,
string mode = "fan_avg",
string distribution = "uniform",
string mode = "FAN_AVG",
int? seed = null,
TF_DataType dtype = TF_DataType.TF_FLOAT) : base(factor: scale,
mode: mode,
@@ -36,7 +35,6 @@ namespace Tensorflow.Operations.Initializers
{
scale = _scale,
mode = _mode,
- distribution = _distribution,
seed = _seed,
dtype = _dtype
};


+17 -15  src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs

@@ -30,6 +30,7 @@ namespace Tensorflow.Operations.Initializers
protected string _distribution;
protected int? _seed;
protected TF_DataType _dtype;
protected bool _uniform;

public VarianceScaling(float factor = 2.0f,
string mode = "FAN_IN",
@@ -49,31 +50,31 @@ namespace Tensorflow.Operations.Initializers
_mode = mode;
_seed = seed;
_dtype = dtype;
_uniform = uniform;
}

public Tensor call(TensorShape shape, TF_DataType dtype, bool? verify_shape = null)
{
+ float n = 0;
var (fan_in, fan_out) = _compute_fans(shape);
- if (_mode == "fan_in")
- _scale /= Math.Max(1, fan_in);
- else if (_mode == "fan_out")
- _scale /= Math.Max(1, fan_out);
- else
- _scale /= Math.Max(1, (fan_in + fan_out) / 2);
+ if (_mode == "FAN_IN")
+ n = fan_in;
+ else if (_mode == "FAN_OUT")
+ n = fan_out;
+ else if(_mode == "FAN_AVG")
+ n = (fan_in + fan_out) / 2.0f;

- if (_distribution == "normal" || _distribution == "truncated_normal")
- {
- float stddev = (float)Math.Sqrt(_scale) / .87962566103423978f;
- return random_ops.truncated_normal(shape, mean: 0.0f, stddev: stddev, dtype: dtype, seed: _seed);
- }
- else if (_distribution == "untruncated_normal")
+ if(_uniform)
{
- throw new NotImplementedException("truncated_normal");
+ var limit = Convert.ToSingle(Math.Sqrt(3.0f * _scale / n));
+ return random_ops.random_uniform(shape, -limit, limit,
+ dtype, seed: _seed);
}
else
{
- var limit = Math.Sqrt(3.0f * _scale);
- return random_ops.random_uniform(shape, (float)-limit, (float)limit, dtype, seed: _seed);
+ var trunc_stddev = Convert.ToSingle(Math.Sqrt(1.3f * _scale / n));
+ return random_ops.truncated_normal(shape, 0.0f, trunc_stddev, dtype,
+ seed: _seed);
}
}

@@ -106,6 +107,7 @@ namespace Tensorflow.Operations.Initializers
mode = _mode,
distribution = _distribution,
seed = _seed,
uniform = _uniform,
dtype = _dtype
};
}
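
Net effect of the initializer changes: modes now use the FAN_IN / FAN_OUT / FAN_AVG names, and the distribution switch is replaced by a _uniform flag, so uniform draws use limit = sqrt(3 * scale / n) and truncated-normal draws use stddev = sqrt(1.3 * scale / n), where n comes from the fan mode. A small sketch exercising the updated GlorotUniform, assuming it asks the base class for uniform sampling:

// Glorot/Xavier-style weights for a 784 -> 10 dense layer.
var glorot = new GlorotUniform();   // defaults: scale = 1.0f, mode = "FAN_AVG"
var weights = glorot.call(new TensorShape(784, 10), TF_DataType.TF_FLOAT);
// For FAN_AVG, n = (784 + 10) / 2 = 397, so the uniform limit is sqrt(3 / 397) ≈ 0.087.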


+673 -674  src/TensorFlowNET.Core/Operations/array_ops.cs (file diff suppressed because it is too large)


+1 -1  src/TensorFlowNET.Core/Operations/gen_array_ops.cs

@@ -383,7 +383,7 @@ namespace Tensorflow
{
var _op = _op_def_lib._apply_op_helper("StopGradient", name, args: new { input = x, name });

- return _op.outputs[0];
+ return _op.output;
}

public static Tensor strided_slice(Tensor input, Tensor begin, Tensor end, Tensor strides,


+1 -1  src/TensorFlowNET.Core/Operations/gen_math_ops.cs

@@ -115,7 +115,7 @@ namespace Tensorflow
{
var _op = _op_def_lib._apply_op_helper("Mean", name, args: new { input, reduction_indices = axis, keep_dims = keep_dims });
- return _op.outputs[0];
+ return _op.output;
}
public static Tensor prod<T1, T2>(T1 input, T2 axis, bool keep_dims = false, string name = null)


+22 -3  src/TensorFlowNET.Core/Operations/gen_random_ops.py.cs → src/TensorFlowNET.Core/Operations/gen_random_ops.cs

@@ -98,7 +98,8 @@ namespace Tensorflow
/// <param name="seed2"></param>
/// <param name="name"></param>
/// <returns></returns>
- public static Tensor random_shuffle(Tensor value, int seed = 0, int seed2 = 0, string name = null)
+ public static Tensor random_shuffle(Tensor value, int seed = 0, int seed2 = 0,
+ string name = null)
{
var _op = _op_def_lib._apply_op_helper("RandomShuffle",
name: name,
@@ -116,7 +117,8 @@ namespace Tensorflow
/// <param name="seed2"></param>
/// <param name="name"></param>
/// <returns></returns>
- public static Tensor truncated_normal(Tensor shape, TF_DataType dtype, int? seed = 0, int? seed2 = 0, string name = null)
+ public static Tensor truncated_normal(Tensor shape, TF_DataType dtype, int? seed = 0,
+ int? seed2 = 0, string name = null)
{
if (!seed.HasValue)
seed = 0;
@@ -127,7 +129,24 @@ namespace Tensorflow
name: name,
args: new { shape, dtype, seed, seed2 });

- return _op.outputs[0];
+ return _op.output;
}

public static Tensor multinomial(Tensor logits, int num_samples, int? seed = 0,
int? seed2 = 0, TF_DataType output_dtype = TF_DataType.TF_INT64, string name = null)
{
if (!seed.HasValue)
seed = 0;
if (!seed2.HasValue)
seed2 = 0;
if (output_dtype == TF_DataType.DtInvalid)
output_dtype = TF_DataType.TF_INT64;

var _op = _op_def_lib._apply_op_helper("Multinomial",
name: name,
args: new { logits, num_samples, seed, seed2, output_dtype });

return _op.output;
}
}
}

+21 -0  src/TensorFlowNET.Core/Operations/math_ops.cs

@@ -81,6 +81,21 @@ namespace Tensorflow
});
}

public static Tensor cast(float x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null)
{
var base_type = dtype.as_base_dtype();

return tf_with(ops.name_scope(name, "Cast", new { x }), scope =>
{
name = scope;
var x_tensor = ops.convert_to_tensor(x, name: "x");
if (x_tensor.dtype.as_base_dtype() != base_type)
x_tensor = gen_math_ops.cast(x_tensor, base_type, name: name);

return x_tensor;
});
}

public static Tensor cumsum<T>(Tensor x, T axis = default, bool exclusive = false, bool reverse = false, string name = null)
{
return tf_with(ops.name_scope(name, "Cumsum", new {x}), scope =>
@@ -204,6 +219,12 @@ namespace Tensorflow
}
}

public static Tensor reduce_mean(Tensor[] input_tensors, int axis, bool keepdims = false, string name = null)
{
var m = gen_math_ops.mean(input_tensors, axis, keepdims, name);
return _may_reduce_to_scalar(keepdims, axis, m);
}

/// <summary>
/// Computes the product of elements across dimensions of a tensor.
/// </summary>
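
math_ops.cast now also accepts a plain float, which PolynomialDecay uses to lift scalar hyperparameters to tensors of the learning-rate dtype, and reduce_mean gains a Tensor[] overload that feeds the array straight into the Mean op. A small sketch; the inputs are illustrative:

// Lift a C# float to a graph tensor of the requested dtype.
var end_lr = math_ops.cast(0.0001f, TF_DataType.TF_FLOAT);

// Reduce a pair of tensors along axis 0 via the new array overload.
var a = tf.constant(new[] { 1f, 2f });
var b = tf.constant(new[] { 3f, 4f });
var avg = math_ops.reduce_mean(new[] { a, b }, axis: 0);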


+29 -0  src/TensorFlowNET.Core/Operations/random_ops.py.cs → src/TensorFlowNET.Core/Operations/random_ops.cs

@@ -142,6 +142,35 @@ namespace Tensorflow
{
return ops.convert_to_tensor(shape, name: "shape");
}

public static Tensor multinomial(Tensor logits, int num_samples, int? seed = null,
string name = null, TF_DataType output_dtype = TF_DataType.DtInvalid)
{
return tf_with(ops.name_scope(name, "multinomial", new { logits }), delegate
{
return multinomial_categorical_impl(logits, num_samples, output_dtype, seed);
});
}

/// <summary>
/// Implementation for random.categorical (v1) and random.categorical (v2).
/// </summary>
/// <param name="logits"></param>
/// <param name="num_samples"></param>
/// <param name="output_dtype"></param>
/// <param name="seed"></param>
/// <returns></returns>
private static Tensor multinomial_categorical_impl(Tensor logits, int num_samples, TF_DataType dtype = TF_DataType.DtInvalid,
int? seed = null)
{
logits = ops.convert_to_tensor(logits, name: "logits");
var (seed1, seed2) = random_seed.get_seed(seed);
return gen_random_ops.multinomial(logits,
num_samples,
seed: seed1,
seed2: seed2,
output_dtype: dtype);
}
}
}


+29 -0  src/TensorFlowNET.Core/Training/learning_rate_decay.cs

@@ -0,0 +1,29 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace Tensorflow.Training
{
public class learning_rate_decay
{
/// <summary>
/// Applies a polynomial decay to the learning rate.
/// </summary>
/// <param name="learning_rate"></param>
/// <param name="global_step"></param>
/// <param name="decay_steps"></param>
/// <param name="end_learning_rate"></param>
/// <param name="power"></param>
/// <param name="cycle"></param>
/// <param name="name"></param>
/// <returns></returns>
public static Tensor polynomial_decay(float learning_rate, RefVariable global_step, float decay_steps,
float end_learning_rate = 0.0001f, float power = 1.0f, bool cycle = false,
string name = null)
{
throw new NotImplementedException("");
}
}
}
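
The static helper is still a stub. A hedged sketch of how it could be completed by delegating to the Keras schedule, mirroring tf.train.polynomial_decay above (requires using Tensorflow.Keras.Optimizers; and is not part of this commit):

// Hypothetical body for learning_rate_decay.polynomial_decay.
public static Tensor polynomial_decay(float learning_rate, RefVariable global_step, float decay_steps,
    float end_learning_rate = 0.0001f, float power = 1.0f, bool cycle = false,
    string name = null)
{
    var schedule = new PolynomialDecay(learning_rate, decay_steps,
        end_learning_rate: end_learning_rate, power: power, cycle: cycle, name: name);
    return schedule.__call__(global_step);
}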
