@@ -24,7 +24,7 @@ namespace Tensorflow | |||||
public GFile gfile = new GFile(); | public GFile gfile = new GFile(); | ||||
public Tensor read_file(string filename, string name = null) => gen_io_ops.read_file(filename, name); | public Tensor read_file(string filename, string name = null) => gen_io_ops.read_file(filename, name); | ||||
public void import_graph_def(GraphDef graph_def, | |||||
public ITensorOrOperation[] import_graph_def(GraphDef graph_def, | |||||
Dictionary<string, Tensor> input_map = null, | Dictionary<string, Tensor> input_map = null, | ||||
string[] return_elements = null, | string[] return_elements = null, | ||||
string name = null, | string name = null, | ||||
@@ -95,6 +95,9 @@ namespace Tensorflow | |||||
throw new NotImplementedException("len() not implemented for type: " + a.GetType()); | throw new NotImplementedException("len() not implemented for type: " + a.GetType()); | ||||
} | } | ||||
public static float min(float a, float b) | |||||
=> Math.Min(a, b); | |||||
public static T[] list<T>(IEnumerable<T> list) | public static T[] list<T>(IEnumerable<T> list) | ||||
=> list.ToArray(); | => list.ToArray(); | ||||
@@ -54,6 +54,7 @@ namespace Tensorflow | |||||
input_map = _ConvertInputMapValues(name, input_map); | input_map = _ConvertInputMapValues(name, input_map); | ||||
}); | }); | ||||
TF_ImportGraphDefResults results = null; | |||||
var bytes = graph_def.ToByteString().ToArray(); | var bytes = graph_def.ToByteString().ToArray(); | ||||
using (var buffer = c_api_util.tf_buffer(bytes)) | using (var buffer = c_api_util.tf_buffer(bytes)) | ||||
using (var scoped_options = c_api_util.ScopedTFImportGraphDefOptions()) | using (var scoped_options = c_api_util.ScopedTFImportGraphDefOptions()) | ||||
@@ -61,9 +62,8 @@ namespace Tensorflow | |||||
{ | { | ||||
_PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements); | _PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements); | ||||
// need to create a class ImportGraphDefWithResults with IDisposal | // need to create a class ImportGraphDefWithResults with IDisposal | ||||
var results = c_api.TF_GraphImportGraphDefWithResults(graph, buffer, scoped_options, status); | |||||
results = c_api.TF_GraphImportGraphDefWithResults(graph, buffer, scoped_options, status); | |||||
status.Check(true); | status.Check(true); | ||||
c_api.TF_DeleteImportGraphDefResults(results); | |||||
} | } | ||||
_ProcessNewOps(graph); | _ProcessNewOps(graph); | ||||
@@ -71,7 +71,34 @@ namespace Tensorflow | |||||
if (return_elements == null) | if (return_elements == null) | ||||
return null; | return null; | ||||
else | else | ||||
throw new NotImplementedException("import_graph_def return_elements"); | |||||
return _GatherReturnElements(return_elements, graph, results); | |||||
} | |||||
private static ITensorOrOperation[] _GatherReturnElements(string[] requested_return_elements, | |||||
Graph graph, | |||||
TF_ImportGraphDefResults results) | |||||
{ | |||||
var return_outputs = results.return_tensors; | |||||
var return_opers = results.return_opers; | |||||
var combined_return_elements = new List<ITensorOrOperation>(); | |||||
int outputs_idx = 0; | |||||
int opers_idx = 0; | |||||
foreach(var name in requested_return_elements) | |||||
{ | |||||
if (name.Contains(":")) | |||||
{ | |||||
combined_return_elements.append(graph.get_tensor_by_tf_output(return_outputs[outputs_idx])); | |||||
outputs_idx += 1; | |||||
} | |||||
else | |||||
{ | |||||
throw new NotImplementedException("_GatherReturnElements"); | |||||
// combined_return_elements.append(graph._get_operation_by_tf_operation(return_opers[opers_idx])); | |||||
} | |||||
} | |||||
return combined_return_elements.ToArray(); | |||||
} | } | ||||
private static void _ProcessNewOps(Graph graph) | private static void _ProcessNewOps(Graph graph) | ||||
@@ -100,8 +127,29 @@ namespace Tensorflow | |||||
foreach (var name in return_elements) | foreach (var name in return_elements) | ||||
{ | { | ||||
throw new NotImplementedException("_PopulateTFImportGraphDefOptions"); | |||||
if(name.Contains(":")) | |||||
{ | |||||
var (op_name, index) = _ParseTensorName(name); | |||||
c_api.TF_ImportGraphDefOptionsAddReturnOutput(options, op_name, index); | |||||
} | |||||
else | |||||
{ | |||||
c_api.TF_ImportGraphDefOptionsAddReturnOperation(options, name); | |||||
} | |||||
} | } | ||||
// c_api.TF_ImportGraphDefOptionsSetValidateColocationConstraints(options, validate_colocation_constraints); | |||||
} | |||||
private static (string, int) _ParseTensorName(string tensor_name) | |||||
{ | |||||
var components = tensor_name.Split(':'); | |||||
if (components.Length == 2) | |||||
return (components[0], int.Parse(components[1])); | |||||
else if (components.Length == 1) | |||||
return (components[0], 0); | |||||
else | |||||
throw new ValueError($"Cannot convert {tensor_name} to a tensor name."); | |||||
} | } | ||||
public static Dictionary<string, Tensor> _ConvertInputMapValues(string name, Dictionary<string, Tensor> input_map) | public static Dictionary<string, Tensor> _ConvertInputMapValues(string name, Dictionary<string, Tensor> input_map) |
@@ -494,6 +494,12 @@ namespace Tensorflow | |||||
c_api.TF_DeleteGraph(handle); | c_api.TF_DeleteGraph(handle); | ||||
} | } | ||||
public Tensor get_tensor_by_tf_output(TF_Output tf_output) | |||||
{ | |||||
var op = _get_operation_by_tf_operation(tf_output.oper); | |||||
return op.outputs[tf_output.index]; | |||||
} | |||||
/// <summary> | /// <summary> | ||||
/// Returns the <see cref="Tensor"/> with the given <paramref name="name"/>. | /// Returns the <see cref="Tensor"/> with the given <paramref name="name"/>. | ||||
/// This method may be called concurrently from multiple threads. | /// This method may be called concurrently from multiple threads. | ||||
@@ -3,13 +3,62 @@ using System.Runtime.InteropServices; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
[StructLayout(LayoutKind.Sequential)] | |||||
public struct TF_ImportGraphDefResults | |||||
public class TF_ImportGraphDefResults : DisposableObject | |||||
{ | { | ||||
public IntPtr return_tensors; | |||||
public IntPtr return_nodes; | |||||
/*public IntPtr return_nodes; | |||||
public IntPtr missing_unused_key_names; | public IntPtr missing_unused_key_names; | ||||
public IntPtr missing_unused_key_indexes; | public IntPtr missing_unused_key_indexes; | ||||
public IntPtr missing_unused_key_names_data; | |||||
public IntPtr missing_unused_key_names_data;*/ | |||||
public TF_ImportGraphDefResults(IntPtr handle) | |||||
{ | |||||
_handle = handle; | |||||
} | |||||
public TF_Output[] return_tensors | |||||
{ | |||||
get | |||||
{ | |||||
IntPtr return_output_handle = IntPtr.Zero; | |||||
int num_outputs = -1; | |||||
c_api.TF_ImportGraphDefResultsReturnOutputs(_handle, ref num_outputs, ref return_output_handle); | |||||
TF_Output[] return_outputs = new TF_Output[num_outputs]; | |||||
unsafe | |||||
{ | |||||
var tf_output_ptr = (TF_Output*)return_output_handle; | |||||
for (int i = 0; i < num_outputs; i++) | |||||
return_outputs[i] = *(tf_output_ptr + i); | |||||
return return_outputs; | |||||
} | |||||
} | |||||
} | |||||
public TF_Operation[] return_opers | |||||
{ | |||||
get | |||||
{ | |||||
return new TF_Operation[0]; | |||||
/*TF_Operation return_output_handle = new TF_Operation(); | |||||
int num_outputs = -1; | |||||
c_api.TF_ImportGraphDefResultsReturnOperations(_handle, ref num_outputs, ref return_output_handle); | |||||
TF_Operation[] return_outputs = new TF_Operation[num_outputs]; | |||||
unsafe | |||||
{ | |||||
var tf_output_ptr = (TF_Operation*)return_output_handle; | |||||
for (int i = 0; i < num_outputs; i++) | |||||
return_outputs[i] = *(tf_output_ptr + i); | |||||
return return_outputs; | |||||
}*/ | |||||
} | |||||
} | |||||
public static implicit operator TF_ImportGraphDefResults(IntPtr handle) | |||||
=> new TF_ImportGraphDefResults(handle); | |||||
public static implicit operator IntPtr(TF_ImportGraphDefResults results) | |||||
=> results._handle; | |||||
protected override void DisposeUnmanagedResources(IntPtr handle) | |||||
=> c_api.TF_DeleteImportGraphDefResults(handle); | |||||
} | } | ||||
} | } |
@@ -65,9 +65,7 @@ namespace Tensorflow | |||||
} | } | ||||
public static implicit operator IntPtr(Status status) | public static implicit operator IntPtr(Status status) | ||||
{ | |||||
return status._handle; | |||||
} | |||||
=> status._handle; | |||||
protected override void DisposeUnmanagedResources(IntPtr handle) | protected override void DisposeUnmanagedResources(IntPtr handle) | ||||
=> TF_DeleteStatus(handle); | => TF_DeleteStatus(handle); | ||||