Browse Source

TF_ImportGraphDefResults.return_tensors

tags/v0.20
Oceania2018 5 years ago
parent
commit
e488c67662
6 changed files with 117 additions and 13 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/APIs/tf.io.cs
  2. +3
    -0
      src/TensorFlowNET.Core/Binding.Util.cs
  3. +52
    -4
      src/TensorFlowNET.Core/Framework/importer.cs
  4. +6
    -0
      src/TensorFlowNET.Core/Graphs/Graph.cs
  5. +54
    -5
      src/TensorFlowNET.Core/Graphs/TF_ImportGraphDefResults.cs
  6. +1
    -3
      src/TensorFlowNET.Core/Status/Status.cs

+ 1
- 1
src/TensorFlowNET.Core/APIs/tf.io.cs View File

@@ -24,7 +24,7 @@ namespace Tensorflow
public GFile gfile = new GFile(); public GFile gfile = new GFile();
public Tensor read_file(string filename, string name = null) => gen_io_ops.read_file(filename, name); public Tensor read_file(string filename, string name = null) => gen_io_ops.read_file(filename, name);


public void import_graph_def(GraphDef graph_def,
public ITensorOrOperation[] import_graph_def(GraphDef graph_def,
Dictionary<string, Tensor> input_map = null, Dictionary<string, Tensor> input_map = null,
string[] return_elements = null, string[] return_elements = null,
string name = null, string name = null,


+ 3
- 0
src/TensorFlowNET.Core/Binding.Util.cs View File

@@ -95,6 +95,9 @@ namespace Tensorflow
throw new NotImplementedException("len() not implemented for type: " + a.GetType()); throw new NotImplementedException("len() not implemented for type: " + a.GetType());
} }


public static float min(float a, float b)
=> Math.Min(a, b);

public static T[] list<T>(IEnumerable<T> list) public static T[] list<T>(IEnumerable<T> list)
=> list.ToArray(); => list.ToArray();




src/TensorFlowNET.Core/Framework/importer.py.cs → src/TensorFlowNET.Core/Framework/importer.cs View File

@@ -54,6 +54,7 @@ namespace Tensorflow
input_map = _ConvertInputMapValues(name, input_map); input_map = _ConvertInputMapValues(name, input_map);
}); });


TF_ImportGraphDefResults results = null;
var bytes = graph_def.ToByteString().ToArray(); var bytes = graph_def.ToByteString().ToArray();
using (var buffer = c_api_util.tf_buffer(bytes)) using (var buffer = c_api_util.tf_buffer(bytes))
using (var scoped_options = c_api_util.ScopedTFImportGraphDefOptions()) using (var scoped_options = c_api_util.ScopedTFImportGraphDefOptions())
@@ -61,9 +62,8 @@ namespace Tensorflow
{ {
_PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements); _PopulateTFImportGraphDefOptions(scoped_options, prefix, input_map, return_elements);
// need to create a class ImportGraphDefWithResults with IDisposal // need to create a class ImportGraphDefWithResults with IDisposal
var results = c_api.TF_GraphImportGraphDefWithResults(graph, buffer, scoped_options, status);
results = c_api.TF_GraphImportGraphDefWithResults(graph, buffer, scoped_options, status);
status.Check(true); status.Check(true);
c_api.TF_DeleteImportGraphDefResults(results);
} }


_ProcessNewOps(graph); _ProcessNewOps(graph);
@@ -71,7 +71,34 @@ namespace Tensorflow
if (return_elements == null) if (return_elements == null)
return null; return null;
else else
throw new NotImplementedException("import_graph_def return_elements");
return _GatherReturnElements(return_elements, graph, results);
}

private static ITensorOrOperation[] _GatherReturnElements(string[] requested_return_elements,
Graph graph,
TF_ImportGraphDefResults results)
{
var return_outputs = results.return_tensors;
var return_opers = results.return_opers;

var combined_return_elements = new List<ITensorOrOperation>();
int outputs_idx = 0;
int opers_idx = 0;
foreach(var name in requested_return_elements)
{
if (name.Contains(":"))
{
combined_return_elements.append(graph.get_tensor_by_tf_output(return_outputs[outputs_idx]));
outputs_idx += 1;
}
else
{
throw new NotImplementedException("_GatherReturnElements");
// combined_return_elements.append(graph._get_operation_by_tf_operation(return_opers[opers_idx]));
}
}

return combined_return_elements.ToArray();
} }


private static void _ProcessNewOps(Graph graph) private static void _ProcessNewOps(Graph graph)
@@ -100,8 +127,29 @@ namespace Tensorflow


