diff --git a/README.md b/README.md index 77688745..d157e09d 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ [![Badge](https://img.shields.io/badge/link-996.icu-red.svg)](https://996.icu/#/en_US) [![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/javiercp/BinderTF.NET/master?urlpath=lab) -*master branch is based on tensorflow 2.3 now, v0.15-tensorflow1.15 is from tensorflow1.15.* +*master branch is based on tensorflow v2.4, v0.3x branch is based on tensorflow v2.3, v0.15-tensorflow1.15 is from tensorflow1.15.* ![tensors_flowing](docs/assets/tensors_flowing.gif) @@ -30,7 +30,8 @@ Go through the online docs [TensorFlow for .NET](https://scisharp.github.io/tens | TensorFlow | tf native1.14, cuda 10.0 | tf native 1.15, cuda 10.0 | tf native 2.3, cuda 10.1 | tf native 2.4, cuda 11 | | -------------------------- | ------------- | -------------- | ------------- | ------------- | -| tf.net 0.3x, tf.keras 0.2 | | | x | not compatible | +| tf.net 0.4x, tf.keras 0.5 | | | | x | +| tf.net 0.3x, tf.keras 0.4 | | | x | | | tf.net 0.2x | | x | x | | | tf.net 0.15 | x | x | | | | tf.net 0.14 | x | | | | @@ -50,10 +51,10 @@ PM> Install-Package TensorFlow.Keras ### Install tensorflow binary ### For CPU version -PM> Install-Package SciSharp.TensorFlow.Redist -Version 2.3.1 +PM> Install-Package SciSharp.TensorFlow.Redist ### For GPU version (CUDA and cuDNN are required) -PM> Install-Package SciSharp.TensorFlow.Redist-Windows-GPU -Version 2.3.1 +PM> Install-Package SciSharp.TensorFlow.Redist-Windows-GPU ``` Import TF.NET and Keras API in your project. diff --git a/src/TensorFlowNET.Console/MemoryBasicTest.cs b/src/TensorFlowNET.Console/MemoryBasicTest.cs index bbb23391..9586fe4e 100644 --- a/src/TensorFlowNET.Console/MemoryBasicTest.cs +++ b/src/TensorFlowNET.Console/MemoryBasicTest.cs @@ -112,16 +112,18 @@ namespace Tensorflow var strides = new[] { 1, 1, 1, 1 }; var dilations = new[] { 1, 1, 1, 1 }; - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Conv2D", null, - null, - input, filter, - "strides", strides, - "use_cudnn_on_gpu", true, - "padding", "VALID", - "explicit_paddings", new int[0], - "data_format", "NHWC", - "dilations", dilations); + var results = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("Conv2D", null, input, filter) + { + attrs = ConvertToDict(new + { + strides, + use_cudnn_on_gpu = true, + padding = "VALID", + explicit_paddings = new int[0], + data_format = "NHWC", + dilations + }) + }); }; public Action Conv2DWithVariable @@ -132,16 +134,18 @@ namespace Tensorflow var strides = new[] { 1, 1, 1, 1 }; var dilations = new[] { 1, 1, 1, 1 }; - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Conv2D", null, - null, - input, filter, - "strides", strides, - "use_cudnn_on_gpu", true, - "padding", "VALID", - "explicit_paddings", new int[0], - "data_format", "NHWC", - "dilations", dilations); + var results = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("Conv2D", null, input, filter) + { + attrs = ConvertToDict(new + { + strides, + use_cudnn_on_gpu = true, + padding = "VALID", + explicit_paddings = new int[0], + data_format = "NHWC", + dilations + }) + }); }; public Action Dataset diff --git a/src/TensorFlowNET.Console/Tensorflow.Console.csproj b/src/TensorFlowNET.Console/Tensorflow.Console.csproj index bc2c90d5..d6a76889 100644 --- a/src/TensorFlowNET.Console/Tensorflow.Console.csproj +++ b/src/TensorFlowNET.Console/Tensorflow.Console.csproj @@ -11,6 +11,7 @@ TRACE;DEBUG + x64 diff --git a/src/TensorFlowNET.Core/APIs/tf.linalg.cs b/src/TensorFlowNET.Core/APIs/tf.linalg.cs index bf8c358c..beb3122c 100644 --- a/src/TensorFlowNET.Core/APIs/tf.linalg.cs +++ b/src/TensorFlowNET.Core/APIs/tf.linalg.cs @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ******************************************************************************/ +using static Tensorflow.Binding; namespace Tensorflow { @@ -37,8 +38,8 @@ namespace Tensorflow public Tensor matmul(Tensor a, Tensor b) => math_ops.matmul(a, b); - public Tensor batch_matmul(Tensor x, Tensor y) - => gen_math_ops.batch_mat_mul(x, y); + public Tensor batch_matmul(Tensor x, Tensor y, bool adj_x = false, bool adj_y = false, string name = null) + => math_ops.batch_matmul(x, y, adj_x: adj_x, adj_y: adj_y, name: name); } public Tensor diag(Tensor diagonal, string name = null) @@ -47,7 +48,32 @@ namespace Tensorflow public Tensor matmul(Tensor a, Tensor b) => math_ops.matmul(a, b); - public Tensor batch_matmul(Tensor x, Tensor y) - => gen_math_ops.batch_mat_mul(x, y); + /// + /// Multiply slices of the two matrices "x" and "y". + /// + /// + /// The `BatchMatMul` operation is embedded into the + /// `MatMul` operation on the DLL side. However the expected + /// attributes are not the same, hence we need to expose this + /// method to have the right args list on the `_apply_op_helper` + /// function. + /// + /// For each rank > 2 the first rank - 2 dimensions are considered + /// as fixed, and have to be consistent across the two matrices. A + /// common matrix multiplication is then applied over the residual + /// 2 dimensions. + /// + /// e.g. + /// x is (3, 6, 12); y is (3, 12, 6) + /// batch_matmul(x, y) ==> (3, 6, 6) + /// + /// + /// + /// + /// + /// + /// + public Tensor batch_matmul(Tensor x, Tensor y, bool adj_x = false, bool adj_y = false, string name = null) + => math_ops.batch_matmul(x, y, adj_x: adj_x, adj_y: adj_y, name: name); } } diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs index ff43c206..f438f870 100644 --- a/src/TensorFlowNET.Core/APIs/tf.math.cs +++ b/src/TensorFlowNET.Core/APIs/tf.math.cs @@ -32,6 +32,28 @@ namespace Tensorflow /// public Tensor erf(Tensor x, string name = null) => math_ops.erf(x, name); + + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + /// + public Tensor bincount(Tensor arr, Tensor weights = null, + Tensor minlength = null, + Tensor maxlength = null, + TF_DataType dtype = TF_DataType.TF_INT32, + string name = null, + TensorShape axis = null, + bool binary_output = false) + => math_ops.bincount(arr, weights: weights, minlength: minlength, maxlength: maxlength, + dtype: dtype, name: name, axis: axis, binary_output: binary_output); } public Tensor abs(Tensor x, string name = null) diff --git a/src/TensorFlowNET.Core/APIs/tf.random.cs b/src/TensorFlowNET.Core/APIs/tf.random.cs index bd74c8fd..d5656e87 100644 --- a/src/TensorFlowNET.Core/APIs/tf.random.cs +++ b/src/TensorFlowNET.Core/APIs/tf.random.cs @@ -93,7 +93,12 @@ namespace Tensorflow => random_ops.random_shuffle(value, seed: seed, name: name); public void set_random_seed(int seed) - => ops.get_default_graph().seed = seed; + { + if (executing_eagerly()) + Context.set_global_seed(seed); + else + ops.get_default_graph().seed = seed; + } public Tensor multinomial(Tensor logits, int num_samples, int? seed = null, string name = null, TF_DataType output_dtype = TF_DataType.DtInvalid) diff --git a/src/TensorFlowNET.Core/APIs/tf.sparse.cs b/src/TensorFlowNET.Core/APIs/tf.sparse.cs index 11f6a55d..7de02f33 100644 --- a/src/TensorFlowNET.Core/APIs/tf.sparse.cs +++ b/src/TensorFlowNET.Core/APIs/tf.sparse.cs @@ -14,17 +14,18 @@ limitations under the License. ******************************************************************************/ +using System; using Tensorflow.Framework; namespace Tensorflow { public partial class tensorflow { - public SparseTensor SparseTensor(long[,] indices, T[] values, long[] dense_shape) - => new SparseTensor(indices, values, dense_shape); + public SparseTensor SparseTensor(long[,] indices, Array values, long[] dense_shape) + => new SparseTensor(indices, values, dense_shape); - public Tensor sparse_tensor_to_dense(SparseTensor sp_input, - T default_value = default, + public Tensor sparse_tensor_to_dense(SparseTensor sp_input, + Array default_value = default, bool validate_indices = true, string name = null) => gen_sparse_ops.sparse_to_dense(sp_input.indices, diff --git a/src/TensorFlowNET.Core/APIs/tf.strings.cs b/src/TensorFlowNET.Core/APIs/tf.strings.cs index f580a67d..38a40eb4 100644 --- a/src/TensorFlowNET.Core/APIs/tf.strings.cs +++ b/src/TensorFlowNET.Core/APIs/tf.strings.cs @@ -14,6 +14,8 @@ limitations under the License. ******************************************************************************/ +using Tensorflow.Framework; + namespace Tensorflow { public partial class tensorflow @@ -64,6 +66,27 @@ namespace Tensorflow public Tensor substr(string input, int pos, int len, string name = null, string @uint = "BYTE") => ops.substr(input, pos, len, @uint: @uint, name: name); + + /// + /// String lengths of `input`. + /// + /// + /// + /// + /// + public Tensor string_length(Tensor input, string name = null, string unit = "BYTE") + => ops.string_length(input, name: name, unit: unit); + + public RaggedTensor split(Tensor input, string sep = "", int maxsplit = -1, string name = null) + => ops.string_split_v2(input, sep: sep, maxsplit : maxsplit, name : name); + + public (RaggedTensor, RaggedTensor) unicode_decode_with_offsets(Tensor input, string input_encoding, + string errors = "replace", int replacement_char = 0xFFFD, + bool replace_control_characters = false, string name = null) + => ops.unicode_decode_with_offsets(input, input_encoding, errors, + replacement_char: replacement_char, + replace_control_characters: replace_control_characters, + name: name); } } } diff --git a/src/TensorFlowNET.Core/Contexts/AutoModeArgs.cs b/src/TensorFlowNET.Core/Contexts/AutoModeArgs.cs deleted file mode 100644 index 4fcd4bf0..00000000 --- a/src/TensorFlowNET.Core/Contexts/AutoModeArgs.cs +++ /dev/null @@ -1,13 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Text; - -namespace Tensorflow -{ - public class AutoModeArgs - { - public Func GetGradientAttrs { get; set; } - public object OpInputArgs { get; set; } - public object OpAttrs { get; set; } - } -} diff --git a/src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs b/src/TensorFlowNET.Core/Contexts/Context.ExecuteOp.cs similarity index 62% rename from src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs rename to src/TensorFlowNET.Core/Contexts/Context.ExecuteOp.cs index 2f22865c..def787a9 100644 --- a/src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs +++ b/src/TensorFlowNET.Core/Contexts/Context.ExecuteOp.cs @@ -30,67 +30,35 @@ namespace Tensorflow.Contexts public sealed partial class Context { // [DebuggerStepThrough] - public T RunInAutoMode(Func graphAction, Func eagerAction, params object[] args) + public Tensors ExecuteOp(string OpType, string Name, ExecuteOpArgs args) { - if (tf.Context.has_graph_arg(args)) + Func graphAction = () => { - if (executing_eagerly()) - { - graph_mode(); - var result = graphAction(); - restore_mode(); - return result; - } - else - { - return graphAction(); - } - } - else - { - if (tf.Context.executing_eagerly()) + var keywords = new Dictionary(); + if(args.OpInputArgs != null) { - return eagerAction(); + foreach (var (i, input) in enumerate(args.OpInputArgs)) + keywords[$"input_{i}"] = input; } - else + + if(args.OpAttrs != null) { - return graphAction(); + foreach (var attr in args.OpAttrs) + keywords[attr.Key] = attr.Value; } - } - } - - // [DebuggerStepThrough] - public Tensors RunInAutoMode2(string OpType, string Name, AutoModeArgs args) - { - var inputArgs = ConvertToDict(args.OpInputArgs); - var attrDict = ConvertToDict(args.OpAttrs); - - Func graphAction = () => - { - foreach (var attr in attrDict) - inputArgs[attr.Key] = attr.Value; - return tf.OpDefLib._apply_op_helper(OpType, Name, inputArgs).output; + + return tf.OpDefLib._apply_op_helper(OpType, Name, keywords).outputs; }; - Func eagerAction = () => + Func eagerAction = () => { - var attrs = new object[attrDict.Count() * 2]; - int i = 0; - foreach(var arg in attrDict) + return tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(OpType, Name, args.OpInputArgs) { - attrs[i]= arg.Key; - attrs[i + 1] = arg.Value; - i += 2; - } - - return tf.Runner.TFE_FastPathExecute2(tf.Context, tf.Context.DeviceName, - OpType, Name, - null, - inputArgs.Values.ToArray(), - attrs).FirstOrDefault(); + attrs = args.OpAttrs + }); }; - if (tf.Context.has_graph_arg(inputArgs.Values)) + if (tf.Context.has_graph_arg(args.OpInputArgs)) { if (executing_eagerly()) { diff --git a/src/TensorFlowNET.Core/Contexts/Context.cs b/src/TensorFlowNET.Core/Contexts/Context.cs index be4b56b2..95f75a94 100644 --- a/src/TensorFlowNET.Core/Contexts/Context.cs +++ b/src/TensorFlowNET.Core/Contexts/Context.cs @@ -42,6 +42,9 @@ namespace Tensorflow.Contexts SafeContextHandle _handle; public SafeContextHandle Handle => _handle; + int? _seed; + Random _rng; + public Context() { _device_policy = ContextDevicePlacementPolicy.DEVICE_PLACEMENT_SILENT; @@ -71,6 +74,24 @@ namespace Tensorflow.Contexts initialized = true; } + public void set_global_seed(int? seed) + { + _seed = seed; + if (seed.HasValue) + _rng = new Random(seed.Value); + else + _rng = null; + // Also clear the kernel cache, to reset any existing seeds + if (_handle != null) + c_api.TFE_ContextClearCaches(_handle); + } + + public int? global_seed() + => _seed; + + public int? internal_operation_seed() + => _rng?.Next(0, int.MaxValue); + public void start_step() => c_api.TFE_ContextStartStep(_handle); @@ -86,7 +107,7 @@ namespace Tensorflow.Contexts { if(context_switches.Count() == 0) tf.enable_eager_execution(); - + return context_switches.Current().EagerMode; } @@ -115,7 +136,10 @@ namespace Tensorflow.Contexts public bool has_graph_arg(params object[] args) { var flatten_args = nest.flatten(args); - bool has_graph_arg = false; + /*if (flatten_args.Count(x => x.GetType().IsValueType) == flatten_args.Count()) + return tf.Context.executing_eagerly() == false*/ + + bool has_graph_arg = !tf.Context.executing_eagerly(); foreach (var el in flatten_args) { if (el is Tensor tensor && !tensor.IsEagerTensor) diff --git a/src/TensorFlowNET.Core/Contexts/ExecuteOpArgs.cs b/src/TensorFlowNET.Core/Contexts/ExecuteOpArgs.cs new file mode 100644 index 00000000..ecdcff8e --- /dev/null +++ b/src/TensorFlowNET.Core/Contexts/ExecuteOpArgs.cs @@ -0,0 +1,25 @@ +using System; +using System.Collections.Generic; +using System.Text; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + public class ExecuteOpArgs + { + public Func GetGradientAttrs { get; set; } + public object[] OpInputArgs { get; set; } + public Dictionary OpAttrs { get; set; } + + public ExecuteOpArgs(params object[] inputArgs) + { + OpInputArgs = inputArgs; + } + + public ExecuteOpArgs SetAttributes(object attrs) + { + OpAttrs = ConvertToDict(attrs); + return this; + } + } +} diff --git a/src/TensorFlowNET.Core/Data/DatasetV2.cs b/src/TensorFlowNET.Core/Data/DatasetV2.cs index 0297eb6b..a8033802 100644 --- a/src/TensorFlowNET.Core/Data/DatasetV2.cs +++ b/src/TensorFlowNET.Core/Data/DatasetV2.cs @@ -68,6 +68,17 @@ namespace Tensorflow public IDatasetV2 map(Func map_func, int num_parallel_calls) => new ParallelMapDataset(this, map_func, num_parallel_calls: num_parallel_calls); + public OwnedIterator make_one_shot_iterator() + { + if (tf.Context.executing_eagerly()) + { + // with ops.colocate_with(self._variant_tensor) + return new OwnedIterator(this); + } + + throw new NotImplementedException(""); + } + public IDatasetV2 flat_map(Func map_func) => new FlatMapDataset(this, map_func); @@ -105,18 +116,7 @@ namespace Tensorflow } public Tensor dataset_cardinality(string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "DatasetCardinality", name, - null, - variant_tensor); - return results[0]; - } - - throw new NotImplementedException(""); - } + => tf.Context.ExecuteOp("DatasetCardinality", name, new ExecuteOpArgs(variant_tensor)); public override string ToString() => $"{GetType().Name} shapes: {string.Join(", ", structure.Select(x => x.shape))}, types: {string.Join(", ", structure.Select(x => "tf." + x.dtype.as_numpy_name()))}"; diff --git a/src/TensorFlowNET.Core/Data/IDatasetV2.cs b/src/TensorFlowNET.Core/Data/IDatasetV2.cs index d0e372dc..9ce392d9 100644 --- a/src/TensorFlowNET.Core/Data/IDatasetV2.cs +++ b/src/TensorFlowNET.Core/Data/IDatasetV2.cs @@ -72,6 +72,8 @@ namespace Tensorflow IDatasetV2 map(Func map_func, int num_parallel_calls); + OwnedIterator make_one_shot_iterator(); + IDatasetV2 flat_map(Func map_func); IDatasetV2 model(AutotuneAlgorithm algorithm, long cpu_budget); diff --git a/src/TensorFlowNET.Core/Data/OwnedIterator.cs b/src/TensorFlowNET.Core/Data/OwnedIterator.cs index 571e79a6..0a955929 100644 --- a/src/TensorFlowNET.Core/Data/OwnedIterator.cs +++ b/src/TensorFlowNET.Core/Data/OwnedIterator.cs @@ -26,6 +26,7 @@ namespace Tensorflow dataset = dataset.apply_options(); _dataset = dataset; _element_spec = dataset.element_spec; + // _flat_output_types = (_iterator_resource, _deleter) = ops.anonymous_iterator_v2(_dataset.output_types, _dataset.output_shapes); ops.make_iterator(dataset.variant_tensor, _iterator_resource); } diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs index a7d8503a..479d2aa4 100644 --- a/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs +++ b/src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs @@ -15,84 +15,54 @@ namespace Tensorflow.Eager /// public partial class EagerRunner { - int kFastPathExecuteInputStartIndex = 0; UnorderedMap thread_local_eager_operation_map = new UnorderedMap(); - public Tensor[] TFE_FastPathExecute2(Context ctx, - string device_name, - string opName, - string name, - Action callbacks, - object[] inputArgs, - object[] attrs) + public Tensor[] TFE_FastPathExecute(FastPathOpExecInfo op_exec_info) { - var args = new List(); - args.AddRange(inputArgs); - if (attrs != null) - args.AddRange(attrs); - return TFE_FastPathExecute(ctx, device_name, opName, name, callbacks, args.ToArray()); - } - - public Tensor[] TFE_FastPathExecute(Context ctx, - string device_name, - string opName, - string name, - Action callbacks, - params object[] args) - { - if (ctx == null) - throw new ValueError("This function does not handle the case of the path where " + - "all inputs are not already EagerTensors."); + if (op_exec_info.ctx == null) + op_exec_info.ctx = tf.Context; + if (string.IsNullOrEmpty(op_exec_info.device_name)) + op_exec_info.device_name = tf.Context.DeviceName; - int args_size = args.Length; var attr_list_sizes = new Dictionary(); - FastPathOpExecInfo op_exec_info = new FastPathOpExecInfo() - { - ctx = ctx, - args = args, - device_name = device_name, - op_name = opName, - name = name, - }; - op_exec_info.run_gradient_callback = HasAccumulatorOrTape(); - op_exec_info.run_post_exec_callbacks = callbacks != null; + op_exec_info.run_post_exec_callbacks = op_exec_info.callbacks != null; op_exec_info.run_callbacks = op_exec_info.run_gradient_callback || op_exec_info.run_post_exec_callbacks; var status = tf.Status; - using var op = GetOp(ctx, opName, status); + using var op = GetOp(op_exec_info.ctx, op_exec_info.op_name, status); - var op_def = tf.get_default_graph().GetOpDef(opName); + var op_def = tf.get_default_graph().GetOpDef(op_exec_info.op_name); var flattened_attrs = new List(op_def.Attr.Count * 2); var flattened_inputs = new List(op_def.InputArg.Count); // Set non-inferred attrs, including setting defaults if the attr is passed in // as None. - for (int i = kFastPathExecuteInputStartIndex + op_def.InputArg.Count; i < args_size; i += 2) + if(op_exec_info.attrs != null) { - var attr_name = args[i].ToString(); - var attr_value = args[i + 1]; - - var attr = op_def.Attr.FirstOrDefault(x => x.Name == attr_name); - if (attr != null) + foreach (var attr1 in op_exec_info.attrs) { - flattened_attrs.Add(attr_name); - flattened_attrs.Add(attr_value); + var attr = op_def.Attr.FirstOrDefault(x => x.Name == attr1.Key); + if (attr != null) + { + flattened_attrs.Add(attr.Name); + flattened_attrs.Add(attr1.Value); - SetOpAttrWithDefaults(ctx, op, attr, attr_name, attr_value, attr_list_sizes, status); - status.Check(true); + SetOpAttrWithDefaults(op_exec_info.ctx, op, attr, attr.Name, attr1.Value, attr_list_sizes, status); + status.Check(true); + } } } - c_api.TFE_OpSetDevice(op, device_name, status.Handle); + c_api.TFE_OpSetDevice(op, op_exec_info.device_name, status.Handle); status.Check(true); // Add inferred attrs and inputs. for (int i = 0; i < op_def.InputArg.Count; i++) { - var input = args[kFastPathExecuteInputStartIndex + i]; + var input = op_exec_info.args[i]; var input_arg = op_def.InputArg[i]; if (!string.IsNullOrEmpty(input_arg.NumberAttr)) { @@ -107,7 +77,7 @@ namespace Tensorflow.Eager if (len > 0) { - var fast_input_array = (object[])args[i]; + var fast_input_array = (object[])op_exec_info.args[i]; // First item adds the type attr. if (!AddInputToOp(fast_input_array[i], true, input_arg, flattened_attrs, flattened_inputs, op, status)) return null; @@ -151,7 +121,7 @@ namespace Tensorflow.Eager else { // The item is a single item. - AddInputToOp(args[i], true, input_arg, flattened_attrs, flattened_inputs, op, status); + AddInputToOp(op_exec_info.args[i], true, input_arg, flattened_attrs, flattened_inputs, op, status); } } @@ -179,7 +149,7 @@ namespace Tensorflow.Eager if (op_exec_info.run_callbacks) { RunCallbacks(op_exec_info, - kFastPathExecuteInputStartIndex + op_def.InputArg.Count(), + op_def.InputArg.Count(), flattened_inputs.ToArray(), flattened_attrs.ToArray(), flat_result); } diff --git a/src/TensorFlowNET.Core/Eager/FastPathOpExecInfo.cs b/src/TensorFlowNET.Core/Eager/FastPathOpExecInfo.cs index 654c25b2..2cdf025a 100644 --- a/src/TensorFlowNET.Core/Eager/FastPathOpExecInfo.cs +++ b/src/TensorFlowNET.Core/Eager/FastPathOpExecInfo.cs @@ -1,6 +1,8 @@ -using Tensorflow.Contexts; +using System; +using System.Collections.Generic; +using Tensorflow.Contexts; -namespace Tensorflow.Eager +namespace Tensorflow { public class FastPathOpExecInfo { @@ -9,8 +11,17 @@ namespace Tensorflow.Eager public string op_name { get; set; } public string name { get; set; } public object[] args { get; set; } + public Dictionary attrs { get; set; } public bool run_gradient_callback { get; set; } public bool run_post_exec_callbacks { get; set; } public bool run_callbacks { get; set; } + public Action callbacks { get; set; } + + public FastPathOpExecInfo(string opName, string name, params object[] inputArgs) + { + this.op_name = opName; + this.name = name; + this.args = inputArgs; + } } } diff --git a/src/TensorFlowNET.Core/Eager/IEagerRunner.cs b/src/TensorFlowNET.Core/Eager/IEagerRunner.cs index c1fb8607..38202af6 100644 --- a/src/TensorFlowNET.Core/Eager/IEagerRunner.cs +++ b/src/TensorFlowNET.Core/Eager/IEagerRunner.cs @@ -16,20 +16,7 @@ namespace Tensorflow.Eager TF_DataType default_dtype = TF_DataType.DtInvalid, object[] args = null); - Tensor[] TFE_FastPathExecute2(Context ctx, - string device_name, - string opName, - string name, - Action callbacks, - object[] inputArgs, - object[] attrs); - - Tensor[] TFE_FastPathExecute(Context ctx, - string device_name, - string opName, - string name, - Action callbacks, - params object[] args); + Tensor[] TFE_FastPathExecute(FastPathOpExecInfo op_exec_info); Tensor[] TFE_Execute(Context ctx, string device_name, diff --git a/src/TensorFlowNET.Core/Framework/random_seed.py.cs b/src/TensorFlowNET.Core/Framework/random_seed.py.cs index e8af1993..8732c030 100644 --- a/src/TensorFlowNET.Core/Framework/random_seed.py.cs +++ b/src/TensorFlowNET.Core/Framework/random_seed.py.cs @@ -14,16 +14,43 @@ limitations under the License. ******************************************************************************/ +using System.Collections.Generic; +using static Tensorflow.Binding; + namespace Tensorflow { public class random_seed { private static int DEFAULT_GRAPH_SEED = 87654321; + private static Dictionary _graph_to_seed_dict = new Dictionary(); public static (int?, int?) get_seed(int? op_seed = null) { + int? global_seed; + + if (tf.executing_eagerly()) + global_seed = tf.Context.global_seed(); + else + global_seed = ops.get_default_graph().seed; + + if (global_seed.HasValue) + { + if (!op_seed.HasValue) + if (tf.executing_eagerly()) + op_seed = tf.Context.internal_operation_seed(); + else + { + if (!_graph_to_seed_dict.TryGetValue(ops.get_default_graph().graph_key, out int seed)) + seed = 0; + _graph_to_seed_dict[ops.get_default_graph().graph_key] = seed + 1; + op_seed = seed; + } + + return (global_seed, op_seed); + } + if (op_seed.HasValue) - return (DEFAULT_GRAPH_SEED, 0); + return (DEFAULT_GRAPH_SEED, op_seed); else return (null, null); } diff --git a/src/TensorFlowNET.Core/Framework/sparse_tensor.py.cs b/src/TensorFlowNET.Core/Framework/sparse_tensor.py.cs deleted file mode 100644 index f17f668c..00000000 --- a/src/TensorFlowNET.Core/Framework/sparse_tensor.py.cs +++ /dev/null @@ -1,63 +0,0 @@ -using System; -using System.Linq; -using static Tensorflow.Binding; - -namespace Tensorflow.Framework -{ - /// - /// Represents a sparse tensor. - /// - public class SparseTensor : CompositeTensor, _TensorLike - { - long[,] _indices; - public Tensor indices; - - T[] _values; - public Tensor values; - - long[] _dense_shape; - public Tensor dense_shape; - - TensorShape _shape; - public TensorShape shape => _shape; - - public TF_DataType dtype => dtypes.as_dtype(typeof(T)); - - public SparseTensor(long[,] indices_, T[] values_, long[] dense_shape_) - { - tf_with(ops.name_scope(null, "SparseTensor", new { }), delegate - { - indices = ops.convert_to_tensor( - indices_, name: "indices", dtype: dtypes.int64); - values = ops.convert_to_tensor(values_, name: "values"); - dense_shape = ops.convert_to_tensor( - dense_shape_, name: "dense_shape", dtype: dtypes.int64); - }); - - _indices = indices_; - _values = values_; - _dense_shape = dense_shape_; - - var indices_shape = indices.TensorShape.with_rank(2); - var values_shape = values.TensorShape.with_rank(1); - var dense_shape_shape = dense_shape.TensorShape.with_rank(1); - - indices_shape["0"].merge_with(values_shape[0]); - indices_shape["1"].merge_with(dense_shape_shape[0]); - - _shape = new TensorShape(_dense_shape.Select(x => Convert.ToInt32(x)).ToArray()); - } - } - - public interface _TensorLike - { - } - - public static class sparse_tensor_extension - { - public static bool is_sparse(this _TensorLike x) - { - return x.GetType().Name.Contains("SparseTensor"); - } - } -} diff --git a/src/TensorFlowNET.Core/Framework/tensor_shape.cs b/src/TensorFlowNET.Core/Framework/tensor_shape.cs index 0cdb633a..0aaca304 100644 --- a/src/TensorFlowNET.Core/Framework/tensor_shape.cs +++ b/src/TensorFlowNET.Core/Framework/tensor_shape.cs @@ -44,14 +44,14 @@ namespace Tensorflow.Framework return true; } - if (other.is_sparse()) + if (other.IsSparseTensor) { return self.dtype.is_compatible_with(other.dtype); } return self.dtype.is_compatible_with(other.dtype) && _shape_is_compatible_0dim(self.shape, other.shape) && - !self.is_sparse(); + !self.IsSparseTensor; } public static Dimension dimension_at_index(TensorShape shape, int index) diff --git a/src/TensorFlowNET.Core/Gradients/math_grad.cs b/src/TensorFlowNET.Core/Gradients/math_grad.cs index 2d0d7d28..a071d234 100644 --- a/src/TensorFlowNET.Core/Gradients/math_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/math_grad.cs @@ -291,23 +291,23 @@ namespace Tensorflow.Gradients var b = math_ops.conj(op.inputs[1]); if (!t_a && !t_b) { - grad_a = gen_math_ops.batch_mat_mul(grad, b, adj_y: true); - grad_b = gen_math_ops.batch_mat_mul(a, grad, adj_x: true); + grad_a = math_ops.batch_matmul(grad, b, adj_y: true); + grad_b = math_ops.batch_matmul(a, grad, adj_x: true); } else if (!t_a && t_b) { - grad_a = gen_math_ops.batch_mat_mul(grad, b); - grad_b = gen_math_ops.batch_mat_mul(grad, a, adj_x: true); + grad_a = math_ops.batch_matmul(grad, b); + grad_b = math_ops.batch_matmul(grad, a, adj_x: true); } else if (t_a && !t_b) { - grad_a = gen_math_ops.batch_mat_mul(grad, b); - grad_b = gen_math_ops.batch_mat_mul(grad, a, adj_x: true); + grad_a = math_ops.batch_matmul(grad, b); + grad_b = math_ops.batch_matmul(grad, a, adj_x: true); } else if (t_a && t_b) { - grad_a = gen_math_ops.batch_mat_mul(b, grad, adj_x: true, adj_y: true); - grad_b = gen_math_ops.batch_mat_mul(grad, a, adj_x: true, adj_y: true); + grad_a = math_ops.batch_matmul(b, grad, adj_x: true, adj_y: true); + grad_b = math_ops.batch_matmul(grad, a, adj_x: true, adj_y: true); } return new Tensor[] { grad_a, grad_b }; diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/TextVectorizationArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/TextVectorizationArgs.cs index ab55da4e..ddeadc00 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/TextVectorizationArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/TextVectorizationArgs.cs @@ -11,5 +11,6 @@ namespace Tensorflow.Keras.ArgsDefinition public int MaxTokens { get; set; } = -1; public string OutputMode { get; set; } = "int"; public int OutputSequenceLength { get; set; } = -1; + public string[] Vocabulary { get; set; } } } diff --git a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs index e2815f81..346ba2dd 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs @@ -40,37 +40,16 @@ namespace Tensorflow.Operations /// /// public static Tensor conv2d(Conv2dParams parameters) - { - if (tf.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Conv2D", parameters.Name, - null, - parameters.Input, parameters.Filter, - "strides", parameters.Strides, - "use_cudnn_on_gpu", parameters.UseCudnnOnGpu, - "padding", parameters.Padding, - "explicit_paddings", parameters.ExplicitPaddings, - "data_format", parameters.DataFormat, - "dilations", parameters.Dilations); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("Conv2D", name: parameters.Name, args: new - { - input = parameters.Input, - filter = parameters.Filter, - strides = parameters.Strides, - padding = parameters.Padding, - use_cudnn_on_gpu = parameters.UseCudnnOnGpu, - explicit_paddings = parameters.ExplicitPaddings, - data_format = parameters.DataFormat, - dilations = parameters.Dilations - }); - - return _op.outputs[0]; - } + => tf.Context.ExecuteOp("Conv2D", parameters.Name, new ExecuteOpArgs(parameters.Input, parameters.Filter) + .SetAttributes(new + { + strides = parameters.Strides, + padding = parameters.Padding, + use_cudnn_on_gpu = parameters.UseCudnnOnGpu, + explicit_paddings = parameters.ExplicitPaddings, + data_format = parameters.DataFormat, + dilations = parameters.Dilations + })); /// /// Computes the gradients of convolution with respect to the filter. @@ -83,43 +62,16 @@ namespace Tensorflow.Operations string data_format = "NHWC", int[] dilations = null, string name = null) - { - if (explicit_paddings == null) - explicit_paddings = new int[0]; - if (dilations == null) - dilations = new int[] { 1, 1, 1, 1 }; - - if (tf.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Conv2DBackpropFilter", name, - null, - input, filter_sizes, out_backprop, - "strides", strides, - "use_cudnn_on_gpu", use_cudnn_on_gpu, - "padding", padding, - "explicit_paddings", explicit_paddings, - "data_format", data_format, - "dilations", dilations); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("Conv2DBackpropFilter", name: name, args: new - { - input, - filter_sizes, - out_backprop, - strides, - padding, - use_cudnn_on_gpu, - explicit_paddings, - data_format, - dilations - }); - - return _op.outputs[0]; - } + => tf.Context.ExecuteOp("Conv2DBackpropFilter", name, new ExecuteOpArgs(input, filter_sizes, out_backprop) + .SetAttributes(new + { + strides, + padding, + use_cudnn_on_gpu, + explicit_paddings = explicit_paddings ?? new int[0], + data_format, + dilations = dilations ?? new int[] { 1, 1, 1, 1 } + })); /// /// Computes the gradients of convolution with respect to the input. @@ -132,99 +84,29 @@ namespace Tensorflow.Operations string data_format = "NHWC", int[] dilations = null, string name = null) - { - if (explicit_paddings == null) - explicit_paddings = new int[0]; - if (dilations == null) - dilations = new int[] { 1, 1, 1, 1 }; - - if (tf.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Conv2DBackpropInput", name, - null, - input_sizes, filter, out_backprop, - "strides", strides, - "use_cudnn_on_gpu", use_cudnn_on_gpu, - "padding", padding, - "explicit_paddings", explicit_paddings, - "data_format", data_format, - "dilations", dilations); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("Conv2DBackpropInput", name: name, args: new - { - input_sizes, - filter, - out_backprop, - strides, - padding, - use_cudnn_on_gpu, - explicit_paddings, - data_format, - dilations - }); - - return _op.outputs[0]; - } + => tf.Context.ExecuteOp("Conv2DBackpropInput", name, new ExecuteOpArgs(input_sizes, filter, out_backprop) + .SetAttributes(new + { + strides, + padding, + use_cudnn_on_gpu, + explicit_paddings = explicit_paddings ?? new int[0], + data_format, + dilations = dilations ?? new int[] { 1, 1, 1, 1 } + })); public static Tensor bias_add(Tensor value, IVariableV1 bias, string data_format = null, string name = null) - { - if (tf.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "BiasAdd", name, - null, - value, bias, - "data_format", data_format); - - return results[0]; - } - - if (data_format == null) - data_format = "NHWC"; - - var _op = tf.OpDefLib._apply_op_helper("BiasAdd", name: name, args: new - { - value, - bias, - data_format - }); - - return _op.outputs[0]; - } + => tf.Context.ExecuteOp("BiasAdd", name, new ExecuteOpArgs(value, bias) + .SetAttributes(new { data_format = data_format ?? "NHWC" })); public static Tensor bias_add_grad(Tensor out_backprop, string data_format = "NHWC", string name = null) - { - if (tf.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "BiasAddGrad", name, - null, - out_backprop, - "data_format", data_format); - - return results[0]; - } - - if (data_format == null) - data_format = "NHWC"; - - var _op = tf.OpDefLib._apply_op_helper("BiasAddGrad", name: name, args: new - { - out_backprop, - data_format - }); - - return _op.outputs[0]; - } + => tf.Context.ExecuteOp("BiasAddGrad", name, new ExecuteOpArgs(out_backprop) + .SetAttributes(new { data_format = data_format ?? "NHWC" })); /// /// Computes exponential linear: exp(features) - 1 if &lt; 0, features otherwise. @@ -269,29 +151,19 @@ namespace Tensorflow.Operations } public static Tensor[] fused_batch_norm_grad_v3(FusedBatchNormParams @params) - => tf.Context.RunInAutoMode(() - => tf.OpDefLib._apply_op_helper("FusedBatchNormGradV3", name: @params.Name, - args: new - { - y_backprop = @params.YBackprop, - x = @params.X, - scale = @params.Scale, - reserve_space_1 = @params.ReserveSpace1, - reserve_space_2 = @params.ReserveSpace2, - reserve_space_3 = @params.ReserveSpace3, - epsilon = @params.Epsilon, - data_format = @params.DataFormat, - is_training = @params.IsTraining - }).outputs, () - => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "FusedBatchNormGradV3", @params.Name, - null, - @params.YBackprop, @params.X, @params.Scale, - @params.ReserveSpace1, @params.ReserveSpace2, @params.ReserveSpace3, - "epsilon", @params.Epsilon, - "data_format", @params.DataFormat, - "is_training", @params.IsTraining), - @params.YBackprop); + => tf.Context.ExecuteOp("FusedBatchNormGradV3", @params.Name, + new ExecuteOpArgs(@params.YBackprop, + @params.X, + @params.Scale, + @params.ReserveSpace1, + @params.ReserveSpace2, + @params.ReserveSpace3) + .SetAttributes(new + { + epsilon = @params.Epsilon, + data_format = @params.DataFormat, + is_training = @params.IsTraining + })); public static Tensor[] fused_batch_norm(Tensor x, Tensor scale, @@ -328,39 +200,8 @@ namespace Tensorflow.Operations string data_format = "NHWC", bool is_training = true, string name = null) - { - if (tf.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "FusedBatchNormV3", name, - null, - x, - scale, - offset, - mean, - variance, - "epsilon", epsilon, - "exponential_avg_factor", exponential_avg_factor, - "data_format", data_format, - "is_training", is_training); - - return results; - } - - var _op = tf.OpDefLib._apply_op_helper("FusedBatchNormV3", name: name, args: new - { - x, - scale, - offset, - mean, - variance, - epsilon, - data_format, - is_training - }); - - return _op.outputs; - } + => tf.Context.ExecuteOp("FusedBatchNormV3", name, new ExecuteOpArgs(x, scale, offset, mean, variance) + .SetAttributes(new { epsilon, data_format, is_training })); /// /// Local Response Normalization. @@ -388,14 +229,7 @@ namespace Tensorflow.Operations } public static Tensor log_softmax(Tensor logits, string name = null) - => tf.Context.RunInAutoMode(() - => tf.OpDefLib._apply_op_helper("LogSoftmax", name: name, - args: new { logits }).output, () - => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "LogSoftmax", name, - null, - logits).FirstOrDefault(), - logits); + => tf.Context.ExecuteOp("LogSoftmax", name, new ExecuteOpArgs(logits)); /// /// Says whether the targets are in the top `K` predictions. @@ -418,19 +252,8 @@ namespace Tensorflow.Operations } public static Tensor leaky_relu(Tensor features, float alpha = 0.2f, string name = null) - => tf.Context.RunInAutoMode(() - => tf.OpDefLib._apply_op_helper("LeakyRelu", name: name, - args: new - { - features, - alpha - }).output, () - => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "LeakyRelu", name, - null, - features, - "alpha", alpha).FirstOrDefault(), - features); + => tf.Context.ExecuteOp("LeakyRelu", name, + new ExecuteOpArgs(features).SetAttributes(new { alpha })); public static Tensor max_pool(Tensor input, int[] ksize, @@ -438,63 +261,25 @@ namespace Tensorflow.Operations string padding, string data_format = "NHWC", string name = null) - { - if (tf.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "MaxPool", name, - null, - input, - "ksize", ksize, - "strides", strides, - "padding", padding, - "data_format", data_format); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("MaxPool", name: name, args: new - { - input, - ksize, - strides, - padding, - data_format, - }); - - return _op.outputs[0]; - } + => tf.Context.ExecuteOp("MaxPool", name, new ExecuteOpArgs(input) + .SetAttributes(new + { + ksize, + strides, + padding, + data_format + })); public static Tensor max_pool_grad(Tensor orig_input, Tensor orig_output, Tensor grad, int[] ksize, int[] strides, string padding, string data_format = "NHWC", string name = null) - { - if (tf.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "MaxPoolGrad", name, - null, - orig_input, orig_output, grad, - "ksize", ksize, - "strides", strides, - "padding", padding, - "data_format", data_format); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("MaxPoolGrad", name: name, args: new - { - orig_input, - orig_output, - grad, - ksize, - strides, - padding, - data_format - }); - - return _op.outputs[0]; - } + => tf.Context.ExecuteOp("MaxPoolGrad", name, new ExecuteOpArgs(orig_input, orig_output, grad) + .SetAttributes(new + { + ksize, + strides, + padding, + data_format + })); public static Tensor[] top_kv2(Tensor input, int k, bool sorted = true, string name = null) { @@ -509,68 +294,14 @@ namespace Tensorflow.Operations } public static Tensor relu_grad(Tensor gradients, Tensor features, string name = null) - { - if (tf.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "ReluGrad", name, - null, - gradients, features); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("ReluGrad", name: name, args: new - { - gradients, - features - }); - - return _op.outputs[0]; - } + => tf.Context.ExecuteOp("ReluGrad", name, new ExecuteOpArgs(gradients, features)); public static Tensor leaky_relu_grad(Tensor gradients, Tensor features, float alpha = 0.2f, string name = null) - { - if (tf.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "LeakyReluGrad", name, - null, - gradients, features, - "alpha", alpha); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("LeakyReluGrad", name: name, args: new - { - gradients, - features, - alpha - }); - - return _op.output; - } + => tf.Context.ExecuteOp("LeakyReluGrad", name, new ExecuteOpArgs(gradients, features) + .SetAttributes(new { alpha })); public static Tensor softmax(Tensor logits, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Softmax", name, - null, - logits); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("Softmax", name: name, args: new - { - logits - }); - - return _op.outputs[0]; - } + => tf.Context.ExecuteOp("Softmax", name, new ExecuteOpArgs(logits)); /// /// Computes softmax cross entropy cost and gradients to backpropagate. @@ -581,23 +312,9 @@ namespace Tensorflow.Operations /// public static (Tensor, Tensor) softmax_cross_entropy_with_logits(Tensor features, Tensor labels, string name = null) { - if (tf.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "SoftmaxCrossEntropyWithLogits", name, - null, - features, labels); - - return (results[0], results[1]); - } + var results = tf.Context.ExecuteOp("SoftmaxCrossEntropyWithLogits", name, new ExecuteOpArgs(features, labels)); - var _op = tf.OpDefLib._apply_op_helper("SoftmaxCrossEntropyWithLogits", name: name, args: new - { - features, - labels - }); - - return (_op.outputs[0], _op.outputs[1]); + return (results[0], results[1]); } /// @@ -629,21 +346,9 @@ namespace Tensorflow.Operations /// public static (Tensor loss, Tensor backprop) sparse_softmax_cross_entropy_with_logits(Tensor features, Tensor labels, string name = "SparseSoftmaxCrossEntropyWithLogits") { - if (tf.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "SparseSoftmaxCrossEntropyWithLogits", name, - null, - features, labels); - - return (results[0], results[1]); - } - - var op = tf.OpDefLib._apply_op_helper("SparseSoftmaxCrossEntropyWithLogits", name: name, args: new { features, labels }); - int _idx = 0; - var loss = op.outputs[_idx++]; - var backprop = op.outputs[_idx++]; - return (loss, backprop); + var results = tf.Context.ExecuteOp("SparseSoftmaxCrossEntropyWithLogits", name, new ExecuteOpArgs(features, labels)); + + return (results[0], results[1]); } /// @@ -653,35 +358,9 @@ namespace Tensorflow.Operations /// A name for the operation (optional). /// A `Tensor`. Has the same type as `features`. public static Tensor relu(Tensor features, string name = null) - { - if (tf.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Relu", name, - null, - features); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("Relu", name: name, args: new { features }); - return _op.outputs[0]; - } + => tf.Context.ExecuteOp("Relu", name, new ExecuteOpArgs(features)); public static Tensor tanh(Tensor x, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Tanh", name, - null, - x); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("Tanh", name: name, args: new { x }); - return _op.outputs[0]; - } + => tf.Context.ExecuteOp("Tanh", name, new ExecuteOpArgs(x)); } } diff --git a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs index 752b1d51..560b681e 100644 --- a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs +++ b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs @@ -68,10 +68,10 @@ namespace Tensorflow string _scope_name = scope; // Perform input type inference - foreach (var input_arg in op_def.InputArg) + foreach (var (i, input_arg) in enumerate(op_def.InputArg)) { var input_name = input_arg.Name; - + if (keywords.ContainsKey(input_name)) values = keywords[input_name]; else if (keywords.ContainsKey(input_name + "_")) @@ -79,6 +79,10 @@ namespace Tensorflow input_name += "_"; values = keywords[input_name]; } + else if (keywords.ContainsKey($"input_{i}")) + { + values = keywords[$"input_{i}"]; + } else throw new TypeError("No argument for input " + input_name); diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs index bc4b1206..2eb32775 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.cs @@ -57,20 +57,8 @@ namespace Tensorflow /// gradients in some corner cases. /// public static Tensor prevent_gradient(Tensor input, string message = "", string name = null) - { - if (tf.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "PreventGradient", name, - null, - input, - "message", message); - return results[0]; - } - - var op = tf.OpDefLib._apply_op_helper("PreventGradient", name: name, args: new { input, message }); - return op.output; - } + => tf.Context.ExecuteOp("PreventGradient", name, new ExecuteOpArgs(input) + .SetAttributes(new { message })); internal static Tensor constant(object value, TF_DataType dtype = TF_DataType.DtInvalid, @@ -737,35 +725,27 @@ namespace Tensorflow public static Tensor strided_slice_grad(Tensor shape, Tensor begin, Tensor end, Tensor strides, Tensor dy, long begin_mask = 0, long end_mask = 0, long ellipsis_mask = 0, long new_axis_mask = 0, long shrink_axis_mask = 0, string name = null) - => tf.Context.RunInAutoMode2("StridedSliceGrad", name, new AutoModeArgs - { - OpInputArgs = new + => tf.Context.ExecuteOp("StridedSliceGrad", name, + new ExecuteOpArgs(shape, begin, end, strides, dy) { - shape, - begin, - end, - strides, - dy - }, - OpAttrs = new + GetGradientAttrs = (op) => new + { + T = op.get_attr("T"), + Index = op.get_attr("Index"), + begin_mask = op.get_attr("begin_mask"), + end_mask = op.get_attr("end_mask"), + ellipsis_mask = op.get_attr("ellipsis_mask"), + new_axis_mask = op.get_attr("new_axis_mask"), + shrink_axis_mask = op.get_attr("shrink_axis_mask") + } + }.SetAttributes(new { begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask - }, - GetGradientAttrs = (op) => new - { - T = op.get_attr("T"), - Index = op.get_attr("Index"), - begin_mask = op.get_attr("begin_mask"), - end_mask = op.get_attr("end_mask"), - ellipsis_mask = op.get_attr("ellipsis_mask"), - new_axis_mask = op.get_attr("new_axis_mask"), - shrink_axis_mask = op.get_attr("shrink_axis_mask") - } - }); + })); /// /// Removes dimensions of size 1 from the shape of a tensor. @@ -800,38 +780,17 @@ namespace Tensorflow int num_cols = -1, float padding_value = 0, string align = "RIGHT_LEFT") - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "MatrixDiagV3", name, - null, - diagonal, k, num_rows, num_cols, padding_value, - "align", align); - return results[0]; - } - - throw new NotImplementedException(""); - } + => tf.Context.ExecuteOp("MatrixDiagV3", name, + new ExecuteOpArgs(diagonal, k, num_rows, num_cols, padding_value) + .SetAttributes(new { align })); public static Tensor matrix_set_diag(Tensor input, Tensor diagonal, string name = "set_diag", int k = 0, string align = "RIGHT_LEFT") - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "MatrixSetDiagV3", name, - null, - input, diagonal, k, - "align", align); - return results[0]; - } - - throw new NotImplementedException(""); - } + => tf.Context.ExecuteOp("MatrixSetDiagV3", name, new ExecuteOpArgs(input, diagonal, k) + .SetAttributes(new { align })); /// /// Computes the shape of a broadcast given symbolic shapes. @@ -960,9 +919,8 @@ namespace Tensorflow => gen_array_ops.slice(input, begin, size, name: name); public static Tensor slice(Tensor input, Tensor begin, Tensor size, string name = null) - => tf.Context.RunInAutoMode2("Slice", name, new AutoModeArgs + => tf.Context.ExecuteOp("Slice", name, new ExecuteOpArgs(input, begin, size) { - OpInputArgs = new { input, begin, size }, GetGradientAttrs = (op) => new { T = op.get_attr("T"), diff --git a/src/TensorFlowNET.Core/Operations/bitwise_ops.cs b/src/TensorFlowNET.Core/Operations/bitwise_ops.cs index 4b4e0f5e..7536357c 100644 --- a/src/TensorFlowNET.Core/Operations/bitwise_ops.cs +++ b/src/TensorFlowNET.Core/Operations/bitwise_ops.cs @@ -94,20 +94,7 @@ namespace Tensorflow.Operations /// /// Tensor unary_op(Tensor x, string opName, string name) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - opName, name, - null, - x); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper(opName, name, args: new { x }); - return _op.output; - } + => tf.Context.ExecuteOp(opName, name, new ExecuteOpArgs(x)); /// /// Helper method to invoke binary operator with specified name. @@ -118,21 +105,7 @@ namespace Tensorflow.Operations /// /// Tensor binary_op(Tensor x, Tensor y, string opName, string name) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - opName, name, - null, - x, y); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper(opName, name, args: new { x, y }); - return _op.output; - } - + => tf.Context.ExecuteOp(opName, name, new ExecuteOpArgs(x, y)); #endregion } } diff --git a/src/TensorFlowNET.Core/Operations/dataset_ops.cs b/src/TensorFlowNET.Core/Operations/dataset_ops.cs index 3a8d70b4..fcad0709 100644 --- a/src/TensorFlowNET.Core/Operations/dataset_ops.cs +++ b/src/TensorFlowNET.Core/Operations/dataset_ops.cs @@ -8,26 +8,10 @@ namespace Tensorflow public class dataset_ops { public Tensor tensor_dataset(Tensor[] components, TensorShape[] output_shapes, string name = null) - { - if (tf.Context.executing_eagerly()) + => tf.Context.ExecuteOp("TensorDataset", name, new ExecuteOpArgs() { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "TensorDataset", name, - null, - new object[] - { - components, - "output_shapes", output_shapes - }); - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("TensorDataset", - name: name, - args: new { components, output_shapes }); - - return _op.output; - } + OpInputArgs = new object[] { components } + }.SetAttributes(new { output_shapes })); /// /// Creates a dataset that emits each dim-0 slice of `components` once. @@ -37,192 +21,62 @@ namespace Tensorflow /// /// public Tensor tensor_slice_dataset(Tensor[] components, TensorShape[] output_shapes, string name = null) - { - if (tf.Context.executing_eagerly()) + => tf.Context.ExecuteOp("TensorSliceDataset", name, new ExecuteOpArgs() { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "TensorSliceDataset", name, - null, - new object[] - { - components, - "output_shapes", output_shapes - }); - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("TensorSliceDataset", - name: name, - args: new { components, output_shapes }); - - return _op.outputs[0]; - } + OpInputArgs = new object[] { components } + }.SetAttributes(new { output_shapes })); public Tensor range_dataset(Tensor start, Tensor stop, Tensor step, TF_DataType[] output_types, TensorShape[] output_shapes, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "RangeDataset", name, - null, - start, stop, step, - "output_types", output_types, - "output_shapes", output_shapes); - return results[0]; - } - - throw new NotImplementedException(""); - } + => tf.Context.ExecuteOp("RangeDataset", name, new ExecuteOpArgs(start, stop, step) + .SetAttributes(new { output_types, output_shapes })); public Tensor repeat_dataset(Tensor input_dataset, Tensor count, TF_DataType[] output_types, TensorShape[] output_shapes, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "RepeatDataset", name, - null, - input_dataset, count, - "output_types", output_types, - "output_shapes", output_shapes); - return results[0]; - } - - throw new NotImplementedException(""); - } + => tf.Context.ExecuteOp("RepeatDataset", name, new ExecuteOpArgs(input_dataset, count) + .SetAttributes(new { output_types, output_shapes })); public Tensor shard_dataset(Tensor input_dataset, Tensor num_shards, Tensor index, TF_DataType[] output_types, TensorShape[] output_shapes, bool require_non_empty = false, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "ShardDataset", name, - null, - input_dataset, num_shards, index, - "require_non_empty", require_non_empty, - "output_types", output_types, - "output_shapes", output_shapes); - return results[0]; - } - - throw new NotImplementedException(""); - } + => tf.Context.ExecuteOp("ShardDataset", name, new ExecuteOpArgs(input_dataset, num_shards, index) + .SetAttributes(new { require_non_empty, output_types, output_shapes })); public Tensor zip_dataset(Tensor[] input_datasets, TF_DataType[] output_types, TensorShape[] output_shapes, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "ZipDataset", name, - null, - new object[] - { - input_datasets, - "output_types", output_types, - "output_shapes", output_shapes - }); - return results[0]; - } - - throw new NotImplementedException(""); - } + => tf.Context.ExecuteOp("ZipDataset", name, new ExecuteOpArgs() + { + OpInputArgs = new object[] { input_datasets } + }.SetAttributes(new { output_types, output_shapes })); public Tensor shuffle_dataset_v3(Tensor input_dataset, Tensor buffer_size, Tensor seed, Tensor seed2, Tensor seed_generator, TF_DataType[] output_types, TensorShape[] output_shapes, bool reshuffle_each_iteration = true, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "ShuffleDatasetV3", name, - null, - input_dataset, buffer_size, - seed, seed2, seed_generator, - "reshuffle_each_iteration", reshuffle_each_iteration, - "output_types", output_types, - "output_shapes", output_shapes); - return results[0]; - } - - throw new NotImplementedException(""); - } + => tf.Context.ExecuteOp("ShuffleDatasetV3", name, new ExecuteOpArgs(input_dataset, buffer_size, seed, seed2, seed_generator) + .SetAttributes(new { reshuffle_each_iteration, output_types, output_shapes })); public Tensor skip_dataset(Tensor input_dataset, Tensor count, TF_DataType[] output_types, TensorShape[] output_shapes, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "SkipDataset", name, - null, - input_dataset, count, - "output_types", output_types, - "output_shapes", output_shapes); - return results[0]; - } - - throw new NotImplementedException(""); - } + => tf.Context.ExecuteOp("SkipDataset", name, new ExecuteOpArgs(input_dataset, count) + .SetAttributes(new { output_types, output_shapes })); public Tensor dummy_seed_generator(string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "DummySeedGenerator", name, - null); - return results[0]; - } - - throw new NotImplementedException(""); - } + => tf.Context.ExecuteOp("DummySeedGenerator", name, new ExecuteOpArgs()); public Tensor concatenate_dataset(Tensor input_dataset, Tensor another_dataset, TF_DataType[] output_types, TensorShape[] output_shapes, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "ConcatenateDataset", name, - null, - input_dataset, another_dataset, - "output_types", output_types, - "output_shapes", output_shapes); - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("ConcatenateDataset", - name: name, - args: new { input_dataset, another_dataset, output_types, output_shapes }); - - return _op.outputs[0]; - } + => tf.Context.ExecuteOp("ConcatenateDataset", name, new ExecuteOpArgs(input_dataset, another_dataset) + .SetAttributes(new { output_types, output_shapes })); public Tensor cache_dataset_v2(Tensor input_dataset, Tensor filename, Tensor cache, TF_DataType[] output_types, TensorShape[] output_shapes, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "CacheDatasetV2", name, - null, - input_dataset, filename, cache, - "output_types", output_types, - "output_shapes", output_shapes); - return results[0]; - } - - throw new NotImplementedException(""); - } + => tf.Context.ExecuteOp("CacheDatasetV2", name, new ExecuteOpArgs(input_dataset, filename, cache) + .SetAttributes(new { output_types, output_shapes })); /// /// Creates a dataset that batches `batch_size` elements from `input_dataset`. @@ -240,21 +94,9 @@ namespace Tensorflow TF_DataType[] output_types, TensorShape[] output_shapes, bool parallel_copy = false, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "BatchDatasetV2", name, - null, - input_dataset, buffer_size, drop_remainder, - "parallel_copy", parallel_copy, - "output_types", output_types, - "output_shapes", output_shapes); - return results[0]; - } - - throw new NotImplementedException(""); - } + => tf.Context.ExecuteOp("BatchDatasetV2", name, + new ExecuteOpArgs(input_dataset, buffer_size, drop_remainder) + .SetAttributes(new { parallel_copy, output_types, output_shapes })); /// /// @@ -262,17 +104,7 @@ namespace Tensorflow /// /// public Tensor dummy_memory_cache(string name = "") - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "DummyMemoryCache", name, - null); - return results[0]; - } - - throw new NotImplementedException(""); - } + => tf.Context.ExecuteOp("DummyMemoryCache", name, new ExecuteOpArgs()); /// /// Creates a dataset that asynchronously prefetches elements from `input_dataset`. @@ -290,22 +122,14 @@ namespace Tensorflow int? slack_period = 0, bool legacy_autotune = true, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "PrefetchDataset", name, - null, - input_dataset, buffer_size, - "output_types", output_types, - "output_shapes", output_shapes, - "slack_period", slack_period, - "legacy_autotune", legacy_autotune); - return results[0]; - } - - throw new NotImplementedException(""); - } + => tf.Context.ExecuteOp("PrefetchDataset", name, new ExecuteOpArgs(input_dataset, buffer_size) + .SetAttributes(new + { + output_types, + output_shapes, + slack_period, + legacy_autotune + })); /// /// Creates a dataset that contains `count` elements from the `input_dataset`. @@ -319,20 +143,8 @@ namespace Tensorflow public Tensor take_dataset(Tensor input_dataset, Tensor count, TF_DataType[] output_types, TensorShape[] output_shapes, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "TakeDataset", name, - null, - input_dataset, count, - "output_types", output_types, - "output_shapes", output_shapes); - return results[0]; - } - - throw new NotImplementedException(""); - } + => tf.Context.ExecuteOp("TakeDataset", name, new ExecuteOpArgs(input_dataset, count) + .SetAttributes(new { output_types, output_shapes })); /// /// Creates a dataset by applying optimizations to `input_dataset`. @@ -348,24 +160,13 @@ namespace Tensorflow TF_DataType[] output_types, TensorShape[] output_shapes, string[] optimization_configs = null, string name = null) - { - if (optimization_configs == null) - optimization_configs = new string[0]; - - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "OptimizeDataset", name, - null, - input_dataset, optimizations, - "output_types", output_types, - "output_shapes", output_shapes, - "optimization_configs", optimization_configs); - return results[0]; - } - - throw new NotImplementedException(""); - } + => tf.Context.ExecuteOp("OptimizeDataset", name, new ExecuteOpArgs(input_dataset, optimizations) + .SetAttributes(new + { + output_types, + output_shapes, + optimization_configs = optimization_configs ?? new string[0] + })); /// /// Identity transformation that models performance. @@ -381,22 +182,14 @@ namespace Tensorflow TF_DataType[] output_types, TensorShape[] output_shapes, AutotuneAlgorithm algorithm, long cpu_budget, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "ModelDataset", name, - null, - input_dataset, - "algorithm", algorithm, - "cpu_budget", cpu_budget, - "output_types", output_types, - "output_shapes", output_shapes); - return results[0]; - } - - throw new NotImplementedException(""); - } + => tf.Context.ExecuteOp("ModelDataset", name, new ExecuteOpArgs(input_dataset) + .SetAttributes(new + { + algorithm, + cpu_budget, + output_types, + output_shapes + })); /// /// A container for an iterator resource. @@ -407,17 +200,9 @@ namespace Tensorflow /// A tuple of `Tensor` objects (handle, deleter). public (Tensor, Tensor) anonymous_iterator_v2(TF_DataType[] output_types, TensorShape[] output_shapes, string name = null) { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "AnonymousIteratorV2", name, - null, - "output_types", output_types, - "output_shapes", output_shapes); - return (results[0], results[1]); - } - - throw new NotImplementedException(""); + var results = tf.Context.ExecuteOp("AnonymousIteratorV2", name, + new ExecuteOpArgs().SetAttributes(new { output_types, output_shapes })); + return (results[0], results[1]); } /// @@ -427,19 +212,8 @@ namespace Tensorflow /// /// /// The created Operation. - public ITensorOrOperation make_iterator(Tensor dataset, Tensor iterator, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "MakeIterator", name, - null, - dataset, iterator); - return null; - } - - throw new NotImplementedException(""); - } + public void make_iterator(Tensor dataset, Tensor iterator, string name = null) + => tf.Context.ExecuteOp("MakeIterator", name, new ExecuteOpArgs(dataset, iterator)); /// /// @@ -450,23 +224,15 @@ namespace Tensorflow /// public Tensor map_dataset(Tensor dataset, ConcreteFunction f, TF_DataType[] output_types, TensorShape[] output_shapes, bool use_inter_op_parallelism = true, bool preserve_cardinality = false, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "MapDataset", name, - null, - dataset, new Tensor[0], - "f", f, - "output_types", output_types, - "output_shapes", output_shapes, - "use_inter_op_parallelism", use_inter_op_parallelism, - "preserve_cardinality", preserve_cardinality); - return results[0]; - } - - throw new NotImplementedException(""); - } + => tf.Context.ExecuteOp("MapDataset", name, new ExecuteOpArgs(dataset, new Tensor[0]) + .SetAttributes(new + { + f, + output_types, + output_shapes, + use_inter_op_parallelism, + preserve_cardinality + })); /// /// Creates a dataset that applies `f` to the outputs of `input_dataset`. @@ -479,21 +245,8 @@ namespace Tensorflow /// public Tensor flat_map_dataset(Tensor dataset, ConcreteFunction f, TF_DataType[] output_types, TensorShape[] output_shapes, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "FlatMapDataset", name, - null, - dataset, new Tensor[0], - "f", f, - "output_types", output_types, - "output_shapes", output_shapes); - return results[0]; - } - - throw new NotImplementedException(""); - } + => tf.Context.ExecuteOp("FlatMapDataset", name, new ExecuteOpArgs(dataset, new Tensor[0]) + .SetAttributes(new { f, output_types, output_shapes })); /// /// Creates a dataset that applies `f` to the outputs of `input_dataset`. @@ -512,24 +265,17 @@ namespace Tensorflow string deterministic = "default", bool preserve_cardinality = false, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "ParallelMapDatasetV2", name, - null, - dataset, new Tensor[0], num_parallel_calls, - "f", f, - "output_types", output_types, - "output_shapes", output_shapes, - "use_inter_op_parallelism", use_inter_op_parallelism, - "deterministic", deterministic, - "preserve_cardinality", preserve_cardinality); - return results[0]; - } - - throw new NotImplementedException(""); - } + => tf.Context.ExecuteOp("ParallelMapDatasetV2", name, + new ExecuteOpArgs(dataset, new Tensor[0], num_parallel_calls) + .SetAttributes(new + { + f, + output_types, + output_shapes, + use_inter_op_parallelism, + deterministic, + preserve_cardinality + })); /// /// A container for an iterator resource. @@ -538,19 +284,8 @@ namespace Tensorflow /// /// /// The created Operation. - public ITensorOrOperation delete_iterator(Tensor handle, Tensor deleter, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "DeleteIterator", name, - null, - handle, deleter); - return null; - } - - throw new NotImplementedException(""); - } + public void delete_iterator(Tensor handle, Tensor deleter, string name = null) + => tf.Context.ExecuteOp("DeleteIterator", name, new ExecuteOpArgs(handle, deleter)); /// /// Gets the next output from the given iterator . @@ -561,19 +296,7 @@ namespace Tensorflow /// /// public Tensor[] iterator_get_next(Tensor iterator, TF_DataType[] output_types, TensorShape[] output_shapes, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "IteratorGetNext", name, - null, - iterator, - "output_types", output_types, - "output_shapes", output_shapes); - return results; - } - - throw new NotImplementedException(""); - } + => tf.Context.ExecuteOp("IteratorGetNext", name, new ExecuteOpArgs(iterator) + .SetAttributes(new { output_types, output_shapes })); } } diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index e29227c4..c034c7fd 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -45,20 +45,7 @@ namespace Tensorflow /// /// public static Tensor concat_v2(T[] values, Ta axis, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "ConcatV2", name, - null, - values, axis); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("ConcatV2", name: name, args: new { values, axis }); - return _op.output; - } + => tf.Context.ExecuteOp("ConcatV2", name, new ExecuteOpArgs(values, axis)); public static Tensor concat_v2(Tensor[] values, Tensor axis, string name = null) { @@ -72,14 +59,7 @@ namespace Tensorflow } public static Tensor concat_v2(Tensor[] values, int axis, string name = null) - => tf.Context.RunInAutoMode(() - => tf.OpDefLib._apply_op_helper("ConcatV2", name: name, - args: new { values, axis }).output, () - => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "ConcatV2", name, - null, - values, axis).FirstOrDefault(), - values); + => tf.Context.ExecuteOp("ConcatV2", name, new ExecuteOpArgs(values, axis)); private static Tensor concat_v2_eager_fallback(T1[] values, T2 axis, string name, Context ctx) { @@ -131,38 +111,11 @@ namespace Tensorflow /// /// public static Tensor diag(Tensor diagonal, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Diag", name, - null, - diagonal); - - return results[0]; - } - - var op = tf.OpDefLib._apply_op_helper("Diag", name: name, args: new { diagonal }); - - return op.output; - } + => tf.Context.ExecuteOp("Diag", name, new ExecuteOpArgs(diagonal)); public static Tensor expand_dims(Tensor input, int axis, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "ExpandDims", name, - null, - input, tf.convert_to_tensor(axis)); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("ExpandDims", name: name, args: new { input, dim = axis }); - - return _op.outputs[0]; - } + => tf.Context.ExecuteOp("ExpandDims", name, new ExecuteOpArgs(input, axis) + .SetAttributes(new { dim = axis })); public static Tensor gather_v2(T1 @params, T2 indices, int axis, string name = null) { @@ -202,14 +155,10 @@ namespace Tensorflow } public static Tensor pack(Tensor[] values, int axis = 0, string name = null) - => tf.Context.RunInAutoMode(() - => tf.OpDefLib._apply_op_helper("Pack", name, new { values, axis }).output, () - => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Pack", name, - null, - values, - "axis", axis).FirstOrDefault(), - values, axis); + => tf.Context.ExecuteOp("Pack", name, new ExecuteOpArgs() + { + OpInputArgs = new object[] { values } + }.SetAttributes(new { axis })); /// /// Return a tensor with the same shape and contents as the input tensor or value. @@ -217,29 +166,7 @@ namespace Tensorflow /// /// public static Tensor identity(Tensor input, string name = null) - { - if (tf.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Identity", name, - null, - input); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("Identity", name, new { input }); - - if (tf.Runner.MustRecordGradient()) - { - tf.Runner.RecordGradient("Identity", _op.inputs, new object[] - { - "T", _op.get_attr("T") - }, _op.outputs); - } - - return _op.output; - } + => tf.Context.ExecuteOp("Identity", name, new ExecuteOpArgs(input)); public static Tensor invert_permutation(Tensor x, string name = null) { @@ -256,21 +183,7 @@ namespace Tensorflow } public static Tensor rank(Tensor input, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Rank", name, - null, - input); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("Rank", name: name, args: new { input }); - - return _op.outputs[0]; - } + => tf.Context.ExecuteOp("Rank", name, new ExecuteOpArgs(input)); /// /// Creates a tensor filled with a scalar value. @@ -280,20 +193,7 @@ namespace Tensorflow /// A name for the operation (optional). /// A `Tensor`. Has the same type as `value`. public static Tensor fill(Tensor dims, T value, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Fill", name, - null, - dims, value); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("Fill", name, new { dims, value }); - return _op.output; - } + => tf.Context.ExecuteOp("Fill", name, new ExecuteOpArgs(dims, value)); /// /// Return the reduction indices for computing gradients of s0 op s1 with broadcast. @@ -304,19 +204,8 @@ namespace Tensorflow /// A tuple of `Tensor` objects (r0, r1). public static (Tensor, Tensor) broadcast_gradient_args(Tensor s0, Tensor s1, string name = "") { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "BroadcastGradientArgs", name, - null, - s0, s1); - - return (results[0], results[1]); - } - - var _op = tf.OpDefLib._apply_op_helper("BroadcastGradientArgs", name, new { s0, s1 }); - - return (_op.outputs[0], _op.outputs[1]); + var results = tf.Context.ExecuteOp("BroadcastGradientArgs", name, new ExecuteOpArgs(s0, s1)); + return (results[0], results[1]); } public static Tensor reverse(Tensor tensor, T axis, string name = null) @@ -326,31 +215,10 @@ namespace Tensorflow } public static Tensor reshape(Tensor tensor, T shape, string name = null) - => tf.Context.RunInAutoMode(() - => tf.OpDefLib._apply_op_helper("Reshape", name, new { tensor, shape }).output, () - => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Reshape", name, - null, - tensor, shape).FirstOrDefault(), - tensor, shape); + => tf.Context.ExecuteOp("Reshape", name, new ExecuteOpArgs(tensor, shape)); public static Tensor reshape(Tensor tensor, object[] shape, string name = null) - { - try - { - return tf.Context.RunInAutoMode(() - => tf.OpDefLib._apply_op_helper("Reshape", name, new { tensor, shape }).output, () - => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Reshape", name, - null, - tensor, shape).FirstOrDefault(), - tensor, shape); - } - catch (InvalidArgumentError ex) - { - return reshape_eager_fallback(tensor, shape, name, tf.Context); - } - } + => tf.Context.ExecuteOp("Reshape", name, new ExecuteOpArgs(tensor, shape)); private static Tensor reshape_eager_fallback(Tensor tensor, object[] shape, string name, Context ctx) { @@ -400,21 +268,8 @@ namespace Tensorflow TF_DataType dtype = TF_DataType.DtInvalid, int axis = -1, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "OneHot", name, - null, - indices, depth, on_value, off_value, - "axis", axis); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("OneHot", name, new { indices, depth, on_value, off_value, axis }); - return _op.outputs[0]; - } + => tf.Context.ExecuteOp("OneHot", name, new ExecuteOpArgs(indices, depth, on_value, off_value) + .SetAttributes(new { axis })); /// /// A placeholder op that passes through `input` when its output is not fed. @@ -430,35 +285,10 @@ namespace Tensorflow } public static Tensor select(Tensor condition, Tx x, Ty y, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Select", name, - null, - condition, x, y); - - return results[0]; - } + => tf.Context.ExecuteOp("Select", name, new ExecuteOpArgs(condition, x, y)); - var _op = tf.OpDefLib._apply_op_helper("Select", name, new { condition, t = x, e = y }); - return _op.outputs[0]; - } public static Tensor select_v2(Tensor condition, Tx x, Ty y, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "SelectV2", name, - null, - condition, x, y); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("SelectV2", name, new { condition, t = x, e = y }); - return _op.outputs[0]; - } + => tf.Context.ExecuteOp("SelectV2", name, new ExecuteOpArgs(condition, x, y)); public static Tensor scatter_nd(Tensor indices, Tensor updates, Tensor[] shape, string name = null) { @@ -467,15 +297,8 @@ namespace Tensorflow } public static Tensor shape(Tensor input, TF_DataType out_type = TF_DataType.TF_INT32, string name = null) - => tf.Context.RunInAutoMode(() - => tf.OpDefLib._apply_op_helper("Shape", name, - new { input, out_type }).output, () - => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Shape", name, - null, - input, - "out_type", out_type).FirstOrDefault(), - input); + => tf.Context.ExecuteOp("Shape", name, new ExecuteOpArgs(input) + .SetAttributes(new { out_type })); /// /// Returns shape of tensors. @@ -485,21 +308,10 @@ namespace Tensorflow /// /// public static Tensor[] shape_n(Tensor[] input, TF_DataType out_type = TF_DataType.TF_INT32, string name = null) - { - if (tf.executing_eagerly()) + => tf.Context.ExecuteOp("ShapeN", name, new ExecuteOpArgs() { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "ShapeN", name, - null, - input, - "out_type", out_type); - - return results; - } - - var _op = tf.OpDefLib._apply_op_helper("ShapeN", name, new { input, out_type }); - return _op.outputs; - } + OpInputArgs = new object[] { input } + }.SetAttributes(new { out_type })); public static Tensor size(Tensor input, TF_DataType out_type = TF_DataType.TF_INT32, string name = null) { @@ -542,72 +354,23 @@ namespace Tensorflow public static Tensor[] split_v(Tensor value, Tensor size_splits, int axis, int num_split, string name = null) - { - if (tf.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "SplitV", name, - null, - value, size_splits, axis, - "num_split", num_split); - - return results; - } - - var _op = tf.OpDefLib._apply_op_helper("SplitV", name, new { split_dim = axis, value, num_split }); - return _op.outputs; - } + => tf.Context.ExecuteOp("SplitV", name, new ExecuteOpArgs(value, size_splits, axis) + .SetAttributes(new { num_split })); public static Tensor tile(Tensor input, Tensor multiples, string name = null) - => tf.Context.RunInAutoMode(() - => tf.OpDefLib._apply_op_helper("Tile", name, new { input, multiples }).output, () - => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Tile", name, - null, - input, multiples).FirstOrDefault(), - input, multiples); + => tf.Context.ExecuteOp("Tile", name, new ExecuteOpArgs(input, multiples)); public static Tensor tile(Tensor input, object[] multiples, string name = null) - => tf.Context.RunInAutoMode(() - => tf.OpDefLib._apply_op_helper("Tile", name, new { input, multiples }).output, () - => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Tile", name, - null, - input, multiples).FirstOrDefault(), - input, multiples); + => tf.Context.ExecuteOp("Tile", name, new ExecuteOpArgs(input, multiples)); public static Tensor transpose(Tensor x, T1 perm, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Transpose", name, - null, - x, perm); - - return results[0]; - } - var _op = tf.OpDefLib._apply_op_helper("Transpose", name, new { x, perm }); - return _op.outputs[0]; - } + => tf.Context.ExecuteOp("Transpose", name, new ExecuteOpArgs(x, perm)); public static Tensor ones_like(Tensor x, string name = null) - => tf.Context.RunInAutoMode(() - => tf.OpDefLib._apply_op_helper("OnesLike", name, new { x }).output, () - => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "OnesLike", name, - null, - x).FirstOrDefault(), - x); + => tf.Context.ExecuteOp("OnesLike", name, new ExecuteOpArgs(x)); public static Tensor zeros_like(Tensor x, string name = null) - => tf.Context.RunInAutoMode(() - => tf.OpDefLib._apply_op_helper("ZerosLike", name, new { x }).output, () - => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "ZerosLike", name, - null, - x).FirstOrDefault(), - x); + => tf.Context.ExecuteOp("ZerosLike", name, new ExecuteOpArgs(x)); public static Tensor stop_gradient(Tensor x, string name = null) { @@ -623,53 +386,32 @@ namespace Tensorflow long new_axis_mask = 0, long shrink_axis_mask = 0, string name = null) - => tf.Context.RunInAutoMode(() - => tf.OpDefLib._apply_op_helper("StridedSlice", name, new - { - input, - begin, - end, - strides, - begin_mask, - end_mask, - ellipsis_mask, - new_axis_mask, - shrink_axis_mask - }).output, () - => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "StridedSlice", name, - null, - input, begin, end, strides, - "begin_mask", begin_mask, - "end_mask", end_mask, - "ellipsis_mask", ellipsis_mask, - "new_axis_mask", new_axis_mask, - "shrink_axis_mask", shrink_axis_mask).FirstOrDefault(), - input, begin, end, strides); - - public static Operation resource_strided_slice_assign(Tensor input, Tensor begin, Tensor end, Tensor strides, Tensor value, + => tf.Context.ExecuteOp("StridedSlice", name, new ExecuteOpArgs(input, begin, end, strides) + .SetAttributes(new + { + begin_mask, + end_mask, + ellipsis_mask, + new_axis_mask, + shrink_axis_mask + })); + + public static Tensor resource_strided_slice_assign(Tensor input, Tensor begin, Tensor end, Tensor strides, Tensor value, int begin_mask = 0, int end_mask = 0, int ellipsis_mask = 0, int new_axis_mask = 0, int shrink_axis_mask = 0, string name = null) - => tf.Context.RunInAutoMode(() - => tf.OpDefLib._apply_op_helper("ResourceStridedSliceAssign", name, new - { - input, begin, end, strides, value, - begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask - }).output, () - => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "ResourceStridedSliceAssign", name, - null, - input, begin, end, strides, value, - "begin_mask", begin_mask, - "end_mask", end_mask, - "ellipsis_mask", ellipsis_mask, - "new_axis_mask", new_axis_mask, - "shrink_axis_mask", shrink_axis_mask).FirstOrDefault(), - input, begin, end, strides, value); + => tf.Context.ExecuteOp("ResourceStridedSliceAssign", name, new ExecuteOpArgs(input, begin, end, strides, value) + .SetAttributes(new + { + begin_mask, + end_mask, + ellipsis_mask, + new_axis_mask, + shrink_axis_mask + })); public static Tensor strided_slice(Tensor input, T[] begin, T[] end, T[] strides, int begin_mask = 0, @@ -707,23 +449,8 @@ namespace Tensorflow /// A name for the operation (optional). /// A `Tensor`. Has the same type as `input`. public static Tensor squeeze(Tensor input, int[] axis = null, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Squeeze", name, - null, - input, - "squeeze_dims", axis); - - return results[0]; - } - - if (axis == null) axis = new int[0]; - var _op = tf.OpDefLib._apply_op_helper("Squeeze", name, args: new { input, squeeze_dims = axis }); - - return _op.outputs[0]; - } + => tf.Context.ExecuteOp("Squeeze", name, new ExecuteOpArgs(input) + .SetAttributes(new { squeeze_dims = axis })); /// /// Return the shape of s0 op s1 with broadcast. @@ -749,20 +476,6 @@ namespace Tensorflow /// /// public static Tensor broadcast_to(Tensor input, T shape, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "BroadcastTo", name, - null, - input, shape); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("BroadcastTo", name, args: new { input, shape, name }); - - return _op.outputs[0]; - } + => tf.Context.ExecuteOp("BroadcastTo", name, new ExecuteOpArgs(input, shape)); } } diff --git a/src/TensorFlowNET.Core/Operations/gen_image_ops.cs b/src/TensorFlowNET.Core/Operations/gen_image_ops.cs index 955f2db3..8b81dc8a 100644 --- a/src/TensorFlowNET.Core/Operations/gen_image_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_image_ops.cs @@ -70,38 +70,17 @@ namespace Tensorflow float acceptable_fraction = 1, string dct_method = "", string name = null) - { - // Add nodes to the TensorFlow graph. - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "DecodeJpeg", name, - null, - contents, - "channels", channels, - "ratio", ratio, - "fancy_upscaling", fancy_upscaling, - "try_recover_truncated", try_recover_truncated, - "acceptable_fraction", acceptable_fraction, - "dct_method", dct_method); - return results[0]; - } - else - { - var _op = tf.OpDefLib._apply_op_helper("DecodeJpeg", name: name, args: new - { - contents, - channels, - ratio, - fancy_upscaling, - try_recover_truncated, - acceptable_fraction, - dct_method - }); - - return _op.outputs[0]; - } - } + => tf.Context.ExecuteOp("DecodeJpeg", name, + new ExecuteOpArgs(contents).SetAttributes( + new + { + channels, + ratio, + fancy_upscaling, + try_recover_truncated, + acceptable_fraction, + dct_method + })); public static Tensor decode_gif(Tensor contents, string name = null) @@ -171,85 +150,36 @@ namespace Tensorflow bool align_corners = false, bool half_pixel_centers = false, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "ResizeBilinear", name, - null, - images, size, - "align_corners", align_corners, - "half_pixel_centers", half_pixel_centers); - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("ResizeBilinear", name: name, args: new - { - images, - size, - align_corners - }); - - return _op.outputs[0]; - } + => tf.Context.ExecuteOp("ResizeBilinear", name, + new ExecuteOpArgs(images, size).SetAttributes(new + { + align_corners, + half_pixel_centers + })); public static Tensor resize_bicubic(Tensor images, Tensor size, bool align_corners = false, bool half_pixel_centers = false, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "ResizeBicubic", name, - null, - images, size, - "align_corners", align_corners, - "half_pixel_centers", half_pixel_centers); - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("ResizeBicubic", name: name, args: new - { - images, - size, - align_corners - }); - - return _op.outputs[0]; - } - + => tf.Context.ExecuteOp("ResizeBicubic", name, + new ExecuteOpArgs(images, size).SetAttributes(new { align_corners, half_pixel_centers })); + public static Tensor resize_nearest_neighbor(Tensor images, Tsize size, bool align_corners = false, bool half_pixel_centers = false, string name = null) - => tf.Context.RunInAutoMode(() - => tf.OpDefLib._apply_op_helper("ResizeNearestNeighbor", name: name, args: new - { - images, - size, - align_corners, - half_pixel_centers - }).output, () - => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "ResizeNearestNeighbor", name, - null, - images, size, - "align_corners", align_corners, - "half_pixel_centers", half_pixel_centers).FirstOrDefault(), - images); + => tf.Context.ExecuteOp("ResizeNearestNeighbor", name, + new ExecuteOpArgs(images, size).SetAttributes(new { align_corners, half_pixel_centers })); public static Tensor resize_nearest_neighbor_grad(Tensor grads, Tensor size, bool align_corners = false, bool half_pixel_centers = false, string name = null) - => tf.Context.RunInAutoMode2("ResizeNearestNeighborGrad", name, new AutoModeArgs + => tf.Context.ExecuteOp("ResizeNearestNeighborGrad", name, new ExecuteOpArgs(grads, size) { - OpInputArgs = new { grads, size }, - OpAttrs = new { align_corners, half_pixel_centers }, GetGradientAttrs = (op) => new { T = op.get_attr("T"), align_corners = op.get_attr("align_corners"), half_pixel_centers = op.get_attr("half_pixel_centers") } - }); + }.SetAttributes(new { align_corners, half_pixel_centers })); } } diff --git a/src/TensorFlowNET.Core/Operations/gen_logging_ops.cs b/src/TensorFlowNET.Core/Operations/gen_logging_ops.cs index c62a8b8a..03159aaa 100644 --- a/src/TensorFlowNET.Core/Operations/gen_logging_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_logging_ops.cs @@ -25,10 +25,9 @@ namespace Tensorflow { if (tf.Context.executing_eagerly()) { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + var results = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo( "Assert", name, - null, - new object[] { condition, data, summarize }); + new object[] { condition, data, summarize })); return results[0]; } diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index b40b3b91..f6775ad9 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -37,20 +37,10 @@ namespace Tensorflow /// /// public static Tensor add_n(Tensor[] inputs, string name = null) - { - if (tf.Context.executing_eagerly()) + => tf.Context.ExecuteOp("AddN", name, new ExecuteOpArgs() { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "AddN", name, - null, - new[] { inputs }); - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("AddN", name, args: new { inputs }); - - return _op.outputs[0]; - } + OpInputArgs = new object[] { inputs } + }); /// /// Returns the index with the largest value across dimensions of a tensor. @@ -61,20 +51,9 @@ namespace Tensorflow /// /// public static Tensor arg_max(Tensor input, int dimension, TF_DataType output_type = TF_DataType.TF_INT64, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "ArgMax", name, - null, - input, dimension, - "output_type", output_type); - - return results[0]; - } + => tf.Context.ExecuteOp("ArgMax", name, new ExecuteOpArgs(input, dimension) + .SetAttributes(new { output_type })); - return tf.OpDefLib._apply_op_helper("ArgMax", name, args: new { input, dimension, output_type }).output; - } /// /// Returns the index with the smallest value across dimensions of a tensor. @@ -116,13 +95,7 @@ namespace Tensorflow /// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) /// public static Tensor div_no_nan(Tensor x, Tensor y, string name = null) - => tf.Context.RunInAutoMode(() - => tf.OpDefLib._apply_op_helper("DivNoNan", name: name, new { x, y }).output, () - => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "DivNoNan", name, - null, - x, y).FirstOrDefault(), - x, y); + => tf.Context.ExecuteOp("DivNoNan", name, new ExecuteOpArgs(x, y)); public static Tensor mean(Tensor input, int axis, bool keep_dims = false, string name = null) => mean(input, ops.convert_to_tensor(axis), keep_dims: keep_dims, name: name); @@ -141,17 +114,15 @@ namespace Tensorflow /// A name for the operation (optional). /// A `Tensor`. Has the same type as `input`. public static Tensor mean(Tensor input, Tensor axis, bool keep_dims = false, string name = null) - => tf.Context.RunInAutoMode2("Mean", name, new AutoModeArgs + => tf.Context.ExecuteOp("Mean", name, new ExecuteOpArgs(input, axis) { - OpInputArgs = new { input, axis }, - OpAttrs = new { keep_dims, reduction_indices = axis }, GetGradientAttrs = (op) => new { T = op.get_attr("T"), Tidx = op.get_attr("Tidx"), keep_dims = op.get_attr("keep_dims") } - }); + }.SetAttributes(new { keep_dims, reduction_indices = axis })); public static Tensor mean(Tensor[] inputs, Tensor axis, bool keep_dims = false, string name = null) { @@ -176,28 +147,8 @@ namespace Tensorflow } public static Tensor prod(T1 input, T2 axis, bool keep_dims = false, string name = null) - { - if (tf.Context.executing_eagerly()) - { - try - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Prod", name, - null, - input, axis, - "keep_dims", keep_dims); - - return results[0]; - } - catch (Exception) - { - return prod_eager_fallback(input as Tensor, axis as int[], keep_dims, name, tf.Context); - } - } - - var _op = tf.OpDefLib._apply_op_helper("Prod", name, args: new { input, reduction_indices = axis, keep_dims }); - return _op.output; - } + => tf.Context.ExecuteOp("Prod", name, + new ExecuteOpArgs(input, axis).SetAttributes(new { keep_dims, reduction_indices = axis })); private static Tensor prod_eager_fallback(Tensor input_t, int[] axis, bool keep_dims, string name, Context ctx = null) { @@ -224,84 +175,22 @@ namespace Tensorflow } public static Tensor add(Tensor x, Tensor y, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Add", name, null, - x, y); - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("Add", name, args: new { x, y }); - - return _op.output; - } + => tf.Context.ExecuteOp("Add", name, new ExecuteOpArgs(x, y)); public static Tensor add(Tx x, Ty y, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Add", name, - null, - x, y); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("Add", name, args: new { x, y }); - - return _op.output; - } + => tf.Context.ExecuteOp("Add", name, new ExecuteOpArgs(x, y)); public static Tensor add_v2(Tx x, Ty y, string name = null) - { - // forward_compatible(2019, 6, 25): - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "AddV2", name, - null, - x, y); - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("AddV2", name, args: new { x, y }); - - return _op.output; - } + => tf.Context.ExecuteOp("AddV2", name, new ExecuteOpArgs(x, y)); public static Tensor atan(Tensor x, string name = null) - { - var _op = tf.OpDefLib._apply_op_helper("Atan", name, args: new { x }); - - return _op.outputs[0]; - } + => tf.Context.ExecuteOp("Atan", name, new ExecuteOpArgs(x)); public static Tensor ceil(Tensor x, string name = null) - { - var _op = tf.OpDefLib._apply_op_helper("Ceil", name, args: new { x }); - - return _op.outputs[0]; - } + => tf.Context.ExecuteOp("Ceil", name, new ExecuteOpArgs(x)); public static Tensor sin(Tensor x, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Sin", name, - null, - x); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("Sin", name, args: new { x }); - - return _op.outputs[0]; - } + => tf.Context.ExecuteOp("Sin", name, new ExecuteOpArgs(x)); /// /// Computes sigmoid of x element-wise. @@ -318,13 +207,7 @@ namespace Tensorflow /// Specifically, y = 1 / (1 + exp(-x)). /// public static Tensor sigmoid(Tensor x, string name = "Sigmoid") - => tf.Context.RunInAutoMode(() - => tf.OpDefLib._apply_op_helper("Sigmoid", name: name, new { x }).output, () - => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Sigmoid", name, - null, - x).FirstOrDefault(), - x); + => tf.Context.ExecuteOp("Sigmoid", name, new ExecuteOpArgs(x)); /// /// Computes the gradient of the sigmoid of x wrt its input. @@ -344,27 +227,10 @@ namespace Tensorflow /// dy is the corresponding input gradient. /// public static Tensor sigmoid_grad(Tensor y, Tensor dy, string name = "SigmoidGrad") - => tf.Context.RunInAutoMode2("SigmoidGrad", name, new AutoModeArgs - { - OpInputArgs = new { y, dy } - }); + => tf.Context.ExecuteOp("SigmoidGrad", name, new ExecuteOpArgs(y, dy)); public static Tensor sign(T x, string name = "Sign") - { - if (tf.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Sign", name, - null, - x); - - return results[0]; - } - - var op = tf.OpDefLib._apply_op_helper("Sign", name: name, args: new { x }); - - return op.outputs[0]; - } + => tf.Context.ExecuteOp("Sign", name, new ExecuteOpArgs(x)); public static Tensor sinh(Tensor x, string name = null) { @@ -374,21 +240,7 @@ namespace Tensorflow } public static Tensor cos(T x, string name = null) - { - if (tf.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Cos", name, - null, - x); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("Cos", name, args: new { x }); - - return _op.outputs[0]; - } + => tf.Context.ExecuteOp("Cos", name, new ExecuteOpArgs(x)); public static Tensor cosh(Tensor x, string name = null) { @@ -397,13 +249,6 @@ namespace Tensorflow return _op.outputs[0]; } - public static Tensor cumsum(Tensor x, T axis, bool exclusive = false, bool reverse = false, string name = null) - { - var _op = tf.OpDefLib._apply_op_helper("Cumsum", name, args: new { x, axis, exclusive, reverse }); - - return _op.outputs[0]; - } - /// /// Computes the sum along segments of a tensor. /// @@ -419,38 +264,10 @@ namespace Tensorflow } public static Tensor tan(Tensor x, string name = null) - { - if (tf.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Tan", name, - null, - x); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("Tan", name, args: new { x }); - - return _op.output; - } + => tf.Context.ExecuteOp("Tan", name, new ExecuteOpArgs(x)); public static Tensor tanh(Tensor x, string name = null) - { - if (tf.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Tanh", name, - null, - x); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("Tanh", name, args: new { x }); - - return _op.outputs[0]; - } + => tf.Context.ExecuteOp("Tanh", name, new ExecuteOpArgs(x)); /// /// Computes the gradient for the tanh of `x` wrt its input. @@ -460,20 +277,7 @@ namespace Tensorflow /// /// public static Tensor tanh_grad(Tensor y, Tensor dy, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "TanhGrad", name, - null, - y, dy); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("TanhGrad", name: name, args: new { y, dy }).output; - return _op.outputs[0]; - } + => tf.Context.ExecuteOp("TanhGrad", name, new ExecuteOpArgs(y, dy)); public static Tensor floor(Tensor x, string name = null) { @@ -490,21 +294,7 @@ namespace Tensorflow } public static Tensor greater(Tx x, Ty y, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Greater", name, - null, - x, y); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("Greater", name: name, args: new { x, y }); - - return _op.outputs[0]; - } + => tf.Context.ExecuteOp("Greater", name, new ExecuteOpArgs(x, y)); /// /// Computes the log of the absolute value of `Gamma(x)` element-wise. @@ -525,82 +315,22 @@ namespace Tensorflow } public static Tensor greater_equal(Tx x, Ty y, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "GreaterEqual", name, - null, - x, y); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("GreaterEqual", name: name, args: new { x, y }); - - return _op.outputs[0]; - } + => tf.Context.ExecuteOp("GreaterEqual", name, new ExecuteOpArgs(x, y)); public static Tensor less(Tx x, Ty y, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Less", name, - null, - x, y); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("Less", name: name, args: new { x, y }); - - return _op.outputs[0]; - } + => tf.Context.ExecuteOp("Less", name, new ExecuteOpArgs(x, y)); public static Tensor less_equal(Tx x, Ty y, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "LessEqual", name, - null, - x, y); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("LessEqual", name: name, args: new { x, y }); - - return _op.outputs[0]; - } + => tf.Context.ExecuteOp("LessEqual", name, new ExecuteOpArgs(x, y)); public static Tensor log1p(Tensor x, string name = null) - => tf.Context.RunInAutoMode(() - => tf.OpDefLib._apply_op_helper("Log1p", name: name, new { x }).output, () - => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Log1p", name, - null, - x).FirstOrDefault(), - x); + => tf.Context.ExecuteOp("Log1p", name, new ExecuteOpArgs(x)); public static Tensor logical_and(Tensor x, Tensor y, string name = null) => tf.OpDefLib._apply_op_helper("LogicalAnd", name, args: new { x, y }); public static Tensor logical_and(bool x, bool y, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "LogicalAnd", name, - null, - x, y); - - return results[0]; - } - - return tf.OpDefLib._apply_op_helper("LogicalAnd", name, args: new { x, y }); - } + => tf.Context.ExecuteOp("LogicalAnd", name, new ExecuteOpArgs(x, y)); public static Tensor logical_not(Tensor x, string name = null) { @@ -625,21 +355,7 @@ namespace Tensorflow } public static Tensor squared_difference(Tensor x, Tensor y, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "SquaredDifference", name, - null, - x,y); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("SquaredDifference", name, args: new { x, y, name }); - - return _op.outputs[0]; - } + => tf.Context.ExecuteOp("SquaredDifference", name, new ExecuteOpArgs(x, y)); /// /// Computes square of x element-wise. @@ -648,21 +364,7 @@ namespace Tensorflow /// A name for the operation (optional). /// A `Tensor`. Has the same type as `x`. public static Tensor square(Tensor x, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Square", name, - null, - x); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("Square", name, args: new { x }); - - return _op.outputs[0]; - } + => tf.Context.ExecuteOp("Square", name, new ExecuteOpArgs(x)); /// /// Returns which elements of x are finite. @@ -691,13 +393,7 @@ namespace Tensorflow /// A name for the operation (optional). /// A `Tensor`. Has the same type as `x`. public static Tensor exp(Tensor x, string name = null) - => tf.Context.RunInAutoMode(() - => tf.OpDefLib._apply_op_helper("Exp", name, args: new { x }).output, () - => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Exp", name, - null, - x).FirstOrDefault(), - x); + => tf.Context.ExecuteOp("Exp", name, new ExecuteOpArgs(x)); /// /// Computes natural logarithm of x element-wise. @@ -706,104 +402,26 @@ namespace Tensorflow /// name: A name for the operation (optional). /// A `Tensor`. Has the same type as `x`. public static Tensor log(Tensor x, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Log", name, - null, - x); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("Log", name, args: new { x }); + => tf.Context.ExecuteOp("Log", name, new ExecuteOpArgs(x)); - return _op.outputs[0]; - } public static Tensor softplus(Tensor features, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Softplus", name, - null, - features); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("Softplus", name, args: new { features }); - - return _op.outputs[0]; - } + => tf.Context.ExecuteOp("Softplus", name, new ExecuteOpArgs(features)); public static Tensor cast(Tensor x, TF_DataType DstT, bool Truncate = false, string name = null) - => tf.Context.RunInAutoMode(() - => tf.OpDefLib._apply_op_helper("Cast", name, args: new { x, DstT, Truncate }).output, () - => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Cast", name, - null, - x, - "DstT", DstT, "Truncate", Truncate).FirstOrDefault(), - x); + => tf.Context.ExecuteOp("Cast", name, new ExecuteOpArgs(x) + .SetAttributes(new { DstT, Truncate })); public static Tensor neg(Tensor x, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Neg", name, - null, - x); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("Neg", name, args: new { x }); - - return _op.outputs[0]; - } + => tf.Context.ExecuteOp("Neg", name, new ExecuteOpArgs(x)); public static Tensor sqrt(Tensor x, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Sqrt", name, - null, - x); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("Sqrt", name, args: new { x }); - - return _op.outputs[0]; - } + => tf.Context.ExecuteOp("Sqrt", name, new ExecuteOpArgs(x)); public static Tensor sub(Tensor x, Tensor y, string name = null) - => tf.Context.RunInAutoMode2("Sub", name, new AutoModeArgs - { - OpInputArgs = new { x, y } - }); + => tf.Context.ExecuteOp("Sub", name, new ExecuteOpArgs(x, y)); public static Tensor sub(Tx x, Ty y, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Sub", name, - null, - x, y); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("Sub", name, args: new { x, y }); - - return _op.outputs[0]; - } + => tf.Context.ExecuteOp("Sub", name, new ExecuteOpArgs(x, y)); /// /// Returns the truth value of (x == y) element-wise. @@ -813,20 +431,7 @@ namespace Tensorflow /// /// public static Tensor equal(Tx x, Ty y, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Equal", name, - null, - x, y); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("Equal", name, args: new { x, y }); - return _op.output; - } + => tf.Context.ExecuteOp("Equal", name, new ExecuteOpArgs(x, y)); /// /// Returns the truth value of (x != y) element-wise. @@ -838,54 +443,13 @@ namespace Tensorflow /// The name. /// public static Tensor not_equal(Tx x, Ty y, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "NotEqual", name, - null, - x, y); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("NotEqual", name, args: new { x, y }); - return _op.output; - } - + => tf.Context.ExecuteOp("NotEqual", name, new ExecuteOpArgs(x, y)); public static Tensor atan2(Tensor y, Tensor x, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Atan2", name, - null, - y, x); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("Atan2", name, args: new { y, x }); - return _op.output; - } + => tf.Context.ExecuteOp("Atan2", name, new ExecuteOpArgs(y, x)); public static Tensor mul(Tx x, Ty y, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Mul", name, - null, - x, y); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("Mul", name, args: new { x, y }); - - return _op.outputs[0]; - } + => tf.Context.ExecuteOp("Mul", name, new ExecuteOpArgs(x, y)); public static Tensor mul_no_nan(Tx x, Ty y, string name = null) { @@ -895,71 +459,16 @@ namespace Tensorflow } public static Tensor real_div(Tensor x, Tensor y, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "RealDiv", name, - null, - x, y); - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("RealDiv", name, args: new { x, y }); - - return _op.outputs[0]; - } + => tf.Context.ExecuteOp("RealDiv", name, new ExecuteOpArgs(x, y)); public static Tensor reciprocal(Tensor x, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Reciprocal", name, - null, - x); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("Reciprocal", name, args: new { x }); - - return _op.outputs[0]; - } + => tf.Context.ExecuteOp("Reciprocal", name, new ExecuteOpArgs(x)); public static Tensor floor_mod(Tensor x, Tensor y, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "FloorMod", name, - null, - x, y); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("FloorMod", name, args: new { x, y }); - - return _op.outputs[0]; - } + => tf.Context.ExecuteOp("FloorMod", name, new ExecuteOpArgs(x, y)); public static Tensor floor_div(Tensor x, Tensor y, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "FloorDiv", name, - null, - x, y); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("FloorDiv", name, args: new { x, y }); - - return _op.outputs[0]; - } + => tf.Context.ExecuteOp("FloorDiv", name, new ExecuteOpArgs(x, y)); /// /// Multiply the matrix "a" by the matrix "b". @@ -971,56 +480,12 @@ namespace Tensorflow /// /// public static Tensor mat_mul(Tensor a, Tensor b, bool transpose_a = false, bool transpose_b = false, string name = null) - { - if (tf.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "MatMul", name, - null, - a, b, - "transpose_a", transpose_a, "transpose_b", transpose_b); - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("MatMul", name, args: new { a, b, transpose_a, transpose_b }); - - return _op.output; - } - - /// - /// Multiply slices of the two matrices "x" and "y". - /// - /// - /// The `BatchMatMul` operation is embedded into the - /// `MatMul` operation on the DLL side. However the expected - /// attributes are not the same, hence we need to expose this - /// method to have the right args list on the `_apply_op_helper` - /// function. - /// - /// For each rank > 2 the first rank - 2 dimensions are considered - /// as fixed, and have to be consistent across the two matrices. A - /// common matrix multiplication is then applied over the residual - /// 2 dimensions. - /// - /// e.g. - /// x is (3, 6, 12); y is (3, 12, 6) - /// batch_matmul(x, y) ==> (3, 6, 6) - /// - /// - /// - /// - /// - /// - /// - public static Tensor batch_mat_mul(Tensor x, Tensor y, bool adj_x = false, bool adj_y = false, string name = null) - { - var _op = tf.OpDefLib._apply_op_helper( - "BatchMatMul", - name, - args: new { x, y, adj_x, adj_y }); - - return _op.outputs[0]; - } + => tf.Context.ExecuteOp("MatMul", name, new ExecuteOpArgs(a, b) + .SetAttributes(new + { + transpose_a, + transpose_b + })); /// /// Returns the max of x and y (i.e. x > y ? x : y) element-wise. @@ -1030,54 +495,13 @@ namespace Tensorflow /// /// public static Tensor maximum(T1 x, T2 y, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Maximum", name, - null, - x, y); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("Maximum", name, args: new { x, y }); - - return _op.outputs[0]; - } + => tf.Context.ExecuteOp("Maximum", name, new ExecuteOpArgs(x, y)); public static Tensor minimum(T1 x, T2 y, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Minimum", name, - null, - x, y); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("Minimum", name, args: new { x, y }); - - return _op.outputs[0]; - } + => tf.Context.ExecuteOp("Minimum", name, new ExecuteOpArgs(x, y)); public static Tensor _abs(Tensor x, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Abs", name, - null, - x); - - return results[0]; - } - var _op = tf.OpDefLib._apply_op_helper("Abs", name, args: new { x }); - - return _op.output; - } + => tf.Context.ExecuteOp("Abs", name, new ExecuteOpArgs(x)); public static Tensor _any(Tx input, Ty axis, bool keep_dims = false, string name = null) { @@ -1087,14 +511,15 @@ namespace Tensorflow } public static Tensor _max(Tx input, Ty axis, bool keep_dims = false, string name = null) - => tf.Context.RunInAutoMode(() - => tf.OpDefLib._apply_op_helper("Max", name, new { input, reduction_indices = axis, keep_dims }).output, () - => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Max", name, - null, - input, axis, - "keep_dims", keep_dims).FirstOrDefault(), - input as Tensor); + => tf.Context.ExecuteOp("Max", name, new ExecuteOpArgs(input, axis) + { + GetGradientAttrs = (op) => new + { + T = op.get_attr("T"), + align_corners = op.get_attr("align_corners"), + half_pixel_centers = op.get_attr("half_pixel_centers") + } + }.SetAttributes(new { keep_dims, reduction_indices = axis })); public static Tensor _min(Tx input, Ty axis, bool keep_dims = false, string name = null) { @@ -1104,39 +529,11 @@ namespace Tensorflow } public static Tensor pow(Tx x, Ty y, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Pow", name, - null, - x, y); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("Pow", name, args: new { x, y }); - - return _op.outputs[0]; - } + => tf.Context.ExecuteOp("Pow", name, new ExecuteOpArgs(x, y)); public static Tensor _sum(Tx input, Ty axis = default, bool keep_dims = false, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Sum", name, - null, - input, axis, - "keep_dims", keep_dims); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("Sum", name, args: new { input, reduction_indices = axis, keep_dims }); - - return _op.outputs[0]; - } + => tf.Context.ExecuteOp("Sum", name, + new ExecuteOpArgs(input, axis).SetAttributes(new { keep_dims, reduction_indices = axis })); public static Tensor _sum(Tensor[] inputs, Tensor axis = default, bool keep_dims = false, string name = null) { @@ -1170,13 +567,7 @@ namespace Tensorflow /// /// public static Tensor range(Tensor start, Tensor limit, Tensor delta, string name = null) - => tf.Context.RunInAutoMode(() - => tf.OpDefLib._apply_op_helper("Range", name, new { start, limit, delta }).output, () - => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Range", name, - null, - start, limit, delta).FirstOrDefault(), - start, limit, delta); + => tf.Context.ExecuteOp("Range", name, new ExecuteOpArgs(start, limit, delta)); /// /// Rounds the values of a tensor to the nearest integer, element-wise. @@ -1207,20 +598,7 @@ namespace Tensorflow /// /// public static Tensor rsqrt(Tensor x, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Rsqrt", name, - null, - x); - - return results[0]; - } - var _op = tf.OpDefLib._apply_op_helper("Rsqrt", name, new { x }); - - return _op.outputs[0]; - } + => tf.Context.ExecuteOp("Rsqrt", name, new ExecuteOpArgs(x)); /// /// Returns the fraction of zeros in value. @@ -1229,10 +607,6 @@ namespace Tensorflow /// A name for the operation (optional). /// The fraction of zeros in value, with type float32. public static Tensor zero_fraction(Tensor value, string name = null) - { - var _op = tf.OpDefLib._apply_op_helper("zero_fraction", name, new { value, name }); - - return _op.outputs[0]; - } + => tf.Context.ExecuteOp("zero_fraction", name, new ExecuteOpArgs(value)); } } diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.eager.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.eager.cs index f80b5f0f..8e6e72d1 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.eager.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.eager.cs @@ -6,13 +6,6 @@ namespace Tensorflow public static partial class gen_math_ops { public static Tensor mul(IntPtr x, IntPtr y, string name = null) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Mul", name, - null, - x, y); - - return results[0]; - } + => tf.Context.ExecuteOp("Mul", name, new ExecuteOpArgs(x, y)); } } diff --git a/src/TensorFlowNET.Core/Operations/gen_random_ops.cs b/src/TensorFlowNET.Core/Operations/gen_random_ops.cs index 8528f4c4..12d41bf2 100644 --- a/src/TensorFlowNET.Core/Operations/gen_random_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_random_ops.cs @@ -29,31 +29,8 @@ namespace Tensorflow /// /// public static Tensor random_standard_normal(Tensor shape, TF_DataType dtype = TF_DataType.DtInvalid, int? seed = null, int? seed2 = null, string name = null) - { - if (tf.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "RandomStandardNormal", name, - null, - shape, - "seed", seed, - "seed2", seed2, - "dtype", dtype); - - return results[0]; - } - - if (!seed.HasValue) - seed = 0; - if (!seed2.HasValue) - seed2 = 0; - - var _op = tf.OpDefLib._apply_op_helper("RandomStandardNormal", - name: name, - args: new { shape, dtype, seed, seed2 }); - - return _op.output; - } + => tf.Context.ExecuteOp("RandomStandardNormal", name, new ExecuteOpArgs(shape) + .SetAttributes(new { dtype, seed = seed ?? 0, seed2 = seed2 ?? 0 })); /// /// Outputs random integers from a uniform distribution. @@ -89,31 +66,8 @@ namespace Tensorflow /// /// public static Tensor random_uniform(Tensor shape, TF_DataType dtype, int? seed = 0, int? seed2 = 0, string name = null) - { - if (!seed.HasValue) - seed = 0; - if (!seed2.HasValue) - seed2 = 0; - - if (tf.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "RandomUniform", name, - null, - shape, - "seed", seed, - "seed2", seed2, - "dtype", dtype); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("RandomUniform", - name: name, - args: new { shape, dtype, seed, seed2 }); - - return _op.outputs[0]; - } + => tf.Context.ExecuteOp("RandomUniform", name, new ExecuteOpArgs(shape) + .SetAttributes(new { dtype, seed = seed ?? 0, seed2 = seed2 ?? 0 })); /// /// @@ -125,23 +79,8 @@ namespace Tensorflow /// public static Tensor random_shuffle(Tensor value, int seed = 0, int seed2 = 0, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "RandomShuffle", name, - null, - value, seed, seed2); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("RandomShuffle", - name: name, - args: new { value, seed, seed2 }); - - return _op.output; - } + => tf.Context.ExecuteOp("RandomShuffle", name, new ExecuteOpArgs(value) + .SetAttributes(new { seed = seed, seed2 = seed2 })); /// /// Outputs random values from a truncated normal distribution. @@ -154,31 +93,8 @@ namespace Tensorflow /// public static Tensor truncated_normal(Tensor shape, TF_DataType dtype, int? seed = 0, int? seed2 = 0, string name = null) - { - if (!seed.HasValue) - seed = 0; - if (!seed2.HasValue) - seed2 = 0; - - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "TruncatedNormal", name, - null, - shape, - "seed", seed, - "seed2", seed2, - "dtype", dtype); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("TruncatedNormal", - name: name, - args: new { shape, dtype, seed, seed2 }); - - return _op.output; - } + => tf.Context.ExecuteOp("TruncatedNormal", name, new ExecuteOpArgs(shape) + .SetAttributes(new { dtype, seed = seed ?? 0, seed2 = seed2 ?? 0 })); public static Tensor multinomial(Tensor logits, int num_samples, int? seed = 0, int? seed2 = 0, TF_DataType output_dtype = TF_DataType.TF_INT64, string name = null) diff --git a/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs b/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs index a59dda67..e9c4a1f2 100644 --- a/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs @@ -24,10 +24,8 @@ namespace Tensorflow { if (tf.Context.executing_eagerly()) { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "AssignSubVariableOp", name, - null, - resource, value); + tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo( + "AssignSubVariableOp", name, resource, value)); return null; } @@ -46,10 +44,8 @@ namespace Tensorflow { if (tf.Context.executing_eagerly()) { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "AssignAddVariableOp", name, - null, - resource, value); + tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("AssignAddVariableOp", name, + resource, value)); return null; } @@ -63,10 +59,8 @@ namespace Tensorflow { if (tf.Context.executing_eagerly()) { - tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "AssignVariableOp", name, - null, - resource, value); + tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("AssignVariableOp", name, + resource, value)); return null; } @@ -80,10 +74,8 @@ namespace Tensorflow { if (tf.Context.executing_eagerly()) { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "VarIsInitializedOp", name, - null, - resource); + var results = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("VarIsInitializedOp", name, + resource)); return results[0]; } @@ -107,14 +99,17 @@ namespace Tensorflow { if (tf.Context.executing_eagerly()) { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "VarHandleOp", name, - null, - "container", container, - "shared_name", shared_name, - "dtype", dtype, - "shape", shape.dims, - "allowed_devices", new string[0]); + var results = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("VarHandleOp", name) + { + attrs = ConvertToDict(new + { + dtype, + shape = shape.dims, + container, + shared_name, + allowed_devices = new string[0] + }) + }); return results[0]; } @@ -131,26 +126,8 @@ namespace Tensorflow } public static Tensor destroy_resource_op(Tensor resource, bool ignore_lookup_error = true, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "DestroyResourceOp", name, - null, - resource, - "ignore_lookup_error", ignore_lookup_error); - - return results.Length == 0 ? null : results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("DestroyResourceOp", name, new - { - resource, - ignore_lookup_error - }); - - return _op.output; - } + => tf.Context.ExecuteOp("DestroyResourceOp", name, + new ExecuteOpArgs(resource).SetAttributes(new { ignore_lookup_error })); /// /// Reads the value of a variable. @@ -160,26 +137,8 @@ namespace Tensorflow /// /// public static Tensor read_variable_op(Tensor resource, TF_DataType dtype, string name = null) - { - if (tf.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "ReadVariableOp", name, - null, - resource, - "dtype", dtype); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("ReadVariableOp", name, new - { - resource, - dtype - }); - - return _op.output; - } + => tf.Context.ExecuteOp("ReadVariableOp", name, new ExecuteOpArgs(resource) + .SetAttributes(new { dtype })); public static Tensor resource_gather(Tensor resource, Tensor indices, TF_DataType dtype, int batch_dims = 0, bool validate_indices = true, string name = null) diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs index fc8a28d5..ef7988fe 100644 --- a/src/TensorFlowNET.Core/Operations/math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/math_ops.cs @@ -45,10 +45,7 @@ namespace Tensorflow => gen_math_ops.add(x, y, name); public static Tensor add_v2(Tensor x, Tensor y, string name = null) - => tf.Context.RunInAutoMode2("AddV2", name, new AutoModeArgs - { - OpInputArgs = new { x, y } - }); + => tf.Context.ExecuteOp("AddV2", name, new ExecuteOpArgs(x, y)); public static Tensor add_v2(Tx x, Ty y, string name = null) => gen_math_ops.add_v2(x, y, name); @@ -171,15 +168,12 @@ namespace Tensorflow } public static Tensor cumsum(Tensor x, T axis = default, bool exclusive = false, bool reverse = false, string name = null) - { - return tf_with(ops.name_scope(name, "Cumsum", new { x }), scope => - { - name = scope; - x = ops.convert_to_tensor(x, name: "x"); - - return gen_math_ops.cumsum(x, axis: axis, exclusive: exclusive, reverse: reverse, name: name); - }); - } + => tf_with(ops.name_scope(name, "Cumsum", new { x }), scope => + { + name = scope; + return tf.Context.ExecuteOp("Cumsum", name, new ExecuteOpArgs(x, axis) + .SetAttributes(new { exclusive, reverse })); + }); /// /// Computes Psi, the derivative of Lgamma (the log of the absolute value of @@ -261,19 +255,13 @@ namespace Tensorflow /// /// public static Tensor erf(Tensor x, string name = null) - => tf.Context.RunInAutoMode2("Erf", name, new AutoModeArgs - { - OpInputArgs = new { x } - }); + => tf.Context.ExecuteOp("Erf", name, new ExecuteOpArgs(x)); public static Tensor sqrt(Tensor x, string name = null) => gen_math_ops.sqrt(x, name: name); public static Tensor multiply(Tensor x, Tensor y, string name = null) - => tf.Context.RunInAutoMode2("Mul", name, new AutoModeArgs - { - OpInputArgs = new { x, y } - }); + => tf.Context.ExecuteOp("Mul", name, new ExecuteOpArgs(x, y)); public static Tensor multiply(Tx x, Ty y, string name = null) => gen_math_ops.mul(x, y, name: name); @@ -720,23 +708,10 @@ namespace Tensorflow => tf_with(ops.name_scope(name, "Pow", new { x, y }), scope => { name = scope; + var x_tensor = ops.convert_to_tensor(x, name: "x"); + var y_tensor = ops.convert_to_tensor(y, name: "y", dtype: x_tensor.dtype.as_base_dtype()); - if (tf.executing_eagerly()) - { - var x_tensor = ops.convert_to_tensor(x, name: "x"); - var y_tensor = ops.convert_to_tensor(y, name: "y", dtype: x_tensor.dtype.as_base_dtype()); - - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Pow", name, - null, - x_tensor, y_tensor); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("Pow", name, args: new { x, y }); - - return _op.output; + return tf.Context.ExecuteOp("Pow", name, new ExecuteOpArgs(x_tensor, y_tensor)); }); public static Tensor range(object start, object limit = null, object delta = null, TF_DataType dtype = TF_DataType.DtInvalid, string name = "range") @@ -818,21 +793,41 @@ namespace Tensorflow public static Tensor batch_matmul(Tensor x, Tensor y, bool adj_x = false, bool adj_y = false, string name = null) - { - Tensor result = null; - - tf_with(ops.name_scope(name, "MatMul", new Tensor[] { x, y }), scope => + => tf_with(ops.name_scope(name, "MatMul", new Tensor[] { x, y }), scope => { name = scope; x = ops.convert_to_tensor(x, name: "a"); y = ops.convert_to_tensor(y, name: "b"); - result = gen_math_ops.batch_mat_mul(x, y, adj_x, adj_y, name); + return tf.Context.ExecuteOp("BatchMatMul", name, new ExecuteOpArgs(x, y) + .SetAttributes(new { adj_x, adj_y })); }); - return result; - } + public static Tensor bincount(Tensor arr, Tensor weights = null, + Tensor minlength = null, + Tensor maxlength = null, + TF_DataType dtype = TF_DataType.TF_INT32, + string name = null, + TensorShape axis = null, + bool binary_output = false) + => tf_with(ops.name_scope(name, "bincount"), scope => + { + name = scope; + if(!binary_output && axis == null) + { + var array_is_nonempty = math_ops.reduce_prod(array_ops.shape(arr)) > 0; + var output_size = math_ops.cast(array_is_nonempty, dtypes.int32) * (math_ops.reduce_max(arr) + 1); + if (minlength != null) + output_size = math_ops.maximum(minlength, output_size); + if (maxlength != null) + output_size = math_ops.minimum(maxlength, output_size); + var weights = constant_op.constant(new long[0], dtype: dtype); + return tf.Context.ExecuteOp("Bincount", name, new ExecuteOpArgs(arr, output_size, weights)); + } + + throw new NotImplementedException(""); + }); /// /// Returns the complex conjugate of a complex number. diff --git a/src/TensorFlowNET.Core/Operations/string_ops.cs b/src/TensorFlowNET.Core/Operations/string_ops.cs index deba7be4..2d7c54c7 100644 --- a/src/TensorFlowNET.Core/Operations/string_ops.cs +++ b/src/TensorFlowNET.Core/Operations/string_ops.cs @@ -14,6 +14,8 @@ limitations under the License. ******************************************************************************/ +using NumSharp; +using Tensorflow.Framework; using static Tensorflow.Binding; namespace Tensorflow @@ -21,53 +23,13 @@ namespace Tensorflow public class string_ops { public Tensor lower(Tensor input, string encoding = "", string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "StringLower", name, - null, - input, encoding); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("StringLower", name: name, args: new - { - input, - encoding - }); - - return _op.output; - } + => tf.Context.ExecuteOp("StringLower", name, new ExecuteOpArgs(input, encoding)); public Tensor regex_replace(Tensor input, string pattern, string rewrite, bool replace_global = true, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "StaticRegexReplace", name, - null, - input, - "pattern", pattern, - "rewrite", rewrite, - "replace_global", replace_global); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("StaticRegexReplace", name: name, args: new - { - input, - pattern, - rewrite, - replace_global - }); - - return _op.output; - } - + => tf.Context.ExecuteOp("StaticRegexReplace", name, new ExecuteOpArgs(input) + .SetAttributes(new { pattern, rewrite, replace_global })); + /// /// Return substrings from `Tensor` of strings. /// @@ -79,28 +41,93 @@ namespace Tensorflow /// public Tensor substr(T input, int pos, int len, string @uint = "BYTE", string name = null) - { - if (tf.Context.executing_eagerly()) + => tf.Context.ExecuteOp("Substr", name, new ExecuteOpArgs(input, pos, len) + .SetAttributes(new { unit = @uint })); + + /// + /// Computes the length of each string given in the input tensor. + /// + /// + /// + /// + /// + public Tensor string_length(Tensor input, string name = null, string unit = "BYTE") + => tf.Context.ExecuteOp("StringLength", name, new ExecuteOpArgs(input) { - var input_tensor = tf.constant(input); - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Substr", name, - null, - input, pos, len, - "unit", @uint); + GetGradientAttrs = op => new + { + unit = op.get_attr("unit") + } + }.SetAttributes(new { unit })); - return results[0]; - } + public RaggedTensor string_split_v2(Tensor input, string sep = "", int maxsplit = -1, string name = null) + { + return tf_with(ops.name_scope(name, "StringSplit"), scope => + { + var sep_tensor = ops.convert_to_tensor(sep, dtype: TF_DataType.TF_STRING); + var result = tf.Context.ExecuteOp("StringSplitV2", name, + new ExecuteOpArgs(input, sep) + { + GetGradientAttrs = op => new + { + maxsplit = op.get_attr("maxsplit") + } + }.SetAttributes(new { maxsplit })); + var (indices, values, shape) = (result[0], result[1], result[2]); + indices.set_shape(new TensorShape(-1, 2)); + values.set_shape(new TensorShape(-1)); + shape.set_shape(new TensorShape(2)); + + var sparse_result = new SparseTensor(indices, values, shape); + return RaggedTensor.from_value_rowids(sparse_result.values, + value_rowids: sparse_result.indices[Slice.All, 0], + nrows: sparse_result.dense_shape[0], + validate: false); + }); + } - var _op = tf.OpDefLib._apply_op_helper("Substr", name: name, args: new + public (RaggedTensor, RaggedTensor) unicode_decode_with_offsets(Tensor input, string input_encoding, string errors, + int replacement_char = 0xFFFD, bool replace_control_characters = false, string name = null) + { + return tf_with(ops.name_scope(name, "UnicodeDecodeWithOffsets"), scope => { - input, - pos, - len, - unit = @uint + var (codepoints, byte_start_offsets) = _unicode_decode(input, input_encoding, errors, + replacement_char, replace_control_characters, + with_offsets: true, name: name); + return (codepoints, byte_start_offsets); }); + } + + (RaggedTensor, RaggedTensor) _unicode_decode(Tensor input, string input_encoding, string errors, int replacement_char, + bool replace_control_characters, bool with_offsets, string name = null) + { + if (with_offsets) + { + var flat_result = tf.Context.ExecuteOp("UnicodeDecodeWithOffsets", name, new ExecuteOpArgs(input) + { + GetGradientAttrs = op => new + { + input_encoding = op.get_attr("input_encoding"), + errors = op.get_attr("errors"), + replacement_char = op.get_attr("replacement_char"), + replace_control_characters = op.get_attr("replace_control_characters"), + Tsplits = op.get_attr("Tsplits") + } + }.SetAttributes(new + { + input_encoding, + errors, + replacement_char, + replace_control_characters + })); + + var codepoints = RaggedTensor.from_row_splits(flat_result[1], flat_result[0], validate: false); + + var offsets = RaggedTensor.from_row_splits(flat_result[2], flat_result[0], validate: false); + return (codepoints, offsets); + } - return _op.output; + return (null, null); } } } diff --git a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj index 26cd5139..7c6e3e00 100644 --- a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj +++ b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj @@ -50,6 +50,7 @@ tf.net 0.4x.x aligns with TensorFlow v2.4.1 native library.true TRACE;DEBUG x64 + TensorFlow.NET.xml diff --git a/src/TensorFlowNET.Core/Tensors/EagerTensorV2.cs b/src/TensorFlowNET.Core/Tensors/EagerTensorV2.cs index 66af00ff..64aadec4 100644 --- a/src/TensorFlowNET.Core/Tensors/EagerTensorV2.cs +++ b/src/TensorFlowNET.Core/Tensors/EagerTensorV2.cs @@ -7,7 +7,7 @@ using static Tensorflow.Binding; namespace Tensorflow { - public class EagerTensorV2 : DisposableObject, ITensor + public class EagerTensorV2 : DisposableObject { SafeTensorHandleHandle EagerTensorHandle; public string Device diff --git a/src/TensorFlowNET.Core/Tensors/ITensor.cs b/src/TensorFlowNET.Core/Tensors/ITensor.cs deleted file mode 100644 index fe483e74..00000000 --- a/src/TensorFlowNET.Core/Tensors/ITensor.cs +++ /dev/null @@ -1,7 +0,0 @@ -namespace Tensorflow -{ - public interface ITensor - { - - } -} diff --git a/src/TensorFlowNET.Core/Tensors/Ragged/RaggedTensor.cs b/src/TensorFlowNET.Core/Tensors/Ragged/RaggedTensor.cs new file mode 100644 index 00000000..567014ab --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/Ragged/RaggedTensor.cs @@ -0,0 +1,147 @@ +/***************************************************************************** + Copyright 2021 Haiping Chen. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.Text; +using System.Linq; +using Tensorflow.Framework; +using static Tensorflow.Binding; +using NumSharp; + +namespace Tensorflow +{ + /// + /// Represents a ragged tensor. + /// + public class RaggedTensor : CompositeTensor + { + Tensor _values; + RowPartition _row_partition; + Tensor _row_splits => _row_partition.row_splits; + + public TF_DataType dtype => _values.dtype; + public TensorShape shape + { + get + { + var nrows = _row_partition.static_nrows; + var ncols = _row_partition.static_uniform_row_length; + return new TensorShape(nrows, ncols); + } + } + + public RaggedTensor this[params Slice[] slices] + { + get + { + var row_key = slices[0]; + var inner_keys = slices.Skip(1).ToArray(); + + var args = tensor_util.ParseSlices(slices); + + return tf_with(ops.name_scope(null, "RaggedGetItem", args), scope => + { + string name = scope; + return _ragged_getitem_inner_dimensions(this, inner_keys); + }); + } + } + + RaggedTensor _ragged_getitem_inner_dimensions(RaggedTensor input, Slice[] slices) + { + return input; + } + + public RaggedTensor(Tensor values, + bool @internal = true, + RowPartition row_partition = null) + { + _values = values; + _row_partition = row_partition; + } + + public static RaggedTensor from_row_partition(Tensor values, RowPartition row_partition, bool validate = true) + { + return new RaggedTensor(values, @internal: true, row_partition: row_partition); + } + + /// + /// Creates a `RaggedTensor` with rows partitioned by `value_rowids`. + /// + /// + /// + /// + /// + /// + /// + public static RaggedTensor from_value_rowids(Tensor values, Tensor value_rowids, + Tensor nrows = null, string name = null, bool validate = true) + { + return tf_with(ops.name_scope(name, "RaggedFromValueRowIds"), scope => + { + var row_partition = RowPartition.from_value_rowids(value_rowids, + nrows: nrows, + validate: validate); + return from_row_partition(values, row_partition, validate: validate); + }); + } + + public static RaggedTensor from_row_splits(Tensor values, Tensor row_splits, + string name = null, bool validate = true) + { + return tf_with(ops.name_scope(name, "RaggedFromRowSplits"), scope => + { + var row_partition = RowPartition.from_row_splits(row_splits, + validate: validate); + return from_row_partition(values, row_partition, validate: validate); + }); + } + + Tensor _to_variant(bool batched_input = false, string name = null) + => tf_with(ops.name_scope(name, "RaggedToVariant"), scope => + { + return tf.Context.ExecuteOp("RaggedTensorToVariant", name, + new ExecuteOpArgs(nested_row_splits, flat_values) + { + GetGradientAttrs = op => new + { + RAGGED_RANK = op.get_attr("RAGGED_RANK"), + Tvalues = op.get_attr("Tvalues"), + Tsplits = op.get_attr("Tsplits"), + batched_input = op.get_attr("batched_input") + } + }.SetAttributes(new { batched_input })); + }); + + Tensor flat_values + => _values; + + Tensor[] nested_row_splits + => new[] { _row_splits }; + + public override string ToString() + => $"tf.RaggedTensor: shape={shape} [{string.Join(", ", _values.StringData().Take(10))}]"; + + public static implicit operator Tensor(RaggedTensor indexedSlices) + => indexedSlices._to_variant(); + + public static implicit operator RaggedTensor(Tensor tensor) + { + return tensor.Tag as RaggedTensor; + } + } +} diff --git a/src/TensorFlowNET.Core/Tensors/Ragged/RowPartition.cs b/src/TensorFlowNET.Core/Tensors/Ragged/RowPartition.cs new file mode 100644 index 00000000..6a52397a --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/Ragged/RowPartition.cs @@ -0,0 +1,103 @@ +/***************************************************************************** + Copyright 2021 Haiping Chen. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Framework; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + /// + /// Partitioning of a sequence of values into contiguous subsequences ("rows"). + /// + public class RowPartition : CompositeTensor + { + Tensor _row_splits; + public Tensor row_splits => _row_splits; + Tensor _row_lengths; + Tensor _value_rowids; + Tensor _nrows; + + public int static_nrows + { + get + { + return _row_splits.shape[0] - 1; + } + } + + public int static_uniform_row_length + { + get + { + return -1; + } + } + + public RowPartition(Tensor row_splits, + Tensor row_lengths = null, Tensor value_rowids = null, Tensor nrows = null, + Tensor uniform_row_length = null) + { + _row_splits = row_splits; + _row_lengths = row_lengths; + _value_rowids = value_rowids; + _nrows = nrows; + } + + /// + /// Creates a `RowPartition` with rows partitioned by `value_rowids`. + /// + /// + /// + /// + /// + /// + public static RowPartition from_value_rowids(Tensor value_rowids, + Tensor nrows = null, bool validate = true, TF_DataType preferred_dtype = TF_DataType.DtInvalid) + { + return tf_with(ops.name_scope(null, "RowPartitionFromValueRowIds"), scope => + { + var value_rowids_int32 = math_ops.cast(value_rowids, dtypes.int32); + var nrows_int32 = math_ops.cast(nrows, dtypes.int32); + var row_lengths = tf.math.bincount(value_rowids_int32, + minlength: nrows_int32, + maxlength: nrows_int32, + dtype: value_rowids.dtype); + var row_splits = array_ops.concat(new object[] + { + ops.convert_to_tensor(new long[] { 0 }), + tf.cumsum(row_lengths) + }, axis: 0); + + return new RowPartition(row_splits, + row_lengths: row_lengths, + value_rowids: value_rowids, + nrows: nrows); + }); + } + + public static RowPartition from_row_splits(Tensor row_splits, + bool validate = true, TF_DataType preferred_dtype = TF_DataType.DtInvalid) + { + return tf_with(ops.name_scope(null, "RowPartitionFromRowSplits"), scope => + { + return new RowPartition(row_splits); + }); + } + } +} diff --git a/src/TensorFlowNET.Core/Tensors/Ragged/SparseTensor.cs b/src/TensorFlowNET.Core/Tensors/Ragged/SparseTensor.cs new file mode 100644 index 00000000..987d8d1d --- /dev/null +++ b/src/TensorFlowNET.Core/Tensors/Ragged/SparseTensor.cs @@ -0,0 +1,76 @@ +/***************************************************************************** + Copyright 2021 Haiping Chen. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Linq; +using Tensorflow.Framework; +using static Tensorflow.Binding; + +namespace Tensorflow +{ + /// + /// Represents a sparse tensor. + /// + public class SparseTensor : CompositeTensor + { + public Tensor indices; + + public Tensor values; + + public Tensor dense_shape; + + public SparseTensor(Tensor indices, Tensor values, Tensor dense_shape) + { + this.indices = indices; + this.values = values; + this.dense_shape = dense_shape; + _init(); + } + + public SparseTensor(long[,] indices_, Array values_, long[] dense_shape_) + { + tf_with(ops.name_scope(null, "SparseTensor", new { }), delegate + { + indices = ops.convert_to_tensor( + indices_, name: "indices", dtype: dtypes.int64); + values = ops.convert_to_tensor(values_, name: "values"); + dense_shape = ops.convert_to_tensor( + dense_shape_, name: "dense_shape", dtype: dtypes.int64); + }); + _init(); + } + + void _init() + { + var indices_shape = indices.TensorShape.with_rank(2); + var values_shape = values.TensorShape.with_rank(1); + var dense_shape_shape = dense_shape.TensorShape.with_rank(1); + + indices_shape["0"].merge_with(values_shape[0]); + indices_shape["1"].merge_with(dense_shape_shape[0]); + } + + public static implicit operator Tensor(SparseTensor indexedSlices) + { + return indexedSlices.values; + } + + public static implicit operator SparseTensor(Tensor tensor) + { + return tensor.Tag as SparseTensor; + } + } +} diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 037a370a..d73a7933 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -33,9 +33,7 @@ namespace Tensorflow /// [SuppressMessage("ReSharper", "ConvertToAutoProperty")] public partial class Tensor : DisposableObject, - ITensor, ITensorOrOperation, - _TensorLike, ITensorOrTensorArray, IPackable, ICanBeFlattened @@ -97,6 +95,7 @@ namespace Tensorflow public SafeTensorHandleHandle EagerTensorHandle { get; set; } public bool IsEagerTensor => this is EagerTensor; + public bool IsSparseTensor => this is SparseTensor; /// /// Returns the shape of a tensor. diff --git a/src/TensorFlowNET.Core/Training/gen_training_ops.cs b/src/TensorFlowNET.Core/Training/gen_training_ops.cs index 675f920f..abe85a14 100644 --- a/src/TensorFlowNET.Core/Training/gen_training_ops.cs +++ b/src/TensorFlowNET.Core/Training/gen_training_ops.cs @@ -21,46 +21,19 @@ namespace Tensorflow { public class gen_training_ops { - public static Operation resource_apply_adam(Tensor var, Tensor m, Tensor v, Tensor beta1_power, Tensor beta2_power, + public static Tensor resource_apply_adam(Tensor var, Tensor m, Tensor v, Tensor beta1_power, Tensor beta2_power, Tensor lr, Tensor beta1, Tensor beta2, Tensor epsilon, Tensor grad, bool use_locking = false, bool use_nesterov = false, string name = null) - { - if (tf.executing_eagerly()) - { - var result = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "ResourceApplyAdam", name, - null, - var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, - "use_locking", use_locking, - "use_nesterov", use_nesterov); - return null; - } - - throw new NotImplementedException(""); - } + => tf.Context.ExecuteOp("ResourceApplyAdam", name, + new ExecuteOpArgs(var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad) + .SetAttributes(new { use_locking, use_nesterov })); public static Tensor apply_adam(Tensor var, Tensor m, Tensor v, Tensor beta1_power, Tensor beta2_power, Tensor lr, Tensor beta1, Tensor beta2, Tensor epsilon, Tensor grad, bool use_locking = false, bool use_nesterov = false, string name = null) - { - var _op = tf.OpDefLib._apply_op_helper("ApplyAdam", name, new - { - var, - m, - v, - beta1_power, - beta2_power, - lr, - beta1, - beta2, - epsilon, - grad, - use_locking, - use_nesterov - }); - - return _op.outputs[0]; - } + => tf.Context.ExecuteOp("ApplyAdam", name, + new ExecuteOpArgs(var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad) + .SetAttributes(new { use_locking, use_nesterov })); public static Tensor apply_gradient_descent(IVariableV1 var, Tensor alpha, Tensor delta, bool use_locking = false, string name = null) { @@ -75,27 +48,8 @@ namespace Tensorflow return _op.output; } - public static Operation resource_apply_gradient_descent(Tensor var, Tensor alpha, Tensor delta, bool use_locking = false, string name = null) - { - if (tf.executing_eagerly()) - { - var result = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "ResourceApplyGradientDescent", name, - null, - var, alpha, delta, - "use_locking", use_locking); - return null; - } - - var _op = tf.OpDefLib._apply_op_helper("ResourceApplyGradientDescent", name, new - { - var, - alpha, - delta, - use_locking - }); - - return _op; - } + public static Tensor resource_apply_gradient_descent(Tensor var, Tensor alpha, Tensor delta, bool use_locking = false, string name = null) + => tf.Context.ExecuteOp("ResourceApplyGradientDescent", name, + new ExecuteOpArgs(var, alpha, delta).SetAttributes(new { use_locking })); } } diff --git a/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs b/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs index df964d5f..8d8c0699 100644 --- a/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs +++ b/src/TensorFlowNET.Core/Variables/gen_state_ops.py.cs @@ -59,31 +59,8 @@ namespace Tensorflow bool validate_shape = true, bool use_locking = true, string name = null) - { - if (tf.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Assign", name, - null, - @ref, value, - "validate_shape", validate_shape, - "use_locking", use_locking); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("Assign", name: name, args: new { @ref, value, validate_shape, use_locking }); - - var _result = _op.outputs; - var _inputs_flat = _op.inputs; - - var _attrs = new Dictionary(); - _attrs["T"] = _op.get_attr("T"); - _attrs["validate_shape"] = _op.get_attr("validate_shape"); - _attrs["use_locking"] = _op.get_attr("use_locking"); - - return _result[0]; - } + => tf.Context.ExecuteOp("Assign", name, new ExecuteOpArgs(@ref, value) + .SetAttributes(new { validate_shape, use_locking })); public static Tensor assign_add(IVariableV1 @ref, T value, bool use_locking = false, string name = null) { diff --git a/src/TensorFlowNET.Keras/Activations/Activations.Relu.cs b/src/TensorFlowNET.Keras/Activations/Activations.Relu.cs index d4f08088..dfebfb29 100644 --- a/src/TensorFlowNET.Keras/Activations/Activations.Relu.cs +++ b/src/TensorFlowNET.Keras/Activations/Activations.Relu.cs @@ -4,21 +4,7 @@ namespace Tensorflow.Keras { public partial class Activations { - public Activation Relu = (features, name) => - { - if (tf.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Relu", name, - null, - features); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("Relu", name: name, args: new { features }); - - return _op.output; - }; + public Activation Relu = (features, name) + => tf.Context.ExecuteOp("Relu", name, new ExecuteOpArgs(features)); } } diff --git a/src/TensorFlowNET.Keras/Activations/Activations.Sigmoid.cs b/src/TensorFlowNET.Keras/Activations/Activations.Sigmoid.cs index 84220f4f..ad900bde 100644 --- a/src/TensorFlowNET.Keras/Activations/Activations.Sigmoid.cs +++ b/src/TensorFlowNET.Keras/Activations/Activations.Sigmoid.cs @@ -5,21 +5,7 @@ namespace Tensorflow.Keras { public partial class Activations { - public Activation Sigmoid = (features, name) => - { - if (tf.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Sigmoid", name, - null, - features); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("Sigmoid", name: name, args: new { x = features }); - - return _op.output; - }; + public Activation Sigmoid = (features, name) + => tf.Context.ExecuteOp("Sigmoid", name, new ExecuteOpArgs(features)); } } diff --git a/src/TensorFlowNET.Keras/Activations/Activations.Tanh.cs b/src/TensorFlowNET.Keras/Activations/Activations.Tanh.cs index 30bbdbf4..33dc5ba6 100644 --- a/src/TensorFlowNET.Keras/Activations/Activations.Tanh.cs +++ b/src/TensorFlowNET.Keras/Activations/Activations.Tanh.cs @@ -5,21 +5,7 @@ namespace Tensorflow.Keras { public partial class Activations { - public Activation Tanh = (features, name) => - { - if (tf.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Tanh", name, - null, - features); - - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("Tanh", name: name, args: new { x = features }); - - return _op.output; - }; + public Activation Tanh = (features, name) + => tf.Context.ExecuteOp("Tanh", name, new ExecuteOpArgs(features)); } } diff --git a/src/TensorFlowNET.Keras/Engine/CombinerPreprocessingLayer.cs b/src/TensorFlowNET.Keras/Engine/CombinerPreprocessingLayer.cs index 11adfe9f..2e564480 100644 --- a/src/TensorFlowNET.Keras/Engine/CombinerPreprocessingLayer.cs +++ b/src/TensorFlowNET.Keras/Engine/CombinerPreprocessingLayer.cs @@ -8,11 +8,23 @@ namespace Tensorflow.Keras.Engine public class CombinerPreprocessingLayer : Layer { PreprocessingLayerArgs args; + protected ICombiner combiner; + protected bool _previously_updated; public CombinerPreprocessingLayer(PreprocessingLayerArgs args) : base(args) { - + _previously_updated = false; + } + + public virtual void adapt(IDatasetV2 data, bool reset_state = true) + { + IAccumulator accumulator; + if (!reset_state) + accumulator = combiner.Restore(); + + var next_data = data.make_one_shot_iterator(); + var data_element = next_data.next(); } } } diff --git a/src/TensorFlowNET.Keras/Engine/Interfaces/IAccumulator.cs b/src/TensorFlowNET.Keras/Engine/Interfaces/IAccumulator.cs new file mode 100644 index 00000000..df819839 --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/Interfaces/IAccumulator.cs @@ -0,0 +1,10 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Engine +{ + public interface IAccumulator + { + } +} diff --git a/src/TensorFlowNET.Keras/Engine/Interfaces/ICombiner.cs b/src/TensorFlowNET.Keras/Engine/Interfaces/ICombiner.cs new file mode 100644 index 00000000..8fe1764d --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/Interfaces/ICombiner.cs @@ -0,0 +1,19 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Engine +{ + /// + /// Functional object that defines a shardable computation. + /// + public interface ICombiner + { + void Compute(Tensor values, IAccumulator accumulator = null); + void Merge(); + void Extract(); + IAccumulator Restore(); + void Serialize(); + void Deserialize(); + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Preprocessing/IndexLookup.cs b/src/TensorFlowNET.Keras/Layers/Preprocessing/IndexLookup.cs new file mode 100644 index 00000000..5e02f562 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Preprocessing/IndexLookup.cs @@ -0,0 +1,30 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; + +namespace Tensorflow.Keras.Layers +{ + public class IndexLookup : CombinerPreprocessingLayer + { + public IndexLookup(int max_tokens = -1, + int num_oov_indices = 1, + string mask_token = "", + string oov_token = "[UNK]", + string encoding = "utf-8", + bool invert = false) : base(new PreprocessingLayerArgs()) + { + var num_mask_tokens = mask_token == null ? 0 : 1; + var vocab_size = max_tokens - (num_oov_indices + num_mask_tokens); + combiner = new IndexLookupCombiner(vocab_size, mask_token); + } + + public override void adapt(IDatasetV2 data, bool reset_state = true) + { + if (!reset_state) + throw new ValueError("IndexLookup does not support streaming adapts."); + base.adapt(data, reset_state); + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Preprocessing/IndexLookupAccumulator.cs b/src/TensorFlowNET.Keras/Layers/Preprocessing/IndexLookupAccumulator.cs new file mode 100644 index 00000000..e2de669d --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Preprocessing/IndexLookupAccumulator.cs @@ -0,0 +1,16 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.Engine; + +namespace Tensorflow.Keras.Layers +{ + public class IndexLookupAccumulator : IAccumulator + { + public Dictionary CountDict { get; set; } + public IndexLookupAccumulator() + { + CountDict = new Dictionary(); + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Preprocessing/IndexLookupCombiner.cs b/src/TensorFlowNET.Keras/Layers/Preprocessing/IndexLookupCombiner.cs new file mode 100644 index 00000000..ac4c5dc9 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Preprocessing/IndexLookupCombiner.cs @@ -0,0 +1,55 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.Engine; + +namespace Tensorflow.Keras.Layers +{ + /// + /// Combiner for the IndexLookup preprocessing layer. + /// + public class IndexLookupCombiner : ICombiner + { + int _vocab_size; + string _mask_value; + + public IndexLookupCombiner(int vocab_size = -1, string mask_value = null) + { + _vocab_size = vocab_size; + _mask_value = mask_value; + } + + public void Compute(Tensor values, IAccumulator accumulator = null) + { + if(accumulator == null) + { + accumulator = new IndexLookupAccumulator(); + } + } + + public void Deserialize() + { + throw new NotImplementedException(); + } + + public void Extract() + { + throw new NotImplementedException(); + } + + public void Merge() + { + throw new NotImplementedException(); + } + + public IAccumulator Restore() + { + throw new NotImplementedException(); + } + + public void Serialize() + { + throw new NotImplementedException(); + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Preprocessing/StringLookup.cs b/src/TensorFlowNET.Keras/Layers/Preprocessing/StringLookup.cs new file mode 100644 index 00000000..616af1c6 --- /dev/null +++ b/src/TensorFlowNET.Keras/Layers/Preprocessing/StringLookup.cs @@ -0,0 +1,23 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Layers +{ + /// + /// Maps strings from a vocabulary to integer indices. + /// + class StringLookup : IndexLookup + { + public StringLookup(int max_tokens = -1, + int num_oov_indices = 1, + string mask_token = "", + string[] vocabulary = null, + string oov_token = "[UNK]", + string encoding = "utf-8", + bool invert = false) + { + + } + } +} diff --git a/src/TensorFlowNET.Keras/Layers/Preprocessing/TextVectorization.cs b/src/TensorFlowNET.Keras/Layers/Preprocessing/TextVectorization.cs index c72860a6..038f419b 100644 --- a/src/TensorFlowNET.Keras/Layers/Preprocessing/TextVectorization.cs +++ b/src/TensorFlowNET.Keras/Layers/Preprocessing/TextVectorization.cs @@ -3,12 +3,14 @@ using System.Collections.Generic; using System.Text; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; +using static Tensorflow.Binding; namespace Tensorflow.Keras.Layers { public class TextVectorization : CombinerPreprocessingLayer { TextVectorizationArgs args; + IndexLookup _index_lookup_layer; public TextVectorization(TextVectorizationArgs args) : base(args) @@ -16,6 +18,11 @@ namespace Tensorflow.Keras.Layers this.args = args; args.DType = TF_DataType.TF_STRING; // string standardize = "lower_and_strip_punctuation", + + var mask_token = args.OutputMode == "int" ? "" : null; + _index_lookup_layer = new StringLookup(max_tokens: args.MaxTokens, + mask_token: mask_token, + vocabulary: args.Vocabulary); } /// @@ -23,13 +30,14 @@ namespace Tensorflow.Keras.Layers /// /// /// - public void adapt(IDatasetV2 data, bool reset_state = true) + public override void adapt(IDatasetV2 data, bool reset_state = true) { var shape = data.output_shapes[0]; if (shape.rank == 1) data = data.map(tensor => array_ops.expand_dims(tensor, -1)); build(data.variant_tensor); var preprocessed_inputs = data.map(_preprocess); + _index_lookup_layer.adapt(preprocessed_inputs); } protected override void build(Tensors inputs) @@ -39,14 +47,17 @@ namespace Tensorflow.Keras.Layers Tensors _preprocess(Tensors inputs) { + Tensor input_tensor = null; if (args.Standardize != null) - inputs = args.Standardize(inputs); + input_tensor = args.Standardize(inputs); if (!string.IsNullOrEmpty(args.Split)) { if (inputs.shape.ndim > 1) - inputs = array_ops.squeeze(inputs, axis: new[] { -1 }); + input_tensor = array_ops.squeeze(inputs, axis: new[] { -1 }); + if (args.Split == "whitespace") + input_tensor = tf.strings.split(input_tensor); } - return inputs; + return input_tensor; } } } diff --git a/src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.index_directory.cs b/src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.index_directory.cs index 6b62b9b2..03c9f8d1 100644 --- a/src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.index_directory.cs +++ b/src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.index_directory.cs @@ -1,4 +1,5 @@ using NumSharp; +using System; using System.Collections.Generic; using System.IO; using System.Linq; @@ -60,6 +61,7 @@ namespace Tensorflow.Keras.Preprocessings } } + Console.WriteLine($"Found {return_file_paths.Length} files belonging to {class_names.Length} classes."); return (return_file_paths, return_labels, class_names); } } diff --git a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj index 9eeb4634..0c50a5a1 100644 --- a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj +++ b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj @@ -6,7 +6,7 @@ 8.0 Tensorflow.Keras AnyCPU;x64 - 0.4.1 + 0.5.0 Haiping Chen Keras for .NET Apache 2.0, Haiping Chen 2020 @@ -23,7 +23,8 @@ * Implemented backward_function. * Support model.load_weights. * Add Subtract layer -* Support YOLOv3 model. +* Support YOLOv3 model. +* Text preprocessing Keras for .NET Keras is an API designed for human beings, not machines. Keras follows best practices for reducing cognitive load: it offers consistent & simple APIs, it minimizes the number of user actions required for common use cases, and it provides clear & actionable error messages. @@ -34,8 +35,8 @@ Keras is an API designed for human beings, not machines. Keras follows best prac Git true Open.snk - 0.4.1.0 - 0.4.1.0 + 0.5.0.0 + 0.5.0.0 LICENSE @@ -48,6 +49,10 @@ Keras is an API designed for human beings, not machines. Keras follows best prac false + + Tensorflow.Keras.xml + + @@ -62,10 +67,6 @@ Keras is an API designed for human beings, not machines. Keras follows best prac - - - - diff --git a/src/TensorFlowNET.Text/Tokenizers/WhitespaceTokenizer.cs b/src/TensorFlowNET.Text/Tokenizers/WhitespaceTokenizer.cs index a0bbe473..bade6f4a 100644 --- a/src/TensorFlowNET.Text/Tokenizers/WhitespaceTokenizer.cs +++ b/src/TensorFlowNET.Text/Tokenizers/WhitespaceTokenizer.cs @@ -1,6 +1,8 @@ -using System; +using NumSharp; +using System; using System.Collections.Generic; using System.Text; +using static Tensorflow.Binding; namespace Tensorflow.Text.Tokenizers { @@ -13,7 +15,31 @@ namespace Tensorflow.Text.Tokenizers /// public Tensor tokenize(Tensor input) { + tokenize_with_offsets(input); throw new NotImplementedException(""); } + + Tensor[] tokenize_with_offsets(Tensor input) + { + tf_with(ops.name_scope(null, "WhitespaceTokenize"), scope => + { + _whitespace_tokenize_with_offsets_encode_decode_wrapper(input); + }); + throw new NotImplementedException(""); + } + + Tensor _whitespace_tokenize_with_offsets_encode_decode_wrapper(Tensor input_tensor) + { + // Decode the strings and get byte offsets + var (codepoints, byte_start_offsets) = tf.strings.unicode_decode_with_offsets(input_tensor, "UTF-8"); + var byte_end_offsets = array_ops.concat(new Tensor[] + { + byte_start_offsets[Slice.All, new Slice(1)], + math_ops.cast( + array_ops.expand_dims(tf.strings.string_length(input_tensor), 1), + dtypes.int64) + }, 1); + return input_tensor; + } } } diff --git a/src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj b/src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj index 60955e68..6567a1ae 100644 --- a/src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj +++ b/src/TensorFlowNet.Benchmarks/Tensorflow.Benchmark.csproj @@ -9,6 +9,7 @@ true DEBUG;TRACE + x64 diff --git a/test/TensorFlowNET.UnitTest/Basics/RandomTest.cs b/test/TensorFlowNET.UnitTest/Basics/RandomTest.cs index ed70fa35..b658586a 100644 --- a/test/TensorFlowNET.UnitTest/Basics/RandomTest.cs +++ b/test/TensorFlowNET.UnitTest/Basics/RandomTest.cs @@ -14,7 +14,7 @@ namespace TensorFlowNET.UnitTest.Basics /// Test the function of setting random seed /// This will help regenerate the same result /// - [TestMethod, Ignore] + [TestMethod] public void TFRandomSeedTest() { var initValue = np.arange(6).reshape(3, 2); @@ -60,7 +60,7 @@ namespace TensorFlowNET.UnitTest.Basics /// /// This part we use funcs in tf.random rather than only tf /// - [TestMethod, Ignore] + [TestMethod] public void TFRandomRaodomSeedTest() { tf.set_random_seed(1234); diff --git a/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs b/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs index 0a858ef9..e8e87840 100644 --- a/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs +++ b/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs @@ -151,5 +151,25 @@ namespace TensorFlowNET.UnitTest.Dataset var cardinality = dataset.dataset_cardinality(); Assert.AreEqual(new long[] { 10 }, cardinality.numpy()); } + + [TestMethod] + public void Shuffle() + { + tf.set_random_seed(1234); + + var dataset = tf.data.Dataset.range(3); + var shuffled = dataset.shuffle(3); + + var zipped = tf.data.Dataset.zip(dataset, shuffled); + + bool allEqual = true; + foreach (var item in zipped) + { + if (item.Item1 != item.Item2) + allEqual = false; + } + + Assert.IsFalse(allEqual); + } } } diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/StringsApiTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/StringsApiTest.cs index f1d2e0fe..d98c5207 100644 --- a/test/TensorFlowNET.UnitTest/ManagedAPI/StringsApiTest.cs +++ b/test/TensorFlowNET.UnitTest/ManagedAPI/StringsApiTest.cs @@ -58,5 +58,13 @@ namespace TensorFlowNET.UnitTest.ManagedAPI Assert.AreEqual(strings[1], stringData[1]); Assert.AreEqual(strings[2], stringData[2]); } + + [TestMethod] + public void StringSplit() + { + var tensor = tf.constant(new[] { "hello world", "tensorflow .net csharp", "fsharp" }); + var ragged_tensor = tf.strings.split(tensor); + Assert.AreEqual((3, -1), ragged_tensor.shape); + } } } diff --git a/test/TensorFlowNET.UnitTest/Text/TokenizerTest.cs b/test/TensorFlowNET.UnitTest/Text/TokenizerTest.cs index 3b8237b9..65c69a3f 100644 --- a/test/TensorFlowNET.UnitTest/Text/TokenizerTest.cs +++ b/test/TensorFlowNET.UnitTest/Text/TokenizerTest.cs @@ -10,10 +10,12 @@ namespace TensorFlowNET.UnitTest.Text [TestClass] public class TokenizerTest { - [TestMethod] + [TestMethod, Ignore] public void Tokenize() { var docs = tf.constant(new[] { "Everything not saved will be lost." }); + var tokenizer = text.WhitespaceTokenizer(); + var tokens = tokenizer.tokenize(docs); } } }