@@ -27,7 +27,7 @@ namespace Tensorflow | |||||
/// string => IntPtr c_api.StringPiece(IntPtr) | /// string => IntPtr c_api.StringPiece(IntPtr) | ||||
/// unsigned char => byte | /// unsigned char => byte | ||||
/// </summary> | /// </summary> | ||||
public static partial class c_api | |||||
public partial class c_api | |||||
{ | { | ||||
public const string TensorFlowLibName = "tensorflow"; | public const string TensorFlowLibName = "tensorflow"; | ||||
@@ -5,7 +5,7 @@ using System.Text; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
public static partial class c_api | |||||
public partial class c_api | |||||
{ | { | ||||
/// <summary> | /// <summary> | ||||
/// Fills in `value` with the value of the attribute `attr_name`. `value` must | /// Fills in `value` with the value of the attribute `attr_name`. `value` must | ||||
@@ -5,7 +5,7 @@ using System.Text; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
public static partial class c_api | |||||
public partial class c_api | |||||
{ | { | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TF_DeleteBuffer(IntPtr buffer); | public static extern void TF_DeleteBuffer(IntPtr buffer); | ||||
@@ -5,7 +5,7 @@ using System.Text; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
public static partial class c_api | |||||
public partial class c_api | |||||
{ | { | ||||
/// <summary> | /// <summary> | ||||
/// Write out a serialized representation of `func` (as a FunctionDef protocol | /// Write out a serialized representation of `func` (as a FunctionDef protocol | ||||
@@ -5,7 +5,7 @@ using System.Text; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
public static partial class c_api | |||||
public partial class c_api | |||||
{ | { | ||||
/// <summary> | /// <summary> | ||||
/// Adds operations to compute the partial derivatives of sum of `y`s w.r.t `x`s, | /// Adds operations to compute the partial derivatives of sum of `y`s w.r.t `x`s, | ||||
@@ -53,10 +53,76 @@ namespace Tensorflow | |||||
using (var namescope = new ops.name_scope<Tensor>(name, "gradients", values: all)) | using (var namescope = new ops.name_scope<Tensor>(name, "gradients", values: all)) | ||||
{ | { | ||||
grad_scope = namescope; | grad_scope = namescope; | ||||
// Get a uid for this call to gradients that can be used to help | |||||
// cluster ops for compilation. | |||||
var gradient_uid = ops.get_default_graph().unique_name("uid"); | |||||
var to_ops = ys1.Select(x => x.op).ToList(); | |||||
var from_ops = xs1.Select(x => x.op).ToList(); | |||||
var stop_gradient_ops = stop_gradients1.Select(x => x.op).ToList(); | |||||
_PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List<object>(), xs1); | |||||
} | } | ||||
} | } | ||||
/// <summary> | |||||
/// | |||||
/// </summary> | |||||
/// <param name="grad_ys"></param> | |||||
/// <param name="ys"></param> | |||||
/// <param name="colocate_gradients_with_ops"></param> | |||||
/// <param name="gradient_uid"></param> | |||||
private void _DefaultGradYs(List<Tensor> grad_ys, List<Tensor> ys, bool colocate_gradients_with_ops, string gradient_uid = "__unsupported__") | |||||
{ | |||||
} | |||||
/// <summary> | |||||
/// Initialize the pending count for ops between two lists of Operations. | |||||
/// 'pending_count[op]' indicates the number of backprop inputs | |||||
/// to this operation. | |||||
/// </summary> | |||||
/// <param name="to_ops"></param> | |||||
/// <param name="from_ops"></param> | |||||
/// <param name="colocate_gradients_with_ops"></param> | |||||
/// <param name="func_graphs"></param> | |||||
/// <param name="xs"></param> | |||||
private static void _PendingCount(List<Operation> to_ops, List<Operation> from_ops, bool colocate_gradients_with_ops, List<object> func_graphs, List<Tensor> xs) | |||||
{ | |||||
List<Operation> reached_ops = new List<Operation>(); | |||||
_MarkReachedOps(from_ops, reached_ops, func_graphs); | |||||
} | |||||
/// <summary> | |||||
/// Mark all ops reached from "from_ops" | |||||
/// </summary> | |||||
/// <param name="from_ops"></param> | |||||
/// <param name="reached_ops"></param> | |||||
/// <param name="func_graphs"></param> | |||||
private static void _MarkReachedOps(List<Operation> from_ops, List<Operation> reached_ops, List<object> func_graphs) | |||||
{ | |||||
foreach(var op in from_ops) | |||||
{ | |||||
reached_ops.Add(op); | |||||
foreach(var output in op.outputs) | |||||
{ | |||||
reached_ops.AddRange(_Consumers(output, func_graphs)); | |||||
} | |||||
} | |||||
reached_ops.Reverse(); | |||||
} | |||||
/// <summary> | |||||
/// Returns the consumers of t, crossing closure boundaries where necessary. | |||||
/// </summary> | |||||
/// <param name="t"></param> | |||||
/// <param name="func_graphs"></param> | |||||
private static List<Operation> _Consumers(Tensor t, List<object> func_graphs) | |||||
{ | |||||
var consumers = t.consumers(); | |||||
return consumers; | |||||
} | |||||
private static List<Tensor> _AsList(object ys) | private static List<Tensor> _AsList(object ys) | ||||
{ | { | ||||
List<Tensor> ret = null; | List<Tensor> ret = null; | ||||
@@ -5,7 +5,7 @@ using System.Text; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
public static partial class c_api | |||||
public partial class c_api | |||||
{ | { | ||||
/// <summary> | /// <summary> | ||||
/// Destroy an options object. Graph will be deleted once no more | /// Destroy an options object. Graph will be deleted once no more | ||||
@@ -32,12 +32,12 @@ namespace Tensorflow | |||||
public unsafe TF_Input[] OutputConsumers(int index, int max_consumers) | public unsafe TF_Input[] OutputConsumers(int index, int max_consumers) | ||||
{ | { | ||||
int size = Marshal.SizeOf<TF_Input>(); | int size = Marshal.SizeOf<TF_Input>(); | ||||
var handle = (TF_Input*)Marshal.AllocHGlobal(size); | |||||
var handle = Marshal.AllocHGlobal(size); | |||||
int num = c_api.TF_OperationOutputConsumers(new TF_Output(_handle, index), handle, max_consumers); | int num = c_api.TF_OperationOutputConsumers(new TF_Output(_handle, index), handle, max_consumers); | ||||
var consumers = new TF_Input[num]; | var consumers = new TF_Input[num]; | ||||
for(int i = 0; i < num; i++) | for(int i = 0; i < num; i++) | ||||
{ | { | ||||
consumers[i] = new TF_Input((*handle).oper + i * size, (*handle).index); | |||||
consumers[i] = Marshal.PtrToStructure<TF_Input>(handle + i * size); | |||||
} | } | ||||
return consumers; | return consumers; | ||||
@@ -161,6 +161,11 @@ namespace Tensorflow | |||||
} | } | ||||
} | } | ||||
public override string ToString() | |||||
{ | |||||
return $"'{Name}' type={OpType}"; | |||||
} | |||||
public static implicit operator Operation(IntPtr handle) => new Operation(handle); | public static implicit operator Operation(IntPtr handle) => new Operation(handle); | ||||
public static implicit operator IntPtr(Operation op) => op._handle; | public static implicit operator IntPtr(Operation op) => op._handle; | ||||
@@ -5,7 +5,7 @@ using System.Text; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
public static partial class c_api | |||||
public partial class c_api | |||||
{ | { | ||||
/// <summary> | /// <summary> | ||||
/// Request that `desc` be co-located on the device where `op` | /// Request that `desc` be co-located on the device where `op` | ||||
@@ -154,12 +154,15 @@ namespace Tensorflow | |||||
/// an operation. Returns the number of output consumers (should match | /// an operation. Returns the number of output consumers (should match | ||||
/// TF_OperationOutputNumConsumers(oper_out)). | /// TF_OperationOutputNumConsumers(oper_out)). | ||||
/// </summary> | /// </summary> | ||||
/// <param name="oper_out"></param> | |||||
/// <param name="consumers"></param> | |||||
/// <param name="max_consumers"></param> | |||||
/// <param name="oper_out">TF_Output</param> | |||||
/// <param name="consumers">TF_Input*</param> | |||||
/// <param name="max_consumers">int</param> | |||||
/// <returns></returns> | /// <returns></returns> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern unsafe int TF_OperationOutputConsumers(TF_Output oper_out, TF_Input* consumers, int max_consumers); | |||||
public static extern unsafe int TF_OperationOutputConsumers(TF_Output oper_out, IntPtr consumers, int max_consumers); | |||||
[DllImport(TensorFlowLibName)] | |||||
public static extern int TF_OperationOutputConsumers(TF_Output oper_out); | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern TF_DataType TF_OperationOutputType(TF_Output oper_out); | public static extern TF_DataType TF_OperationOutputType(TF_Output oper_out); | ||||
@@ -26,5 +26,20 @@ namespace Tensorflow | |||||
return new Tensor(_op, 0, dtype); | return new Tensor(_op, 0, dtype); | ||||
} | } | ||||
/// <summary> | |||||
/// Return a tensor with the same shape and contents as the input tensor or value. | |||||
/// </summary> | |||||
/// <param name="input"></param> | |||||
/// <param name="name"></param> | |||||
public static Tensor identity(Tensor input, string name = "") | |||||
{ | |||||
var keywords = new Dictionary<string, object>(); | |||||
keywords.Add("input", input); | |||||
var _op = _op_def_lib._apply_op_helper("Identity", name, keywords); | |||||
return _op.outputs[0]; | |||||
} | |||||
} | } | ||||
} | } |
@@ -16,7 +16,7 @@ namespace Tensorflow | |||||
keywords.Add("x", x); | keywords.Add("x", x); | ||||
keywords.Add("y", y); | keywords.Add("y", y); | ||||
var _op = _op_def_lib._apply_op_helper("Add", name: "add", keywords: keywords); | |||||
var _op = _op_def_lib._apply_op_helper("Add", keywords: keywords); | |||||
return new Tensor(_op, 0, _op.OutputType(0)); | return new Tensor(_op, 0, _op.OutputType(0)); | ||||
} | } | ||||
@@ -38,7 +38,7 @@ namespace Tensorflow | |||||
keywords.Add("x", x); | keywords.Add("x", x); | ||||
keywords.Add("y", y); | keywords.Add("y", y); | ||||
var _op = _op_def_lib._apply_op_helper("Mul", name: "mul", keywords: keywords); | |||||
var _op = _op_def_lib._apply_op_helper("Mul", keywords: keywords); | |||||
return new Tensor(_op, 0, _op.OutputType(0)); | return new Tensor(_op, 0, _op.OutputType(0)); | ||||
} | } | ||||
@@ -62,7 +62,7 @@ namespace Tensorflow | |||||
keywords.Add("transpose_a", transpose_a); | keywords.Add("transpose_a", transpose_a); | ||||
keywords.Add("transpose_b", transpose_b); | keywords.Add("transpose_b", transpose_b); | ||||
var _op = _op_def_lib._apply_op_helper("MatMul", name: "MatMul", keywords: keywords); | |||||
var _op = _op_def_lib._apply_op_helper("MatMul", keywords: keywords); | |||||
return new Tensor(_op, 0, _op.OutputType(0)); | return new Tensor(_op, 0, _op.OutputType(0)); | ||||
} | } | ||||
@@ -73,7 +73,7 @@ namespace Tensorflow | |||||
keywords.Add("x", x); | keywords.Add("x", x); | ||||
keywords.Add("y", y); | keywords.Add("y", y); | ||||
var _op = _op_def_lib._apply_op_helper("Pow", name: "Pow", keywords: keywords); | |||||
var _op = _op_def_lib._apply_op_helper("Pow", keywords: keywords); | |||||
return new Tensor(_op, 0, _op.OutputType(0)); | return new Tensor(_op, 0, _op.OutputType(0)); | ||||
} | } | ||||
@@ -5,7 +5,7 @@ using System.Text; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
public static partial class c_api | |||||
public partial class c_api | |||||
{ | { | ||||
/// <summary> | /// <summary> | ||||
/// Destroy a session object. | /// Destroy a session object. | ||||
@@ -0,0 +1,26 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Runtime.InteropServices; | |||||
using System.Text; | |||||
namespace Tensorflow | |||||
{ | |||||
public partial class c_api | |||||
{ | |||||
public static string[] TF_OperationOutputConsumers_wrapper(TF_Output oper_out) | |||||
{ | |||||
int num_consumers = TF_OperationOutputConsumers(oper_out); | |||||
int size = Marshal.SizeOf<TF_Input>(); | |||||
var handle = Marshal.AllocHGlobal(size * num_consumers); | |||||
int num = TF_OperationOutputConsumers(oper_out, handle, num_consumers); | |||||
var consumers = new string[num_consumers]; | |||||
for (int i = 0; i < num; i++) | |||||
{ | |||||
TF_Input input = Marshal.PtrToStructure<TF_Input>(handle + i * size); | |||||
consumers[i] = Marshal.PtrToStringAnsi(TF_OperationName(input.oper)); | |||||
} | |||||
return consumers; | |||||
} | |||||
} | |||||
} |
@@ -5,7 +5,7 @@ using System.Text; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
public static partial class c_api | |||||
public partial class c_api | |||||
{ | { | ||||
/// <summary> | /// <summary> | ||||
/// Delete a previously created status object. | /// Delete a previously created status object. | ||||
@@ -34,6 +34,7 @@ namespace Tensorflow | |||||
public ulong dataTypeSize => _handle == IntPtr.Zero ? 0 : c_api.TF_DataTypeSize(dtype); | public ulong dataTypeSize => _handle == IntPtr.Zero ? 0 : c_api.TF_DataTypeSize(dtype); | ||||
public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / dataTypeSize; | public ulong size => _handle == IntPtr.Zero ? 0 : bytesize / dataTypeSize; | ||||
public IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle); | public IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle); | ||||
public int num_consumers(TF_Output oper_out) => _handle == IntPtr.Zero ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out); | |||||
public long[] shape | public long[] shape | ||||
{ | { | ||||
get | get | ||||
@@ -160,6 +161,13 @@ namespace Tensorflow | |||||
this._dtype = dtype; | this._dtype = dtype; | ||||
} | } | ||||
public List<Operation> consumers() | |||||
{ | |||||
var output = _as_tf_output(); | |||||
var consumer_names = c_api.TF_OperationOutputConsumers_wrapper(output); | |||||
return consumer_names.Select(x => Graph.OperationByName(x)).ToList(); | |||||
} | |||||
public TF_Output _as_tf_output() | public TF_Output _as_tf_output() | ||||
{ | { | ||||
return new TF_Output(op, value_index); | return new TF_Output(op, value_index); | ||||
@@ -225,7 +233,7 @@ namespace Tensorflow | |||||
} | } | ||||
} | } | ||||
return $"{name} {dtype.ToString()} {rank} {string.Join(",", shape)}"; | |||||
return $"{name} shape=({string.Join(",", shape)}) dtype={dtype.ToString()}"; | |||||
} | } | ||||
public void Dispose() | public void Dispose() | ||||
@@ -5,7 +5,7 @@ using System.Text; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
public static partial class c_api | |||||
public partial class c_api | |||||
{ | { | ||||
/// <summary> | /// <summary> | ||||
/// Allocate and return a new Tensor. | /// Allocate and return a new Tensor. | ||||
@@ -11,6 +11,7 @@ namespace Tensorflow | |||||
public string _graph_key; | public string _graph_key; | ||||
public bool _trainable; | public bool _trainable; | ||||
public Tensor _variable; | public Tensor _variable; | ||||
public Tensor _snapshot; | |||||
public RefVariable(object initial_value, | public RefVariable(object initial_value, | ||||
bool trainable = true, | bool trainable = true, | ||||
@@ -87,9 +88,12 @@ namespace Tensorflow | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
_snapshot = gen_array_ops.identity(_variable, name = "read"); | |||||
} | } | ||||
// clear g._name_stack | |||||
ops.get_default_graph().old_stack = ""; | |||||
ops.add_to_collections(collections, this); | ops.add_to_collections(collections, this); | ||||
} | } | ||||
} | } | ||||