Browse Source

TF_ImportGraphDefResults.return_tensors

tags/v0.20
Oceania2018 5 years ago
parent
commit
e488c67662
6 changed files with 117 additions and 13 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/APIs/tf.io.cs
  2. +3
    -0
      src/TensorFlowNET.Core/Binding.Util.cs
  3. +52
    -4
      src/TensorFlowNET.Core/Framework/importer.cs
  4. +6
    -0
      src/TensorFlowNET.Core/Graphs/Graph.cs
  5. +54
    -5
      src/TensorFlowNET.Core/Graphs/TF_ImportGraphDefResults.cs
  6. +1
    -3
      src/TensorFlowNET.Core/Status/Status.cs

+ 1
- 1
src/TensorFlowNET.Core/APIs/tf.io.cs View File

@@ -24,7 +24,7 @@ namespace Tensorflow
public GFile gfile = new GFile();
public Tensor read_file(string filename, string name = null) => gen_io_ops.read_file(filename, name);

public void import_graph_def(GraphDef graph_def,
public ITensorOrOperation[] import_graph_def(GraphDef graph_def,
Dictionary<string, Tensor> input_map = null,
string[] return_elements = null,
string name = null,


+ 3
- 0
src/TensorFlowNET.Core/Binding.Util.cs View File

@@ -95,6 +95,9 @@ namespace Tensorflow
throw new NotImplementedException("len() not implemented for type: " + a.GetType());
}

public static float min(float a, float b)
=> Math.Min(a, b);

public static T[] list<T>(IEnumerable<T> list)
=> list.ToArray();



src/TensorFlowNET.Core/Framework/importer.py.cs → src/TensorFlowNET.Core/Framework/importer.cs View File

@@ -54,6 +54,7 @@ namespace Tensorflow
input_map = _ConvertInputMapValues(name, input_map);
});

TF_ImportGraphDefResults results = null;
var bytes = graph_def.ToByteString().ToArray();
using (var buffer = c_api_util.tf_buffer(bytes))
using (var scoped_options = c_api_util.ScopedTFImportGraphDefOptions())
@@ -61,9 +62,8 @@ namespace Tensorflow
{
_PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements);
// need to create a class ImportGraphDefWithResults with IDisposal
var results = c_api.TF_GraphImportGraphDefWithResults(graph, buffer, scoped_options, status);
results = c_api.TF_GraphImportGraphDefWithResults(graph, buffer, scoped_options, status);
status.Check(true);
c_api.TF_DeleteImportGraphDefResults(results);
}

_ProcessNewOps(graph);
@@ -71,7 +71,34 @@ namespace Tensorflow
if (return_elements == null)
return null;
else
throw new NotImplementedException("import_graph_def return_elements");
return _GatherReturnElements(return_elements, graph, results);
}

private static ITensorOrOperation[] _GatherReturnElements(string[] requested_return_elements,
Graph graph,
TF_ImportGraphDefResults results)
{
var return_outputs = results.return_tensors;
var return_opers = results.return_opers;

var combined_return_elements = new List<ITensorOrOperation>();
int outputs_idx = 0;
int opers_idx = 0;
foreach(var name in requested_return_elements)
{
if (name.Contains(":"))
{
combined_return_elements.append(graph.get_tensor_by_tf_output(return_outputs[outputs_idx]));
outputs_idx += 1;
}
else
{
throw new NotImplementedException("_GatherReturnElements");
// combined_return_elements.append(graph._get_operation_by_tf_operation(return_opers[opers_idx]));
}
}

return combined_return_elements.ToArray();
}

private static void _ProcessNewOps(Graph graph)
@@ -100,8 +127,29 @@ namespace Tensorflow

foreach (var name in return_elements)
{
throw new NotImplementedException("_PopulateTFImportGraphDefOptions");
if(name.Contains(":"))
{
var (op_name, index) = _ParseTensorName(name);
c_api.TF_ImportGraphDefOptionsAddReturnOutput(options, op_name, index);
}
else
{
c_api.TF_ImportGraphDefOptionsAddReturnOperation(options, name);
}
}