foreach (var name in return_elements) foreach (var name in return_elements)
{ {
throw new NotImplementedException("_PopulateTFImportGraphDefOptions");
if(name.Contains(":"))
{
var (op_name, index) = _ParseTensorName(name);
c_api.TF_ImportGraphDefOptionsAddReturnOutput(options, op_name, index);
}
else
{
c_api.TF_ImportGraphDefOptionsAddReturnOperation(options, name);
}
} }

// c_api.TF_ImportGraphDefOptionsSetValidateColocationConstraints(options, validate_colocation_constraints);
}

private static (string, int) _ParseTensorName(string tensor_name)
{
var components = tensor_name.Split(':');
if (components.Length == 2)
return (components[0], int.Parse(components[1]));
else if (components.Length == 1)
return (components[0], 0);
else
throw new ValueError($"Cannot convert {tensor_name} to a tensor name.");
} }


public static Dictionary<string, Tensor> _ConvertInputMapValues(string name, Dictionary<string, Tensor> input_map) public static Dictionary<string, Tensor> _ConvertInputMapValues(string name, Dictionary<string, Tensor> input_map)

+ 6
- 0
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -494,6 +494,12 @@ namespace Tensorflow
c_api.TF_DeleteGraph(handle); c_api.TF_DeleteGraph(handle);
} }


public Tensor get_tensor_by_tf_output(TF_Output tf_output)
{
var op = _get_operation_by_tf_operation(tf_output.oper);
return op.outputs[tf_output.index];
}

/// <summary> /// <summary>
/// Returns the <see cref="Tensor"/> with the given <paramref name="name"/>. /// Returns the <see cref="Tensor"/> with the given <paramref name="name"/>.
/// This method may be called concurrently from multiple threads. /// This method may be called concurrently from multiple threads.


+ 54
- 5
src/TensorFlowNET.Core/Graphs/TF_ImportGraphDefResults.cs View File

@@ -3,13 +3,62 @@ using System.Runtime.InteropServices;


namespace Tensorflow namespace Tensorflow
{ {
[StructLayout(LayoutKind.Sequential)]
public struct TF_ImportGraphDefResults
public class TF_ImportGraphDefResults : DisposableObject
{ {
public IntPtr return_tensors;
public IntPtr return_nodes;
/*public IntPtr return_nodes;
public IntPtr missing_unused_key_names; public IntPtr missing_unused_key_names;
public IntPtr missing_unused_key_indexes; public IntPtr missing_unused_key_indexes;
public IntPtr missing_unused_key_names_data;
public IntPtr missing_unused_key_names_data;*/

public TF_ImportGraphDefResults(IntPtr handle)
{
_handle = handle;
}

public TF_Output[] return_tensors
{
get
{
IntPtr return_output_handle = IntPtr.Zero;
int num_outputs = -1;
c_api.TF_ImportGraphDefResultsReturnOutputs(_handle, ref num_outputs, ref return_output_handle);
TF_Output[] return_outputs = new TF_Output[num_outputs];
unsafe
{
var tf_output_ptr = (TF_Output*)return_output_handle;
for (int i = 0; i < num_outputs; i++)
return_outputs[i] = *(tf_output_ptr + i);
return return_outputs;
}
}
}

public TF_Operation[] return_opers
{
get
{
return new TF_Operation[0];
/*TF_Operation return_output_handle = new TF_Operation();
int num_outputs = -1;
c_api.TF_ImportGraphDefResultsReturnOperations(_handle, ref num_outputs, ref return_output_handle);
TF_Operation[] return_outputs = new TF_Operation[num_outputs];
unsafe
{
var tf_output_ptr = (TF_Operation*)return_output_handle;
for (int i = 0; i < num_outputs; i++)
return_outputs[i] = *(tf_output_ptr + i);
return return_outputs;
}*/
}
}

public static implicit operator TF_ImportGraphDefResults(IntPtr handle)
=> new TF_ImportGraphDefResults(handle);

public static implicit operator IntPtr(TF_ImportGraphDefResults results)
=> results._handle;

protected override void DisposeUnmanagedResources(IntPtr handle)
=> c_api.TF_DeleteImportGraphDefResults(handle);
} }
} }

+ 1
- 3
src/TensorFlowNET.Core/Status/Status.cs View File

@@ -65,9 +65,7 @@ namespace Tensorflow
} }


public static implicit operator IntPtr(Status status) public static implicit operator IntPtr(Status status)
{
return status._handle;
}
=> status._handle;


protected override void DisposeUnmanagedResources(IntPtr handle) protected override void DisposeUnmanagedResources(IntPtr handle)
=> TF_DeleteStatus(handle); => TF_DeleteStatus(handle);


Loading…
Cancel
Save