diff --git a/README.md b/README.md index 15f72bf5..95caa446 100644 --- a/README.md +++ b/README.md @@ -11,8 +11,6 @@ *master branch is based on tensorflow 2.2 now, v0.15-tensorflow1.15 is from tensorflow1.15.* -TF.NET is a member project of [SciSharp STACK](https://github.com/SciSharp). - ![tensors_flowing](docs/assets/tensors_flowing.gif) @@ -56,59 +54,40 @@ using static Tensorflow.Binding; Linear Regression: ```c# -// We can set a fixed init value in order to debug +// Parameters +int training_steps = 1000; +float learning_rate = 0.01f; +int display_step = 100; + +// We can set a fixed init value in order to demo var W = tf.Variable(-0.06f, name: "weight"); var b = tf.Variable(-0.73f, name: "bias"); +var optimizer = tf.optimizers.SGD(learning_rate); -// Construct a linear model -var pred = tf.add(tf.multiply(X, W), b); - -// Mean squared error -var cost = tf.reduce_sum(tf.pow(pred - Y, 2.0f)) / (2.0f * n_samples); - -// Gradient descent -// Note, minimize() knows to modify W and b because Variable objects are trainable=True by default -var optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost); - -// Initialize the variables (i.e. assign their default value) -var init = tf.global_variables_initializer(); - -// Start training -using(tf.Session()) +// Run training for the given number of steps. +foreach (var step in range(1, training_steps + 1)) { - // Run the initializer - sess.run(init); - - // Fit all training data - for (int epoch = 0; epoch < training_epochs; epoch++) + // Run the optimization to update W and b values. + // Wrap computation inside a GradientTape for automatic differentiation. + using var g = tf.GradientTape(); + // Linear regression (Wx + b). + var pred = W * X + b; + // Mean square error. + var loss = tf.reduce_sum(tf.pow(pred - Y, 2)) / (2 * n_samples); + // should stop recording + // Compute gradients. + var gradients = g.gradient(loss, (W, b)); + + // Update W and b following gradients. + optimizer.apply_gradients(zip(gradients, (W, b))); + + if (step % display_step == 0) { - foreach (var (x, y) in zip(train_X, train_Y)) - sess.run(optimizer, (X, x), (Y, y)); - - // Display logs per epoch step - if ((epoch + 1) % display_step == 0) - { - var c = sess.run(cost, (X, train_X), (Y, train_Y)); - Console.WriteLine($"Epoch: {epoch + 1} cost={c} " + $"W={sess.run(W)} b={sess.run(b)}"); - } + pred = W * X + b; + loss = tf.reduce_sum(tf.pow(pred - Y, 2)) / (2 * n_samples); + print($"step: {step}, loss: {loss.numpy()}, W: {W.numpy()}, b: {b.numpy()}"); } - - Console.WriteLine("Optimization Finished!"); - var training_cost = sess.run(cost, (X, train_X), (Y, train_Y)); - Console.WriteLine($"Training cost={training_cost} W={sess.run(W)} b={sess.run(b)}"); - - // Testing example - var test_X = np.array(6.83f, 4.668f, 8.9f, 7.91f, 5.7f, 8.7f, 3.1f, 2.1f); - var test_Y = np.array(1.84f, 2.273f, 3.2f, 2.831f, 2.92f, 3.24f, 1.35f, 1.03f); - Console.WriteLine("Testing... (Mean square loss Comparison)"); - var testing_cost = sess.run(tf.reduce_sum(tf.pow(pred - Y, 2.0f)) / (2.0f * test_X.shape[0]), - (X, test_X), (Y, test_Y)); - Console.WriteLine($"Testing cost={testing_cost}"); - var diff = Math.Abs((float)training_cost - (float)testing_cost); - Console.WriteLine($"Absolute mean square loss difference: {diff}"); - - return diff < 0.01; -}); +} ``` Run this example in [Jupyter Notebook](https://github.com/SciSharp/SciSharpCube). diff --git a/docs/source/HelloWorld.md b/docs/source/HelloWorld.md index 8023d9f9..d8c6b32e 100644 --- a/docs/source/HelloWorld.md +++ b/docs/source/HelloWorld.md @@ -25,7 +25,15 @@ TensorFlow.NET uses the .NET Standard 2.0 standard, so your new project Target F ```cmd +### install tensorflow C# binding PM> Install-Package TensorFlow.NET + +### Install tensorflow binary +### For CPU version +PM> Install-Package SciSharp.TensorFlow.Redist + +### For GPU version (CUDA and cuDNN are required) +PM> Install-Package SciSharp.TensorFlow.Redist-Windows-GPU ``` ### Start coding Hello World @@ -36,7 +44,7 @@ After installing the TensorFlow.NET package, you can use the `using Tensorflow` ```csharp using System; -using Tensorflow; +using static Tensorflow.Binding; namespace TensorFlowNET.Examples { diff --git a/docs/source/Placeholder.md b/docs/source/Placeholder.md index a578a127..2cf345bd 100644 --- a/docs/source/Placeholder.md +++ b/docs/source/Placeholder.md @@ -8,13 +8,13 @@ In this chapter we will talk about another common data type in TensorFlow: Place var x = tf.placeholder(tf.int32); var y = x * 3; -Python.with(tf.Session(), sess => +using (var sess = tf.Session()) { var result = sess.run(y, feed_dict: new FeedItem[] { new FeedItem(x, 2) }); // (int)result should be 6; -}); +} ``` diff --git a/src/TensorFlowNET.Console/TensorFlowNET.Console.csproj b/src/TensorFlowNET.Console/TensorFlowNET.Console.csproj index b78a25f3..a047afb9 100644 --- a/src/TensorFlowNET.Console/TensorFlowNET.Console.csproj +++ b/src/TensorFlowNET.Console/TensorFlowNET.Console.csproj @@ -8,7 +8,7 @@ - + diff --git a/src/TensorFlowNET.Core/APIs/c_api.cs b/src/TensorFlowNET.Core/APIs/c_api.cs index c1575fb4..d3dc15ed 100644 --- a/src/TensorFlowNET.Core/APIs/c_api.cs +++ b/src/TensorFlowNET.Core/APIs/c_api.cs @@ -43,7 +43,7 @@ namespace Tensorflow /// public partial class c_api { - public const string TensorFlowLibName = @"D:\SciSharp\tensorflow-google\bazel-bin\tensorflow\tensorflow.dll"; + public const string TensorFlowLibName = "tensorflow"; public static string StringPiece(IntPtr handle) { diff --git a/src/TensorFlowNET.Core/APIs/tf.array.cs b/src/TensorFlowNET.Core/APIs/tf.array.cs index ec17cecc..59689bc5 100644 --- a/src/TensorFlowNET.Core/APIs/tf.array.cs +++ b/src/TensorFlowNET.Core/APIs/tf.array.cs @@ -186,7 +186,7 @@ namespace Tensorflow => array_ops.slice(input, begin, size, name: name); public Tensor squeeze(Tensor input, int[] axis = null, string name = null, int squeeze_dims = -1) - => gen_array_ops.squeeze(input, axis, name); + => array_ops.squeeze(input, axis, name); /// /// Stacks a list of rank-`R` tensors into one rank-`(R+1)` tensor. @@ -217,7 +217,7 @@ namespace Tensorflow Tensor off_value = null, TF_DataType dtype = TF_DataType.DtInvalid, int axis = -1, - string name = null) => array_ops.one_hot(indices, depth, dtype: dtype, axis: axis, name: name); + string name = null) => array_ops.one_hot(indices, ops.convert_to_tensor(depth), dtype: dtype, axis: axis, name: name); /// /// Pads a tensor diff --git a/src/TensorFlowNET.Core/APIs/tf.data.cs b/src/TensorFlowNET.Core/APIs/tf.data.cs new file mode 100644 index 00000000..7eee3a90 --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.data.cs @@ -0,0 +1,30 @@ +/***************************************************************************** + Copyright 2020 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using NumSharp; + +namespace Tensorflow +{ + public partial class tensorflow + { + public DataOps data { get; } = new DataOps(); + + public class DataOps + { + public TensorSliceDataset Dataset { get; } = new TensorSliceDataset(); + } + } +} diff --git a/src/TensorFlowNET.Core/APIs/tf.keras.cs b/src/TensorFlowNET.Core/APIs/tf.keras.cs new file mode 100644 index 00000000..ec3af440 --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.keras.cs @@ -0,0 +1,25 @@ +/***************************************************************************** + Copyright 2020 Haiping Chen. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using NumSharp; + +namespace Tensorflow +{ + public partial class tensorflow + { + public KerasApi keras { get; } = new KerasApi(); + } +} diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs index 4ad70420..f960e14c 100644 --- a/src/TensorFlowNET.Core/APIs/tf.math.cs +++ b/src/TensorFlowNET.Core/APIs/tf.math.cs @@ -21,6 +21,13 @@ namespace Tensorflow { public partial class tensorflow { + public MathApi math { get; } = new MathApi(); + public class MathApi + { + public Tensor log(Tensor x, string name = null) + => gen_math_ops.log(x, name); + } + public Tensor abs(Tensor x, string name = null) => math_ops.abs(x, name); @@ -254,7 +261,7 @@ namespace Tensorflow /// Any values less than clip_value_min are set to clip_value_min. Any values /// greater than clip_value_max are set to clip_value_max. /// - public Tensor clip_by_value (Tensor t, Tensor clip_value_min, Tensor clip_value_max, string name = "ClipByValue") + public Tensor clip_by_value(Tensor t, T1 clip_value_min, T2 clip_value_max, string name = "ClipByValue") => clip_ops.clip_by_value(t, clip_value_min, clip_value_max, name); public Tensor sub(Tx a, Ty b, string name = null) diff --git a/src/TensorFlowNET.Core/APIs/tf.nn.cs b/src/TensorFlowNET.Core/APIs/tf.nn.cs index 19afce1d..3f756502 100644 --- a/src/TensorFlowNET.Core/APIs/tf.nn.cs +++ b/src/TensorFlowNET.Core/APIs/tf.nn.cs @@ -14,6 +14,7 @@ limitations under the License. ******************************************************************************/ +using System; using Tensorflow.Operations; using Tensorflow.Operations.Activation; using static Tensorflow.Binding; @@ -182,7 +183,13 @@ namespace Tensorflow => nn_impl.sigmoid_cross_entropy_with_logits(labels: labels, logits: logits, name: name); public Tensor softmax(Tensor logits, int axis = -1, string name = null) - => gen_nn_ops.softmax(logits, name); + { + if (axis == -1) + return gen_nn_ops.softmax(logits, name); + else + throw new NotImplementedException(""); + } + /// /// Computes sparse softmax cross entropy between `logits` and `labels`. diff --git a/src/TensorFlowNET.Core/APIs/tf.random.cs b/src/TensorFlowNET.Core/APIs/tf.random.cs index 54fd57be..a8c18808 100644 --- a/src/TensorFlowNET.Core/APIs/tf.random.cs +++ b/src/TensorFlowNET.Core/APIs/tf.random.cs @@ -38,6 +38,24 @@ namespace Tensorflow TF_DataType dtype = TF_DataType.TF_FLOAT, int? seed = null, string name = null) => random_ops.random_normal(shape, mean, stddev, dtype, seed, name); + + /// + /// Outputs random values from a truncated normal distribution. + /// + /// + /// + /// + /// + /// + /// + /// + public Tensor truncated_normal(TensorShape shape, + float mean = 0.0f, + float stddev = 1.0f, + TF_DataType dtype = TF_DataType.TF_FLOAT, + int? seed = null, + string name = null) => random_ops.truncated_normal(shape, mean, stddev, dtype, seed, name); + public Tensor categorical( Tensor logits, int num_samples, diff --git a/src/TensorFlowNET.Core/Data/DatasetOps.cs b/src/TensorFlowNET.Core/Data/DatasetOps.cs new file mode 100644 index 00000000..4035ee4f --- /dev/null +++ b/src/TensorFlowNET.Core/Data/DatasetOps.cs @@ -0,0 +1,10 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class DatasetOps + { + } +} diff --git a/src/TensorFlowNET.Core/Data/TensorSliceDataset.cs b/src/TensorFlowNET.Core/Data/TensorSliceDataset.cs new file mode 100644 index 00000000..8f6a6dac --- /dev/null +++ b/src/TensorFlowNET.Core/Data/TensorSliceDataset.cs @@ -0,0 +1,20 @@ +using NumSharp; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class TensorSliceDataset + { + public TensorSliceDataset(params NDArray[] elements) + { + + } + + public TensorSliceDataset from_tensor_slices(params NDArray[] elements) + { + throw new NotImplementedException(""); + } + } +} diff --git a/src/TensorFlowNET.Core/Eager/EagerOperation.cs b/src/TensorFlowNET.Core/Eager/EagerOperation.cs index 39038608..6f092a57 100644 --- a/src/TensorFlowNET.Core/Eager/EagerOperation.cs +++ b/src/TensorFlowNET.Core/Eager/EagerOperation.cs @@ -1,20 +1,30 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Text; +using static Tensorflow.Binding; namespace Tensorflow.Eager { public class EagerOperation : Operation { - public int NumInputs; + static Dictionary op_dict; + public string Name { get; set; } + public new int NumInputs; public IntPtr[] InputHandles { get; set; } public Tensor[] Inputs { get; set; } - public int NumOutputs; + public new int NumOutputs; public IntPtr[] OutputHandles { get; set; } public Tensor[] Outputs { get; set; } - public int[] SkipInputIndices { get; set; } + public BindingArray SkipInputIndicesArray { get; set; } + public unsafe int[] SkipInputIndices => SkipInputIndicesArray.Data.Select(x => *(int*) x).ToArray(); + public string[] AttrsArray { get; set; } - public EagerOperation() : base(IntPtr.Zero) { } + public EagerOperation() : base(IntPtr.Zero) + { + if (op_dict == null) + op_dict = op_def_registry.get_registered_ops(); + } public override InputList inputs { @@ -22,13 +32,6 @@ namespace Tensorflow.Eager { if (_inputs_val == null) { - var retval = new Tensor[NumInputs]; - - for (int i = 0; i < NumInputs; i++) - { - - } - _inputs_val = new InputList(Inputs); } @@ -48,5 +51,35 @@ namespace Tensorflow.Eager return _outputs; } } + + public override object get_attr(string attr_name) + { + object value = null; + byte isList = 0; + using var status = new Status(); + var attrType = c_api.TFE_OpNameGetAttrType(tf.context, Name, attr_name, ref isList, status.Handle); + switch (attrType) + { + case TF_AttrType.TF_ATTR_BOOL: + value = get_attr_bool(attr_name); + break; + default: + break; + } + + return value; + } + + public bool get_attr_bool(string attr_name) + { + for (int i = 0; i < AttrsArray.Length; i = i + 2) + if (AttrsArray[i] == attr_name) + return AttrsArray[i + 1] == "1"; + + throw new ValueError($"Can't find attr: {attr_name}"); + } + + public override string ToString() + => $"tf.EagerOperation {Name}"; } } diff --git a/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs b/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs index ea13c59b..a5fc064b 100644 --- a/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs +++ b/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs @@ -2,6 +2,7 @@ using System; using System.Collections.Generic; using System.Text; +using System.Threading; using static Tensorflow.Binding; namespace Tensorflow.Eager @@ -49,6 +50,10 @@ namespace Tensorflow.Eager print($"new TensorHandle {Id} {tfe_tensor_handle.ToString("x16")}"); print($"new EagerTensor {Id} {EagerTensorHandle.ToString("x16")}");*/ + if (tfe_tensor_handle == IntPtr.Zero && _id == 0) + { + } + GarbageCollector.Increase(_handle, GCItemType.TensorHandle); GarbageCollector.Increase(tfe_tensor_handle, GCItemType.LocalTensorHandle); GarbageCollector.Increase(EagerTensorHandle, GCItemType.EagerTensorHandle); @@ -56,6 +61,9 @@ namespace Tensorflow.Eager return this; } + public override IntPtr ToPointer() + => EagerTensorHandle; + protected override void DisposeUnmanagedResources(IntPtr handle) { GarbageCollector.Decrease(_handle); diff --git a/src/TensorFlowNET.Core/Eager/EagerTensor.cs b/src/TensorFlowNET.Core/Eager/EagerTensor.cs index a59f98e1..cb99d458 100644 --- a/src/TensorFlowNET.Core/Eager/EagerTensor.cs +++ b/src/TensorFlowNET.Core/Eager/EagerTensor.cs @@ -13,7 +13,7 @@ namespace Tensorflow.Eager public IntPtr EagerTensorHandle { get; set; } public override string Device => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(tfe_tensor_handle, status.Handle)); - // public override int rank => c_api.TFE_TensorHandleNumDims(tfe_tensor_handle, status); + public override int rank => c_api.TFE_TensorHandleNumDims(tfe_tensor_handle, status.Handle); public static int GetRank(IntPtr handle) { diff --git a/src/TensorFlowNET.Core/Eager/c_api.eager.cs b/src/TensorFlowNET.Core/Eager/c_api.eager.cs index fd02b5b0..d38f08fe 100644 --- a/src/TensorFlowNET.Core/Eager/c_api.eager.cs +++ b/src/TensorFlowNET.Core/Eager/c_api.eager.cs @@ -25,7 +25,7 @@ namespace Tensorflow public delegate IntPtr gradient_function_callback(string op_name, IntPtr op_inputs, IntPtr op_outputs, - int num_attrs, + string attrs_string, IntPtr output_grads, IntPtr skip_input_indices); @@ -72,6 +72,9 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern TF_AttrType TFE_OpGetAttrType(IntPtr op, string attr_name, ref byte is_list, SafeStatusHandle status); + [DllImport(TensorFlowLibName)] + public static extern TF_AttrType TFE_OpNameGetAttrType(IntPtr ct, string op_or_function_name, string attr_name, ref byte is_list, SafeStatusHandle status); + /// /// Returns the length (number of tensors) of the input argument `input_name` /// found in the provided `op`. @@ -399,6 +402,7 @@ namespace Tensorflow string name, IntPtr[] inputs, int input_size, + string attrs_string, TFE_FastPathExecute_SetOpAttrs set_op_attrs, IntPtr[] outputs, int output_size); diff --git a/src/TensorFlowNET.Core/Eager/wrap_tfe_src.TFE_FastPathExecute.cs b/src/TensorFlowNET.Core/Eager/wrap_tfe_src.TFE_FastPathExecute.cs index fef57c06..ad11fd6d 100644 --- a/src/TensorFlowNET.Core/Eager/wrap_tfe_src.TFE_FastPathExecute.cs +++ b/src/TensorFlowNET.Core/Eager/wrap_tfe_src.TFE_FastPathExecute.cs @@ -31,6 +31,37 @@ namespace Tensorflow.Eager } } + public static string SetOpAttrs2(params object[] attrs) + { + string attr_string = string.Empty; + for(int i = 0; i < attrs.Length; i = i + 2) + { + object key = attrs[i]; + object value = attrs[i + 1]; + + switch (value) + { + case TF_DataType dtype: + value = (int)dtype; + break; + case bool bVal: + value = bVal ? 1 : 0; + break; + case int[] shape: + value = shape.Length == 0 ? "null" : string.Join(" ", shape); + break; + default: + break; + } + + attr_string += string.IsNullOrEmpty(attr_string) ? + $"{key},{value}" : + $",{key},{value}"; + } + + return attr_string; + } + /// /// This function will set the op attrs required. If an attr has the value of /// None, then it will read the AttrDef to get the default value and set that diff --git a/src/TensorFlowNET.Core/Gradients/GradientTape.cs b/src/TensorFlowNET.Core/Gradients/GradientTape.cs index ea18e557..e45ca404 100644 --- a/src/TensorFlowNET.Core/Gradients/GradientTape.cs +++ b/src/TensorFlowNET.Core/Gradients/GradientTape.cs @@ -1,7 +1,9 @@ using Google.Protobuf.WellKnownTypes; +using NumSharp.Utilities; using System; using System.Collections.Generic; using System.Linq; +using System.Reflection; using System.Runtime.InteropServices; using System.Text; using Tensorflow.Eager; @@ -72,22 +74,40 @@ namespace Tensorflow.Gradients public Tensor gradient(Tensor target, Tensor source) { - if(_recording) + if (_recording) { if (!_persistent) _pop_tape(); } - var results = new[] { new EagerTensor() }; + var results = EagerTensorPass.Create(); + var targets = EagerTensorPass.From(target); + var sources = EagerTensorPass.From(source); + using Status status = new Status(c_api.TFE_TapeGradient(_tape, - new [] { (target as EagerTensor).EagerTensorHandle }, 1, - new [] { (source as EagerTensor).EagerTensorHandle }, 1, - results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); + targets.Points, targets.Length, + sources.Points, sources.Length, + results.Points, results.Length)); status.Check(true); + return results[0].Resolve(); } - public unsafe (Tensor, Tensor) gradient(Tensor target, (ResourceVariable, ResourceVariable) sources) + public Tensor gradient(Tensor target, ResourceVariable source) + { + var results = gradient(target as EagerTensor, new[] { source }); + + return results[0]; + } + + public (Tensor, Tensor) gradient(Tensor target, (ResourceVariable, ResourceVariable) sources) + { + var results = gradient(target as EagerTensor, new[] { sources.Item1, sources.Item2 }); + + return (results[0], results[1]); + } + + public EagerTensor[] gradient(EagerTensor target, ResourceVariable[] sources) { if (_recording) { @@ -95,18 +115,14 @@ namespace Tensorflow.Gradients _pop_tape(); } - var results = new[] { new EagerTensor(), new EagerTensor() }; + var results = EagerTensorPass.Create(sources.Length); + var target_inputs = EagerTensorPass.From(target); + var source_inputs = EagerTensorPass.From(sources.Select(x => x.Handle).ToArray()); + using Status status = new Status(c_api.TFE_TapeGradient(_tape, - new IntPtr[] - { - target as EagerTensor - }, 1, - new IntPtr[] - { - (sources.Item1.Handle as EagerTensor).EagerTensorHandle, - (sources.Item2.Handle as EagerTensor).EagerTensorHandle - }, 2, - results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); + target_inputs.Points, target_inputs.Length, + source_inputs.Points, source_inputs.Length, + results.Points, results.Length)); status.Check(true); if (!_persistent) @@ -116,13 +132,15 @@ namespace Tensorflow.Gradients _tape = null; } - return (results[0].Resolve(), results[1].Resolve()); + return results.Items.Select(x => x.Resolve()).ToArray(); } public void Dispose() { if (_recording) _pop_tape(); + + tf.tensorMgr.Reset(); } } } diff --git a/src/TensorFlowNET.Core/Gradients/math_grad.cs b/src/TensorFlowNET.Core/Gradients/math_grad.cs index 427de88b..4e5a5e85 100644 --- a/src/TensorFlowNET.Core/Gradients/math_grad.cs +++ b/src/TensorFlowNET.Core/Gradients/math_grad.cs @@ -310,23 +310,26 @@ namespace Tensorflow.Gradients var input_shape = op.inputs[0]._shape_tuple(); var output_shape = op.outputs[0]._shape_tuple(); + Tensor result, factor_tensor; if(input_shape != null && output_shape != null) { var input_size = np.prod(input_shape); var output_size = np.prod(output_shape); var factor = (int)input_size / Math.Max((int)output_size, 1); - var factor_tensor = constant_op.constant((int)input_size, dtype: sum_grad.dtype); - return new Tensor[] { math_ops.truediv(sum_grad, math_ops.cast(factor_tensor, sum_grad.dtype)), null }; + factor_tensor = constant_op.constant(factor, dtype: sum_grad.dtype); } else { var input_shape_tensor = array_ops.shape(op.inputs[0]); var output_shape_tensor = array_ops.shape(op.outputs[0]); var factor = _safe_shape_div(math_ops.reduce_prod(input_shape_tensor), math_ops.reduce_prod(output_shape_tensor)); - - return new Tensor[] { math_ops.truediv(sum_grad, math_ops.cast(factor, sum_grad.dtype)), null }; + throw new NotImplementedException(""); + factor_tensor = null; } + + result = math_ops.truediv(sum_grad, math_ops.cast(factor_tensor, sum_grad.dtype)); + return new Tensor[] { result, null }; } /// @@ -497,8 +500,8 @@ namespace Tensorflow.Gradients if (tf.context.executing_eagerly()) { // should add ones_rank_cache - var new_shape_tensor = constant_op.constant(np.array(new int[] { 1 }) * rank, dtype: TF_DataType.TF_INT32); - grad = array_ops.reshape(grad, new_shape_tensor); + var new_shape = constant_op.constant(range(0, rank).Select(x => 1).ToArray(), dtype: TF_DataType.TF_INT32); + grad = array_ops.reshape(grad, new_shape); } else { @@ -513,20 +516,23 @@ namespace Tensorflow.Gradients input_shape = array_ops.shape(op.inputs[0]); return new Tensor[] { gen_array_ops.tile(grad, input_shape), null }; } - else + else if (!input_0_shape.Contains(-1) && !tf.context.executing_eagerly()) { - + throw new NotImplementedException(""); } } } input_shape = array_ops.shape(op.inputs[0]); - ops.colocate_with(input_shape); - var output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1]); - var tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims); - grad = gen_array_ops.reshape(grad, output_shape_kept_dims); + if (!op.get_attr("keep_dims")) + { + ops.colocate_with(input_shape); + var output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1]); + // var tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims); + grad = gen_array_ops.reshape(grad, output_shape_kept_dims); + } - return new Tensor[] { gen_array_ops.tile(grad, tile_scaling), null }; + return new Tensor[] { gen_array_ops.broadcast_to(grad, input_shape), null }; } [RegisterGradient("RealDiv")] diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Control.cs b/src/TensorFlowNET.Core/Graphs/Graph.Control.cs index 81c13827..b119745c 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.Control.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.Control.cs @@ -17,6 +17,7 @@ using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.Linq; +using Tensorflow.Eager; using Tensorflow.Operations; using static Tensorflow.Binding; @@ -81,6 +82,9 @@ namespace Tensorflow /// public _ControlDependenciesController control_dependencies(object[] control_inputs) { + if (tf.context.executing_eagerly()) + return new _ControlDependenciesController(this, null); + if (control_inputs == null) return new _ControlDependenciesController(this, null); diff --git a/src/TensorFlowNET.Core/Graphs/NullContextmanager.cs b/src/TensorFlowNET.Core/Graphs/NullContextmanager.cs new file mode 100644 index 00000000..a8b57ea3 --- /dev/null +++ b/src/TensorFlowNET.Core/Graphs/NullContextmanager.cs @@ -0,0 +1,34 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class NullContextmanager : ITensorFlowObject + { + public void __init__() + { + throw new NotImplementedException(); + } + + public void __enter__() + { + throw new NotImplementedException(); + } + + public void __del__() + { + throw new NotImplementedException(); + } + + public void __exit__() + { + throw new NotImplementedException(); + } + + public void Dispose() + { + throw new NotImplementedException(); + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Datasets/DatasetPass.cs b/src/TensorFlowNET.Core/Keras/Datasets/DatasetPass.cs new file mode 100644 index 00000000..48fb3b82 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Datasets/DatasetPass.cs @@ -0,0 +1,27 @@ +using NumSharp; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Datasets +{ + public class DatasetPass + { + public (NDArray, NDArray) Train { get; set; } + public (NDArray, NDArray) Test { get; set; } + + public void Deconstruct(out NDArray x_train, out NDArray y_train, out NDArray x_test, out NDArray y_test) + { + x_train = Train.Item1; + y_train = Train.Item2; + x_test = Test.Item1; + y_test = Test.Item2; + } + + public void Deconstruct(out (NDArray, NDArray) train, out (NDArray, NDArray) test) + { + train = Train; + test = Test; + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Datasets/KerasDataset.cs b/src/TensorFlowNET.Core/Keras/Datasets/KerasDataset.cs new file mode 100644 index 00000000..5aabfdf0 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Datasets/KerasDataset.cs @@ -0,0 +1,27 @@ +/***************************************************************************** + Copyright 2020 Haiping Chen. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Datasets +{ + public class KerasDataset + { + public Mnist mnist { get; } = new Mnist(); + } +} diff --git a/src/TensorFlowNET.Core/Keras/Datasets/Mnist.cs b/src/TensorFlowNET.Core/Keras/Datasets/Mnist.cs new file mode 100644 index 00000000..37eae319 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Datasets/Mnist.cs @@ -0,0 +1,76 @@ +/***************************************************************************** + Copyright 2020 Haiping Chen. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using NumSharp; +using System; +using System.Collections.Generic; +using System.IO; +using System.Net; +using System.Text; + +namespace Tensorflow.Keras.Datasets +{ + public class Mnist + { + string origin_folder = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/"; + string file_name = "mnist.npz"; + + /// + /// Loads the [MNIST dataset](http://yann.lecun.com/exdb/mnist/). + /// + /// + public DatasetPass load_data() + { + var file = Download(); + var bytes = File.ReadAllBytes(file); + var datax = LoadX(bytes); + var datay = LoadY(bytes); + return new DatasetPass + { + Train = (datax.Item1, datay.Item1), + Test = (datax.Item2, datay.Item2) + }; + } + + (NDArray, NDArray) LoadX(byte[] bytes) + { + var y = np.Load_Npz(bytes); + return (y["x_train.npy"], y["x_test.npy"]); + } + + (NDArray, NDArray) LoadY(byte[] bytes) + { + var y = np.Load_Npz(bytes); + return (y["y_train.npy"], y["y_test.npy"]); + } + + string Download() + { + var fileSaveTo = Path.Combine(Path.GetTempPath(), file_name); + + if (File.Exists(fileSaveTo)) + { + Console.WriteLine($"The file {fileSaveTo} already exists"); + return fileSaveTo; + } + + using var wc = new WebClient(); + wc.DownloadFileTaskAsync(origin_folder + file_name, fileSaveTo).Wait(); + + return fileSaveTo; + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/KerasApi.cs b/src/TensorFlowNET.Core/Keras/KerasApi.cs new file mode 100644 index 00000000..f9fd94d6 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/KerasApi.cs @@ -0,0 +1,12 @@ +using System.Data; +using Tensorflow.Keras; +using Tensorflow.Keras.Datasets; + +namespace Tensorflow +{ + public class KerasApi + { + public KerasDataset datasets { get; } = new KerasDataset(); + public Initializers initializers { get; } = new Initializers(); + } +} diff --git a/src/TensorFlowNET.Core/Keras/Optimizers/OptimizerV2.cs b/src/TensorFlowNET.Core/Keras/Optimizers/OptimizerV2.cs index e2c4808d..5c75e9bf 100644 --- a/src/TensorFlowNET.Core/Keras/Optimizers/OptimizerV2.cs +++ b/src/TensorFlowNET.Core/Keras/Optimizers/OptimizerV2.cs @@ -36,6 +36,13 @@ namespace Tensorflow.Keras.Optimizers apply_state = new Dictionary>(); } + public void apply_gradients((Tensor, ResourceVariable) grads_and_vars, + string name = null, + bool experimental_aggregate_gradients = true) + => apply_gradients(new (Tensor, ResourceVariable)[] { grads_and_vars }, + name: name, + experimental_aggregate_gradients: experimental_aggregate_gradients); + /// /// Apply gradients to variables. /// diff --git a/src/TensorFlowNET.Core/Keras/tf.keras.cs b/src/TensorFlowNET.Core/Keras/tf.keras.cs deleted file mode 100644 index dee173f8..00000000 --- a/src/TensorFlowNET.Core/Keras/tf.keras.cs +++ /dev/null @@ -1,12 +0,0 @@ -using Tensorflow.Keras; - -namespace Tensorflow -{ - public partial class tensorflow - { - public class keras - { - public static Initializers initializers => new Initializers(); - } - } -} diff --git a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs index 535070e9..108b64b7 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs @@ -373,6 +373,19 @@ namespace Tensorflow.Operations public static Tensor relu_grad(Tensor gradients, Tensor features, string name = null) { + if (tf.context.executing_eagerly()) + { + var results = EagerTensorPass.Create(); + var inputs = EagerTensorPass.From(gradients, features); + using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, + "ReluGrad", name, + inputs.Points, inputs.Length, + null, null, + results.Points, results.Length)); + status.Check(true); + return results[0].Resolve(); + } + var _op = _op_def_lib._apply_op_helper("ReluGrad", name: name, args: new { gradients, @@ -396,6 +409,19 @@ namespace Tensorflow.Operations public static Tensor softmax(Tensor logits, string name = null) { + if (tf.context.executing_eagerly()) + { + var results = EagerTensorPass.Create(); + var inputs = EagerTensorPass.From(logits); + using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, + "Softmax", name, + inputs.Points, inputs.Length, + null, null, + results.Points, results.Length)); + status.Check(true); + return results[0].Resolve(); + } + var _op = _op_def_lib._apply_op_helper("Softmax", name: name, args: new { logits @@ -473,7 +499,8 @@ namespace Tensorflow.Operations "Relu", name, new IntPtr[] { features as EagerTensor, - }, 1, null, + }, 1, + null, null, results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); status.Check(true); return results[0].Resolve(); @@ -492,7 +519,8 @@ namespace Tensorflow.Operations "Tanh", name, new IntPtr[] { x as EagerTensor, - }, 1, null, + }, 1, + null, null, results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); status.Check(true); return results[0].Resolve(); diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index b4ab4aa3..18ef6266 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -57,7 +57,7 @@ namespace Tensorflow public int _id_value { get; set; } public Operation op => this; public TF_DataType dtype => TF_DataType.DtInvalid; - public string name => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationName(_handle)); + public virtual string name => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationName(_handle)); public string OpType => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationOpType(_handle)); public string Device => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationDevice(_handle)); @@ -228,7 +228,7 @@ namespace Tensorflow public T get_attr(string name) => (T)get_attr(name); - public object get_attr(string name) + public virtual object get_attr(string name) { AttrValue x = null; diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs index 3a69eda5..23c6febb 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.cs @@ -349,7 +349,7 @@ namespace Tensorflow return fill(shape_tensor, ones, name: name); }); - public static Tensor one_hot(Tensor indices, int depth, + public static Tensor one_hot(Tensor indices, Tensor depth, Tensor on_value = null, Tensor off_value = null, TF_DataType dtype = TF_DataType.DtInvalid, diff --git a/src/TensorFlowNET.Core/Operations/clip_ops.cs b/src/TensorFlowNET.Core/Operations/clip_ops.cs index 701664f4..018949c0 100644 --- a/src/TensorFlowNET.Core/Operations/clip_ops.cs +++ b/src/TensorFlowNET.Core/Operations/clip_ops.cs @@ -25,7 +25,7 @@ namespace Tensorflow { public class clip_ops { - public static Tensor clip_by_value(Tensor t, Tensor clip_value_min, Tensor clip_value_max, string name = null) + public static Tensor clip_by_value(Tensor t, T1 clip_value_min, T2 clip_value_max, string name = null) { return tf_with(ops.name_scope(name, "clip_by_value", new { t, clip_value_min, clip_value_max }), delegate { diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index 54024910..316519e2 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -21,6 +21,7 @@ using static Tensorflow.Binding; using Tensorflow.Eager; using System.Linq; using static Tensorflow.Binding; +using System.Security.Cryptography.X509Certificates; namespace Tensorflow { @@ -60,7 +61,8 @@ namespace Tensorflow { values as EagerTensor, axis as EagerTensor - }, 2, null, + }, 2, + null, null, results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); status.Check(true); return results[0].Resolve(); @@ -165,7 +167,8 @@ namespace Tensorflow var results = new[] { new EagerTensor() }; using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, "Pack", name, - values.Select(x => (x as EagerTensor).EagerTensorHandle).ToArray(), values.Length, + values.Select(x => (x as EagerTensor).EagerTensorHandle).ToArray(), values.Length, + wrap_tfe_src.SetOpAttrs2("axis", axis), op => wrap_tfe_src.SetOpAttrs(op, "axis", axis), results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); status.Check(true); @@ -235,7 +238,8 @@ namespace Tensorflow "Identity", name, new IntPtr[] { input as EagerTensor - }, 1, null, + }, 1, + null, null, results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); status.Check(true); return results[0].Resolve(); @@ -278,15 +282,16 @@ namespace Tensorflow { if (tf.context.executing_eagerly()) { - var results = new[] { new EagerTensor() }; + var results = EagerTensorPass.Create(); + var inputs = EagerTensorPass.From(dims, value); + using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, - "Fill", name, new IntPtr[] - { - dims as EagerTensor, - value as EagerTensor - }, 2, null, - results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); + "Fill", name, + inputs.Points, inputs.Length, + null, null, + results.Points, results.Length)); status.Check(true); + return results[0].Resolve(); } @@ -311,7 +316,8 @@ namespace Tensorflow { s0 as EagerTensor, s1 as EagerTensor - }, 2, null, + }, 2, + null, null, results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); status.Check(true); return (results[0].Resolve(), results[1].Resolve()); @@ -338,7 +344,8 @@ namespace Tensorflow { tensor as EagerTensor, shape as EagerTensor - }, 2, null, + }, 2, + null, null, results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); status.Check(true); return results[0].Resolve(); @@ -381,13 +388,30 @@ namespace Tensorflow return _op.output; } - public static Tensor one_hot(Tensor indices, int depth, + public static Tensor one_hot(Tensor indices, Tensor depth, Tensor on_value = null, Tensor off_value = null, TF_DataType dtype = TF_DataType.DtInvalid, int axis = -1, string name = null) { + if (tf.context.executing_eagerly()) + { + var results = EagerTensorPass.Create(); + var inputs = EagerTensorPass.From(indices, depth, on_value, off_value); + var attrs = new object[] { "axis", axis }; + + using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, + "OneHot", name, + inputs.Points, inputs.Length, + wrap_tfe_src.SetOpAttrs2(attrs), + op => wrap_tfe_src.SetOpAttrs(op, attrs), + results.Points, results.Length)); + status.Check(true); + + return results[0].Resolve(); + } + var _op = _op_def_lib._apply_op_helper("OneHot", name, new { indices, depth, on_value, off_value, axis }); return _op.outputs[0]; } @@ -407,6 +431,21 @@ namespace Tensorflow public static Tensor select(Tensor condition, Tx t, Ty e, string name = null) { + if (tf.context.executing_eagerly()) + { + var results = EagerTensorPass.Create(); + var inputs = EagerTensorPass.From(condition, t, e); + + using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, + "SelectV2", name, + inputs.Points, inputs.Length, + null, null, + results.Points, results.Length)); + status.Check(true); + + return results[0].Resolve(); + } + var _op = _op_def_lib._apply_op_helper("Select", name, new { condition, t, e }); return _op.outputs[0]; } @@ -427,6 +466,7 @@ namespace Tensorflow { input as EagerTensor, }, 1, + wrap_tfe_src.SetOpAttrs2("out_type", out_type), op => wrap_tfe_src.SetOpAttrs(op, "out_type", out_type), results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); status.Check(true); @@ -486,7 +526,8 @@ namespace Tensorflow { input as EagerTensor, multiples as EagerTensor - }, 2, null, + }, 2, + null, null, results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); status.Check(true); return results[0].Resolve(); @@ -526,6 +567,14 @@ namespace Tensorflow if (tf.context.executing_eagerly()) { var results = new[] { new EagerTensor() }; + var attrs = new object[] + { + "begin_mask", begin_mask, + "end_mask", end_mask, + "ellipsis_mask", ellipsis_mask, + "new_axis_mask", new_axis_mask, + "shrink_axis_mask", shrink_axis_mask + }; using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, "StridedSlice", name, new IntPtr[] { @@ -534,12 +583,8 @@ namespace Tensorflow end as EagerTensor, strides as EagerTensor, }, 4, - op => wrap_tfe_src.SetOpAttrs(op, - "begin_mask", begin_mask, - "end_mask", end_mask, - "ellipsis_mask", ellipsis_mask, - "new_axis_mask", new_axis_mask, - "shrink_axis_mask", shrink_axis_mask), + wrap_tfe_src.SetOpAttrs2(attrs), + op => wrap_tfe_src.SetOpAttrs(op, attrs), results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); status.Check(true); return results[0].Resolve(); @@ -645,6 +690,21 @@ namespace Tensorflow /// A `Tensor`. Has the same type as `input`. public static Tensor squeeze(Tensor input, int[] axis = null, string name = null) { + if (tf.context.executing_eagerly()) + { + var results = new[] { new EagerTensor() }; + using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, + "Squeeze", name, new IntPtr[] + { + input as EagerTensor + }, 1, + wrap_tfe_src.SetOpAttrs2("squeeze_dims", axis), + op => wrap_tfe_src.SetOpAttrs(op, "squeeze_dims", axis), + results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); + status.Check(true); + return results[0].Resolve(); + } + if (axis == null) axis = new int[0]; var _op = _op_def_lib._apply_op_helper("Squeeze", name, args: new { input, squeeze_dims = axis }); @@ -674,8 +734,22 @@ namespace Tensorflow /// /// /// - public static Tensor broadcast_to(Tensor input, int[] shape, string name = null) + public static Tensor broadcast_to(Tensor input, T shape, string name = null) { + if (tf.context.executing_eagerly()) + { + var results = EagerTensorPass.Create(); + var inputs = EagerTensorPass.From(input, shape); + + using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, + "BroadcastTo", name, + inputs.Points, inputs.Length, + null, null, + results.Points, results.Length)); + status.Check(true); + return results[0].Resolve(); + } + var _op = _op_def_lib._apply_op_helper("BroadcastTo", name, args: new { input, shape, name }); return _op.outputs[0]; diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index c5dd0b98..435d46c0 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -48,7 +48,7 @@ namespace Tensorflow using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, "AddN", name, inputs.Select(x => (x as EagerTensor).EagerTensorHandle).ToArray(), inputs.Length, - null, + null, null, results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); status.Check(true); return results[0].Resolve(); @@ -65,7 +65,7 @@ namespace Tensorflow using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, "AddN", name, inputs, inputs.Length, - null, + null, null, results, results.Length)); status.Check(true); return results[0]; @@ -80,7 +80,23 @@ namespace Tensorflow /// /// public static Tensor arg_max(Tensor input, int dimension, TF_DataType output_type = TF_DataType.TF_INT64, string name = null) - => _op_def_lib._apply_op_helper("ArgMax", name, args: new { input, dimension, output_type }).outputs[0]; + { + if (tf.context.executing_eagerly()) + { + var results = EagerTensorPass.Create(); + var inputs = EagerTensorPass.From(input, dimension); + using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, + "ArgMax", name, + inputs.Points, inputs.Length, + wrap_tfe_src.SetOpAttrs2("output_type", output_type), + op => wrap_tfe_src.SetOpAttrs(op, "output_type", output_type), + results.Points, results.Length)); + status.Check(true); + return results[0].Resolve(); + } + + return _op_def_lib._apply_op_helper("ArgMax", name, args: new { input, dimension, output_type }).output; + } /// /// Returns the index with the smallest value across dimensions of a tensor. @@ -152,6 +168,7 @@ namespace Tensorflow input as EagerTensor, axis as EagerTensor }, 2, + wrap_tfe_src.SetOpAttrs2("keep_dims", keep_dims), op => wrap_tfe_src.SetOpAttrs(op, "keep_dims", keep_dims), results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); status.Check(true); @@ -198,6 +215,7 @@ namespace Tensorflow input as EagerTensor, axis as EagerTensor }, 2, + wrap_tfe_src.SetOpAttrs2("keep_dims", keep_dims), op => wrap_tfe_src.SetOpAttrs(op, "keep_dims", keep_dims), results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); status.Check(true); @@ -247,7 +265,8 @@ namespace Tensorflow { x as EagerTensor, y as EagerTensor - }, 2, null, + }, 2, + null, null, results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); status.Check(true); return results[0].Resolve(); @@ -268,7 +287,8 @@ namespace Tensorflow { x as EagerTensor, y as EagerTensor - }, 2, null, + }, 2, + null, null, results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); status.Check(true); return results[0].Resolve(); @@ -290,7 +310,8 @@ namespace Tensorflow { x as EagerTensor, y as EagerTensor - }, 2, null, + }, 2, + null, null, results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); status.Check(true); return results[0].Resolve(); @@ -324,7 +345,8 @@ namespace Tensorflow "Sin", name, new IntPtr[] { x as EagerTensor, - }, 1, null, + }, 1, + null, null, results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); status.Check(true); return results[0].Resolve(); @@ -358,7 +380,8 @@ namespace Tensorflow "Sigmoid", name, new IntPtr[] { x as EagerTensor, - }, 1, null, + }, 1, + null, null, results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); status.Check(true); return results[0].Resolve(); @@ -451,7 +474,8 @@ namespace Tensorflow "Tan", name, new IntPtr[] { x as EagerTensor, - }, 1, null, + }, 1, + null, null, results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); status.Check(true); return results[0].Resolve(); @@ -464,6 +488,20 @@ namespace Tensorflow public static Tensor tanh(Tensor x, string name = null) { + if (tf.context.executing_eagerly()) + { + var results = new[] { new EagerTensor() }; + using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, + "Tanh", name, new IntPtr[] + { + x as EagerTensor, + }, 1, + null, null, + results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); + status.Check(true); + return results[0].Resolve(); + } + var _op = _op_def_lib._apply_op_helper("Tanh", name, args: new { x }); return _op.outputs[0]; @@ -477,7 +515,25 @@ namespace Tensorflow /// /// public static Tensor tanh_grad(Tensor y, Tensor dy, string name = null) - => _op_def_lib._apply_op_helper("TanhGrad", name: name, args: new { y, dy }).output; + { + if (tf.context.executing_eagerly()) + { + var results = new[] { new EagerTensor() }; + using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, + "TanhGrad", name, new IntPtr[] + { + y as EagerTensor, + dy as EagerTensor + }, 2, + null, null, + results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); + status.Check(true); + return results[0].Resolve(); + } + + var _op = _op_def_lib._apply_op_helper("TanhGrad", name: name, args: new { y, dy }).output; + return _op.outputs[0]; + } public static Tensor floor(Tensor x, string name = null) { @@ -495,6 +551,19 @@ namespace Tensorflow public static Tensor greater(Tx x, Ty y, string name = null) { + if (tf.context.executing_eagerly()) + { + var results = EagerTensorPass.Create(); + var inputs = EagerTensorPass.From(x, y); + using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, + "Greater", name, + inputs.Points, inputs.Length, + null, null, + results.Points, results.Length)); + status.Check(true); + return results[0].Resolve(); + } + var _op = _op_def_lib._apply_op_helper("Greater", name: name, args: new { x, y }); return _op.outputs[0]; @@ -520,6 +589,21 @@ namespace Tensorflow public static Tensor greater_equal(Tx x, Ty y, string name = null) { + if (tf.context.executing_eagerly()) + { + var results = EagerTensorPass.Create(); + var inputs = EagerTensorPass.From(x, y); + + using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, + "GreaterEqual", name, + inputs.Points, inputs.Length, + null, null, + results.Points, results.Length)); + status.Check(true); + + return results[0].Resolve(); + } + var _op = _op_def_lib._apply_op_helper("GreaterEqual", name: name, args: new { x, y }); return _op.outputs[0]; @@ -529,14 +613,13 @@ namespace Tensorflow { if (tf.context.executing_eagerly()) { - var results = new[] { new EagerTensor() }; + var results = EagerTensorPass.Create(); + var inputs = EagerTensorPass.From(x, y); using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, - "Less", name, new IntPtr[] - { - x as EagerTensor, - y as EagerTensor - }, 2, null, - results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); + "Less", name, + inputs.Points, inputs.Length, + null, null, + results.Points, results.Length)); status.Check(true); return results[0].Resolve(); } @@ -548,6 +631,19 @@ namespace Tensorflow public static Tensor less_equal(Tx x, Ty y, string name = null) { + if (tf.context.executing_eagerly()) + { + var results = EagerTensorPass.Create(); + var inputs = EagerTensorPass.From(x, y); + using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, + "LessEqual", name, + inputs.Points, inputs.Length, + null, null, + results.Points, results.Length)); + status.Check(true); + return results[0].Resolve(); + } + var _op = _op_def_lib._apply_op_helper("LessEqual", name: name, args: new { x, y }); return _op.outputs[0]; @@ -611,7 +707,8 @@ namespace Tensorflow "Square", name, new IntPtr[] { x as EagerTensor, - }, 1, null, + }, 1, + null, null, results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); status.Check(true); return results[0].Resolve(); @@ -663,6 +760,21 @@ namespace Tensorflow /// A `Tensor`. Has the same type as `x`. public static Tensor log(Tensor x, string name = null) { + if (tf.context.executing_eagerly()) + { + var results = EagerTensorPass.Create(); + var inputs = EagerTensorPass.From(x); + + using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, + "Log", name, + inputs.Points, inputs.Length, + null, null, + results.Points, results.Length)); + status.Check(true); + + return results[0].Resolve(); + } + var _op = _op_def_lib._apply_op_helper("Log", name, args: new { x }); return _op.outputs[0]; @@ -673,12 +785,20 @@ namespace Tensorflow if (tf.context.executing_eagerly()) { var results = new[] { new EagerTensor() }; + var attrs = new object[] + { + "DstT", DstT, + "Truncate", Truncate + }; + using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, "Cast", name, new IntPtr[] { x as EagerTensor }, 1, - op => wrap_tfe_src.SetOpAttrs(op, "DstT", DstT, "Truncate", Truncate), + wrap_tfe_src.SetOpAttrs2(attrs), + op => wrap_tfe_src.SetOpAttrs(op, attrs), results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); status.Check(true); + return results[0].Resolve(); } @@ -691,14 +811,16 @@ namespace Tensorflow { if (tf.context.executing_eagerly()) { - var results = new[] { new EagerTensor() }; + var results = EagerTensorPass.Create(); + var inputs = EagerTensorPass.From(x); + using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, - "Neg", name, new IntPtr[] - { - x as EagerTensor - }, 2, null, - results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); + "Neg", name, + inputs.Points, inputs.Length, + null, null, + results.Points, results.Length)); status.Check(true); + return results[0].Resolve(); } @@ -716,7 +838,8 @@ namespace Tensorflow "Sqrt", name, new IntPtr[] { x as EagerTensor, - }, 1, null, + }, 1, + null, null, results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); status.Check(true); return results[0].Resolve(); @@ -737,7 +860,8 @@ namespace Tensorflow { x as EagerTensor, y as EagerTensor - }, 2, null, + }, 2, + null, null, results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); status.Check(true); return results[0].Resolve(); @@ -758,7 +882,8 @@ namespace Tensorflow { x as EagerTensor, y as EagerTensor - }, 2, null, + }, 2, + null, null, results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); status.Check(true); return results[0].Resolve(); @@ -786,7 +911,8 @@ namespace Tensorflow { x as EagerTensor, y as EagerTensor - }, 2, null, + }, 2, + null, null, results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); status.Check(true); return results[0].Resolve(); @@ -815,7 +941,8 @@ namespace Tensorflow { x as EagerTensor, y as EagerTensor - }, 2, null, + }, 2, + null, null, results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); status.Check(true); return results[0].Resolve(); @@ -836,7 +963,8 @@ namespace Tensorflow { y as EagerTensor, x as EagerTensor - }, 2, null, + }, 2, + null, null, results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); status.Check(true); return results[0].Resolve(); @@ -856,7 +984,8 @@ namespace Tensorflow { x as EagerTensor, y as EagerTensor - }, 2, null, + }, 2, + null, null, results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); status.Check(true); return results[0].Resolve(); @@ -877,7 +1006,8 @@ namespace Tensorflow { x as EagerTensor, y as EagerTensor, - }, 2, null, + }, 2, + null, null, results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); status.Check(true); return results[0].Resolve(); @@ -905,7 +1035,8 @@ namespace Tensorflow { x as EagerTensor, y as EagerTensor - }, 2, null, + }, 2, + null, null, results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); status.Check(true); return results[0].Resolve(); @@ -918,6 +1049,21 @@ namespace Tensorflow public static Tensor reciprocal(Tensor x, string name = null) { + if (tf.context.executing_eagerly()) + { + var results = EagerTensorPass.Create(); + var inputs = EagerTensorPass.From(x); + + using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, + "Reciprocal", name, + inputs.Points, inputs.Length, + null, null, + results.Points, results.Length)); + status.Check(true); + + return results[0].Resolve(); + } + var _op = _op_def_lib._apply_op_helper("Reciprocal", name, args: new { x }); return _op.outputs[0]; @@ -933,7 +1079,8 @@ namespace Tensorflow { x as EagerTensor, y as EagerTensor - }, 2, null, + }, 2, + null, null, results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); status.Check(true); return results[0].Resolve(); @@ -954,7 +1101,8 @@ namespace Tensorflow { x as EagerTensor, y as EagerTensor - }, 2, null, + }, 2, + null, null, results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); status.Check(true); return results[0].Resolve(); @@ -978,18 +1126,19 @@ namespace Tensorflow { if (tf.context.executing_eagerly()) { - var results = new[] { new EagerTensor() }; + var results = EagerTensorPass.Create(); + var inputs = EagerTensorPass.From(a, b); + var attrs = new object[] + { + "transpose_a", transpose_a, + "transpose_b", transpose_b + }; using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, "MatMul", name, - new IntPtr[] - { - a as EagerTensor, - b as EagerTensor - }, 2, - op => wrap_tfe_src.SetOpAttrs(op, - "transpose_a", transpose_a, - "transpose_b", transpose_b), - results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); + inputs.Points, inputs.Length, + wrap_tfe_src.SetOpAttrs2(attrs), + op => wrap_tfe_src.SetOpAttrs(op, attrs), + results.Points, results.Length)); status.Check(true); return results[0].Resolve(); } @@ -1043,6 +1192,21 @@ namespace Tensorflow /// public static Tensor maximum(T1 x, T2 y, string name = null) { + if (tf.context.executing_eagerly()) + { + var results = EagerTensorPass.Create(); + var inputs = EagerTensorPass.From(x, y); + + using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, + "Maximum", name, + inputs.Points, inputs.Length, + null, null, + results.Points, results.Length)); + status.Check(true); + + return results[0].Resolve(); + } + var _op = _op_def_lib._apply_op_helper("Maximum", name, args: new { x, y }); return _op.outputs[0]; @@ -1050,6 +1214,21 @@ namespace Tensorflow public static Tensor minimum(T1 x, T2 y, string name = null) { + if (tf.context.executing_eagerly()) + { + var results = EagerTensorPass.Create(); + var inputs = EagerTensorPass.From(x, y); + + using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, + "Minimum", name, + inputs.Points, inputs.Length, + null, null, + results.Points, results.Length)); + status.Check(true); + + return results[0].Resolve(); + } + var _op = _op_def_lib._apply_op_helper("Minimum", name, args: new { x, y }); return _op.outputs[0]; @@ -1093,7 +1272,8 @@ namespace Tensorflow { x as EagerTensor, y as EagerTensor - }, 2, null, + }, 2, + null, null, results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); status.Check(true); return results[0].Resolve(); @@ -1108,17 +1288,18 @@ namespace Tensorflow { if (tf.context.executing_eagerly()) { - var results = new[] { new EagerTensor() }; + var results = EagerTensorPass.Create(); + var inputs = EagerTensorPass.From(input, axis); + var attrs = new object[] { "keep_dims", keep_dims }; + using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, "Sum", name, - new IntPtr[] - { - input as EagerTensor, - axis as EagerTensor - }, 2, - op => wrap_tfe_src.SetOpAttrs(op, "keep_dims", keep_dims), - results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); + inputs.Points, inputs.Length, + wrap_tfe_src.SetOpAttrs2(attrs), + op => wrap_tfe_src.SetOpAttrs(op, attrs), + results.Points, results.Length)); status.Check(true); + return results[0].Resolve(); } @@ -1169,7 +1350,8 @@ namespace Tensorflow start as EagerTensor, limit as EagerTensor, delta as EagerTensor - }, 3, null, + }, 3, + null, null, results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); status.Check(true); return results[0].Resolve(); diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.eager.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.eager.cs index fb8990a9..9ad5ad97 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.eager.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.eager.cs @@ -11,14 +11,15 @@ namespace Tensorflow { public static EagerTensor mul(IntPtr x, IntPtr y, string name = null) { - var results = new[] { new EagerTensor() }; + var results = EagerTensorPass.Create(); using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, "Mul", name, new IntPtr[] { x, y, - }, 2, null, - results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); + }, 2, + null, null, + results.Points, results.Length)); status.Check(true); return results[0].Resolve(); } diff --git a/src/TensorFlowNET.Core/Operations/gen_random_ops.cs b/src/TensorFlowNET.Core/Operations/gen_random_ops.cs index 036d2a4d..aac75c22 100644 --- a/src/TensorFlowNET.Core/Operations/gen_random_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_random_ops.cs @@ -42,17 +42,20 @@ namespace Tensorflow if (tf.context.executing_eagerly()) { - var results = new[] { new EagerTensor() }; + var results = EagerTensorPass.Create(); + var attrs = new object[] + { + "seed", seed, + "seed2", seed2, + "dtype", dtype + }; + var inputs = EagerTensorPass.From(shape); using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, - "RandomStandardNormal", name, new IntPtr[] - { - shape as EagerTensor, - }, 1, - op => wrap_tfe_src.SetOpAttrs(op, - "seed", seed, - "seed2", seed2, - "dtype", dtype), - results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); + "RandomStandardNormal", name, + inputs.Points, inputs.Length, + wrap_tfe_src.SetOpAttrs2(attrs), + op => wrap_tfe_src.SetOpAttrs(op, attrs), + results.Points, results.Length)); status.Check(true); return results[0].Resolve(); } @@ -146,6 +149,26 @@ namespace Tensorflow if (!seed2.HasValue) seed2 = 0; + if (tf.context.executing_eagerly()) + { + var results = EagerTensorPass.Create(); + var inputs = EagerTensorPass.From(shape); + var attrs = new object[] + { + "seed", seed, + "seed2", seed2, + "dtype", dtype + }; + using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, + "TruncatedNormal", name, + inputs.Points, inputs.Length, + wrap_tfe_src.SetOpAttrs2(attrs), + op => wrap_tfe_src.SetOpAttrs(op, attrs), + results.Points, results.Length)); + status.Check(true); + return results[0].Resolve(); + } + var _op = _op_def_lib._apply_op_helper("TruncatedNormal", name: name, args: new { shape, dtype, seed, seed2 }); diff --git a/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs b/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs index d4aed731..d3b571a7 100644 --- a/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs @@ -29,15 +29,13 @@ namespace Tensorflow { if (tf.context.executing_eagerly()) { - var results = new[] { new EagerTensor() }; + var results = EagerTensorPass.Create(); + var inputs = EagerTensorPass.From(resource, value); using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, "AssignSubVariableOp", name, - new IntPtr[] - { - resource as EagerTensor, - value as EagerTensor - }, 2, null, - results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); + inputs.Points, inputs.Length, + null, null, + results.Points, results.Length)); status.Check(true); return results[0].Resolve(); } @@ -56,13 +54,11 @@ namespace Tensorflow { if (tf.context.executing_eagerly()) { + var inputs = EagerTensorPass.From(resource, value); using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, "AssignAddVariableOp", name, - new IntPtr[] - { - resource as EagerTensor, - value as EagerTensor - }, 2, null, + inputs.Points, inputs.Length, + null, null, null, 0)); status.Check(true); return null; @@ -75,13 +71,11 @@ namespace Tensorflow { if (tf.context.executing_eagerly()) { + var inputs = EagerTensorPass.From(resource, value); using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, "AssignVariableOp", name, - new IntPtr[] - { - resource as EagerTensor, - value as EagerTensor - }, 2, null, + inputs.Points, inputs.Length, + null, null, null, 0)); status.Check(true); return null; @@ -100,7 +94,8 @@ namespace Tensorflow using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, "VarIsInitializedOp", name, new IntPtr[] { resource as EagerTensor }, - 1, null, + 1, + null, null, results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); status.Check(true); return results[0].Resolve(); @@ -125,15 +120,19 @@ namespace Tensorflow { if(tf.context.executing_eagerly()) { - var results = new[] { new EagerTensor() }; + var results = EagerTensorPass.Create(); + var attrs = new object[] + { + "container", container, + "shared_name", shared_name, + "dtype", dtype, + "shape", shape.dims + }; using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, "VarHandleOp", name, null, 0, - op => wrap_tfe_src.SetOpAttrs(op, - "container", container, - "shared_name", shared_name, - "dtype", dtype, - "shape", shape.dims), - results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); + wrap_tfe_src.SetOpAttrs2(attrs), + op => wrap_tfe_src.SetOpAttrs(op, attrs), + results.Points, results.Length)); status.Check(true); return results[0].Resolve(); } @@ -163,6 +162,7 @@ namespace Tensorflow using Status status = new Status(c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, "ReadVariableOp", name, new IntPtr[] { resource as EagerTensor }, 1, + wrap_tfe_src.SetOpAttrs2("dtype", dtype), op => wrap_tfe_src.SetOpAttrs(op, "dtype", dtype), results.Select(x => x.EagerTensorHandle).ToArray(), results.Length)); status.Check(true); diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs index a58c90ec..37e749ec 100644 --- a/src/TensorFlowNET.Core/Operations/math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/math_ops.cs @@ -348,6 +348,14 @@ namespace Tensorflow /// A 1-D Tensor, the output shape as if keepdims were set to True. public static Tensor reduced_shape(Tensor input_shape, Tensor axes) { + if(tf.context.executing_eagerly()) + { + var input_shape_val = input_shape.numpy(); + var axes_val = (int)axes.numpy(); + input_shape_val[axes_val] = 1; + return tf.constant(input_shape_val); + } + input_shape = to_int32(input_shape); axes = to_int32(axes); @@ -522,7 +530,8 @@ namespace Tensorflow public static Tensor reduce_sum(Tensor input_tensor, int axis, bool keepdims = false, string name = null) { - var m = gen_math_ops._sum(input_tensor, axis, keep_dims: keepdims, name: name); + var dims = _ReductionDims(input_tensor, axis); + var m = gen_math_ops._sum(input_tensor, dims, keep_dims: keepdims, name: name); return _may_reduce_to_scalar(keepdims, axis, m); } diff --git a/src/TensorFlowNET.Core/System/GarbageCollector.cs b/src/TensorFlowNET.Core/System/GarbageCollector.cs index 32c78d29..1ecb4f4f 100644 --- a/src/TensorFlowNET.Core/System/GarbageCollector.cs +++ b/src/TensorFlowNET.Core/System/GarbageCollector.cs @@ -54,8 +54,11 @@ namespace Tensorflow public static void Decrease(IntPtr handle) { - if (handle != IntPtr.Zero && container.ContainsKey(handle)) - container[handle].RefCounter--; + lock (locker) + { + if (handle != IntPtr.Zero && container.ContainsKey(handle)) + container[handle].RefCounter--; + } } private static void Recycle() @@ -64,7 +67,7 @@ namespace Tensorflow lock (locker) { var items = container.Values - .Where(x => x.RefCounter <= 0 && (DateTime.Now - x.LastUpdateTime).TotalMilliseconds > 100) + .Where(x => x.RefCounter <= 0 && (DateTime.Now - x.LastUpdateTime).TotalMilliseconds > 300) .ToArray(); foreach (var item in items) @@ -74,15 +77,15 @@ namespace Tensorflow switch (item.ItemType) { case GCItemType.TensorHandle: - // print($"c_api.TF_DeleteTensor({item.Handle.ToString("x16")})"); + //print($"c_api.TF_DeleteTensor({item.Handle.ToString("x16")})"); c_api.TF_DeleteTensor(item.Handle); break; case GCItemType.LocalTensorHandle: - // print($"c_api.TFE_DeleteTensorHandle({item.Handle.ToString("x16")})"); + //print($"c_api.TFE_DeleteTensorHandle({item.Handle.ToString("x16")})"); c_api.TFE_DeleteTensorHandle(item.Handle); break; case GCItemType.EagerTensorHandle: - // print($"c_api.TFE_DeleteEagerTensor({item.Handle.ToString("x16")})"); + //print($"c_api.TFE_DeleteEagerTensor({item.Handle.ToString("x16")})"); c_api.TFE_DeleteEagerTensor(item.Handle); break; default: diff --git a/src/TensorFlowNET.Core/TensorFlow.Binding.csproj b/src/TensorFlowNET.Core/TensorFlow.Binding.csproj index e462fded..410c62c2 100644 --- a/src/TensorFlowNET.Core/TensorFlow.Binding.csproj +++ b/src/TensorFlowNET.Core/TensorFlow.Binding.csproj @@ -5,7 +5,7 @@ TensorFlow.NET Tensorflow 2.2.0 - 0.20.0-alpha2 + 0.20.0-preview1 8.0 Haiping Chen, Meinrad Recheis, Eli Belash SciSharp STACK diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 76be53a4..f9c24510 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -38,7 +38,8 @@ namespace Tensorflow _TensorLike, ITensorOrTensorArray, IPackable, - ICanBeFlattened + ICanBeFlattened, + IPointerInputs { protected int _id; private readonly Operation _op; @@ -280,6 +281,10 @@ namespace Tensorflow } else throw new InvalidOperationException($"Tensor.AllocationHandle is not null ({AllocationHandle}) but AllocationType is not matched to a C# allocation type ({AllocationType})."); } + + public virtual IntPtr ToPointer() + => _handle; + public bool IsDisposed => _disposed; // public int tensor_int_val { get; set; } diff --git a/src/TensorFlowNET.Core/Tensors/dtypes.cs b/src/TensorFlowNET.Core/Tensors/dtypes.cs index d9be6b99..6a9adcf3 100644 --- a/src/TensorFlowNET.Core/Tensors/dtypes.cs +++ b/src/TensorFlowNET.Core/Tensors/dtypes.cs @@ -199,6 +199,7 @@ namespace Tensorflow => type switch { TF_DataType.TF_STRING => "string", + TF_DataType.TF_UINT8 => "uint8", TF_DataType.TF_INT32 => "int32", TF_DataType.TF_FLOAT => "float32", TF_DataType.TF_BOOL => "bool", diff --git a/src/TensorFlowNET.Core/Training/gen_training_ops.cs b/src/TensorFlowNET.Core/Training/gen_training_ops.cs index bc14bfdd..e7105012 100644 --- a/src/TensorFlowNET.Core/Training/gen_training_ops.cs +++ b/src/TensorFlowNET.Core/Training/gen_training_ops.cs @@ -72,6 +72,7 @@ namespace Tensorflow alpha, delta }, 3, + wrap_tfe_src.SetOpAttrs2("use_locking", use_locking), op => wrap_tfe_src.SetOpAttrs(op, "use_locking", use_locking), null, 0)); status.Check(true); diff --git a/src/TensorFlowNET.Core/Util/EagerTensorPass.cs b/src/TensorFlowNET.Core/Util/EagerTensorPass.cs new file mode 100644 index 00000000..5b0bbde8 --- /dev/null +++ b/src/TensorFlowNET.Core/Util/EagerTensorPass.cs @@ -0,0 +1,22 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow.Eager; + +namespace Tensorflow +{ + public class EagerTensorPass : PointerInputs + { + public EagerTensorPass(params EagerTensor[] tensors) + { + data = tensors; + } + + public static EagerTensorPass Create(int count = 1) + => new EagerTensorPass(Enumerable.Range(0, count).Select(x => new EagerTensor()).ToArray()); + + public static EagerTensorPass From(params object[] objects) + => new EagerTensorPass(objects.Select(x => ops.convert_to_tensor(x) as EagerTensor).ToArray()); + } +} diff --git a/src/TensorFlowNET.Core/Util/IPointerInputs.cs b/src/TensorFlowNET.Core/Util/IPointerInputs.cs new file mode 100644 index 00000000..e0cdd0d3 --- /dev/null +++ b/src/TensorFlowNET.Core/Util/IPointerInputs.cs @@ -0,0 +1,11 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public interface IPointerInputs + { + public IntPtr ToPointer(); + } +} diff --git a/src/TensorFlowNET.Core/Util/PointerInputs.cs b/src/TensorFlowNET.Core/Util/PointerInputs.cs new file mode 100644 index 00000000..64d638a2 --- /dev/null +++ b/src/TensorFlowNET.Core/Util/PointerInputs.cs @@ -0,0 +1,30 @@ +using System; +using System.Collections.Generic; +using System.Text; +using System.Linq; + +namespace Tensorflow +{ + public abstract class PointerInputs + where T : IPointerInputs, new() + { + protected T[] data; + public int Length + => data.Length; + + public IntPtr[] Points + => data.Select(x => x.ToPointer()).ToArray(); + + public PointerInputs(params T[] data) + => this.data = data; + + public T this[int idx] + => data[idx]; + + public T[] Items + => data; + + public static implicit operator IntPtr[](PointerInputs inputs) + => inputs.data.Select(x => x.ToPointer()).ToArray(); + } +} diff --git a/src/TensorFlowNET.Core/Util/TensorManager.cs b/src/TensorFlowNET.Core/Util/TensorManager.cs new file mode 100644 index 00000000..6a3e518a --- /dev/null +++ b/src/TensorFlowNET.Core/Util/TensorManager.cs @@ -0,0 +1,31 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Eager; + +namespace Tensorflow +{ + public class TensorManager + { + Dictionary tensors; + public TensorManager() + { + tensors = new Dictionary(); + } + + public EagerTensor GetTensor(IntPtr handle) + { + if (tensors.ContainsKey(handle)) + return tensors[handle]; + + //return new EagerTensor(handle); + tensors[handle] = new EagerTensor(handle); + return tensors[handle]; + } + + public void Reset() + { + tensors.Clear(); + } + } +} diff --git a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs index ac0eda44..12089b78 100644 --- a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs @@ -54,7 +54,7 @@ namespace Tensorflow public BaseResourceVariable(IntPtr handle, IntPtr tensor) { _handle = handle; - this.handle = new EagerTensor(tensor); + this.handle = tf.tensorMgr.GetTensor(tensor); } public void __init__(bool trainable = true, diff --git a/src/TensorFlowNET.Core/Variables/ResourceVariable.Operators.cs b/src/TensorFlowNET.Core/Variables/ResourceVariable.Operators.cs index b96576e5..08ef0462 100644 --- a/src/TensorFlowNET.Core/Variables/ResourceVariable.Operators.cs +++ b/src/TensorFlowNET.Core/Variables/ResourceVariable.Operators.cs @@ -22,21 +22,24 @@ namespace Tensorflow { public partial class ResourceVariable { + public static OpDefLibrary _op_def_lib = new OpDefLibrary(); + public static Tensor operator +(ResourceVariable x, int y) => op_helper("add", x, y); public static Tensor operator +(ResourceVariable x, float y) => op_helper("add", x, y); public static Tensor operator +(ResourceVariable x, double y) => op_helper("add", x, y); - + public static Tensor operator +(ResourceVariable x, ResourceVariable y) => op_helper("add", x, y); public static Tensor operator -(ResourceVariable x, int y) => op_helper("sub", x, y); public static Tensor operator -(ResourceVariable x, float y) => op_helper("sub", x, y); public static Tensor operator -(ResourceVariable x, double y) => op_helper("sub", x, y); public static Tensor operator -(ResourceVariable x, Tensor y) => op_helper("sub", x, y); + public static Tensor operator -(ResourceVariable x, ResourceVariable y) => op_helper("sub", x, y); public static Tensor operator *(ResourceVariable x, ResourceVariable y) => op_helper("mul", x, y); public static Tensor operator *(ResourceVariable x, NDArray y) => op_helper("mul", x, y); - public static Tensor operator <(ResourceVariable x, Tensor y) => gen_math_ops.less(x.value(), y); + public static Tensor operator <(ResourceVariable x, Tensor y) => op_helper("less", x, y); - public static Tensor operator >(ResourceVariable x, Tensor y) => gen_math_ops.greater(x.value(), y); + public static Tensor operator >(ResourceVariable x, Tensor y) => op_helper("greater", x, y); private static Tensor op_helper(string default_name, ResourceVariable x, T y) => tf_with(ops.name_scope(null, default_name, new { x, y }), scope => @@ -58,6 +61,12 @@ namespace Tensorflow case "mul": result = gen_math_ops.mul(xVal, yTensor, name: name); break; + case "less": + result = gen_math_ops.less(xVal, yTensor, name); + break; + case "greater": + result = gen_math_ops.greater(xVal, yTensor, name); + break; default: throw new NotImplementedException(""); } diff --git a/src/TensorFlowNET.Core/ops.name_scope.cs b/src/TensorFlowNET.Core/ops.name_scope.cs index 0d48cb3a..d3d78a45 100644 --- a/src/TensorFlowNET.Core/ops.name_scope.cs +++ b/src/TensorFlowNET.Core/ops.name_scope.cs @@ -15,6 +15,7 @@ ******************************************************************************/ using System.Collections.Generic; +using System.Diagnostics; using Tensorflow.Eager; using static Tensorflow.Binding; @@ -96,15 +97,18 @@ namespace Tensorflow get_default_graph()._name_stack = old_scope_name; } + [DebuggerNonUserCode] public void __exit__() { } + [DebuggerNonUserCode] public void __init__() { } + [DebuggerNonUserCode] public void __del__() { diff --git a/src/TensorFlowNET.Core/tensorflow.cs b/src/TensorFlowNET.Core/tensorflow.cs index fa41cf98..b5780917 100644 --- a/src/TensorFlowNET.Core/tensorflow.cs +++ b/src/TensorFlowNET.Core/tensorflow.cs @@ -40,14 +40,15 @@ namespace Tensorflow public TF_DataType @string = TF_DataType.TF_STRING; public Context context = new Context(new ContextOptions(), new Status()); - + public TensorManager tensorMgr; public tensorflow() { _constructThreadingObjects(); InitGradientEnvironment(); + tensorMgr = new TensorManager(); } - private unsafe void InitGradientEnvironment() + private void InitGradientEnvironment() { GarbageCollector.Init(); @@ -64,25 +65,30 @@ namespace Tensorflow ops.RegisterFromAssembly(); // ops.RegisterFromAssemblyEager(); - c_api.TFE_RegisterGradientFunction((op_name, op_inputs, op_outputs, num_attrs, output_grads, skip_input_indices) => + c_api.TFE_RegisterGradientFunction((op_name, op_inputs, op_outputs, attrs_string, output_grads, skip_input_indices) => { /*var input_tensors = new BindingArray(op_inputs); var output_tensors = new BindingArray(op_outputs); var output_grad_tensors = new BindingArray(output_grads);*/ - var input_tensors = new BindingTensorArray(op_inputs).Data.Select(x => new EagerTensor(x)).ToArray(); - var output_tensors = new BindingTensorArray(op_outputs).Data.Select(x => new EagerTensor(x)).ToArray(); - var output_grad_tensors = new BindingTensorArray(output_grads).Data.Select(x => new EagerTensor(x)).ToArray(); - var skip_input_indices_param = new BindingArray(skip_input_indices).Data.Select(x => *(int*)x).ToArray(); + var input_tensors = new BindingTensorArray(op_inputs) + .Data.Select(x => tf.tensorMgr.GetTensor(x)).ToArray(); + var output_tensors = new BindingTensorArray(op_outputs) + .Data.Select(x => tf.tensorMgr.GetTensor(x)).ToArray(); + var output_grad_tensors = new BindingTensorArray(output_grads) + .Data.Select(x => tf.tensorMgr.GetTensor(x)).ToArray(); + var skip_input_indices_param = new BindingArray(skip_input_indices); var gradients = ops.gradientFunctions[op_name](new EagerOperation { + Name = op_name, NumInputs = input_tensors.Length, Inputs = input_tensors, // InputHandles = input_tensors.Data, NumOutputs = output_tensors.Length, Outputs = output_tensors, // OutputHandles = output_tensors.Data, - SkipInputIndices = skip_input_indices_param + SkipInputIndicesArray = skip_input_indices_param, + AttrsArray = attrs_string.Split(',') }, output_grad_tensors); var gradients_handles = gradients.Select(x => x == null ? IntPtr.Zero : (x as EagerTensor).EagerTensorHandle).ToArray(); diff --git a/test/TensorFlowNET.UnitTest/Basics/VariableTest.cs b/test/TensorFlowNET.UnitTest/Basics/VariableTest.cs index e41edb89..8f78e768 100644 --- a/test/TensorFlowNET.UnitTest/Basics/VariableTest.cs +++ b/test/TensorFlowNET.UnitTest/Basics/VariableTest.cs @@ -56,10 +56,10 @@ namespace TensorFlowNET.UnitTest.Basics public void Accumulation() { var x = tf.Variable(10, name: "x"); - /*for (int i = 0; i < 5; i++) - x = x + 1; + for (int i = 0; i < 5; i++) + x.assign(x + 1); - Assert.AreEqual(15, (int)x.numpy());*/ + Assert.AreEqual(15, (int)x.numpy()); } [TestMethod] diff --git a/test/TensorFlowNET.UnitTest/Tensorflow.UnitTest.csproj b/test/TensorFlowNET.UnitTest/Tensorflow.UnitTest.csproj index 081bb3a2..ec8c2ec1 100644 --- a/test/TensorFlowNET.UnitTest/Tensorflow.UnitTest.csproj +++ b/test/TensorFlowNET.UnitTest/Tensorflow.UnitTest.csproj @@ -44,10 +44,10 @@ - - + + - +