@@ -24,7 +24,7 @@ namespace Tensorflow | |||
public GFile gfile = new GFile(); | |||
public Tensor read_file(string filename, string name = null) => gen_io_ops.read_file(filename, name); | |||
public void import_graph_def(GraphDef graph_def, | |||
public ITensorOrOperation[] import_graph_def(GraphDef graph_def, | |||
Dictionary<string, Tensor> input_map = null, | |||
string[] return_elements = null, | |||
string name = null, | |||
@@ -95,6 +95,9 @@ namespace Tensorflow | |||
throw new NotImplementedException("len() not implemented for type: " + a.GetType()); | |||
} | |||
public static float min(float a, float b) | |||
=> Math.Min(a, b); | |||
public static T[] list<T>(IEnumerable<T> list) | |||
=> list.ToArray(); | |||
@@ -54,6 +54,7 @@ namespace Tensorflow | |||
input_map = _ConvertInputMapValues(name, input_map); | |||
}); | |||
TF_ImportGraphDefResults results = null; | |||
var bytes = graph_def.ToByteString().ToArray(); | |||
using (var buffer = c_api_util.tf_buffer(bytes)) | |||
using (var scoped_options = c_api_util.ScopedTFImportGraphDefOptions()) | |||
@@ -61,9 +62,8 @@ namespace Tensorflow | |||
{ | |||
_PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements); | |||
// need to create a class ImportGraphDefWithResults with IDisposal | |||
var results = c_api.TF_GraphImportGraphDefWithResults(graph, buffer, scoped_options, status); | |||
results = c_api.TF_GraphImportGraphDefWithResults(graph, buffer, scoped_options, status); | |||
status.Check(true); | |||
c_api.TF_DeleteImportGraphDefResults(results); | |||
} | |||
_ProcessNewOps(graph); | |||
@@ -71,7 +71,34 @@ namespace Tensorflow | |||
if (return_elements == null) | |||
return null; | |||
else | |||
throw new NotImplementedException("import_graph_def return_elements"); | |||
return _GatherReturnElements(return_elements, graph, results); | |||
} | |||
private static ITensorOrOperation[] _GatherReturnElements(string[] requested_return_elements, | |||
Graph graph, | |||
TF_ImportGraphDefResults results) | |||
{ | |||
var return_outputs = results.return_tensors; | |||
var return_opers = results.return_opers; | |||
var combined_return_elements = new List<ITensorOrOperation>(); | |||
int outputs_idx = 0; | |||
int opers_idx = 0; | |||
foreach(var name in requested_return_elements) | |||
{ | |||
if (name.Contains(":")) | |||
{ | |||
combined_return_elements.append(graph.get_tensor_by_tf_output(return_outputs[outputs_idx])); | |||
outputs_idx += 1; | |||
} | |||
else | |||
{ | |||
throw new NotImplementedException("_GatherReturnElements"); | |||
// combined_return_elements.append(graph._get_operation_by_tf_operation(return_opers[opers_idx])); | |||
} | |||
} | |||
return combined_return_elements.ToArray(); | |||
} | |||
private static void _ProcessNewOps(Graph graph) | |||
@@ -100,8 +127,29 @@ namespace Tensorflow | |||
foreach (var name in return_elements) | |||
{ | |||
throw new NotImplementedException("_PopulateTFImportGraphDefOptions"); | |||
if(name.Contains(":")) | |||
{ | |||
var (op_name, index) = _ParseTensorName(name); | |||
c_api.TF_ImportGraphDefOptionsAddReturnOutput(options, op_name, index); | |||
} | |||
else | |||
{ | |||
c_api.TF_ImportGraphDefOptionsAddReturnOperation(options, name); | |||
} | |||
} | |||
// c_api.TF_ImportGraphDefOptionsSetValidateColocationConstraints(options, validate_colocation_constraints); | |||
} | |||
private static (string, int) _ParseTensorName(string tensor_name) | |||
{ | |||
var components = tensor_name.Split(':'); | |||
if (components.Length == 2) | |||
return (components[0], int.Parse(components[1])); | |||
else if (components.Length == 1) | |||
return (components[0], 0); | |||
else | |||
throw new ValueError($"Cannot convert {tensor_name} to a tensor name."); | |||
} | |||
public static Dictionary<string, Tensor> _ConvertInputMapValues(string name, Dictionary<string, Tensor> input_map) |
@@ -494,6 +494,12 @@ namespace Tensorflow | |||
c_api.TF_DeleteGraph(handle); | |||
} | |||
public Tensor get_tensor_by_tf_output(TF_Output tf_output) | |||
{ | |||
var op = _get_operation_by_tf_operation(tf_output.oper); | |||
return op.outputs[tf_output.index]; | |||
} | |||
/// <summary> | |||
/// Returns the <see cref="Tensor"/> with the given <paramref name="name"/>. | |||
/// This method may be called concurrently from multiple threads. | |||
@@ -3,13 +3,62 @@ using System.Runtime.InteropServices; | |||
namespace Tensorflow | |||
{ | |||
[StructLayout(LayoutKind.Sequential)] | |||
public struct TF_ImportGraphDefResults | |||
public class TF_ImportGraphDefResults : DisposableObject | |||
{ | |||
public IntPtr return_tensors; | |||
public IntPtr return_nodes; | |||
/*public IntPtr return_nodes; | |||
public IntPtr missing_unused_key_names; | |||
public IntPtr missing_unused_key_indexes; | |||
public IntPtr missing_unused_key_names_data; | |||
public IntPtr missing_unused_key_names_data;*/ | |||
public TF_ImportGraphDefResults(IntPtr handle) | |||
{ | |||
_handle = handle; | |||
} | |||
public TF_Output[] return_tensors | |||
{ | |||
get | |||
{ | |||
IntPtr return_output_handle = IntPtr.Zero; | |||
int num_outputs = -1; | |||
c_api.TF_ImportGraphDefResultsReturnOutputs(_handle, ref num_outputs, ref return_output_handle); | |||
TF_Output[] return_outputs = new TF_Output[num_outputs]; | |||
unsafe | |||
{ | |||
var tf_output_ptr = (TF_Output*)return_output_handle; | |||
for (int i = 0; i < num_outputs; i++) | |||
return_outputs[i] = *(tf_output_ptr + i); | |||
return return_outputs; | |||
} | |||
} | |||
} | |||
public TF_Operation[] return_opers | |||
{ | |||
get | |||
{ | |||
return new TF_Operation[0]; | |||
/*TF_Operation return_output_handle = new TF_Operation(); | |||
int num_outputs = -1; | |||
c_api.TF_ImportGraphDefResultsReturnOperations(_handle, ref num_outputs, ref return_output_handle); | |||
TF_Operation[] return_outputs = new TF_Operation[num_outputs]; | |||
unsafe | |||
{ | |||
var tf_output_ptr = (TF_Operation*)return_output_handle; | |||
for (int i = 0; i < num_outputs; i++) | |||
return_outputs[i] = *(tf_output_ptr + i); | |||
return return_outputs; | |||
}*/ | |||
} | |||
} | |||
public static implicit operator TF_ImportGraphDefResults(IntPtr handle) | |||
=> new TF_ImportGraphDefResults(handle); | |||
public static implicit operator IntPtr(TF_ImportGraphDefResults results) | |||
=> results._handle; | |||
protected override void DisposeUnmanagedResources(IntPtr handle) | |||
=> c_api.TF_DeleteImportGraphDefResults(handle); | |||
} | |||
} |
@@ -65,9 +65,7 @@ namespace Tensorflow | |||
} | |||
public static implicit operator IntPtr(Status status) | |||
{ | |||
return status._handle; | |||
} | |||
=> status._handle; | |||
protected override void DisposeUnmanagedResources(IntPtr handle) | |||
=> TF_DeleteStatus(handle); | |||