diff --git a/src/TensorFlowNET.Core/APIs/tf.io.cs b/src/TensorFlowNET.Core/APIs/tf.io.cs index 394357de..6d1f7d17 100644 --- a/src/TensorFlowNET.Core/APIs/tf.io.cs +++ b/src/TensorFlowNET.Core/APIs/tf.io.cs @@ -24,7 +24,7 @@ namespace Tensorflow public GFile gfile = new GFile(); public Tensor read_file(string filename, string name = null) => gen_io_ops.read_file(filename, name); - public void import_graph_def(GraphDef graph_def, + public ITensorOrOperation[] import_graph_def(GraphDef graph_def, Dictionary input_map = null, string[] return_elements = null, string name = null, diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs index 31ea0d84..eaeefd73 100644 --- a/src/TensorFlowNET.Core/Binding.Util.cs +++ b/src/TensorFlowNET.Core/Binding.Util.cs @@ -95,6 +95,9 @@ namespace Tensorflow throw new NotImplementedException("len() not implemented for type: " + a.GetType()); } + public static float min(float a, float b) + => Math.Min(a, b); + public static T[] list(IEnumerable list) => list.ToArray(); diff --git a/src/TensorFlowNET.Core/Framework/importer.py.cs b/src/TensorFlowNET.Core/Framework/importer.cs similarity index 75% rename from src/TensorFlowNET.Core/Framework/importer.py.cs rename to src/TensorFlowNET.Core/Framework/importer.cs index b6c011c4..b4bf1c73 100644 --- a/src/TensorFlowNET.Core/Framework/importer.py.cs +++ b/src/TensorFlowNET.Core/Framework/importer.cs @@ -54,6 +54,7 @@ namespace Tensorflow input_map = _ConvertInputMapValues(name, input_map); }); + TF_ImportGraphDefResults results = null; var bytes = graph_def.ToByteString().ToArray(); using (var buffer = c_api_util.tf_buffer(bytes)) using (var scoped_options = c_api_util.ScopedTFImportGraphDefOptions()) @@ -61,9 +62,8 @@ namespace Tensorflow { _PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements); // need to create a class ImportGraphDefWithResults with IDisposal - var results = c_api.TF_GraphImportGraphDefWithResults(graph, buffer, scoped_options, status); + results = c_api.TF_GraphImportGraphDefWithResults(graph, buffer, scoped_options, status); status.Check(true); - c_api.TF_DeleteImportGraphDefResults(results); } _ProcessNewOps(graph); @@ -71,7 +71,34 @@ namespace Tensorflow if (return_elements == null) return null; else - throw new NotImplementedException("import_graph_def return_elements"); + return _GatherReturnElements(return_elements, graph, results); + } + + private static ITensorOrOperation[] _GatherReturnElements(string[] requested_return_elements, + Graph graph, + TF_ImportGraphDefResults results) + { + var return_outputs = results.return_tensors; + var return_opers = results.return_opers; + + var combined_return_elements = new List(); + int outputs_idx = 0; + int opers_idx = 0; + foreach(var name in requested_return_elements) + { + if (name.Contains(":")) + { + combined_return_elements.append(graph.get_tensor_by_tf_output(return_outputs[outputs_idx])); + outputs_idx += 1; + } + else + { + throw new NotImplementedException("_GatherReturnElements"); + // combined_return_elements.append(graph._get_operation_by_tf_operation(return_opers[opers_idx])); + } + } + + return combined_return_elements.ToArray(); } private static void _ProcessNewOps(Graph graph) @@ -100,8 +127,29 @@ namespace Tensorflow foreach (var name in return_elements) { - throw new NotImplementedException("_PopulateTFImportGraphDefOptions"); + if(name.Contains(":")) + { + var (op_name, index) = _ParseTensorName(name); + c_api.TF_ImportGraphDefOptionsAddReturnOutput(options, op_name, index); + } + else + { + c_api.TF_ImportGraphDefOptionsAddReturnOperation(options, name); + } } + + // c_api.TF_ImportGraphDefOptionsSetValidateColocationConstraints(options, validate_colocation_constraints); + } + + private static (string, int) _ParseTensorName(string tensor_name) + { + var components = tensor_name.Split(':'); + if (components.Length == 2) + return (components[0], int.Parse(components[1])); + else if (components.Length == 1) + return (components[0], 0); + else + throw new ValueError($"Cannot convert {tensor_name} to a tensor name."); } public static Dictionary _ConvertInputMapValues(string name, Dictionary input_map) diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 1f62295a..48420d18 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -494,6 +494,12 @@ namespace Tensorflow c_api.TF_DeleteGraph(handle); } + public Tensor get_tensor_by_tf_output(TF_Output tf_output) + { + var op = _get_operation_by_tf_operation(tf_output.oper); + return op.outputs[tf_output.index]; + } + /// /// Returns the with the given . /// This method may be called concurrently from multiple threads. diff --git a/src/TensorFlowNET.Core/Graphs/TF_ImportGraphDefResults.cs b/src/TensorFlowNET.Core/Graphs/TF_ImportGraphDefResults.cs index e9ad8474..71ea5306 100644 --- a/src/TensorFlowNET.Core/Graphs/TF_ImportGraphDefResults.cs +++ b/src/TensorFlowNET.Core/Graphs/TF_ImportGraphDefResults.cs @@ -3,13 +3,62 @@ using System.Runtime.InteropServices; namespace Tensorflow { - [StructLayout(LayoutKind.Sequential)] - public struct TF_ImportGraphDefResults + public class TF_ImportGraphDefResults : DisposableObject { - public IntPtr return_tensors; - public IntPtr return_nodes; + /*public IntPtr return_nodes; public IntPtr missing_unused_key_names; public IntPtr missing_unused_key_indexes; - public IntPtr missing_unused_key_names_data; + public IntPtr missing_unused_key_names_data;*/ + + public TF_ImportGraphDefResults(IntPtr handle) + { + _handle = handle; + } + + public TF_Output[] return_tensors + { + get + { + IntPtr return_output_handle = IntPtr.Zero; + int num_outputs = -1; + c_api.TF_ImportGraphDefResultsReturnOutputs(_handle, ref num_outputs, ref return_output_handle); + TF_Output[] return_outputs = new TF_Output[num_outputs]; + unsafe + { + var tf_output_ptr = (TF_Output*)return_output_handle; + for (int i = 0; i < num_outputs; i++) + return_outputs[i] = *(tf_output_ptr + i); + return return_outputs; + } + } + } + + public TF_Operation[] return_opers + { + get + { + return new TF_Operation[0]; + /*TF_Operation return_output_handle = new TF_Operation(); + int num_outputs = -1; + c_api.TF_ImportGraphDefResultsReturnOperations(_handle, ref num_outputs, ref return_output_handle); + TF_Operation[] return_outputs = new TF_Operation[num_outputs]; + unsafe + { + var tf_output_ptr = (TF_Operation*)return_output_handle; + for (int i = 0; i < num_outputs; i++) + return_outputs[i] = *(tf_output_ptr + i); + return return_outputs; + }*/ + } + } + + public static implicit operator TF_ImportGraphDefResults(IntPtr handle) + => new TF_ImportGraphDefResults(handle); + + public static implicit operator IntPtr(TF_ImportGraphDefResults results) + => results._handle; + + protected override void DisposeUnmanagedResources(IntPtr handle) + => c_api.TF_DeleteImportGraphDefResults(handle); } } diff --git a/src/TensorFlowNET.Core/Status/Status.cs b/src/TensorFlowNET.Core/Status/Status.cs index 21ff6f6e..928f39f2 100644 --- a/src/TensorFlowNET.Core/Status/Status.cs +++ b/src/TensorFlowNET.Core/Status/Status.cs @@ -65,9 +65,7 @@ namespace Tensorflow } public static implicit operator IntPtr(Status status) - { - return status._handle; - } + => status._handle; protected override void DisposeUnmanagedResources(IntPtr handle) => TF_DeleteStatus(handle);