Browse Source

Override create_op in FuncGraph.

tags/v0.30
Oceania2018 4 years ago
parent
commit
4d86da6650
3 changed files with 138 additions and 4 deletions
  1. +14
    -0
      src/TensorFlowNET.Core/Exceptions/InaccessibleTensorError.cs
  2. +123
    -3
      src/TensorFlowNET.Core/Graphs/FuncGraph.cs
  3. +1
    -1
      src/TensorFlowNET.Core/Graphs/Graph.cs

+ 14
- 0
src/TensorFlowNET.Core/Exceptions/InaccessibleTensorError.cs View File

@@ -0,0 +1,14 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Exceptions
{
public class InaccessibleTensorError : TensorflowException
{
public InaccessibleTensorError(string message) : base(message)
{

}
}
}

+ 123
- 3
src/TensorFlowNET.Core/Graphs/FuncGraph.cs View File

@@ -1,6 +1,9 @@
using System;
using Google.Protobuf;
using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Eager;
using Tensorflow.Exceptions;
using static Tensorflow.Binding;

namespace Tensorflow.Graphs
@@ -21,7 +24,9 @@ namespace Tensorflow.Graphs

public Tensors Inputs { get; set; }
public Tensors Outputs { get; set; }
public Dictionary<string, string> Attrs { get; set; }

Dictionary<long, (Tensor, Tensor)> _captures = new Dictionary<long, (Tensor, Tensor)>();
/// <summary>
/// Construct a new FuncGraph.
/// </summary>
@@ -34,10 +39,14 @@ namespace Tensorflow.Graphs
as_default();
}

public FuncGraph(IntPtr handle, string name)
public FuncGraph(IntPtr handle, string name, Dictionary<string, string> attrs) : base()
{
outer_graph = ops.get_default_graph();
func_name = name;
Attrs = attrs;
// Will to test if FuncGraph has memory leak
// c_api.TF_DeleteGraph(_handle);
_handle = handle;

tf.Context.graph_mode();
as_default();
@@ -63,6 +72,8 @@ namespace Tensorflow.Graphs
status.Handle);
status.Check(true);

SetAttrs();

c_api.TF_GraphCopyFunction(outer_graph, func_handle, IntPtr.Zero, status.Handle);
status.Check(true);

@@ -73,13 +84,122 @@ namespace Tensorflow.Graphs

Inputs = inputs;
// mark_as_return
Outputs = outputs.Select(x => array_ops.identity(x)).ToArray();
Outputs = outputs;// .Select(x => array_ops.identity(x)).ToArray();

tf.Context.restore_mode();

return func_handle;
}

public override Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes, TF_DataType[] input_types = null, string name = null, Dictionary<string, AttrValue> attrs = null, OpDef op_def = null, bool compute_device = true)
{
foreach(var (i, inp) in enumerate(inputs))
inputs[i] = capture(inp);

return base.create_op(op_type, inputs, dtypes, input_types, name, attrs, op_def, compute_device);
}

Tensor capture(Tensor tensor, string name = null, TF_DataType shape = TF_DataType.DtInvalid)
{
if(tensor is EagerTensor)
{
throw new NotImplementedException("");
}

if(tensor.graph != this)
{
if (name == null)
name = tensor.op.name;
var inner_graph = tensor.graph;
while(inner_graph != null && inner_graph is FuncGraph inner_func_graph)
{
if (inner_graph == this)
throw new InaccessibleTensorError($"The tensor '{tensor.name}' cannot be accessed here: it is defined" +
" in another function or code block. Use return values," +
" explicit Python locals or TensorFlow collections to access" +
$" it. Defined in: {tensor.graph.graph_key}; accessed from: {graph_key}.");
inner_graph = inner_func_graph.outer_graph;
}
return _capture_helper(tensor, name);
}

return tensor;
}

Tensor _capture_helper(Tensor tensor, string name, TensorShape shape = null)
{
Tensor placeholder = null;
if (!_captures.ContainsKey(tensor.Id))
{
placeholder = _create_substitute_placeholder(tensor,
name: name,
dtype: tensor.dtype,
shape: shape);
add_capture(tensor, placeholder);
}
else
{
placeholder = _captures[tensor.Id].Item1;
}

BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) =>
{
return output_grads;
};

tf.Runner.RecordGradient("captured_value",
new[] { placeholder }, null,
new[] { tensor },
getBackwardFunction: () => _backward_function_wrapper
/*getForwardFunction: forward_function*/);

return placeholder;
}

void add_capture(Tensor tensor, Tensor placeholder)
{
_captures[tensor.Id] = (tensor, placeholder);
if (Inputs == null)
Inputs = new Tensors(placeholder);
else
{
var inputs = Inputs.ToList();
inputs.Add(placeholder);
Inputs = new Tensors(inputs.ToArray());
}
}

Tensor _create_substitute_placeholder(Tensor value,
string name = null,
TF_DataType dtype = TF_DataType.DtInvalid,
TensorShape shape = null)
{
if (shape is null)
shape = value.shape;
if (dtype == TF_DataType.DtInvalid)
dtype = value.dtype;

var placeholder = tf_with(ops.control_dependencies(null), ctl => array_ops.placeholder(dtype, shape: shape, name: name));
// custom_gradient.copy_handle_data(value, placeholder)
return placeholder;
}

void SetAttrs()
{
if (Attrs == null)
return;

foreach (var (_name, attr_value) in enumerate(Attrs))
{
var serialized = new AttrValue
{
S = ByteString.CopyFromUtf8(attr_value)
}.ToByteArray();
c_api.TF_FunctionSetAttrValueProto(func_handle, _name, serialized, serialized.Length, tf.Status.Handle);
tf.Status.Check(true);
}
}

protected override void DisposeManagedResources()
{
base.DisposeManagedResources();


+ 1
- 1
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -262,7 +262,7 @@ namespace Tensorflow
throw new RuntimeError("Graph is finalized and cannot be modified.");
}

public Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes,
public virtual Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes,
TF_DataType[] input_types = null, string name = null,
Dictionary<string, AttrValue> attrs = null, OpDef op_def = null,
bool compute_device = true)


Loading…
Cancel
Save