diff --git a/TensorFlow.NET.sln b/TensorFlow.NET.sln index ae290b70..36f71409 100644 --- a/TensorFlow.NET.sln +++ b/TensorFlow.NET.sln @@ -16,95 +16,51 @@ EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU - Debug|x64 = Debug|x64 Debug-Minimal|Any CPU = Debug-Minimal|Any CPU - Debug-Minimal|x64 = Debug-Minimal|x64 Publish|Any CPU = Publish|Any CPU - Publish|x64 = Publish|x64 Release|Any CPU = Release|Any CPU - Release|x64 = Release|x64 EndGlobalSection GlobalSection(ProjectConfigurationPlatforms) = postSolution {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|Any CPU.Build.0 = Debug|Any CPU - {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|x64.ActiveCfg = Debug|x64 - {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|x64.Build.0 = Debug|x64 {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU - {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug-Minimal|x64.ActiveCfg = Debug|Any CPU - {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug-Minimal|x64.Build.0 = Debug|Any CPU {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Publish|Any CPU.ActiveCfg = Release|Any CPU {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Publish|Any CPU.Build.0 = Release|Any CPU - {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Publish|x64.ActiveCfg = Release|Any CPU - {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Publish|x64.Build.0 = Release|Any CPU {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.ActiveCfg = Release|Any CPU {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.Build.0 = Release|Any CPU - {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|x64.ActiveCfg = Release|Any CPU - {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|x64.Build.0 = Release|Any CPU {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug|Any CPU.Build.0 = Debug|Any CPU - {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug|x64.ActiveCfg = Debug|Any CPU - {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug|x64.Build.0 = Debug|Any CPU {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU - {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug-Minimal|x64.ActiveCfg = Debug|Any CPU - {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug-Minimal|x64.Build.0 = Debug|Any CPU {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Publish|Any CPU.ActiveCfg = Release|Any CPU {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Publish|Any CPU.Build.0 = Release|Any CPU - {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Publish|x64.ActiveCfg = Release|Any CPU - {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Publish|x64.Build.0 = Release|Any CPU {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Release|Any CPU.ActiveCfg = Release|Any CPU {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Release|Any CPU.Build.0 = Release|Any CPU - {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Release|x64.ActiveCfg = Release|Any CPU - {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Release|x64.Build.0 = Release|Any CPU {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug|Any CPU.Build.0 = Debug|Any CPU - {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug|x64.ActiveCfg = Debug|Any CPU - {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug|x64.Build.0 = Debug|Any CPU {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU - {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug-Minimal|x64.ActiveCfg = Debug|Any CPU - {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug-Minimal|x64.Build.0 = Debug|Any CPU {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Publish|Any CPU.ActiveCfg = Release|Any CPU {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Publish|Any CPU.Build.0 = Release|Any CPU - {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Publish|x64.ActiveCfg = Release|Any CPU - {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Publish|x64.Build.0 = Release|Any CPU {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|Any CPU.ActiveCfg = Release|Any CPU {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|Any CPU.Build.0 = Release|Any CPU - {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|x64.ActiveCfg = Release|Any CPU - {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|x64.Build.0 = Release|Any CPU {6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug|Any CPU.Build.0 = Debug|Any CPU - {6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug|x64.ActiveCfg = Debug|Any CPU - {6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug|x64.Build.0 = Debug|Any CPU {6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU {6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU - {6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug-Minimal|x64.ActiveCfg = Debug|Any CPU - {6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug-Minimal|x64.Build.0 = Debug|Any CPU {6268B461-486A-460B-9B3C-86493CBBAAF7}.Publish|Any CPU.ActiveCfg = Release|Any CPU {6268B461-486A-460B-9B3C-86493CBBAAF7}.Publish|Any CPU.Build.0 = Release|Any CPU - {6268B461-486A-460B-9B3C-86493CBBAAF7}.Publish|x64.ActiveCfg = Release|Any CPU - {6268B461-486A-460B-9B3C-86493CBBAAF7}.Publish|x64.Build.0 = Release|Any CPU {6268B461-486A-460B-9B3C-86493CBBAAF7}.Release|Any CPU.ActiveCfg = Release|Any CPU {6268B461-486A-460B-9B3C-86493CBBAAF7}.Release|Any CPU.Build.0 = Release|Any CPU - {6268B461-486A-460B-9B3C-86493CBBAAF7}.Release|x64.ActiveCfg = Release|Any CPU - {6268B461-486A-460B-9B3C-86493CBBAAF7}.Release|x64.Build.0 = Release|Any CPU {EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug|Any CPU.Build.0 = Debug|Any CPU - {EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug|x64.ActiveCfg = Debug|Any CPU - {EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug|x64.Build.0 = Debug|Any CPU {EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU {EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU - {EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug-Minimal|x64.ActiveCfg = Debug|Any CPU - {EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug-Minimal|x64.Build.0 = Debug|Any CPU {EB92DD90-6346-41FB-B967-2B33A860AD98}.Publish|Any CPU.ActiveCfg = Release|Any CPU {EB92DD90-6346-41FB-B967-2B33A860AD98}.Publish|Any CPU.Build.0 = Release|Any CPU - {EB92DD90-6346-41FB-B967-2B33A860AD98}.Publish|x64.ActiveCfg = Release|Any CPU - {EB92DD90-6346-41FB-B967-2B33A860AD98}.Publish|x64.Build.0 = Release|Any CPU {EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|Any CPU.ActiveCfg = Release|Any CPU {EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|Any CPU.Build.0 = Release|Any CPU - {EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|x64.ActiveCfg = Release|Any CPU - {EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|x64.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/src/TensorFlowNET.Core/APIs/c_api.cs b/src/TensorFlowNET.Core/APIs/c_api.cs index 56672173..bdf2785f 100644 --- a/src/TensorFlowNET.Core/APIs/c_api.cs +++ b/src/TensorFlowNET.Core/APIs/c_api.cs @@ -43,7 +43,7 @@ namespace Tensorflow /// public partial class c_api { - public const string TensorFlowLibName = "tensorflow"; + public const string TensorFlowLibName = @"D:\SciSharp\tensorflow-google\bazel-bin\tensorflow\tensorflow.dll"; public static string StringPiece(IntPtr handle) { diff --git a/src/TensorFlowNET.Core/Eager/EagerOperation.cs b/src/TensorFlowNET.Core/Eager/EagerOperation.cs new file mode 100644 index 00000000..ca10caaa --- /dev/null +++ b/src/TensorFlowNET.Core/Eager/EagerOperation.cs @@ -0,0 +1,34 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Eager +{ + public class EagerOperation : Operation + { + public int NumInputs; + public Tensor[] Inputs { get; set; } + + public EagerOperation() : base(IntPtr.Zero) { } + + public override InputList inputs + { + get + { + if (_inputs_val == null) + { + var retval = new Tensor[NumInputs]; + + for (int i = 0; i < NumInputs; i++) + { + + } + + _inputs_val = new InputList(Inputs); + } + + return _inputs_val; + } + } + } +} diff --git a/src/TensorFlowNET.Core/Eager/EagerTensor.Implicit.cs b/src/TensorFlowNET.Core/Eager/EagerTensor.Implicit.cs index de08e9a3..a8a6952d 100644 --- a/src/TensorFlowNET.Core/Eager/EagerTensor.Implicit.cs +++ b/src/TensorFlowNET.Core/Eager/EagerTensor.Implicit.cs @@ -10,5 +10,8 @@ namespace Tensorflow.Eager { public static explicit operator TFE_TensorHandle(EagerTensor tensor) => tensor.tfe_tensor_handle; + + public static implicit operator IntPtr(EagerTensor tensor) + => tensor.EagerTensorHandle; } } diff --git a/src/TensorFlowNET.Core/Eager/EagerTensor.cs b/src/TensorFlowNET.Core/Eager/EagerTensor.cs index 258c9ca7..bfe3a9e1 100644 --- a/src/TensorFlowNET.Core/Eager/EagerTensor.cs +++ b/src/TensorFlowNET.Core/Eager/EagerTensor.cs @@ -24,31 +24,10 @@ namespace Tensorflow.Eager tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status); } - public EagerTensor(int value, string device_name) : base(value) - { - tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status); - EagerTensorHandle = c_api.TFE_EagerTensorFromHandle(tf.context, tfe_tensor_handle); - } - - public EagerTensor(float value, string device_name) : base(value) - { - tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status); - EagerTensorHandle = c_api.TFE_EagerTensorFromHandle(tf.context, tfe_tensor_handle); - } - - public EagerTensor(float[] value, string device_name) : base(value) - { - tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status); - } - - public EagerTensor(double[] value, string device_name) : base(value) - { - tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status); - } - public EagerTensor(NDArray value, string device_name) : base(value) { tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status); + EagerTensorHandle = c_api.TFE_EagerTensorFromHandle(tf.context, tfe_tensor_handle); } public override string ToString() @@ -56,23 +35,24 @@ namespace Tensorflow.Eager switch (rank) { case -1: - return $"tf.Tensor: shape=, dtype={dtype.as_numpy_name()}, numpy={GetFormattedString()}"; + return $"tf.Tensor: shape=, dtype={dtype.as_numpy_name()}, numpy={GetFormattedString(dtype, numpy())}"; case 0: - return $"tf.Tensor: shape=(), dtype={dtype.as_numpy_name()}, numpy={GetFormattedString()}"; + return $"tf.Tensor: shape=(), dtype={dtype.as_numpy_name()}, numpy={GetFormattedString(dtype, numpy())}"; default: - return $"tf.Tensor: shape=({string.Join(",", shape)}), dtype={dtype.as_numpy_name()}, numpy={GetFormattedString()}"; + return $"tf.Tensor: shape=({string.Join(",", shape)}), dtype={dtype.as_numpy_name()}, numpy={GetFormattedString(dtype, numpy())}"; } } - private string GetFormattedString() + public static string GetFormattedString(TF_DataType dtype, NDArray nd) { - var nd = numpy(); switch (dtype) { case TF_DataType.TF_STRING: return $"b'{(string)nd}'"; case TF_DataType.TF_BOOL: return (nd.GetByte(0) > 0).ToString(); + case TF_DataType.TF_RESOURCE: + return ""; default: return nd.ToString(); } diff --git a/src/TensorFlowNET.Core/Eager/c_api.eager.cs b/src/TensorFlowNET.Core/Eager/c_api.eager.cs index 13660a4a..48f0a5d5 100644 --- a/src/TensorFlowNET.Core/Eager/c_api.eager.cs +++ b/src/TensorFlowNET.Core/Eager/c_api.eager.cs @@ -11,8 +11,18 @@ namespace Tensorflow public static extern void TFE_RegisterGradientFunction(_gradient_function_callback callbackPointer); [UnmanagedFunctionPointer(CallingConvention.StdCall)] - public delegate void _gradient_function_callback(string op_name, int num_inputs, IntPtr attrs, int num_attrs); + public delegate IntPtr _gradient_function_callback(string op_name, int num_inputs, IntPtr[] op_inputs, int num_attrs, IntPtr[] output_grads); + [DllImport(TensorFlowLibName)] + public static extern IntPtr VSpace_Handle(VSpace_callback_Ones ones, VSpace_callback_AggregateGrads aggregate_grads); + [UnmanagedFunctionPointer(CallingConvention.StdCall)] + public delegate IntPtr VSpace_callback_Ones(long[] shape, int dims, TF_DataType dtype); + [UnmanagedFunctionPointer(CallingConvention.StdCall)] + public delegate IntPtr VSpace_callback_AggregateGrads(IntPtr gradients, int num_grads); + + [DllImport(TensorFlowLibName)] + public static extern void TFE_RegisterVSpace(IntPtr vspace); + /// /// Return a new options object. /// @@ -330,7 +340,10 @@ namespace Tensorflow string name, IntPtr[] args, int input_size, + TFE_FastPathExecute_SetOpAttrs set_op_attrs, IntPtr status); + [UnmanagedFunctionPointer(CallingConvention.StdCall)] + public delegate void TFE_FastPathExecute_SetOpAttrs(IntPtr op); [DllImport(TensorFlowLibName)] public static extern IntPtr TFE_TapeSetNew(bool persistent, bool watch_accessed_variables); @@ -342,7 +355,8 @@ namespace Tensorflow public static extern void TFE_TapeWatch(IntPtr tape, IntPtr tensor); [DllImport(TensorFlowLibName)] - public static extern IntPtr TFE_TapeGradient(IntPtr tape, IntPtr[] target, int target_size, + public static extern IntPtr TFE_TapeGradient(IntPtr tape, + IntPtr[] target, int target_size, IntPtr[] sources, int source_size, IntPtr status); } diff --git a/src/TensorFlowNET.Core/Eager/wrap_tfe_src.TFE_FastPathExecute.cs b/src/TensorFlowNET.Core/Eager/wrap_tfe_src.TFE_FastPathExecute.cs index b9aaeab2..7b4226f9 100644 --- a/src/TensorFlowNET.Core/Eager/wrap_tfe_src.TFE_FastPathExecute.cs +++ b/src/TensorFlowNET.Core/Eager/wrap_tfe_src.TFE_FastPathExecute.cs @@ -11,6 +11,8 @@ namespace Tensorflow.Eager public partial class wrap_tfe_src { static int kFastPathExecuteInputStartIndex = 0; + + [Obsolete] public static EagerTensor TFE_FastPathExecute(Context ctx, string device_name, string opName, @@ -203,7 +205,7 @@ namespace Tensorflow.Eager /// /// /// - private static void SetOpAttrWithDefaults(Context ctx, IntPtr op, AttrDef attr, + public static void SetOpAttrWithDefaults(Context ctx, IntPtr op, AttrDef attr, string attr_name, object attr_value, Dictionary attr_list_sizes, Status status) diff --git a/src/TensorFlowNET.Core/Gradients/GradientActor.cs b/src/TensorFlowNET.Core/Gradients/GradientActor.cs index 82f37ac3..e6dbe92a 100644 --- a/src/TensorFlowNET.Core/Gradients/GradientActor.cs +++ b/src/TensorFlowNET.Core/Gradients/GradientActor.cs @@ -74,12 +74,12 @@ namespace Tensorflow.Gradients } using var status = new Status(); - var et = c_api.TFE_TapeGradient(_tape, - new IntPtr[] { (target as EagerTensor).EagerTensorHandle }, 1, - new IntPtr[] { (sources as EagerTensor).EagerTensorHandle }, 1, + var et = c_api.TFE_TapeGradient(_tape, + new [] { (target as EagerTensor).EagerTensorHandle }, 1, + new [] { (sources as EagerTensor).EagerTensorHandle }, 1, status); status.Check(true); - return et; + return new EagerTensor(et); } public void Dispose() diff --git a/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs b/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs index b479ba0b..a43799aa 100644 --- a/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs +++ b/src/TensorFlowNET.Core/Gradients/ops.gradient_function_mapping.cs @@ -24,9 +24,9 @@ namespace Tensorflow { public partial class ops { - static Dictionary> gradientFunctions = null; + public static Dictionary> gradientFunctions = null; - private static void RegisterFromAssembly() + public static void RegisterFromAssembly() { if (gradientFunctions == null) { diff --git a/src/TensorFlowNET.Core/Operations/Operation.Input.cs b/src/TensorFlowNET.Core/Operations/Operation.Input.cs index 5c992aff..48f1800b 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.Input.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.Input.cs @@ -40,9 +40,9 @@ namespace Tensorflow public int NumInputs => c_api.TF_OperationNumInputs(_handle); private TF_DataType[] _input_types => _inputs_val._inputs.Select(x => x.dtype).ToArray(); - private InputList _inputs_val; + protected InputList _inputs_val; - public InputList inputs + public virtual InputList inputs { get { diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index eb746f98..70509ad5 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -152,8 +152,14 @@ namespace Tensorflow { if(tf.context.executing_eagerly()) { - var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name, "Pack", name, null, values, "axis", axis); - return _result; + using var status = new Status(); + var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, + "Pack", name, + values.Select(x => (x as EagerTensor).EagerTensorHandle).ToArray(), 1, + (op) => wrap_tfe_src.SetOpAttrWithDefaults(tf.context, op, null, "axis", axis, null, status), + status); + status.Check(true); + return new EagerTensor(tensor); } var _op = _op_def_lib._apply_op_helper("Pack", name: name, args: new { values, axis }); diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index 93c43ca3..c1a9a0db 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -41,6 +41,18 @@ namespace Tensorflow /// public static Tensor add_n(Tensor[] inputs, string name = null) { + if (tf.context.executing_eagerly()) + { + using var status = new Status(); + var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, + "AddN", name, + inputs.Select(x => (x as EagerTensor).EagerTensorHandle).ToArray(), inputs.Length, + null, + status); + status.Check(true); + return new EagerTensor(_result); + } + var _op = _op_def_lib._apply_op_helper("AddN", name, args: new { inputs }); return _op.outputs[0]; @@ -121,10 +133,18 @@ namespace Tensorflow { try { - var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name, - "Mean", name, null, - input, axis, "keep_dims", keep_dims); - return _result; + using var status = new Status(); + var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, + "Mean", name, + new IntPtr[] + { + (input as EagerTensor).EagerTensorHandle, + (axis as EagerTensor).EagerTensorHandle + }, 2, + (op) => wrap_tfe_src.SetOpAttrWithDefaults(tf.context, op, null, "keep_dims", keep_dims, null, status), + status); + status.Check(true); + return new EagerTensor(tensor); } catch (Exception) { @@ -196,17 +216,15 @@ namespace Tensorflow { if (tf.context.executing_eagerly()) { - using (var status = new Status()) - { - var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, - "Add", name, new IntPtr[] - { - (x as EagerTensor).EagerTensorHandle, - (y as EagerTensor).EagerTensorHandle - }, 2, status); - status.Check(true); - return new EagerTensor(_result); - } + using var status = new Status(); + var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, + "Add", name, new IntPtr[] + { + (x as EagerTensor).EagerTensorHandle, + (y as EagerTensor).EagerTensorHandle + }, 2, null, status); + status.Check(true); + return new EagerTensor(_result); } var _op = _op_def_lib._apply_op_helper("Add", name, args: new { x, y }); @@ -574,10 +592,18 @@ namespace Tensorflow { if (tf.context.executing_eagerly()) { - var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name, - "Cast", name, null, - x, "DstT", DstT, "Truncate", Truncate); - return _result; + using var status = new Status(); + var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, + "Cast", name, + new IntPtr[] { (x as EagerTensor).EagerTensorHandle }, 1, + (op) => + { + wrap_tfe_src.SetOpAttrWithDefaults(tf.context, op, null, "DstT", DstT, null, status); + wrap_tfe_src.SetOpAttrWithDefaults(tf.context, op, null, "Truncate", Truncate, null, status); + }, + status); + status.Check(true); + return new EagerTensor(tensor); } var _op = _op_def_lib._apply_op_helper("Cast", name, args: new { x, DstT, Truncate }); @@ -619,17 +645,15 @@ namespace Tensorflow { if (tf.context.executing_eagerly()) { - using (var status = new Status()) - { - var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, - "Sub", name, new IntPtr[] - { - (x as EagerTensor).EagerTensorHandle, - (y as EagerTensor).EagerTensorHandle - }, 2, status); - status.Check(true); - return new EagerTensor(_result); - } + using var status = new Status(); + var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, + "Sub", name, new IntPtr[] + { + (x as EagerTensor).EagerTensorHandle, + (y as EagerTensor).EagerTensorHandle + }, 2, null, status); + status.Check(true); + return new EagerTensor(_result); } var _op = _op_def_lib._apply_op_helper("Sub", name, args: new { x, y }); @@ -717,11 +741,11 @@ namespace Tensorflow { using var status = new Status(); var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, - "Mul", name, new IntPtr[] - { - (x as EagerTensor).EagerTensorHandle, - (y as EagerTensor).EagerTensorHandle - }, 2, status); + "Mul", name, new IntPtr[] + { + (x as EagerTensor).EagerTensorHandle, + (y as EagerTensor).EagerTensorHandle + }, 2, null, status); status.Check(true); return new EagerTensor(_result); } @@ -757,17 +781,15 @@ namespace Tensorflow { if (tf.context.executing_eagerly()) { - using (var status = new Status()) - { - var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, + using var status = new Status(); + var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, "RealDiv", name, new IntPtr[] { (x as EagerTensor).EagerTensorHandle, (y as EagerTensor).EagerTensorHandle - }, 2, status); - status.Check(true); - return new EagerTensor(_result); - } + }, 2, null, status); + status.Check(true); + return new EagerTensor(_result); } var _op = _op_def_lib._apply_op_helper("RealDiv", name, args: new { x, y }); @@ -962,8 +984,16 @@ namespace Tensorflow { if (tf.context.executing_eagerly()) { - var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name, "Range", name, null, start, limit, delta); - return _result; + using var status = new Status(); + var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, + "Range", name, new IntPtr[] + { + (start as EagerTensor).EagerTensorHandle, + (limit as EagerTensor).EagerTensorHandle, + (delta as EagerTensor).EagerTensorHandle + }, 3, null, status); + status.Check(true); + return new EagerTensor(tensor); } var _op = _op_def_lib._apply_op_helper("Range", name, new { start, limit, delta }); diff --git a/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs b/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs index edc83091..97079aa7 100644 --- a/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs @@ -14,6 +14,8 @@ limitations under the License. ******************************************************************************/ +using System; +using System.Linq; using Tensorflow.Eager; using static Tensorflow.Binding; @@ -27,10 +29,16 @@ namespace Tensorflow { if (tf.context.executing_eagerly()) { - var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name, - "AssignVariableOp", name, null, - resource, value); - return _result; + using var status = new Status(); + var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, + "AssignVariableOp", name, + new[] + { + (resource as EagerTensor).EagerTensorHandle, + (value as EagerTensor).EagerTensorHandle + }, 2, null, status); + status.Check(true); + return tensor; } var _op = _op_def_lib._apply_op_helper("AssignVariableOp", name, new { resource, value }); @@ -42,10 +50,13 @@ namespace Tensorflow { if (tf.context.executing_eagerly()) { - var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name, - "VarIsInitializedOp", name, null, - resource); - return _result; + using var status = new Status(); + var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, + "VarIsInitializedOp", name, + new[] { (resource as EagerTensor).EagerTensorHandle }, + 1, null, status); + status.Check(true); + return new EagerTensor(tensor); } var _op = _op_def_lib._apply_op_helper("VarIsInitializedOp", name, new { resource }); @@ -67,10 +78,17 @@ namespace Tensorflow { if (tf.context.executing_eagerly()) { - var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name, - "VarHandleOp", name, null, - "container", container, "shared_name", shared_name, "dtype", dtype, "shape", shape.dims); - return _result; + using var status = new Status(); + var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, + "VarHandleOp", name, null, 0, op => + { + wrap_tfe_src.SetOpAttrWithDefaults(tf.context, op, null, "container", container, null, status); + wrap_tfe_src.SetOpAttrWithDefaults(tf.context, op, null, "shared_name", shared_name, null, status); + wrap_tfe_src.SetOpAttrWithDefaults(tf.context, op, null, "dtype", dtype, null, status); + wrap_tfe_src.SetOpAttrWithDefaults(tf.context, op, null, "shape", shape.dims, null, status); + }, status); + status.Check(true); + return new EagerTensor(tensor); } var _op = _op_def_lib._apply_op_helper("VarHandleOp", name, new { @@ -94,10 +112,13 @@ namespace Tensorflow { if (tf.context.executing_eagerly()) { - var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name, - "ReadVariableOp", name, null, - resource, "dtype", dtype); - return _result; + using var status = new Status(); + var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, + "ReadVariableOp", name, new IntPtr[] { (resource as EagerTensor).EagerTensorHandle }, 1, + (op) => wrap_tfe_src.SetOpAttrWithDefaults(tf.context, op, null, "dtype", dtype, null, status), + status); + status.Check(true); + return new EagerTensor(tensor); } var _op = _op_def_lib._apply_op_helper("ReadVariableOp", name, new diff --git a/src/TensorFlowNET.Core/TensorFlow.Binding.csproj b/src/TensorFlowNET.Core/TensorFlow.Binding.csproj index 520ff9e4..cef3653b 100644 --- a/src/TensorFlowNET.Core/TensorFlow.Binding.csproj +++ b/src/TensorFlowNET.Core/TensorFlow.Binding.csproj @@ -31,7 +31,7 @@ https://tensorflownet.readthedocs.io true true Open.snk - AnyCPU;x64 + AnyCPU @@ -40,20 +40,10 @@ https://tensorflownet.readthedocs.io x64 - - true - TRACE;DEBUG;SERIALIZABLE_ - x64 - - true - - true - - diff --git a/src/TensorFlowNET.Core/Tensors/constant_op.cs b/src/TensorFlowNET.Core/Tensors/constant_op.cs index 9d800503..15f61072 100644 --- a/src/TensorFlowNET.Core/Tensors/constant_op.cs +++ b/src/TensorFlowNET.Core/Tensors/constant_op.cs @@ -127,6 +127,8 @@ namespace Tensorflow return new EagerTensor(val, ctx.device_name); case float val: return new EagerTensor(val, ctx.device_name); + case float[,] val: + return new EagerTensor(val, ctx.device_name); case double val: return new EagerTensor(val, ctx.device_name); case float[] val: diff --git a/src/TensorFlowNET.Core/Tensors/dtypes.cs b/src/TensorFlowNET.Core/Tensors/dtypes.cs index decaf075..d9be6b99 100644 --- a/src/TensorFlowNET.Core/Tensors/dtypes.cs +++ b/src/TensorFlowNET.Core/Tensors/dtypes.cs @@ -202,6 +202,7 @@ namespace Tensorflow TF_DataType.TF_INT32 => "int32", TF_DataType.TF_FLOAT => "float32", TF_DataType.TF_BOOL => "bool", + TF_DataType.TF_RESOURCE => "resource", _ => type.ToString() }; diff --git a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs index b639e1b8..fa5ee600 100644 --- a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs @@ -18,6 +18,7 @@ using Google.Protobuf; using NumSharp; using System; using System.Collections.Generic; +using Tensorflow.Eager; using static Tensorflow.Binding; namespace Tensorflow @@ -221,5 +222,10 @@ namespace Tensorflow return array_ops.identity(value); }); } + + public override string ToString() + { + return $"tf.Variable: '{name}' shape={string.Join(",", shape)}, dtype={dtype.as_numpy_name()}, numpy={EagerTensor.GetFormattedString(dtype, numpy())}"; + } } } diff --git a/src/TensorFlowNET.Core/tensorflow.cs b/src/TensorFlowNET.Core/tensorflow.cs index 715d15be..3ce88e36 100644 --- a/src/TensorFlowNET.Core/tensorflow.cs +++ b/src/TensorFlowNET.Core/tensorflow.cs @@ -14,6 +14,9 @@ limitations under the License. ******************************************************************************/ +using System; +using System.Linq; +using System.Runtime.InteropServices; using System.Threading; using Tensorflow.Eager; @@ -39,6 +42,45 @@ namespace Tensorflow public tensorflow() { _constructThreadingObjects(); + InitGradientEnvironment(); + } + + private unsafe void InitGradientEnvironment() + { + var vspace = c_api.VSpace_Handle((shape, dims, dtype) => + { + var ones = constant_op.constant(1.0f, dtype: dtype) as EagerTensor; + return ones.EagerTensorHandle; + }, (gradients, num_grads) => + { + var input_grads = new EagerTensor[num_grads]; + for (int i = 0; i < num_grads; i++) + input_grads[i] = new EagerTensor(*((IntPtr*)gradients + i)); + + var add_n = gen_math_ops.add_n(input_grads); + return (add_n as EagerTensor).EagerTensorHandle; + }); + + ops.RegisterFromAssembly(); + c_api.TFE_RegisterGradientFunction((op_name, num_inputs, op_inputs, num_attrs, output_grads) => + { + var output_grad_tensors = output_grads.Select(x => new EagerTensor(x)).ToArray(); + + var input_tensors = new EagerTensor[num_inputs]; + for (int i = 0; i < num_inputs; i++) + input_tensors[i] = new EagerTensor(op_inputs[op_inputs.Length == 1 ? 0 : i]); + + var gradients = ops.gradientFunctions[op_name](new EagerOperation + { + NumInputs = num_inputs, + Inputs = input_tensors + }, output_grad_tensors); + + var ret_tensors = Marshal.AllocHGlobal(sizeof(IntPtr) * num_inputs); + Marshal.Copy(gradients.Select(x => (x as EagerTensor).EagerTensorHandle).ToArray(), 0, ret_tensors, 2); + // Marshal.FreeHGlobal(ret_tensors); + return ret_tensors; + }); } public ResourceVariable Variable(T data, diff --git a/test/TensorFlowNET.UnitTest/Eager/GradientEagerTest.cs b/test/TensorFlowNET.UnitTest/Eager/GradientEagerTest.cs new file mode 100644 index 00000000..edd1a438 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/Eager/GradientEagerTest.cs @@ -0,0 +1,27 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow; +using static Tensorflow.Binding; + +namespace TensorFlowNET.UnitTest.Gradient +{ + [TestClass] + public class GradientEagerTest : PythonTest + { + [TestMethod] + public void ConstantSq() + { + // Calcute the gradient of w * w + // by Automatic Differentiation in Eager mode + // in tensorflow.net 2.x that is in development intensively + var w = tf.constant(1.5f); + using var tape = tf.GradientTape(); + tape.watch(w); + var loss = w * w; + var grad = tape.gradient(loss, w); + print(grad); + } + } +} diff --git a/test/TensorFlowNET.UnitTest/Tensorflow.UnitTest.csproj b/test/TensorFlowNET.UnitTest/Tensorflow.UnitTest.csproj index a7430e7e..351f40d7 100644 --- a/test/TensorFlowNET.UnitTest/Tensorflow.UnitTest.csproj +++ b/test/TensorFlowNET.UnitTest/Tensorflow.UnitTest.csproj @@ -30,11 +30,10 @@ - + - diff --git a/test/Tensorflow.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj b/test/Tensorflow.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj index b52a923b..030c3920 100644 --- a/test/Tensorflow.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj +++ b/test/Tensorflow.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj @@ -7,7 +7,7 @@ - +