From 0e2488ca7aad56ef4d36f2f28fe3303bfa47a847 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 17 Nov 2019 20:08:16 -0600 Subject: [PATCH] math.reduce_sum, tf.variables_initializer --- src/TensorFlowNET.Core/APIs/tf.array.cs | 26 +- src/TensorFlowNET.Core/APIs/tf.math.cs | 12 +- src/TensorFlowNET.Core/APIs/tf.random.cs | 5 + src/TensorFlowNET.Core/APIs/tf.train.cs | 21 + src/TensorFlowNET.Core/APIs/tf.variable.cs | 9 + src/TensorFlowNET.Core/Binding.Util.cs | 6 +- .../Keras/Optimizers/LearningRateSchedule.cs | 16 + .../Keras/Optimizers/PolynomialDecay.cs | 62 + .../Operations/Initializers/GlorotUniform.cs | 4 +- .../Initializers/VarianceScaling.cs | 32 +- .../{array_ops.py.cs => array_ops.cs} | 1347 ++++++++--------- .../Operations/gen_array_ops.cs | 2 +- .../Operations/gen_math_ops.cs | 2 +- ...gen_random_ops.py.cs => gen_random_ops.cs} | 25 +- src/TensorFlowNET.Core/Operations/math_ops.cs | 21 + .../{random_ops.py.cs => random_ops.cs} | 29 + .../Training/learning_rate_decay.cs | 29 + 17 files changed, 943 insertions(+), 705 deletions(-) create mode 100644 src/TensorFlowNET.Core/Keras/Optimizers/LearningRateSchedule.cs create mode 100644 src/TensorFlowNET.Core/Keras/Optimizers/PolynomialDecay.cs rename src/TensorFlowNET.Core/Operations/{array_ops.py.cs => array_ops.cs} (97%) rename src/TensorFlowNET.Core/Operations/{gen_random_ops.py.cs => gen_random_ops.cs} (85%) rename src/TensorFlowNET.Core/Operations/{random_ops.py.cs => random_ops.cs} (83%) create mode 100644 src/TensorFlowNET.Core/Training/learning_rate_decay.cs diff --git a/src/TensorFlowNET.Core/APIs/tf.array.cs b/src/TensorFlowNET.Core/APIs/tf.array.cs index 34303bf9..ec17cecc 100644 --- a/src/TensorFlowNET.Core/APIs/tf.array.cs +++ b/src/TensorFlowNET.Core/APIs/tf.array.cs @@ -17,7 +17,9 @@ using NumSharp; using System; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; +using static Tensorflow.Binding; namespace Tensorflow { @@ -76,7 +78,14 @@ namespace Tensorflow public Tensor concat(IList values, int axis, string name = "concat") { if (values.Count == 1) - throw new NotImplementedException("tf.concat length is 1"); + { + return tf_with(ops.name_scope(name), scope => + { + var tensor = ops.convert_to_tensor(axis, name: "concat_dim", dtype: dtypes.int32); + Debug.Assert(tensor.TensorShape.ndim == 0); + return identity(values[0], name: scope); + }); + } return gen_array_ops.concat_v2(values.ToArray(), axis, name: name); } @@ -111,7 +120,7 @@ namespace Tensorflow /// /// /// - public static Tensor identity(Tensor input, string name = null) + public Tensor identity(Tensor input, string name = null) => array_ops.identity(input, name: name); /// @@ -150,10 +159,10 @@ namespace Tensorflow /// /// /// - public static Tensor reverse(Tensor tensor, int[] axis, string name = null) + public Tensor reverse(Tensor tensor, int[] axis, string name = null) => gen_array_ops.reverse(tensor, axis, name: name); - public static Tensor reverse(Tensor tensor, Tensor axis, string name = null) + public Tensor reverse(Tensor tensor, Tensor axis, string name = null) => gen_array_ops.reverse(tensor, axis, name: name); /// @@ -277,5 +286,14 @@ namespace Tensorflow /// A `Tensor` with all elements set to zero. public Tensor zeros_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) => array_ops.zeros_like(tensor, dtype: dtype, name: name, optimize: optimize); + + /// + /// Stops gradient computation. + /// + /// + /// + /// + public Tensor stop_gradient(Tensor x, string name = null) + => gen_array_ops.stop_gradient(x, name: name); } } diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs index e3b9e257..6cb43980 100644 --- a/src/TensorFlowNET.Core/APIs/tf.math.cs +++ b/src/TensorFlowNET.Core/APIs/tf.math.cs @@ -434,11 +434,14 @@ namespace Tensorflow public Tensor reduce_sum(Tensor input, int? axis = null, int? reduction_indices = null, bool keepdims = false, string name = null) { - if(!axis.HasValue && reduction_indices.HasValue) + if (!axis.HasValue && reduction_indices.HasValue && !keepdims) return math_ops.reduce_sum(input, reduction_indices.Value); - else if (axis.HasValue && !reduction_indices.HasValue) + else if (axis.HasValue && !reduction_indices.HasValue && !keepdims) return math_ops.reduce_sum(input, axis.Value); - return math_ops.reduce_sum(input, keepdims: keepdims, name: name); + else if (axis.HasValue && !reduction_indices.HasValue && keepdims) + return math_ops.reduce_sum(input, keepdims: keepdims, axis: axis.Value, name: name); + else + return math_ops.reduce_sum(input, keepdims: keepdims, name: name); } public Tensor reduce_sum(Tensor input, TensorShape axis, int? reduction_indices = null, @@ -471,6 +474,9 @@ namespace Tensorflow public Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null, int? reduction_indices = null) => math_ops.reduce_mean(input_tensor, axis: axis, keepdims: keepdims, name: name, reduction_indices: reduction_indices); + public Tensor reduce_mean(Tensor[] input_tensors, int axis, bool keepdims = false, string name = null) + => math_ops.reduce_mean(input_tensors, axis: axis, keepdims: keepdims, name: name); + public Tensor round(Tensor x, string name = null) => gen_math_ops.round(x, name: name); diff --git a/src/TensorFlowNET.Core/APIs/tf.random.cs b/src/TensorFlowNET.Core/APIs/tf.random.cs index c11ca791..56fa840d 100644 --- a/src/TensorFlowNET.Core/APIs/tf.random.cs +++ b/src/TensorFlowNET.Core/APIs/tf.random.cs @@ -65,5 +65,10 @@ namespace Tensorflow public void set_random_seed(int seed) => ops.get_default_graph().seed = seed; + + public Tensor multinomial(Tensor logits, int num_samples, int? seed = null, + string name = null, TF_DataType output_dtype = TF_DataType.DtInvalid) + => random_ops.multinomial(logits, num_samples, seed: seed, + name: name, output_dtype: output_dtype); } } diff --git a/src/TensorFlowNET.Core/APIs/tf.train.cs b/src/TensorFlowNET.Core/APIs/tf.train.cs index 03b0a0e2..862212ef 100644 --- a/src/TensorFlowNET.Core/APIs/tf.train.cs +++ b/src/TensorFlowNET.Core/APIs/tf.train.cs @@ -15,6 +15,7 @@ ******************************************************************************/ using System.Collections.Generic; +using Tensorflow.Keras.Optimizers; using Tensorflow.Train; namespace Tensorflow @@ -73,6 +74,26 @@ namespace Tensorflow public CheckpointState get_checkpoint_state(string checkpoint_dir, string latest_filename = null) => checkpoint_management.get_checkpoint_state(checkpoint_dir, latest_filename: latest_filename); + + public Tensor polynomial_decay(float learning_rate, + RefVariable global_step, + float decay_steps, + float end_learning_rate = 0.0001f, + float power = 1.0f, + bool cycle = false, + string name = null) + { + var decayed = new PolynomialDecay(learning_rate, + decay_steps, + end_learning_rate: end_learning_rate, + power: power, + cycle: cycle, + name: name); + + var decayed_lr = decayed.__call__(global_step); + + return decayed_lr; + } } } } diff --git a/src/TensorFlowNET.Core/APIs/tf.variable.cs b/src/TensorFlowNET.Core/APIs/tf.variable.cs index da7fb027..cbdf68ba 100644 --- a/src/TensorFlowNET.Core/APIs/tf.variable.cs +++ b/src/TensorFlowNET.Core/APIs/tf.variable.cs @@ -27,6 +27,15 @@ namespace Tensorflow .ToArray(); } + /// + /// Returns an Op that initializes a list of variables. + /// + /// List of `Variable` objects to initialize. + /// Optional name for the returned operation. + /// An Op that run the initializers of all the specified variables. + public Operation variables_initializer(VariableV1[] var_list, string name = "init") + => variables.variables_initializer(var_list, name: name); + public Operation global_variables_initializer() { var g = variables.global_variables(); diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs index 9df8d45c..31ea0d84 100644 --- a/src/TensorFlowNET.Core/Binding.Util.cs +++ b/src/TensorFlowNET.Core/Binding.Util.cs @@ -115,6 +115,7 @@ namespace Tensorflow return instance; } + [DebuggerStepThrough] [DebuggerNonUserCode()] // with "Just My Code" enabled this lets the debugger break at the origin of the exception public static void tf_with(IObjectLife py, Action action) { @@ -273,7 +274,10 @@ namespace Tensorflow return sum; } - public static double sum(IEnumerable enumerable) + public static float sum(IEnumerable enumerable) + => enumerable.Sum(); + + public static int sum(IEnumerable enumerable) => enumerable.Sum(); public static double sum(Dictionary values) diff --git a/src/TensorFlowNET.Core/Keras/Optimizers/LearningRateSchedule.cs b/src/TensorFlowNET.Core/Keras/Optimizers/LearningRateSchedule.cs new file mode 100644 index 00000000..8bcbb58f --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Optimizers/LearningRateSchedule.cs @@ -0,0 +1,16 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Tensorflow.Keras.Optimizers +{ + public class LearningRateSchedule + { + public LearningRateSchedule() + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Optimizers/PolynomialDecay.cs b/src/TensorFlowNET.Core/Keras/Optimizers/PolynomialDecay.cs new file mode 100644 index 00000000..b44595b5 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Optimizers/PolynomialDecay.cs @@ -0,0 +1,62 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Optimizers +{ + /// + /// A LearningRateSchedule that uses a polynomial decay schedule. + /// + public class PolynomialDecay : LearningRateSchedule + { + float initial_learning_rate; + float decay_steps; + float end_learning_rate; + float power; + bool cycle; + string name; + + public PolynomialDecay(float initial_learning_rate, + float decay_steps, + float end_learning_rate = 0.0001f, + float power = 1.0f, + bool cycle = false, + string name = null) : base() + { + this.initial_learning_rate = initial_learning_rate; + this.decay_steps = decay_steps; + this.end_learning_rate = end_learning_rate; + this.power = power; + this.cycle = cycle; + this.name = name; + } + + public Tensor __call__(RefVariable step) + { + tf_with(ops.name_scope(name ?? "PolynomialDecay"), scope => + { + name = scope; + var initial_learning_rate_tensor = ops.convert_to_tensor(initial_learning_rate, name: "initial_learning_rate"); + var dtype = initial_learning_rate_tensor.dtype; + var end_learning_rate_tensor = math_ops.cast(end_learning_rate, dtype); + var power_tensor = math_ops.cast(power, dtype); + + var global_step_recomp = math_ops.cast(step, dtype); + var decay_steps_recomp = math_ops.cast(decay_steps, dtype); + + if(cycle) + { + throw new NotImplementedException("PolynomialDecay cycle"); + } + else + { + + } + }); + throw new NotImplementedException(""); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs b/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs index f418f8a3..0eead27d 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs @@ -19,8 +19,7 @@ namespace Tensorflow.Operations.Initializers public class GlorotUniform : VarianceScaling { public GlorotUniform(float scale = 1.0f, - string mode = "fan_avg", - string distribution = "uniform", + string mode = "FAN_AVG", int? seed = null, TF_DataType dtype = TF_DataType.TF_FLOAT) : base(factor: scale, mode: mode, @@ -36,7 +35,6 @@ namespace Tensorflow.Operations.Initializers { scale = _scale, mode = _mode, - distribution = _distribution, seed = _seed, dtype = _dtype }; diff --git a/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs b/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs index c0bbcd88..41b6689c 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs @@ -30,6 +30,7 @@ namespace Tensorflow.Operations.Initializers protected string _distribution; protected int? _seed; protected TF_DataType _dtype; + protected bool _uniform; public VarianceScaling(float factor = 2.0f, string mode = "FAN_IN", @@ -49,31 +50,31 @@ namespace Tensorflow.Operations.Initializers _mode = mode; _seed = seed; _dtype = dtype; + _uniform = uniform; } public Tensor call(TensorShape shape, TF_DataType dtype, bool? verify_shape = null) { + float n = 0; var (fan_in, fan_out) = _compute_fans(shape); - if (_mode == "fan_in") - _scale /= Math.Max(1, fan_in); - else if (_mode == "fan_out") - _scale /= Math.Max(1, fan_out); - else - _scale /= Math.Max(1, (fan_in + fan_out) / 2); + if (_mode == "FAN_IN") + n = fan_in; + else if (_mode == "FAN_OUT") + n = fan_out; + else if(_mode == "FAN_AVG") + n = (fan_in + fan_out) / 2.0f; - if (_distribution == "normal" || _distribution == "truncated_normal") - { - float stddev = (float)Math.Sqrt(_scale) / .87962566103423978f; - return random_ops.truncated_normal(shape, mean: 0.0f, stddev: stddev, dtype: dtype, seed: _seed); - } - else if (_distribution == "untruncated_normal") + if(_uniform) { - throw new NotImplementedException("truncated_normal"); + var limit = Convert.ToSingle(Math.Sqrt(3.0f * _scale / n)); + return random_ops.random_uniform(shape, -limit, limit, + dtype, seed: _seed); } else { - var limit = Math.Sqrt(3.0f * _scale); - return random_ops.random_uniform(shape, (float)-limit, (float)limit, dtype, seed: _seed); + var trunc_stddev = Convert.ToSingle(Math.Sqrt(1.3f * _scale / n)); + return random_ops.truncated_normal(shape, 0.0f, trunc_stddev, dtype, + seed: _seed); } } @@ -106,6 +107,7 @@ namespace Tensorflow.Operations.Initializers mode = _mode, distribution = _distribution, seed = _seed, + uniform = _uniform, dtype = _dtype }; } diff --git a/src/TensorFlowNET.Core/Operations/array_ops.py.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs similarity index 97% rename from src/TensorFlowNET.Core/Operations/array_ops.py.cs rename to src/TensorFlowNET.Core/Operations/array_ops.cs index 86ab150f..04964069 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.cs @@ -1,674 +1,673 @@ -/***************************************************************************** - Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -******************************************************************************/ - -using NumSharp; -using System; -using System.Collections.Generic; -using System.Linq; -using Tensorflow.Framework; -using static Tensorflow.Binding; - -namespace Tensorflow -{ - public class array_ops - { - public static Tensor placeholder_with_default(T input, int[] shape, string name = null) - => gen_array_ops.placeholder_with_default(input, shape, name); - - public static Tensor prevent_gradient(Tensor input, string message = "", string name = null) - => gen_array_ops.prevent_gradient(input, message: message, name: name); - - internal static Tensor constant(object value, - TF_DataType dtype = TF_DataType.DtInvalid, - int[] shape = null, - string name = "Const", - bool verify_shape = false) => constant_op._constant_impl(value, - dtype, - shape, - name, - verify_shape: verify_shape, - allow_broadcast: false); - - public static Tensor zeros(TensorShape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) - { - dtype = dtype.as_base_dtype(); - return tf_with(ops.name_scope(name, "zeros", shape), scope => - { - name = scope; - switch (dtype) - { - case TF_DataType.TF_BOOL: - return _constant_if_small(false, shape, dtype, name); - case TF_DataType.TF_DOUBLE: - return _constant_if_small(0.0D, shape, dtype, name); - case TF_DataType.TF_FLOAT: - return _constant_if_small(0.0F, shape, dtype, name); - case TF_DataType.TF_INT64: - return _constant_if_small(0l, shape, dtype, name); - case TF_DataType.TF_INT32: - return _constant_if_small(0, shape, dtype, name); - case TF_DataType.TF_INT8: - return _constant_if_small(0, shape, dtype, name); - default: - throw new TypeError("can't find type for zeros"); - } - }); - } - - public static Tensor boolean_mask(T1 tensor, T2 mask, string name = "boolean_mask", int axis = 0) - { - return tf_with(ops.name_scope(name, values: new { tensor, mask }), delegate - { - var tensor_tensor = ops.convert_to_tensor(tensor, name: "tensor"); - var mask_tensor = ops.convert_to_tensor(mask, name: "mask"); - - var shape_mask = mask_tensor.TensorShape; - var ndims_mask = shape_mask.ndim; - var shape_tensor = tensor_tensor.TensorShape; - - if (ndims_mask < 1) - throw new ValueError("mask cannot be scalar."); - - var leading_size = gen_math_ops.prod(shape(tensor_tensor)[$"{axis}:{axis + ndims_mask}"], new[] { 0 }); - var shape1 = concat(new[] - { - shape(tensor_tensor)[$":{axis}"], - tf.expand_dims(leading_size, 0), - shape(tensor_tensor)[$"{axis + ndims_mask}:"] - }, 0); - tensor_tensor = reshape(tensor, shape1); - var first_dim = shape_tensor.dims.Skip(axis).Take(ndims_mask).First(); - var s1 = tensor_shape.as_shape(shape_tensor.dims.Take(axis).ToArray()); - var s2 = s1.concatenate(new[] { first_dim }).concatenate(shape_tensor.dims.Skip(axis + ndims_mask).ToArray()); - tensor_tensor.set_shape(s2); - - mask_tensor = reshape(mask_tensor, new[] { -1 }); - return _apply_mask_1d(tensor_tensor, mask_tensor, axis); - }); - } - - private static Tensor _apply_mask_1d(Tensor reshaped_tensor, Tensor mask, int axis = 0) - { - var indices = squeeze(where(mask), axis: new[] { 1 }); - return gather(reshaped_tensor, indices, axis: axis); - } - - public static Tensor zeros(Tensor shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) - { - dtype = dtype.as_base_dtype(); - return tf_with(ops.name_scope(name, "zeros", shape), scope => - { - name = scope; - switch (dtype) - { - case TF_DataType.TF_BOOL: - return gen_array_ops.fill(shape, tf.constant(false, dtype: dtype), name: name); - case TF_DataType.TF_DOUBLE: - return gen_array_ops.fill(shape, tf.constant(0.0D, dtype: dtype), name: name); - case TF_DataType.TF_FLOAT: - return gen_array_ops.fill(shape, tf.constant(0.0F, dtype: dtype), name: name); - case TF_DataType.TF_INT32: - return gen_array_ops.fill(shape, tf.constant(0, dtype: dtype), name: name); - default: - throw new TypeError("can't find type for zeros"); - } - - }); - } - - private static Tensor _constant_if_small(int value, Tensor shape) - { - return shape < 1000; - } - - private static Tensor _constant_if_small(T value, TensorShape shape, TF_DataType dtype, string name) - { - Tensor tShape = null; - if (shape.size < 1000) - { - return constant_op.constant(value, shape: shape, dtype: dtype, name: name); - } - else - { - tShape = constant_op._tensor_shape_tensor_conversion_function(shape); - var c = constant_op.constant(0, dtype: dtype); - return gen_array_ops.fill(tShape, c, name: name); - } - } - - public static Tensor _autopacking_conversion_function(object[] v, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false) - { - var inferred_dtype = _get_dtype_from_nested_lists(v); - if (dtype == TF_DataType.DtInvalid) - dtype = inferred_dtype; - - return _autopacking_helper(v, dtype, name == null ? "packed" : name); - } - - private static TF_DataType _get_dtype_from_nested_lists(object[] list_or_tuple) - { - TF_DataType dtype = TF_DataType.DtInvalid; - - foreach(var obj in list_or_tuple) - { - switch (obj) - { - case Tensor t: - dtype = t.dtype.as_base_dtype(); - break; - } - - if (dtype != TF_DataType.DtInvalid) - break; - } - - return dtype; - } - - public static Tensor _autopacking_helper(object[] list_or_tuple, TF_DataType dtype, string name) - { - var must_pack = false; - var converted_elems = new List(); - return tf_with(ops.name_scope(name), scope => - { - foreach (var (i, elem) in enumerate(list_or_tuple)) - { - converted_elems.Add(elem); - must_pack = true; - } - - if(must_pack) - { - var elems_as_tensors = new List(); - foreach (var (i, elem) in enumerate(converted_elems)) - { - if (elem is Tensor tensor) - elems_as_tensors.Add(tensor); - else - { - var elem_tensor = constant_op.constant(elem, dtype: dtype, name: i.ToString()); - elems_as_tensors.Add(elem_tensor); - } - } - - return gen_array_ops.pack(elems_as_tensors.ToArray(), name: scope); - } - else - { - // return converted_elems.ToArray(); - throw new NotImplementedException("_autopacking_helper.converted_elems"); - } - }); - } - - public static Tensor expand_dims(Tensor input, int axis = -1, string name = null, int dim = -1) - => expand_dims_v2(input, axis, name); - - private static Tensor expand_dims_v2(Tensor input, int axis, string name = null) - => gen_array_ops.expand_dims(input, axis, name); - - /// - /// Returns the rank of a tensor. - /// - /// - /// - /// - public static Tensor rank(Tensor input, string name = null) - => rank_internal(input, name, optimize: true); - - public static Tensor rank_internal(Tensor input, string name = null, bool optimize = true) - { - return tf_with(ops.name_scope(name, "Rank", new List { input }), scope => - { - name = scope; - var input_tensor = ops.convert_to_tensor(input); - var input_shape = tensor_util.to_shape(input_tensor.shape); - if (optimize && input_shape.ndim > 0) - return constant_op.constant(input_shape.ndim, dtype: tf.int32, name: name); - else - return gen_array_ops.rank(input, name); - }); - } - - /// - /// Creates a tensor with all elements set to 1. - /// - /// - /// - /// - /// - /// - public static Tensor ones_like(T tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) - => ones_like_impl(tensor, dtype, name, optimize); - - public static Tensor reshape(T1 tensor, T2 shape, string name = null) - => gen_array_ops.reshape(tensor, shape, null); - - private static Tensor ones_like_impl(T tensor, TF_DataType dtype, string name, bool optimize = true) - { - return tf_with(ops.name_scope(name, "ones_like", new { tensor }), scope => - { - name = scope; - var tensor1 = ops.convert_to_tensor(tensor, name: "tensor"); - var ones_shape = shape_internal(tensor1, optimize: optimize); - if (dtype == TF_DataType.DtInvalid) - dtype = tensor1.dtype; - var ret = ones(ones_shape, dtype: dtype, name: name); - ret.shape = tensor1.shape; - return ret; - }); - } - - public static Tensor ones(Tensor shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) - { - dtype = dtype.as_base_dtype(); - return tf_with(ops.name_scope(name, "ones", new { shape }), scope => - { - name = scope; - var output = gen_array_ops.fill(shape, constant_op.constant(1.0f, dtype: dtype), name: name); - return output; - }); - } - - public static Tensor ones(Tensor[] shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) - { - dtype = dtype.as_base_dtype(); - return tf_with(ops.name_scope(name, "ones", new { shape }), scope => - { - name = scope; - var output = _constant_if_small(1, shape[0]); - var shape1 = ops.convert_to_tensor(shape, dtype: TF_DataType.TF_INT32); - output = gen_array_ops.fill(shape1, constant_op.constant(1, dtype: dtype), name: name); - return output; - }); - } - - public static Tensor ones(int[] dims, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) - { - dtype = dtype.as_base_dtype(); - return tf_with(ops.name_scope(name, "ones", new { dims }), scope => - { - name = scope; - var output = _constant_if_small(1, dims, dtype, name); - return output; - }); - } - - public static Tensor one_hot(Tensor indices, int depth, - Tensor on_value = null, - Tensor off_value = null, - TF_DataType dtype = TF_DataType.DtInvalid, - int axis = -1, - string name = null) - { - return tf_with(ops.name_scope(name, "one_hot", new { indices, depth, dtype }), scope => - { - name = scope; - var on_exists = false; - var off_exists = false; - var on_dtype = TF_DataType.DtInvalid; - var off_dtype = TF_DataType.DtInvalid; - - if (dtype == TF_DataType.DtInvalid) - dtype = TF_DataType.TF_FLOAT; - - if(!on_exists) - { - on_value = ops.convert_to_tensor(1, dtype, name: "on_value"); - on_dtype = dtype; - } - - if (!off_exists) - { - off_value = ops.convert_to_tensor(0, dtype, name = "off_value"); - off_dtype = dtype; - } - - return gen_array_ops.one_hot(indices, depth, - on_value: on_value, - off_value: off_value, - axis: axis, - name: name); - }); - } - - public static (Tensor, Tensor) unique(Tensor x, TF_DataType out_idx = TF_DataType.TF_INT32, string name = null) - => gen_array_ops.unique(x, out_idx: out_idx, name: name); - - public static Tensor stack(Tensor[] values, int axis = 0, string name = "stack") - { - if (axis == 0) - { - return ops.convert_to_tensor(values, name: name); - } - - var value_shape = ops.convert_to_tensor(values[0], name: name).TensorShape; - - return gen_array_ops.pack(values, axis: axis, name: name); - } - - public static Tensor[] unstack(Tensor value, int? num = null, int axis = 0, string name = "unstack") - { - if(num == null) - { - value = ops.convert_to_tensor(value); - var value_shape = value.TensorShape; - num = value_shape.dims[axis]; - } - - return gen_array_ops.unpack(value, num: num.Value, axis: axis, name: name); - } - - public static Tensor where(Tensor condition, object x = null, object y = null, string name = null) - { - if( x == null && y == null) - { - return tf_with(ops.name_scope(name, "Where", new { condition }), scope => - { - name = scope; - condition = ops.convert_to_tensor(condition, preferred_dtype: dtypes.@bool, name: "condition"); - return gen_array_ops.where(condition: condition, name: name); - }); - } - else if(x != null && y != null) - { - return gen_array_ops.select(condition, x, y, name); - } - else - { - throw new ValueError("x and y must both be non-None or both be None."); - } - } - - /// - /// Returns the shape of a tensor. - /// - /// A `Tensor` or `SparseTensor`. - /// A name for the operation (optional). - /// - /// (Optional) The specified output type of the operation - /// (`int32` or `int64`). Defaults to `tf.int32`. - /// - /// A `Tensor` of type `out_type`. - public static Tensor shape(Tensor input, string name = null, TF_DataType out_type = TF_DataType.TF_INT32) - => shape_internal(input, name, optimize: true, out_type: out_type); - - public static Tensor size(Tensor input, string name = null, bool optimize = true, TF_DataType out_type = TF_DataType.TF_INT32) - => size_internal(input, name, optimize: optimize, out_type: out_type); - - public static Tensor shape_internal(Tensor input, string name = null, bool optimize = true, TF_DataType out_type = TF_DataType.TF_INT32) - { - return tf_with(ops.name_scope(name, "Shape", new { input }), scope => - { - name = scope; - - if (!tf.context.executing_eagerly()) - { - var input_tensor = ops.convert_to_tensor(input); - var input_shape = tensor_util.to_shape(input_tensor.shape); - if (optimize && input_tensor.NDims > -1 && input_shape.is_fully_defined()) - { - var nd = np.array(input_tensor.shape).astype(out_type.as_numpy_dtype()); - return constant_op.constant(nd, name: name); - } - } - - return gen_array_ops.shape(input, name: name, out_type: out_type); - }); - } - - private static Tensor size_internal(Tensor input, string name = null, bool optimize = true, TF_DataType out_type = TF_DataType.TF_INT32) - { - return tf_with(ops.name_scope(name, "Size", new { input }), scope => - { - name = scope; - - var input_tensor = ops.convert_to_tensor(input); - var input_shape = tensor_util.to_shape(input_tensor.shape); - if (optimize) - { - if (input_shape.is_fully_defined()) - { - return constant_op.constant(input_shape.size, dtype: out_type, name: name); - } - } - - return gen_array_ops.size(input, name: name, out_type: out_type); - }); - } - - public static Tensor zeros_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) - { - return tf_with(ops.name_scope(name, "zeros_like", new Tensor[] { tensor }), scope => - { - name = scope; - tensor = ops.convert_to_tensor(tensor, name: "tensor"); - - // is_fully_defined return unexpected value. - if (optimize && tensor_util.to_shape(tensor.shape).is_fully_defined() && dtype != TF_DataType.TF_VARIANT) - { - - } - - if(dtype != TF_DataType.DtInvalid && dtype != tensor.dtype && dtype != TF_DataType.TF_VARIANT) - { - throw new NotImplementedException("zeros_like"); - // return zeros(shape_internal(tensor, optimize: optimize), dtype: dtype, name: name); - } - else - { - return gen_array_ops.zeros_like(tensor, name: name); - } - }); - } - - /// - /// When building ops to compute gradients, this op prevents the contribution of - /// its inputs to be taken into account.Normally, the gradient generator adds ops - /// to a graph to compute the derivatives of a specified 'loss' by recursively - /// finding out inputs that contributed to its computation.If you insert this op - /// in the graph it inputs are masked from the gradient generator. They are not - /// taken into account for computing gradients. - /// - /// - /// - /// - public static Tensor stop_gradient(Tensor input, string name = null) - => gen_array_ops.stop_gradient(input, name); - - /// - /// Extracts a strided slice of a tensor (generalized python array indexing). - /// - /// - /// - /// - /// - /// - /// - /// - /// - /// - /// - /// - public static Tensor strided_slice(Tensor input_, Tensor begin, Tensor end, - Tensor strides = null, - int begin_mask = 0, - int end_mask = 0, - int ellipsis_mask = 0, - int new_axis_mask = 0, - int shrink_axis_mask = 0, - string name = null) - { - var op = gen_array_ops.strided_slice( - input: input_, - begin: begin, - end: end, - strides: strides, - begin_mask: begin_mask, - end_mask: end_mask, - ellipsis_mask: ellipsis_mask, - new_axis_mask: new_axis_mask, - shrink_axis_mask: shrink_axis_mask, - name: name); - - string parent_name = name; - - return op; - } - - /// - /// Removes dimensions of size 1 from the shape of a tensor. - /// Given a tensor `input`, this operation returns a tensor of the same type with - /// all dimensions of size 1 removed.If you don't want to remove all size 1 - /// dimensions, you can remove specific size 1 dimensions by specifying - /// `axis`. - /// - /// A `Tensor`. The `input` to squeeze. - /// An optional list of `ints`. Defaults to `[]`. - /// If specified, only squeezes the dimensions listed.The dimension - /// index starts at 0. It is an error to squeeze a dimension that is not 1. - /// Must be in the range `[-rank(input), rank(input))`. - /// A name for the operation (optional). - /// Deprecated keyword argument that is now axis. - /// A `Tensor`. Has the same type as `input`. - /// Contains the same data as `input`, but has one or more dimensions of - /// size 1 removed. - public static Tensor squeeze(Tensor input, int[] axis = null, string name = null, int[] squeeze_dims = null) - => gen_array_ops.squeeze(input, axis, name); - - public static Tensor identity(Tensor input, string name = null) - => gen_array_ops.identity(input, name); - - public static Tensor invert_permutation(Tensor x, string name = null) - => gen_array_ops.invert_permutation(x, name: name); - - /// - /// Computes the shape of a broadcast given symbolic shapes. - /// When shape_x and shape_y are Tensors representing shapes(i.e.the result of - /// calling tf.shape on another Tensor) this computes a Tensor which is the shape - /// of the result of a broadcasting op applied in tensors of shapes shape_x and - /// shape_y. - /// For example, if shape_x is [1, 2, 3] and shape_y is [5, 1, 3], the result is a - /// Tensor whose value is [5, 2, 3]. - /// This is useful when validating the result of a broadcasting operation when the - /// tensors do not have statically known shapes. - /// - /// A rank 1 integer `Tensor`, representing the shape of x. - /// A rank 1 integer `Tensor`, representing the shape of y. - /// A rank 1 integer `Tensor` representing the broadcasted shape. - public static Tensor broadcast_dynamic_shape(Tensor shape_x, Tensor shape_y) - => gen_array_ops.broadcast_args(shape_x, shape_y); - - public static Tensor broadcast_static_shape(Tensor shape_x, Tensor shape_y) - => Framework.common_shapes.broadcast_shape(shape_x, shape_y); - - /// - /// Concatenates tensors along one dimension. - /// - /// - /// - /// - /// - public static Tensor concat(Tensor[] values, int axis, string name = "concat") - { - if(values.Length == 1) // Degenerate case of one tensor. - { - return tf_with(ops.name_scope(name), scope => { - var t = ops.convert_to_tensor(axis, name: "concat_dim", dtype: TF_DataType.TF_INT32); - return identity(values[0], name: scope); - }); - } - - return gen_array_ops.concat_v2(values, axis, name: name); - } - - public static Tensor concat(object[] values, int axis, string name = "concat") - { - return gen_array_ops.concat_v2(values, axis, name: name); - } - - public static Tensor gather(T1 @params, T2 indices, string name = null, int axis = 0) - { - if (axis != 0) - return gen_array_ops.gather_v2(@params, indices, axis, name: name); - - if (@params is ResourceVariable variable && - indices is Tensor indices_tensor) - return variable.sparse_read(indices_tensor, name); - - return gen_array_ops.gather_v2(@params, indices, axis, name: name); - } - - public static Tensor transpose(T1 a, T2 perm, string name = "transpose", bool conjugate = false) - { - return tf_with(ops.name_scope(name, "transpose", new { a }), scope => - { - return gen_array_ops.transpose(a, perm, name: scope); - }); - } - - public static Tensor slice(Tensor input, Tb begin, Ts size, string name = null) - => gen_array_ops.slice(input, begin, size, name: name); - - public static Tensor stack(object values, int axis = 0, string name = "stack") - { - if (axis == 0) - // If the input is a constant list, it can be converted to a constant op - return ops.convert_to_tensor(values, name: name); - - throw new NotImplementedException("array_ops.stack"); - } - - public static Tensor pad(Tensor tensor, Tensor paddings, string mode = "CONSTANT", string name = null, int constant_values = 0) - { - Tensor result = null; - mode = mode.ToUpper(); - if(mode == "CONSTANT") - { - if (constant_values != 0) - throw new NotImplementedException("gen_array_ops.pad_v2"); - else - result = gen_array_ops.pad(tensor, paddings, name: name); - } - - // Restore shape information where possible. - var paddings_constant = tensor_util.constant_value( - result.op.inputs[1], partial: true); - var input_shape = result.op.inputs[0].TensorShape; - if (input_shape.ndim > -1 && - !result.TensorShape.is_fully_defined() && - !(paddings_constant is null)) - { - var new_shape = new List(); - foreach((NDArray padding, int dim) in zip(paddings_constant.GetNDArrays(), np.array(input_shape.dims).GetNDArrays())) - { - if (padding is null || dim == -1 || padding.GetData().Contains(-1)) - new_shape.Add(-1); - else - new_shape.Add(np.sum(padding) + dim); - } - result.set_shape(new_shape.ToArray()); - } - - return result; - } - - public static Tensor placeholder(TF_DataType dtype) - { - throw new NotImplementedException("array_ops.placeholder"); - } - } -} +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using NumSharp; +using System; +using System.Collections.Generic; +using System.Linq; +using Tensorflow.Framework; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public class array_ops + { + public static Tensor placeholder_with_default(T input, int[] shape, string name = null) + => gen_array_ops.placeholder_with_default(input, shape, name); + + public static Tensor prevent_gradient(Tensor input, string message = "", string name = null) + => gen_array_ops.prevent_gradient(input, message: message, name: name); + + internal static Tensor constant(object value, + TF_DataType dtype = TF_DataType.DtInvalid, + int[] shape = null, + string name = "Const", + bool verify_shape = false) => constant_op._constant_impl(value, + dtype, + shape, + name, + verify_shape: verify_shape, + allow_broadcast: false); + + public static Tensor zeros(TensorShape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) + { + dtype = dtype.as_base_dtype(); + return tf_with(ops.name_scope(name, "zeros", shape), scope => + { + name = scope; + switch (dtype) + { + case TF_DataType.TF_BOOL: + return _constant_if_small(false, shape, dtype, name); + case TF_DataType.TF_DOUBLE: + return _constant_if_small(0.0D, shape, dtype, name); + case TF_DataType.TF_FLOAT: + return _constant_if_small(0.0F, shape, dtype, name); + case TF_DataType.TF_INT64: + return _constant_if_small(0l, shape, dtype, name); + case TF_DataType.TF_INT32: + return _constant_if_small(0, shape, dtype, name); + case TF_DataType.TF_INT8: + return _constant_if_small(0, shape, dtype, name); + default: + throw new TypeError("can't find type for zeros"); + } + }); + } + + public static Tensor boolean_mask(T1 tensor, T2 mask, string name = "boolean_mask", int axis = 0) + { + return tf_with(ops.name_scope(name, values: new { tensor, mask }), delegate + { + var tensor_tensor = ops.convert_to_tensor(tensor, name: "tensor"); + var mask_tensor = ops.convert_to_tensor(mask, name: "mask"); + + var shape_mask = mask_tensor.TensorShape; + var ndims_mask = shape_mask.ndim; + var shape_tensor = tensor_tensor.TensorShape; + + if (ndims_mask < 1) + throw new ValueError("mask cannot be scalar."); + + var leading_size = gen_math_ops.prod(shape(tensor_tensor)[$"{axis}:{axis + ndims_mask}"], new[] { 0 }); + var shape1 = concat(new[] + { + shape(tensor_tensor)[$":{axis}"], + tf.expand_dims(leading_size, 0), + shape(tensor_tensor)[$"{axis + ndims_mask}:"] + }, 0); + tensor_tensor = reshape(tensor, shape1); + var first_dim = shape_tensor.dims.Skip(axis).Take(ndims_mask).First(); + var s1 = tensor_shape.as_shape(shape_tensor.dims.Take(axis).ToArray()); + var s2 = s1.concatenate(new[] { first_dim }).concatenate(shape_tensor.dims.Skip(axis + ndims_mask).ToArray()); + tensor_tensor.set_shape(s2); + + mask_tensor = reshape(mask_tensor, new[] { -1 }); + return _apply_mask_1d(tensor_tensor, mask_tensor, axis); + }); + } + + private static Tensor _apply_mask_1d(Tensor reshaped_tensor, Tensor mask, int axis = 0) + { + var indices = squeeze(where(mask), axis: new[] { 1 }); + return gather(reshaped_tensor, indices, axis: axis); + } + + public static Tensor zeros(Tensor shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) + { + dtype = dtype.as_base_dtype(); + return tf_with(ops.name_scope(name, "zeros", shape), scope => + { + name = scope; + switch (dtype) + { + case TF_DataType.TF_BOOL: + return gen_array_ops.fill(shape, tf.constant(false, dtype: dtype), name: name); + case TF_DataType.TF_DOUBLE: + return gen_array_ops.fill(shape, tf.constant(0.0D, dtype: dtype), name: name); + case TF_DataType.TF_FLOAT: + return gen_array_ops.fill(shape, tf.constant(0.0F, dtype: dtype), name: name); + case TF_DataType.TF_INT32: + return gen_array_ops.fill(shape, tf.constant(0, dtype: dtype), name: name); + default: + throw new TypeError("can't find type for zeros"); + } + + }); + } + + private static Tensor _constant_if_small(int value, Tensor shape) + { + return shape < 1000; + } + + private static Tensor _constant_if_small(T value, TensorShape shape, TF_DataType dtype, string name) + { + Tensor tShape = null; + if (shape.size < 1000) + { + return constant_op.constant(value, shape: shape, dtype: dtype, name: name); + } + else + { + tShape = constant_op._tensor_shape_tensor_conversion_function(shape); + var c = constant_op.constant(0, dtype: dtype); + return gen_array_ops.fill(tShape, c, name: name); + } + } + + public static Tensor _autopacking_conversion_function(object[] v, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false) + { + var inferred_dtype = _get_dtype_from_nested_lists(v); + if (dtype == TF_DataType.DtInvalid) + dtype = inferred_dtype; + + return _autopacking_helper(v, dtype, name == null ? "packed" : name); + } + + private static TF_DataType _get_dtype_from_nested_lists(object[] list_or_tuple) + { + TF_DataType dtype = TF_DataType.DtInvalid; + + foreach(var obj in list_or_tuple) + { + switch (obj) + { + case Tensor t: + dtype = t.dtype.as_base_dtype(); + break; + } + + if (dtype != TF_DataType.DtInvalid) + break; + } + + return dtype; + } + + public static Tensor _autopacking_helper(object[] list_or_tuple, TF_DataType dtype, string name) + { + var must_pack = false; + var converted_elems = new List(); + return tf_with(ops.name_scope(name), scope => + { + foreach (var (i, elem) in enumerate(list_or_tuple)) + { + converted_elems.Add(elem); + must_pack = true; + } + + if(must_pack) + { + var elems_as_tensors = new List(); + foreach (var (i, elem) in enumerate(converted_elems)) + { + if (elem is Tensor tensor) + elems_as_tensors.Add(tensor); + else + { + var elem_tensor = constant_op.constant(elem, dtype: dtype, name: i.ToString()); + elems_as_tensors.Add(elem_tensor); + } + } + + return gen_array_ops.pack(elems_as_tensors.ToArray(), name: scope); + } + else + { + return tf.constant(np.array(new float[0])); + } + }); + } + + public static Tensor expand_dims(Tensor input, int axis = -1, string name = null, int dim = -1) + => expand_dims_v2(input, axis, name); + + private static Tensor expand_dims_v2(Tensor input, int axis, string name = null) + => gen_array_ops.expand_dims(input, axis, name); + + /// + /// Returns the rank of a tensor. + /// + /// + /// + /// + public static Tensor rank(Tensor input, string name = null) + => rank_internal(input, name, optimize: true); + + public static Tensor rank_internal(Tensor input, string name = null, bool optimize = true) + { + return tf_with(ops.name_scope(name, "Rank", new List { input }), scope => + { + name = scope; + var input_tensor = ops.convert_to_tensor(input); + var input_shape = tensor_util.to_shape(input_tensor.shape); + if (optimize && input_shape.ndim > 0) + return constant_op.constant(input_shape.ndim, dtype: tf.int32, name: name); + else + return gen_array_ops.rank(input, name); + }); + } + + /// + /// Creates a tensor with all elements set to 1. + /// + /// + /// + /// + /// + /// + public static Tensor ones_like(T tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) + => ones_like_impl(tensor, dtype, name, optimize); + + public static Tensor reshape(T1 tensor, T2 shape, string name = null) + => gen_array_ops.reshape(tensor, shape, null); + + private static Tensor ones_like_impl(T tensor, TF_DataType dtype, string name, bool optimize = true) + { + return tf_with(ops.name_scope(name, "ones_like", new { tensor }), scope => + { + name = scope; + var tensor1 = ops.convert_to_tensor(tensor, name: "tensor"); + var ones_shape = shape_internal(tensor1, optimize: optimize); + if (dtype == TF_DataType.DtInvalid) + dtype = tensor1.dtype; + var ret = ones(ones_shape, dtype: dtype, name: name); + ret.shape = tensor1.shape; + return ret; + }); + } + + public static Tensor ones(Tensor shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) + { + dtype = dtype.as_base_dtype(); + return tf_with(ops.name_scope(name, "ones", new { shape }), scope => + { + name = scope; + var output = gen_array_ops.fill(shape, constant_op.constant(1.0f, dtype: dtype), name: name); + return output; + }); + } + + public static Tensor ones(Tensor[] shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) + { + dtype = dtype.as_base_dtype(); + return tf_with(ops.name_scope(name, "ones", new { shape }), scope => + { + name = scope; + var output = _constant_if_small(1, shape[0]); + var shape1 = ops.convert_to_tensor(shape, dtype: TF_DataType.TF_INT32); + output = gen_array_ops.fill(shape1, constant_op.constant(1, dtype: dtype), name: name); + return output; + }); + } + + public static Tensor ones(int[] dims, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) + { + dtype = dtype.as_base_dtype(); + return tf_with(ops.name_scope(name, "ones", new { dims }), scope => + { + name = scope; + var output = _constant_if_small(1, dims, dtype, name); + return output; + }); + } + + public static Tensor one_hot(Tensor indices, int depth, + Tensor on_value = null, + Tensor off_value = null, + TF_DataType dtype = TF_DataType.DtInvalid, + int axis = -1, + string name = null) + { + return tf_with(ops.name_scope(name, "one_hot", new { indices, depth, dtype }), scope => + { + name = scope; + var on_exists = false; + var off_exists = false; + var on_dtype = TF_DataType.DtInvalid; + var off_dtype = TF_DataType.DtInvalid; + + if (dtype == TF_DataType.DtInvalid) + dtype = TF_DataType.TF_FLOAT; + + if(!on_exists) + { + on_value = ops.convert_to_tensor(1, dtype, name: "on_value"); + on_dtype = dtype; + } + + if (!off_exists) + { + off_value = ops.convert_to_tensor(0, dtype, name = "off_value"); + off_dtype = dtype; + } + + return gen_array_ops.one_hot(indices, depth, + on_value: on_value, + off_value: off_value, + axis: axis, + name: name); + }); + } + + public static (Tensor, Tensor) unique(Tensor x, TF_DataType out_idx = TF_DataType.TF_INT32, string name = null) + => gen_array_ops.unique(x, out_idx: out_idx, name: name); + + public static Tensor stack(Tensor[] values, int axis = 0, string name = "stack") + { + if (axis == 0) + { + return ops.convert_to_tensor(values, name: name); + } + + var value_shape = ops.convert_to_tensor(values[0], name: name).TensorShape; + + return gen_array_ops.pack(values, axis: axis, name: name); + } + + public static Tensor[] unstack(Tensor value, int? num = null, int axis = 0, string name = "unstack") + { + if(num == null) + { + value = ops.convert_to_tensor(value); + var value_shape = value.TensorShape; + num = value_shape.dims[axis]; + } + + return gen_array_ops.unpack(value, num: num.Value, axis: axis, name: name); + } + + public static Tensor where(Tensor condition, object x = null, object y = null, string name = null) + { + if( x == null && y == null) + { + return tf_with(ops.name_scope(name, "Where", new { condition }), scope => + { + name = scope; + condition = ops.convert_to_tensor(condition, preferred_dtype: dtypes.@bool, name: "condition"); + return gen_array_ops.where(condition: condition, name: name); + }); + } + else if(x != null && y != null) + { + return gen_array_ops.select(condition, x, y, name); + } + else + { + throw new ValueError("x and y must both be non-None or both be None."); + } + } + + /// + /// Returns the shape of a tensor. + /// + /// A `Tensor` or `SparseTensor`. + /// A name for the operation (optional). + /// + /// (Optional) The specified output type of the operation + /// (`int32` or `int64`). Defaults to `tf.int32`. + /// + /// A `Tensor` of type `out_type`. + public static Tensor shape(Tensor input, string name = null, TF_DataType out_type = TF_DataType.TF_INT32) + => shape_internal(input, name, optimize: true, out_type: out_type); + + public static Tensor size(Tensor input, string name = null, bool optimize = true, TF_DataType out_type = TF_DataType.TF_INT32) + => size_internal(input, name, optimize: optimize, out_type: out_type); + + public static Tensor shape_internal(Tensor input, string name = null, bool optimize = true, TF_DataType out_type = TF_DataType.TF_INT32) + { + return tf_with(ops.name_scope(name, "Shape", new { input }), scope => + { + name = scope; + + if (!tf.context.executing_eagerly()) + { + var input_tensor = ops.convert_to_tensor(input); + var input_shape = tensor_util.to_shape(input_tensor.shape); + if (optimize && input_tensor.NDims > -1 && input_shape.is_fully_defined()) + { + var nd = np.array(input_tensor.shape).astype(out_type.as_numpy_dtype()); + return constant_op.constant(nd, name: name); + } + } + + return gen_array_ops.shape(input, name: name, out_type: out_type); + }); + } + + private static Tensor size_internal(Tensor input, string name = null, bool optimize = true, TF_DataType out_type = TF_DataType.TF_INT32) + { + return tf_with(ops.name_scope(name, "Size", new { input }), scope => + { + name = scope; + + var input_tensor = ops.convert_to_tensor(input); + var input_shape = tensor_util.to_shape(input_tensor.shape); + if (optimize) + { + if (input_shape.is_fully_defined()) + { + return constant_op.constant(input_shape.size, dtype: out_type, name: name); + } + } + + return gen_array_ops.size(input, name: name, out_type: out_type); + }); + } + + public static Tensor zeros_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) + { + return tf_with(ops.name_scope(name, "zeros_like", new Tensor[] { tensor }), scope => + { + name = scope; + tensor = ops.convert_to_tensor(tensor, name: "tensor"); + + // is_fully_defined return unexpected value. + if (optimize && tensor_util.to_shape(tensor.shape).is_fully_defined() && dtype != TF_DataType.TF_VARIANT) + { + + } + + if(dtype != TF_DataType.DtInvalid && dtype != tensor.dtype && dtype != TF_DataType.TF_VARIANT) + { + throw new NotImplementedException("zeros_like"); + // return zeros(shape_internal(tensor, optimize: optimize), dtype: dtype, name: name); + } + else + { + return gen_array_ops.zeros_like(tensor, name: name); + } + }); + } + + /// + /// When building ops to compute gradients, this op prevents the contribution of + /// its inputs to be taken into account.Normally, the gradient generator adds ops + /// to a graph to compute the derivatives of a specified 'loss' by recursively + /// finding out inputs that contributed to its computation.If you insert this op + /// in the graph it inputs are masked from the gradient generator. They are not + /// taken into account for computing gradients. + /// + /// + /// + /// + public static Tensor stop_gradient(Tensor input, string name = null) + => gen_array_ops.stop_gradient(input, name); + + /// + /// Extracts a strided slice of a tensor (generalized python array indexing). + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor strided_slice(Tensor input_, Tensor begin, Tensor end, + Tensor strides = null, + int begin_mask = 0, + int end_mask = 0, + int ellipsis_mask = 0, + int new_axis_mask = 0, + int shrink_axis_mask = 0, + string name = null) + { + var op = gen_array_ops.strided_slice( + input: input_, + begin: begin, + end: end, + strides: strides, + begin_mask: begin_mask, + end_mask: end_mask, + ellipsis_mask: ellipsis_mask, + new_axis_mask: new_axis_mask, + shrink_axis_mask: shrink_axis_mask, + name: name); + + string parent_name = name; + + return op; + } + + /// + /// Removes dimensions of size 1 from the shape of a tensor. + /// Given a tensor `input`, this operation returns a tensor of the same type with + /// all dimensions of size 1 removed.If you don't want to remove all size 1 + /// dimensions, you can remove specific size 1 dimensions by specifying + /// `axis`. + /// + /// A `Tensor`. The `input` to squeeze. + /// An optional list of `ints`. Defaults to `[]`. + /// If specified, only squeezes the dimensions listed.The dimension + /// index starts at 0. It is an error to squeeze a dimension that is not 1. + /// Must be in the range `[-rank(input), rank(input))`. + /// A name for the operation (optional). + /// Deprecated keyword argument that is now axis. + /// A `Tensor`. Has the same type as `input`. + /// Contains the same data as `input`, but has one or more dimensions of + /// size 1 removed. + public static Tensor squeeze(Tensor input, int[] axis = null, string name = null, int[] squeeze_dims = null) + => gen_array_ops.squeeze(input, axis, name); + + public static Tensor identity(Tensor input, string name = null) + => gen_array_ops.identity(input, name); + + public static Tensor invert_permutation(Tensor x, string name = null) + => gen_array_ops.invert_permutation(x, name: name); + + /// + /// Computes the shape of a broadcast given symbolic shapes. + /// When shape_x and shape_y are Tensors representing shapes(i.e.the result of + /// calling tf.shape on another Tensor) this computes a Tensor which is the shape + /// of the result of a broadcasting op applied in tensors of shapes shape_x and + /// shape_y. + /// For example, if shape_x is [1, 2, 3] and shape_y is [5, 1, 3], the result is a + /// Tensor whose value is [5, 2, 3]. + /// This is useful when validating the result of a broadcasting operation when the + /// tensors do not have statically known shapes. + /// + /// A rank 1 integer `Tensor`, representing the shape of x. + /// A rank 1 integer `Tensor`, representing the shape of y. + /// A rank 1 integer `Tensor` representing the broadcasted shape. + public static Tensor broadcast_dynamic_shape(Tensor shape_x, Tensor shape_y) + => gen_array_ops.broadcast_args(shape_x, shape_y); + + public static Tensor broadcast_static_shape(Tensor shape_x, Tensor shape_y) + => Framework.common_shapes.broadcast_shape(shape_x, shape_y); + + /// + /// Concatenates tensors along one dimension. + /// + /// + /// + /// + /// + public static Tensor concat(Tensor[] values, int axis, string name = "concat") + { + if(values.Length == 1) // Degenerate case of one tensor. + { + return tf_with(ops.name_scope(name), scope => { + var t = ops.convert_to_tensor(axis, name: "concat_dim", dtype: TF_DataType.TF_INT32); + return identity(values[0], name: scope); + }); + } + + return gen_array_ops.concat_v2(values, axis, name: name); + } + + public static Tensor concat(object[] values, int axis, string name = "concat") + { + return gen_array_ops.concat_v2(values, axis, name: name); + } + + public static Tensor gather(T1 @params, T2 indices, string name = null, int axis = 0) + { + if (axis != 0) + return gen_array_ops.gather_v2(@params, indices, axis, name: name); + + if (@params is ResourceVariable variable && + indices is Tensor indices_tensor) + return variable.sparse_read(indices_tensor, name); + + return gen_array_ops.gather_v2(@params, indices, axis, name: name); + } + + public static Tensor transpose(T1 a, T2 perm, string name = "transpose", bool conjugate = false) + { + return tf_with(ops.name_scope(name, "transpose", new { a }), scope => + { + return gen_array_ops.transpose(a, perm, name: scope); + }); + } + + public static Tensor slice(Tensor input, Tb begin, Ts size, string name = null) + => gen_array_ops.slice(input, begin, size, name: name); + + public static Tensor stack(object values, int axis = 0, string name = "stack") + { + if (axis == 0) + // If the input is a constant list, it can be converted to a constant op + return ops.convert_to_tensor(values, name: name); + + throw new NotImplementedException("array_ops.stack"); + } + + public static Tensor pad(Tensor tensor, Tensor paddings, string mode = "CONSTANT", string name = null, int constant_values = 0) + { + Tensor result = null; + mode = mode.ToUpper(); + if(mode == "CONSTANT") + { + if (constant_values != 0) + throw new NotImplementedException("gen_array_ops.pad_v2"); + else + result = gen_array_ops.pad(tensor, paddings, name: name); + } + + // Restore shape information where possible. + var paddings_constant = tensor_util.constant_value( + result.op.inputs[1], partial: true); + var input_shape = result.op.inputs[0].TensorShape; + if (input_shape.ndim > -1 && + !result.TensorShape.is_fully_defined() && + !(paddings_constant is null)) + { + var new_shape = new List(); + foreach((NDArray padding, int dim) in zip(paddings_constant.GetNDArrays(), np.array(input_shape.dims).GetNDArrays())) + { + if (padding is null || dim == -1 || padding.GetData().Contains(-1)) + new_shape.Add(-1); + else + new_shape.Add(np.sum(padding) + dim); + } + result.set_shape(new_shape.ToArray()); + } + + return result; + } + + public static Tensor placeholder(TF_DataType dtype) + { + throw new NotImplementedException("array_ops.placeholder"); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index cea3e440..29910d04 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -383,7 +383,7 @@ namespace Tensorflow { var _op = _op_def_lib._apply_op_helper("StopGradient", name, args: new { input = x, name }); - return _op.outputs[0]; + return _op.output; } public static Tensor strided_slice(Tensor input, Tensor begin, Tensor end, Tensor strides, diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index e4a8d175..62b0f1b4 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -115,7 +115,7 @@ namespace Tensorflow { var _op = _op_def_lib._apply_op_helper("Mean", name, args: new { input, reduction_indices = axis, keep_dims = keep_dims }); - return _op.outputs[0]; + return _op.output; } public static Tensor prod(T1 input, T2 axis, bool keep_dims = false, string name = null) diff --git a/src/TensorFlowNET.Core/Operations/gen_random_ops.py.cs b/src/TensorFlowNET.Core/Operations/gen_random_ops.cs similarity index 85% rename from src/TensorFlowNET.Core/Operations/gen_random_ops.py.cs rename to src/TensorFlowNET.Core/Operations/gen_random_ops.cs index 011b673f..1bba3a93 100644 --- a/src/TensorFlowNET.Core/Operations/gen_random_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/gen_random_ops.cs @@ -98,7 +98,8 @@ namespace Tensorflow /// /// /// - public static Tensor random_shuffle(Tensor value, int seed = 0, int seed2 = 0, string name = null) + public static Tensor random_shuffle(Tensor value, int seed = 0, int seed2 = 0, + string name = null) { var _op = _op_def_lib._apply_op_helper("RandomShuffle", name: name, @@ -116,7 +117,8 @@ namespace Tensorflow /// /// /// - public static Tensor truncated_normal(Tensor shape, TF_DataType dtype, int? seed = 0, int? seed2 = 0, string name = null) + public static Tensor truncated_normal(Tensor shape, TF_DataType dtype, int? seed = 0, + int? seed2 = 0, string name = null) { if (!seed.HasValue) seed = 0; @@ -127,7 +129,24 @@ namespace Tensorflow name: name, args: new { shape, dtype, seed, seed2 }); - return _op.outputs[0]; + return _op.output; + } + + public static Tensor multinomial(Tensor logits, int num_samples, int? seed = 0, + int? seed2 = 0, TF_DataType output_dtype = TF_DataType.TF_INT64, string name = null) + { + if (!seed.HasValue) + seed = 0; + if (!seed2.HasValue) + seed2 = 0; + if (output_dtype == TF_DataType.DtInvalid) + output_dtype = TF_DataType.TF_INT64; + + var _op = _op_def_lib._apply_op_helper("Multinomial", + name: name, + args: new { logits, num_samples, seed, seed2, output_dtype }); + + return _op.output; } } } diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs index fd73dd0f..848a89cd 100644 --- a/src/TensorFlowNET.Core/Operations/math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/math_ops.cs @@ -81,6 +81,21 @@ namespace Tensorflow }); } + public static Tensor cast(float x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null) + { + var base_type = dtype.as_base_dtype(); + + return tf_with(ops.name_scope(name, "Cast", new { x }), scope => + { + name = scope; + var x_tensor = ops.convert_to_tensor(x, name: "x"); + if (x_tensor.dtype.as_base_dtype() != base_type) + x_tensor = gen_math_ops.cast(x_tensor, base_type, name: name); + + return x_tensor; + }); + } + public static Tensor cumsum(Tensor x, T axis = default, bool exclusive = false, bool reverse = false, string name = null) { return tf_with(ops.name_scope(name, "Cumsum", new {x}), scope => @@ -204,6 +219,12 @@ namespace Tensorflow } } + public static Tensor reduce_mean(Tensor[] input_tensors, int axis, bool keepdims = false, string name = null) + { + var m = gen_math_ops.mean(input_tensors, axis, keepdims, name); + return _may_reduce_to_scalar(keepdims, axis, m); + } + /// /// Computes the product of elements across dimensions of a tensor. /// diff --git a/src/TensorFlowNET.Core/Operations/random_ops.py.cs b/src/TensorFlowNET.Core/Operations/random_ops.cs similarity index 83% rename from src/TensorFlowNET.Core/Operations/random_ops.py.cs rename to src/TensorFlowNET.Core/Operations/random_ops.cs index be4aef55..bd718768 100644 --- a/src/TensorFlowNET.Core/Operations/random_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/random_ops.cs @@ -142,6 +142,35 @@ namespace Tensorflow { return ops.convert_to_tensor(shape, name: "shape"); } + + public static Tensor multinomial(Tensor logits, int num_samples, int? seed = null, + string name = null, TF_DataType output_dtype = TF_DataType.DtInvalid) + { + return tf_with(ops.name_scope(name, "multinomial", new { logits }), delegate + { + return multinomial_categorical_impl(logits, num_samples, output_dtype, seed); + }); + } + + /// + /// Implementation for random.categorical (v1) and random.categorical (v2). + /// + /// + /// + /// + /// + /// + private static Tensor multinomial_categorical_impl(Tensor logits, int num_samples, TF_DataType dtype = TF_DataType.DtInvalid, + int? seed = null) + { + logits = ops.convert_to_tensor(logits, name: "logits"); + var (seed1, seed2) = random_seed.get_seed(seed); + return gen_random_ops.multinomial(logits, + num_samples, + seed: seed1, + seed2: seed2, + output_dtype: dtype); + } } } diff --git a/src/TensorFlowNET.Core/Training/learning_rate_decay.cs b/src/TensorFlowNET.Core/Training/learning_rate_decay.cs new file mode 100644 index 00000000..0315789c --- /dev/null +++ b/src/TensorFlowNET.Core/Training/learning_rate_decay.cs @@ -0,0 +1,29 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Tensorflow.Training +{ + public class learning_rate_decay + { + /// + /// Applies a polynomial decay to the learning rate. + /// + /// + /// + /// + /// + /// + /// + /// + /// + public static Tensor polynomial_decay(float learning_rate, RefVariable global_step, float decay_steps, + float end_learning_rate = 0.0001f, float power = 1.0f, bool cycle = false, + string name = null) + { + throw new NotImplementedException(""); + } + } +}