@@ -35,6 +35,7 @@ namespace Tensorflow.Contexts | |||
public string ScopeName { get; set; } = ""; | |||
bool initialized = false; | |||
ContextSwitchStack context_switches; | |||
public FunctionCallOptions FunctionCallOptions { get; } | |||
public SafeContextHandle Handle { get; } | |||
@@ -44,6 +45,7 @@ namespace Tensorflow.Contexts | |||
status.Check(true); | |||
context_switches = new ContextSwitchStack(defaultExecutionMode == EAGER_MODE, false); | |||
initialized = true; | |||
FunctionCallOptions = new FunctionCallOptions(); | |||
} | |||
/// <summary> | |||
@@ -0,0 +1,20 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Google.Protobuf; | |||
using Google.Protobuf.Collections; | |||
namespace Tensorflow.Contexts
{
    /// <summary>
    /// Options for function calls. The only member currently exposed is a
    /// serialized <see cref="ConfigProto"/>.
    /// </summary>
    public class FunctionCallOptions
    {
        /// <summary>
        /// Builds a fresh <see cref="ConfigProto"/> with soft placement enabled,
        /// serializes it, and decodes the serialized bytes as a UTF-8 string.
        /// </summary>
        /// <returns>The serialized configuration, decoded as UTF-8 text.</returns>
        public string config_proto_serialized()
        {
            var proto = new ConfigProto();
            proto.AllowSoftPlacement = true;

            // NOTE(review): the serialized form is binary. Decoding it as UTF-8 is
            // only lossless while the bytes happen to be valid UTF-8 - confirm this
            // still holds before adding more fields to the configuration.
            ByteString wire = proto.ToByteString();
            return wire.ToStringUtf8();
        }
    }
}
@@ -371,7 +371,7 @@ namespace Tensorflow.Eager | |||
switch (type) | |||
{ | |||
case TF_AttrType.TF_ATTR_STRING: | |||
c_api.TFE_OpSetAttrString(op, key, value.ToString(), (uint)value.ToString().Length); | |||
c_api.TFE_OpSetAttrString(op, key, value.ToString(), (ulong)value.ToString().Length); | |||
break; | |||
case TF_AttrType.TF_ATTR_TYPE: | |||
c_api.TFE_OpSetAttrType(op, key, (TF_DataType)value); | |||
@@ -241,7 +241,7 @@ namespace Tensorflow | |||
/// <param name="value">const void*</param> | |||
/// <param name="length">size_t</param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TFE_OpSetAttrString(SafeOpHandle op, string attr_name, string value, uint length); | |||
public static extern void TFE_OpSetAttrString(SafeOpHandle op, string attr_name, string value, ulong length); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TFE_OpSetAttrTypeList(SafeOpHandle op, string attr_name, TF_DataType[] values, int num_values); | |||
@@ -14,6 +14,7 @@ namespace Tensorflow.Functions | |||
{ | |||
IntPtr _handle; | |||
FuncGraph func_graph; | |||
/// <summary>
/// Result of calling <c>external_captures()</c> on the wrapped function graph.
/// The call is made again on every read of this property.
/// </summary>
public Tensor[] CapturedInputs
{
    get { return func_graph.external_captures(); }
}
public string Name | |||
{ | |||
@@ -38,6 +39,8 @@ namespace Tensorflow.Functions | |||
/// <summary>
/// Creates a concrete function around an existing function graph.
/// </summary>
/// <param name="graph">Graph to wrap; its <c>Inputs</c> and <c>Outputs</c> are handed to <c>ToGraph</c>.</param>
/// <param name="attrs">Function attributes. Not read by this constructor body.</param>
public ConcreteFunction(FuncGraph graph, Dictionary<string, string> attrs)
{
    // Store the graph before calling ToGraph, matching the original statement order.
    this.func_graph = graph;

    var graph_inputs = graph.Inputs;
    var graph_outputs = graph.Outputs;
    ToGraph(graph_inputs, graph_outputs);
}
public ConcreteFunction(Func<Tensor, Tensor> func, TF_DataType dtype) | |||
@@ -124,6 +127,21 @@ namespace Tensorflow.Functions | |||
return flat_outputs; | |||
} | |||
/// <summary>
/// Executes this function through <c>tf.Runner.Execute</c>, passing the explicit
/// arguments followed by the captured inputs.
/// </summary>
/// <param name="args">Explicit arguments; placed first in the input list.</param>
/// <param name="captured_inputs">Captured tensors; appended after <paramref name="args"/>.</param>
/// <returns>The tensors returned by <c>tf.Runner.Execute</c>.</returns>
public Tensor[] CallFlat(Tensor[] args, Tensor[] captured_inputs)
{
    var flat_inputs = new List<Tensor>(args);
    flat_inputs.AddRange(captured_inputs);

    // Request as many results as the wrapped graph declares instead of a fixed 1,
    // so functions with several outputs (e.g. one gradient per input) are not
    // executed with a mismatched output count. Falls back to the previous value
    // of 1 when the graph has no outputs assigned.
    int num_outputs = 1;
    if (!(func_graph.Outputs is null))
    {
        Tensor[] declared_outputs = func_graph.Outputs;
        num_outputs = declared_outputs.Length;
    }

    var attrs = new object[]
    {
        "executor_type", "",
        "config_proto", tf.Context.FunctionCallOptions.config_proto_serialized()
    };

    return tf.Runner.Execute(tf.Context, func_graph.FuncName, num_outputs, flat_inputs.ToArray(), attrs);
}
ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly) | |||
{ | |||
var functions = new FirstOrderTapeGradientFunctions(func_graph, false); | |||
@@ -48,7 +48,7 @@ namespace Tensorflow.Functions | |||
getBackwardFunction: () => backward_function); | |||
} | |||
(BackwardFunction, Tensors) _wrap_backward_function(FuncGraph forward_graph, ConcreteFunction backward, Tensors flat_outputs) | |||
(BackwardFunction, Tensors) _wrap_backward_function(FuncGraph forward_graph, ConcreteFunction backward, Tensors outputs) | |||
{ | |||
BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) => | |||
{ | |||
@@ -61,10 +61,11 @@ namespace Tensorflow.Functions | |||
processed_args.add(arg); | |||
input_index += 1; | |||
} | |||
return output_grads;// backward.Invoke(processed_args.ToArray()); | |||
return backward.CallFlat(processed_args.ToArray(), outputs); | |||
}; | |||
return (_backward_function_wrapper, flat_outputs); | |||
return (_backward_function_wrapper, outputs); | |||
} | |||
protected (EagerDefinedFunction, FuncGraph, ConcreteFunction, List<int>, int) | |||
@@ -82,7 +83,7 @@ namespace Tensorflow.Functions | |||
} | |||
var gradients_wrt_outputs = new List<Tensor>(); | |||
var backwards_graph = new FuncGraph($"{_BACKWARD_PREFIX}{_func_graph.FuncName}_{ops.uid()}"); | |||
var backwards_graph = new FuncGraph($"{_BACKWARD_PREFIX}_{ops.uid()}"); | |||
foreach (var output in trainable_outputs) | |||
gradients_wrt_outputs.Add(tf.placeholder(output.dtype, output.shape)); | |||
var gradients_wrt_inputs = gradients_util._GradientsHelper(trainable_outputs.ToArray(), | |||
@@ -90,16 +91,15 @@ namespace Tensorflow.Functions | |||
grad_ys: gradients_wrt_outputs.ToArray(), | |||
src_graph: _func_graph); | |||
tf.Context.restore_mode(); | |||
var forward_function_name = $"{_FORWARD_PREFIX}{_func_graph.FuncName}_{ops.uid()}"; | |||
var forward_function_name = $"{_FORWARD_PREFIX}_{ops.uid()}"; | |||
var backward_function_attr = new Dictionary<string, string>(); | |||
backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name; | |||
gradients_wrt_outputs.append(backwards_graph.internal_captures()); | |||
backwards_graph.Inputs = gradients_wrt_outputs; | |||
backwards_graph.Outputs = gradients_wrt_inputs; | |||
var backward_function = new ConcreteFunction(backwards_graph, backward_function_attr); | |||
var forward_function_attr = new Dictionary<string, string>(); | |||
forward_function_attr[BACKWARD_FUNCTION_ATTRIBUTE_NAME] = backward_function.Name; | |||
var forward_function = new EagerDefinedFunction(forward_function_name, _func_graph, | |||
@@ -49,14 +49,14 @@ namespace Tensorflow | |||
RegisterGradientFunction(m.GetCustomAttribute<RegisterGradient>().Name, | |||
(oper, out_grads) => | |||
{ | |||
tf.Logger.Debug($"Caculate Gradient: {m.Name}"); | |||
tf.Logger.Debug($"Caculate Gradient: {oper.name} {m.Name}"); | |||
var results = g.InvokeMember(m.Name, | |||
BindingFlags.InvokeMethod, | |||
null, | |||
null, | |||
args: new object[] { oper, out_grads }) as Tensor[]; | |||
foreach (var result in results.Where(x => x != null)) | |||
tf.Logger.Debug($"{result.TensorShape}"); | |||
tf.Logger.Debug($"Gradient: {result.name} {result.TensorShape}"); | |||
return results; | |||
} | |||
); | |||
@@ -26,7 +26,9 @@ namespace Tensorflow.Graphs | |||
public Tensors Outputs { get; set; } | |||
public Dictionary<string, string> Attrs { get; set; } | |||
Dictionary<long, (Tensor, Tensor)> _captures = new Dictionary<long, (Tensor, Tensor)>(); | |||
// new Dictionary<long, (Tensor, Tensor)> _captures = new Dictionary<long, (Tensor, Tensor)>(); | |||
// public new Tensor[] external_captures => _captures.Values.Select(x => x.Item1).ToArray(); | |||
/// <summary> | |||
/// Construct a new FuncGraph. | |||
/// </summary> | |||
@@ -129,7 +131,7 @@ namespace Tensorflow.Graphs | |||
Tensor _capture_helper(Tensor tensor, string name, TensorShape shape = null) | |||
{ | |||
Tensor placeholder = null; | |||
if (!_captures.ContainsKey(tensor.Id)) | |||
if (!_captures.Contains(tensor.Id)) | |||
{ | |||
placeholder = _create_substitute_placeholder(tensor, | |||
name: name, | |||
@@ -139,7 +141,7 @@ namespace Tensorflow.Graphs | |||
} | |||
else | |||
{ | |||
placeholder = _captures[tensor.Id].Item1; | |||
placeholder = (((Tensor, Tensor))_captures[tensor.Id]).Item2; | |||
} | |||
BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) => | |||
@@ -557,16 +557,16 @@ namespace Tensorflow | |||
/// <summary>
/// Copies the keys of the capture table into a newly allocated tensor array.
/// </summary>
/// <returns>A new array with one element per entry in <c>_captures</c>.</returns>
public Tensor[] external_captures()
{
    var captures = new Tensor[_captures.Count];

    // Use the non-generic ICollection view so CopyTo(Array, int) is the overload called.
    // NOTE(review): assumes every key is a Tensor - confirm against the code that
    // populates _captures.
    ICollection inner = _captures.Keys; // c[0]
    inner.CopyTo(captures, 0);
    return captures;
}
/// <summary>
/// Copies the values of the capture table into a newly allocated tensor array.
/// </summary>
/// <returns>A new array with one element per entry in <c>_captures</c>.</returns>
public Tensor[] internal_captures()
{
    var captures = new Tensor[_captures.Count];

    // Use the non-generic ICollection view so CopyTo(Array, int) is the overload called.
    // NOTE(review): assumes every value is a Tensor - confirm against the code that
    // populates _captures.
    ICollection inner = _captures.Values; // c[1]
    inner.CopyTo(captures, 0);
    return captures;
}
@@ -340,7 +340,7 @@ namespace Tensorflow.Keras.Engine | |||
tf.Logger.Debug($"{depth}: {node.Layer}: {node.Layer.Name}"); | |||
var outputs = node.Layer.Apply(layer_inputs, is_training: training); | |||
foreach (var output in outputs.Where(x => x != null)) | |||
tf.Logger.Debug($"{output.TensorShape}"); | |||
tf.Logger.Debug($"{depth}: {node.Layer}: {node.Layer.Name} {output.TensorShape}"); | |||
// Update tensor_dict for next input | |||
foreach (var (x_id, y) in zip(node.FlatOutputIds, outputs)) | |||
tensor_dict[x_id] = new Queue<Tensor>(Enumerable.Range(0, tensor_usage_count[x_id]).Select(x => y)); | |||