diff --git a/src/TensorFlowNET.Core/APIs/c_api.customize.cs b/src/TensorFlowNET.Core/APIs/c_api.customize.cs index 173bdbe2..d2aab9ac 100644 --- a/src/TensorFlowNET.Core/APIs/c_api.customize.cs +++ b/src/TensorFlowNET.Core/APIs/c_api.customize.cs @@ -9,5 +9,9 @@ namespace Tensorflow { [DllImport(TensorFlowLibName)] public static extern void TFC_SetAttr(SafeGraphHandle graph, IntPtr op, string attr_name, SafeBufferHandle attr_value_proto, SafeStatusHandle status); + [DllImport(TensorFlowLibName)] + public static extern IntPtr TFC_GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output); + [DllImport(TensorFlowLibName)] + public static extern void TFC_SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data, long proto_len, SafeStatusHandle status); } } diff --git a/src/TensorFlowNET.Core/APIs/tf.gradients.cs b/src/TensorFlowNET.Core/APIs/tf.gradients.cs index 492b1034..d722cb14 100644 --- a/src/TensorFlowNET.Core/APIs/tf.gradients.cs +++ b/src/TensorFlowNET.Core/APIs/tf.gradients.cs @@ -21,7 +21,7 @@ namespace Tensorflow { public partial class tensorflow { - internal GradientTape _tapeSet; + GradientTape _tapeSet; /// /// Record operations for automatic differentiation. diff --git a/src/TensorFlowNET.Core/Framework/importer.cs b/src/TensorFlowNET.Core/Framework/importer.cs index b569c8e1..e7e7cf39 100644 --- a/src/TensorFlowNET.Core/Framework/importer.cs +++ b/src/TensorFlowNET.Core/Framework/importer.cs @@ -79,42 +79,6 @@ namespace Tensorflow return _GatherReturnElements(return_elements, graph, results); } - //private static ITensorOrOperation[] _import_graph_def_internal(GraphDef graph_def, Dictionary input_map = null, string[] return_elements = null, - // bool validate_colocation_constraints = true, string name = null, OpList producer_op_list = null) - //{ - // graph_def = _ProcessGraphDefParam(graph_def); - // input_map = _ProcessInputMapParam(input_map); - // return_elements = _ProcessReturnElementsParam(return_elements); - - // if(producer_op_list is not null) - // { - // _RemoveDefaultAttrs(producer_op_list, graph_def); - // } - - // var graph = ops.get_default_graph(); - // string prefix = null; - // tf_with(ops.name_scope(name, "import", input_map.Values), scope => - // { - // if (scope is not null) - // { - // Debug.Assert(scope.scope_name.EndsWith("/")); - // prefix = scope.scope_name[scope.scope_name.Length - 1].ToString(); - // } - // else - // { - // prefix = ""; - // } - - // input_map = _ConvertInputMapValues(name, input_map); - // }); - - // var scope_options = c_api_util.ScopedTFImportGraphDefOptions(); - // var options = scope_options.Options; - // _PopulateTFImportGraphDefOptions(scope_options, prefix, input_map, return_elements, validate_colocation_constraints); - - - //} - private static ITensorOrOperation[] _GatherReturnElements(string[] requested_return_elements, Graph graph, TF_ImportGraphDefResults results) diff --git a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs index 5c2d3a8d..88dce7d9 100644 --- a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs +++ b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs @@ -305,7 +305,7 @@ namespace Tensorflow.Functions private Tensors _build_call_outputs(Tensors result) { - // TODO(Rinne): dwal with `func_graph.structured_outputs` + // TODO(Rinne): deal with `func_graph.structured_outputs` return result; } diff --git a/src/TensorFlowNET.Core/Functions/Function.cs b/src/TensorFlowNET.Core/Functions/Function.cs index ea1b3eec..e301048a 100644 --- a/src/TensorFlowNET.Core/Functions/Function.cs +++ b/src/TensorFlowNET.Core/Functions/Function.cs @@ -4,7 +4,7 @@ using Tensorflow.Train; namespace Tensorflow { - public class Function: Trackable + public class Function: Trackable, IGenericFunction { #pragma warning disable CS0169 // The field 'Function._handle' is never used private IntPtr _handle; @@ -34,6 +34,11 @@ namespace Tensorflow return result; } + public ConcreteFunction get_concrete_function(params Tensor[] args) + { + return _get_concrete_function_garbage_collected(args); + } + protected virtual Tensors _call(Tensors inputs) { if(_variable_creation_fn is not null) @@ -57,6 +62,18 @@ namespace Tensorflow return false; } + protected ConcreteFunction _get_concrete_function_garbage_collected(Tensor[] args) + { + if(_variable_creation_fn is null) + { + _initialize(args); + // TODO(Rinne): _initialize_uninitialized_variables + } + + var concrete = _variable_creation_fn._get_concrete_function_internal_garbage_collected(args); + return concrete; + } + private void _initialize(Tensor[] args) { _variable_creation_fn = _compiler(_csharp_function); diff --git a/src/TensorFlowNET.Core/Functions/IGenericFunction.cs b/src/TensorFlowNET.Core/Functions/IGenericFunction.cs index be6a3b2a..f046731b 100644 --- a/src/TensorFlowNET.Core/Functions/IGenericFunction.cs +++ b/src/TensorFlowNET.Core/Functions/IGenericFunction.cs @@ -6,7 +6,7 @@ namespace Tensorflow.Functions { public interface IGenericFunction { - object[] Apply(params object[] args); - ConcreteFunction get_concrete_function(params object[] args); + Tensors Apply(Tensors args); + ConcreteFunction get_concrete_function(params Tensor[] args); } } diff --git a/src/TensorFlowNET.Core/Functions/TracingCompiler.cs b/src/TensorFlowNET.Core/Functions/TracingCompiler.cs index fb109595..aa30c9f7 100644 --- a/src/TensorFlowNET.Core/Functions/TracingCompiler.cs +++ b/src/TensorFlowNET.Core/Functions/TracingCompiler.cs @@ -49,7 +49,7 @@ namespace Tensorflow.Functions private (ConcreteFunction, Tensor[]) _maybe_define_function(Tensor[] args) { - var lookup_func_key = male_cache_key(args); + var lookup_func_key = make_cache_key(args); if(_function_cache.TryGetValue(lookup_func_key, out var concrete_function)) { return (concrete_function, args); @@ -71,7 +71,7 @@ namespace Tensorflow.Functions return concrete_function; } - private static string male_cache_key(Tensor[] inputs) + private static string make_cache_key(Tensor[] inputs) { //string res = ""; //foreach (var input in inputs) diff --git a/src/TensorFlowNET.Core/Gradients/gradients_util.cs b/src/TensorFlowNET.Core/Gradients/gradients_util.cs index 71d3d9ca..1fb32778 100644 --- a/src/TensorFlowNET.Core/Gradients/gradients_util.cs +++ b/src/TensorFlowNET.Core/Gradients/gradients_util.cs @@ -727,7 +727,7 @@ namespace Tensorflow private static Tensor[] _MaybeCompile(string scope, Operation op, Tensor[] out_grads, Action func, Func grad_fn) { - //scope = scope.TrimEnd('/').Replace('/', '_'); + // scope = scope.TrimEnd('/').Replace('/', '_'); return grad_fn(op, out_grads); } diff --git a/src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs b/src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs index b7f793ee..cc283db4 100644 --- a/src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs +++ b/src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs @@ -38,21 +38,6 @@ namespace Tensorflow.Graphs // make function as an Operation by autograph // need to restore mode when exits - - //var func_graph = new FuncGraph(func_name); - //func_graph.as_default(); - //var input_placeholders = args.Arguments.Select(x => tf.placeholder(((Tensor)x).dtype)).ToArray(); - //// stop the function from recursive call. - //already_in_boundary = true; - //var outputs = args.Method.Invoke(args.Instance, input_placeholders) as Tensors; - //already_in_boundary = false; - - //var opers = func_graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); - //func_graph.ToGraph(opers, - // input_placeholders, - // outputs, - // null); - //func_graph.Exit(); function = new ConcreteFunction(func_name); function.Enter(); diff --git a/src/TensorFlowNET.Core/Operations/c_api.ops.cs b/src/TensorFlowNET.Core/Operations/c_api.ops.cs index e5f55631..900db8ca 100644 --- a/src/TensorFlowNET.Core/Operations/c_api.ops.cs +++ b/src/TensorFlowNET.Core/Operations/c_api.ops.cs @@ -208,9 +208,5 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern int TF_OperationOutputListLength(IntPtr oper, string arg_name, SafeStatusHandle status); - [DllImport(TensorFlowLibName)] - public static extern IntPtr TFC_GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output); - [DllImport(TensorFlowLibName)] - public static extern void TFC_SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data, long proto_len, SafeStatusHandle status); } } diff --git a/src/TensorFlowNET.Core/Operations/gen_ops.cs b/src/TensorFlowNET.Core/Operations/gen_ops.cs index 8f8b2f11..ba59b367 100644 --- a/src/TensorFlowNET.Core/Operations/gen_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_ops.cs @@ -10060,7 +10060,7 @@ namespace Tensorflow.Operations } catch (Exception) { - Console.WriteLine(); + } try { @@ -10068,7 +10068,7 @@ namespace Tensorflow.Operations } catch (Exception) { - Console.WriteLine(); + } }