Browse Source

_FetchMapper

tags/v0.1.0-Tensor
Oceania2018 6 years ago
parent
commit
daef22991b
6 changed files with 91 additions and 2 deletions
  1. +44
    -0
      src/TensorFlowNET.Core/Graph.cs
  2. +1
    -0
      src/TensorFlowNET.Core/Operation.cs
  3. +24
    -0
      src/TensorFlowNET.Core/Session/_ElementFetchMapper.cs
  4. +4
    -2
      src/TensorFlowNET.Core/Session/_FetchHandler.cs
  5. +16
    -0
      src/TensorFlowNET.Core/Session/_FetchMapper.cs
  6. +2
    -0
      src/TensorFlowNET.Core/Tensor.cs

+ 44
- 0
src/TensorFlowNET.Core/Graph.cs View File

@@ -31,6 +31,50 @@ namespace Tensorflow
_names_in_use = new Dictionary<string, int>(); _names_in_use = new Dictionary<string, int>();
} }


public T as_graph_element<T>(T obj, bool allow_tensor = true, bool allow_operation = true)
{
return _as_graph_element_locked(obj, allow_tensor, allow_operation);
}

private Func<object> _as_graph_element(object obj)
{
return null;
}

private T _as_graph_element_locked<T>(T obj, bool allow_tensor = true, bool allow_operation = true)
{
string types_str = "";

if (allow_tensor && allow_operation)
{
types_str = "Tensor or Operation";
}
else if (allow_tensor)
{
types_str = "Tensor";
}
else if (allow_operation)
{
types_str = "Operation";
}

var temp_obj = _as_graph_element(obj);

if(obj is Tensor && allow_tensor)
{
if ((obj as Tensor).graph.Equals(this))
{
return obj;
}
else
{
throw new Exception($"Tensor {obj} is not an element of this graph.");
}
}

throw new Exception($"Can not convert a {typeof(T).Name} into a {types_str}.");
}

public unsafe Operation create_op(string op_type, List<Tensor> inputs, TF_DataType[] dtypes, public unsafe Operation create_op(string op_type, List<Tensor> inputs, TF_DataType[] dtypes,
TF_DataType[] input_types = null, string name = "", TF_DataType[] input_types = null, string name = "",
Dictionary<string, AttrValue> attrs = null, OpDef op_def = null) Dictionary<string, AttrValue> attrs = null, OpDef op_def = null)


+ 1
- 0
src/TensorFlowNET.Core/Operation.cs View File

@@ -8,6 +8,7 @@ namespace Tensorflow
public class Operation public class Operation
{ {
private Graph _graph; private Graph _graph;
public Graph graph => _graph;
public IntPtr _c_op; public IntPtr _c_op;
public int _id => _id_value; public int _id => _id_value;
private int _id_value; private int _id_value;


+ 24
- 0
src/TensorFlowNET.Core/Session/_ElementFetchMapper.cs View File

@@ -0,0 +1,24 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
/// <summary>
/// Fetch mapper for singleton tensors and ops.
/// </summary>
public class _ElementFetchMapper : _FetchMapper
{
private List<Object> _unique_fetches = new List<object>();
private Action _contraction_fn;

public _ElementFetchMapper(List<Tensor> fetches, Action contraction_fn)
{
foreach(var tensor in fetches)
{
var fetch = ops.get_default_graph().as_graph_element(tensor, allow_tensor: true, allow_operation: true);
_unique_fetches.Add(fetch);
}
}
}
}

+ 4
- 2
src/TensorFlowNET.Core/Session/_FetchHandler.cs View File

@@ -9,9 +9,11 @@ namespace Tensorflow
/// </summary> /// </summary>
public class _FetchHandler public class _FetchHandler
{ {
public _FetchHandler(Graph graph, Tensor fetches)
{
private _ElementFetchMapper _fetch_mapper;


public _FetchHandler(Graph graph, Tensor fetches, object feeds = null, object feed_handles = null)
{
_fetch_mapper = new _FetchMapper().for_fetch(fetches);
} }
} }
} }

+ 16
- 0
src/TensorFlowNET.Core/Session/_FetchMapper.cs View File

@@ -0,0 +1,16 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public class _FetchMapper
{
public _ElementFetchMapper for_fetch(Tensor fetch)
{
var fetches = new List<Tensor> { fetch };

return new _ElementFetchMapper(fetches, null);
}
}
}

+ 2
- 0
src/TensorFlowNET.Core/Tensor.cs View File

@@ -13,6 +13,8 @@ namespace Tensorflow
private DataType _dtype; private DataType _dtype;
public DataType dtype => _dtype; public DataType dtype => _dtype;


public Graph graph => _op.graph;

public string name; public string name;


public Tensor(Operation op, int value_index, DataType dtype) public Tensor(Operation op, int value_index, DataType dtype)


Loading…
Cancel
Save