diff --git a/src/TensorFlowNET.Core/APIs/c_api.cs b/src/TensorFlowNET.Core/APIs/c_api.cs index a983d033..898afe34 100644 --- a/src/TensorFlowNET.Core/APIs/c_api.cs +++ b/src/TensorFlowNET.Core/APIs/c_api.cs @@ -27,7 +27,7 @@ namespace Tensorflow /// string => IntPtr c_api.StringPiece(IntPtr) /// unsigned char => byte /// - public static partial class c_api + public partial class c_api { public const string TensorFlowLibName = "tensorflow"; diff --git a/src/TensorFlowNET.Core/Attributes/c_api.ops.cs b/src/TensorFlowNET.Core/Attributes/c_api.ops.cs index 4e99300d..8f067d3e 100644 --- a/src/TensorFlowNET.Core/Attributes/c_api.ops.cs +++ b/src/TensorFlowNET.Core/Attributes/c_api.ops.cs @@ -5,7 +5,7 @@ using System.Text; namespace Tensorflow { - public static partial class c_api + public partial class c_api { /// /// Fills in `value` with the value of the attribute `attr_name`. `value` must diff --git a/src/TensorFlowNET.Core/Buffers/c_api.buffer.cs b/src/TensorFlowNET.Core/Buffers/c_api.buffer.cs index 9adfd411..d9792f12 100644 --- a/src/TensorFlowNET.Core/Buffers/c_api.buffer.cs +++ b/src/TensorFlowNET.Core/Buffers/c_api.buffer.cs @@ -5,7 +5,7 @@ using System.Text; namespace Tensorflow { - public static partial class c_api + public partial class c_api { [DllImport(TensorFlowLibName)] public static extern void TF_DeleteBuffer(IntPtr buffer); diff --git a/src/TensorFlowNET.Core/Functions/c_api.function.cs b/src/TensorFlowNET.Core/Functions/c_api.function.cs index 32c020a6..cd4adfea 100644 --- a/src/TensorFlowNET.Core/Functions/c_api.function.cs +++ b/src/TensorFlowNET.Core/Functions/c_api.function.cs @@ -5,7 +5,7 @@ using System.Text; namespace Tensorflow { - public static partial class c_api + public partial class c_api { /// /// Write out a serialized representation of `func` (as a FunctionDef protocol diff --git a/src/TensorFlowNET.Core/Gradients/c_api.gradient.cs b/src/TensorFlowNET.Core/Gradients/c_api.gradient.cs index 16a32ae1..ba992b82 100644 --- a/src/TensorFlowNET.Core/Gradients/c_api.gradient.cs +++ b/src/TensorFlowNET.Core/Gradients/c_api.gradient.cs @@ -5,7 +5,7 @@ using System.Text; namespace Tensorflow { - public static partial class c_api + public partial class c_api { /// /// Adds operations to compute the partial derivatives of sum of `y`s w.r.t `x`s, diff --git a/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs b/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs index 53288942..edcbd33f 100644 --- a/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs +++ b/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs @@ -53,10 +53,76 @@ namespace Tensorflow using (var namescope = new ops.name_scope(name, "gradients", values: all)) { grad_scope = namescope; + // Get a uid for this call to gradients that can be used to help + // cluster ops for compilation. + var gradient_uid = ops.get_default_graph().unique_name("uid"); + var to_ops = ys1.Select(x => x.op).ToList(); + var from_ops = xs1.Select(x => x.op).ToList(); + var stop_gradient_ops = stop_gradients1.Select(x => x.op).ToList(); + _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List(), xs1); } } + /// + /// + /// + /// + /// + /// + /// + private void _DefaultGradYs(List grad_ys, List ys, bool colocate_gradients_with_ops, string gradient_uid = "__unsupported__") + { + + } + + /// + /// Initialize the pending count for ops between two lists of Operations. + /// 'pending_count[op]' indicates the number of backprop inputs + /// to this operation. + /// + /// + /// + /// + /// + /// + private static void _PendingCount(List to_ops, List from_ops, bool colocate_gradients_with_ops, List func_graphs, List xs) + { + List reached_ops = new List(); + _MarkReachedOps(from_ops, reached_ops, func_graphs); + } + + /// + /// Mark all ops reached from "from_ops" + /// + /// + /// + /// + private static void _MarkReachedOps(List from_ops, List reached_ops, List func_graphs) + { + foreach(var op in from_ops) + { + reached_ops.Add(op); + foreach(var output in op.outputs) + { + reached_ops.AddRange(_Consumers(output, func_graphs)); + } + } + + reached_ops.Reverse(); + } + + /// + /// Returns the consumers of t, crossing closure boundaries where necessary. + /// + /// + /// + private static List _Consumers(Tensor t, List func_graphs) + { + var consumers = t.consumers(); + return consumers; + } + private static List _AsList(object ys) { List ret = null; diff --git a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs index 6e7a5bb3..52c23785 100644 --- a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs +++ b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs @@ -5,7 +5,7 @@ using System.Text; namespace Tensorflow { - public static partial class c_api + public partial class c_api { /// /// Destroy an options object. Graph will be deleted once no more diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index a4b23294..aad9c555 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -32,12 +32,12 @@ namespace Tensorflow public unsafe TF_Input[] OutputConsumers(int index, int max_consumers) { int size = Marshal.SizeOf(); - var handle = (TF_Input*)Marshal.AllocHGlobal(size); + var handle = Marshal.AllocHGlobal(size); int num = c_api.TF_OperationOutputConsumers(new TF_Output(_handle, index), handle, max_consumers); var consumers = new TF_Input[num]; for(int i = 0; i < num; i++) { - consumers[i] = new TF_Input((*handle).oper + i * size, (*handle).index); + consumers[i] = Marshal.PtrToStructure(handle + i * size); } return consumers; @@ -161,6 +161,11 @@ namespace Tensorflow } } + public override string ToString() + { + return $"'{Name}' type={OpType}"; + } + public static implicit operator Operation(IntPtr handle) => new Operation(handle); public static implicit operator IntPtr(Operation op) => op._handle; diff --git a/src/TensorFlowNET.Core/Operations/c_api.ops.cs b/src/TensorFlowNET.Core/Operations/c_api.ops.cs index ac293dae..46317bac 100644 --- a/src/TensorFlowNET.Core/Operations/c_api.ops.cs +++ b/src/TensorFlowNET.Core/Operations/c_api.ops.cs @@ -5,7 +5,7 @@ using System.Text; namespace Tensorflow { - public static partial class c_api + public partial class c_api { /// /// Request that `desc` be co-located on the device where `op` @@ -154,12 +154,15 @@ namespace Tensorflow /// an operation. Returns the number of output consumers (should match /// TF_OperationOutputNumConsumers(oper_out)). /// - /// - /// - /// + /// TF_Output + /// TF_Input* + /// int /// [DllImport(TensorFlowLibName)] - public static extern unsafe int TF_OperationOutputConsumers(TF_Output oper_out, TF_Input* consumers, int max_consumers); + public static extern unsafe int TF_OperationOutputConsumers(TF_Output oper_out, IntPtr consumers, int max_consumers); + + [DllImport(TensorFlowLibName)] + public static extern int TF_OperationOutputConsumers(TF_Output oper_out); [DllImport(TensorFlowLibName)] public static extern TF_DataType TF_OperationOutputType(TF_Output oper_out); diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index 5e727d9f..b08d5e1f 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -26,5 +26,20 @@ namespace Tensorflow return new Tensor(_op, 0, dtype); } + + /// + /// Return a tensor with the same shape and contents as the input tensor or value. + /// + /// + /// + public static Tensor identity(Tensor input, string name = "") + { + var keywords = new Dictionary(); + keywords.Add("input", input); + + var _op = _op_def_lib._apply_op_helper("Identity", name, keywords); + + return _op.outputs[0]; + } } } diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index f1963f38..a7cfce43 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -16,7 +16,7 @@ namespace Tensorflow keywords.Add("x", x); keywords.Add("y", y); - var _op = _op_def_lib._apply_op_helper("Add", name: "add", keywords: keywords); + var _op = _op_def_lib._apply_op_helper("Add", keywords: keywords); return new Tensor(_op, 0, _op.OutputType(0)); } @@ -38,7 +38,7 @@ namespace Tensorflow keywords.Add("x", x); keywords.Add("y", y); - var _op = _op_def_lib._apply_op_helper("Mul", name: "mul", keywords: keywords); + var _op = _op_def_lib._apply_op_helper("Mul", keywords: keywords); return new Tensor(_op, 0, _op.OutputType(0)); } @@ -62,7 +62,7 @@ namespace Tensorflow keywords.Add("transpose_a", transpose_a); keywords.Add("transpose_b", transpose_b); - var _op = _op_def_lib._apply_op_helper("MatMul", name: "MatMul", keywords: keywords); + var _op = _op_def_lib._apply_op_helper("MatMul", keywords: keywords); return new Tensor(_op, 0, _op.OutputType(0)); } @@ -73,7 +73,7 @@ namespace Tensorflow keywords.Add("x", x); keywords.Add("y", y); - var _op = _op_def_lib._apply_op_helper("Pow", name: "Pow", keywords: keywords); + var _op = _op_def_lib._apply_op_helper("Pow", keywords: keywords); return new Tensor(_op, 0, _op.OutputType(0)); } diff --git a/src/TensorFlowNET.Core/Sessions/c_api.session.cs b/src/TensorFlowNET.Core/Sessions/c_api.session.cs index 646ff4b2..3fc20365 100644 --- a/src/TensorFlowNET.Core/Sessions/c_api.session.cs +++ b/src/TensorFlowNET.Core/Sessions/c_api.session.cs @@ -5,7 +5,7 @@ using System.Text; namespace Tensorflow { - public static partial class c_api + public partial class c_api { /// /// Destroy a session object. diff --git a/src/TensorFlowNET.Core/Sessions/c_api.tf_session_helper.cs b/src/TensorFlowNET.Core/Sessions/c_api.tf_session_helper.cs new file mode 100644 index 00000000..97d55ec4 --- /dev/null +++ b/src/TensorFlowNET.Core/Sessions/c_api.tf_session_helper.cs @@ -0,0 +1,26 @@ +using System; +using System.Collections.Generic; +using System.Runtime.InteropServices; +using System.Text; + +namespace Tensorflow +{ + public partial class c_api + { + public static string[] TF_OperationOutputConsumers_wrapper(TF_Output oper_out) + { + int num_consumers = TF_OperationOutputConsumers(oper_out); + int size = Marshal.SizeOf(); + var handle = Marshal.AllocHGlobal(size * num_consumers); + int num = TF_OperationOutputConsumers(oper_out, handle, num_consumers); + var consumers = new string[num_consumers]; + for (int i = 0; i < num; i++) + { + TF_Input input = Marshal.PtrToStructure(handle + i * size); + consumers[i] = Marshal.PtrToStringAnsi(TF_OperationName(input.oper)); + } + + return consumers; + } + } +} diff --git a/src/TensorFlowNET.Core/Status/c_api.status.cs b/src/TensorFlowNET.Core/Status/c_api.status.cs index 5ba62136..855a638e 100644 --- a/src/TensorFlowNET.Core/Status/c_api.status.cs +++ b/src/TensorFlowNET.Core/Status/c_api.status.cs @@ -5,7 +5,7 @@ using System.Text; namespace Tensorflow { - public static partial class c_api + public partial class c_api { /// /// Delete a previously created status object. diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 11f185f8..b8ecf011 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -34,6 +34,7 @@ namespace Tensorflow public ulong dataTypeSize => _handle == IntPtr.Zero ? 0 : c_api.TF_DataTypeSize(dtype); public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / dataTypeSize; public IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle); + public int num_consumers(TF_Output oper_out) => _handle == IntPtr.Zero ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out); public long[] shape { get @@ -160,6 +161,13 @@ namespace Tensorflow this._dtype = dtype; } + public List consumers() + { + var output = _as_tf_output(); + var consumer_names = c_api.TF_OperationOutputConsumers_wrapper(output); + return consumer_names.Select(x => Graph.OperationByName(x)).ToList(); + } + public TF_Output _as_tf_output() { return new TF_Output(op, value_index); @@ -225,7 +233,7 @@ namespace Tensorflow } } - return $"{name} {dtype.ToString()} {rank} {string.Join(",", shape)}"; + return $"{name} shape=({string.Join(",", shape)}) dtype={dtype.ToString()}"; } public void Dispose() diff --git a/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs b/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs index 62ed55a9..ec4cb488 100644 --- a/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs @@ -5,7 +5,7 @@ using System.Text; namespace Tensorflow { - public static partial class c_api + public partial class c_api { /// /// Allocate and return a new Tensor. diff --git a/src/TensorFlowNET.Core/Variables/RefVariable.cs b/src/TensorFlowNET.Core/Variables/RefVariable.cs index e8fa72a0..01a354f0 100644 --- a/src/TensorFlowNET.Core/Variables/RefVariable.cs +++ b/src/TensorFlowNET.Core/Variables/RefVariable.cs @@ -11,6 +11,7 @@ namespace Tensorflow public string _graph_key; public bool _trainable; public Tensor _variable; + public Tensor _snapshot; public RefVariable(object initial_value, bool trainable = true, @@ -87,9 +88,12 @@ namespace Tensorflow } else { - + _snapshot = gen_array_ops.identity(_variable, name = "read"); } + // clear g._name_stack + ops.get_default_graph().old_stack = ""; + ops.add_to_collections(collections, this); } }