// c_api.TF_ImportGraphDefOptionsSetValidateColocationConstraints(options, validate_colocation_constraints);
}

private static (string, int) _ParseTensorName(string tensor_name)
{
var components = tensor_name.Split(':');
if (components.Length == 2)
return (components[0], int.Parse(components[1]));
else if (components.Length == 1)
return (components[0], 0);
else
throw new ValueError($"Cannot convert {tensor_name} to a tensor name.");
}

public static Dictionary<string, Tensor> _ConvertInputMapValues(string name, Dictionary<string, Tensor> input_map)

+ 6
- 0
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -494,6 +494,12 @@ namespace Tensorflow
c_api.TF_DeleteGraph(handle);
}

public Tensor get_tensor_by_tf_output(TF_Output tf_output)
{
var op = _get_operation_by_tf_operation(tf_output.oper);
return op.outputs[tf_output.index];
}

/// <summary>
/// Returns the <see cref="Tensor"/> with the given <paramref name="name"/>.
/// This method may be called concurrently from multiple threads.


+ 54
- 5
src/TensorFlowNET.Core/Graphs/TF_ImportGraphDefResults.cs View File

@@ -3,13 +3,62 @@ using System.Runtime.InteropServices;

namespace Tensorflow
{
[StructLayout(LayoutKind.Sequential)]
public struct TF_ImportGraphDefResults
public class TF_ImportGraphDefResults : DisposableObject
{
public IntPtr return_tensors;
public IntPtr return_nodes;
/*public IntPtr return_nodes;
public IntPtr missing_unused_key_names;
public IntPtr missing_unused_key_indexes;
public IntPtr missing_unused_key_names_data;
public IntPtr missing_unused_key_names_data;*/

public TF_ImportGraphDefResults(IntPtr handle)
{
_handle = handle;
}

public TF_Output[] return_tensors
{
get
{
IntPtr return_output_handle = IntPtr.Zero;
int num_outputs = -1;
c_api.TF_ImportGraphDefResultsReturnOutputs(_handle, ref num_outputs, ref return_output_handle);
TF_Output[] return_outputs = new TF_Output[num_outputs];
unsafe
{
var tf_output_ptr = (TF_Output*)return_output_handle;
for (int i = 0; i < num_outputs; i++)
return_outputs[i] = *(tf_output_ptr + i);
return return_outputs;
}
}
}

public TF_Operation[] return_opers
{
get
{
return new TF_Operation[0];
/*TF_Operation return_output_handle = new TF_Operation();
int num_outputs = -1;
c_api.TF_ImportGraphDefResultsReturnOperations(_handle, ref num_outputs, ref return_output_handle);
TF_Operation[] return_outputs = new TF_Operation[num_outputs];
unsafe
{
var tf_output_ptr = (TF_Operation*)return_output_handle;
for (int i = 0; i < num_outputs; i++)
return_outputs[i] = *(tf_output_ptr + i);
return return_outputs;
}*/
}
}

public static implicit operator TF_ImportGraphDefResults(IntPtr handle)
=> new TF_ImportGraphDefResults(handle);

public static implicit operator IntPtr(TF_ImportGraphDefResults results)
=> results._handle;

protected override void DisposeUnmanagedResources(IntPtr handle)
=> c_api.TF_DeleteImportGraphDefResults(handle);
}
}

+ 1
- 3
src/TensorFlowNET.Core/Status/Status.cs View File

@@ -65,9 +65,7 @@ namespace Tensorflow
}

public static implicit operator IntPtr(Status status)
{
return status._handle;
}
=> status._handle;

protected override void DisposeUnmanagedResources(IntPtr handle)
=> TF_DeleteStatus(handle);


Loading…
Cancel
Save