Browse Source

#172

tags/v0.8.0
Oceania2018 6 years ago
parent
commit
e237ba6bff
7 changed files with 49 additions and 26 deletions
  1. +7
    -7
      src/TensorFlowNET.Core/Graphs/Graph.cs
  2. +3
    -3
      src/TensorFlowNET.Core/Sessions/BaseSession.cs
  3. +21
    -5
      src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs
  4. +7
    -7
      src/TensorFlowNET.Core/Sessions/_FetchHandler.cs
  5. +4
    -4
      src/TensorFlowNET.Core/Sessions/_FetchMapper.cs
  6. +5
    -0
      src/TensorFlowNET.Core/tf.cs
  7. +2
    -0
      test/TensorFlowNET.UnitTest/GradientTest.cs

+ 7
- 7
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -15,8 +15,8 @@ namespace Tensorflow
public partial class Graph : IDisposable
{
private IntPtr _handle;
private Dictionary<int, Operation> _nodes_by_id;
private Dictionary<string, Operation> _nodes_by_name;
private Dictionary<int, ITensorOrOperation> _nodes_by_id;
public Dictionary<string, ITensorOrOperation> _nodes_by_name;
private Dictionary<string, int> _names_in_use;
public int _version;
private int _next_id_counter;
@@ -35,13 +35,13 @@ namespace Tensorflow
{
_handle = c_api.TF_NewGraph();
Status = new Status();
_nodes_by_id = new Dictionary<int, Operation>();
_nodes_by_name = new Dictionary<string, Operation>();
_nodes_by_id = new Dictionary<int, ITensorOrOperation>();
_nodes_by_name = new Dictionary<string, ITensorOrOperation>();
_names_in_use = new Dictionary<string, int>();
_graph_key = $"grap-key-{ops.uid()}/";
}

public object as_graph_element(object obj, bool allow_tensor = true, bool allow_operation = true)
public ITensorOrOperation as_graph_element(ITensorOrOperation obj, bool allow_tensor = true, bool allow_operation = true)
{
return _as_graph_element_locked(obj, allow_tensor, allow_operation);
}
@@ -54,7 +54,7 @@ namespace Tensorflow
return null;
}

private object _as_graph_element_locked(object obj, bool allow_tensor = true, bool allow_operation = true)
private ITensorOrOperation _as_graph_element_locked(ITensorOrOperation obj, bool allow_tensor = true, bool allow_operation = true)
{
string types_str = "";

@@ -294,7 +294,7 @@ namespace Tensorflow
return c_api.TF_GraphOperationByName(_handle, operName);
}

public Operation[] get_operations()
public ITensorOrOperation[] get_operations()
{
return _nodes_by_name.Values.Select(x => x).ToArray();
}


+ 3
- 3
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -36,12 +36,12 @@ namespace Tensorflow
}


public virtual NDArray run<T>(T fetches, FeedItem[] feed_dict = null)
public virtual NDArray run(object fetches, FeedItem[] feed_dict = null)
{
return _run(fetches, feed_dict);
}

private NDArray _run<T>(T fetches, FeedItem[] feed_dict = null)
private NDArray _run(object fetches, FeedItem[] feed_dict = null)
{
var feed_dict_tensor = new Dictionary<object, object>();

@@ -49,7 +49,7 @@ namespace Tensorflow
feed_dict.ToList().ForEach(x => feed_dict_tensor.Add(x.Key, x.Value));

// Create a fetch handler to take care of the structure of fetches.
var fetch_handler = new _FetchHandler<T>(_graph, fetches, feed_dict_tensor);
var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor);

// Run request and get response.
// We need to keep the returned movers alive for the following _do_run().


+ 21
- 5
src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs View File

@@ -8,20 +8,36 @@ namespace Tensorflow
/// <summary>
/// Fetch mapper for singleton tensors and ops.
/// </summary>
public class _ElementFetchMapper<T> : _FetchMapper<T>
public class _ElementFetchMapper : _FetchMapper
{
private List<object> _unique_fetches = new List<object>();
private Func<List<object>, object> _contraction_fn;

public _ElementFetchMapper(List<T> fetches, Func<List<object>, object> contraction_fn)
public _ElementFetchMapper(object[] fetches, Func<List<object>, object> contraction_fn)
{
var g = ops.get_default_graph();
ITensorOrOperation el = null;

foreach(var fetch in fetches)
{
var g = ops.get_default_graph();
var el = g.as_graph_element(fetch, allow_tensor: true, allow_operation: true);
_unique_fetches.Add(el);
switch(fetch)
{
case Tensor tensor:
el = g.as_graph_element(tensor, allow_tensor: true, allow_operation: true);
break;
case Operation op:
el = g.as_graph_element(op, allow_tensor: true, allow_operation: true);
break;
case String str:
// Looks like a Tensor name and can be a Tensor.
el = g._nodes_by_name[str];
break;
default:
throw new NotImplementedException("_ElementFetchMapper");
}
}

_unique_fetches.Add(el);
_contraction_fn = contraction_fn;
}



+ 7
- 7
src/TensorFlowNET.Core/Sessions/_FetchHandler.cs View File

@@ -8,24 +8,24 @@ namespace Tensorflow
/// <summary>
/// Handler for structured fetches.
/// </summary>
public class _FetchHandler<T>
public class _FetchHandler
{
private _ElementFetchMapper<T> _fetch_mapper;
private _ElementFetchMapper _fetch_mapper;
private List<Tensor> _fetches = new List<Tensor>();
private List<bool> _ops = new List<bool>();
private List<Tensor> _final_fetches = new List<Tensor>();
private List<T> _targets = new List<T>();
private List<object> _targets = new List<object>();

public _FetchHandler(Graph graph, T fetches, Dictionary<object, object> feeds = null, Action feed_handles = null)
public _FetchHandler(Graph graph, object fetches, Dictionary<object, object> feeds = null, Action feed_handles = null)
{
_fetch_mapper = new _FetchMapper<T>().for_fetch(fetches);
_fetch_mapper = new _FetchMapper().for_fetch(fetches);
foreach(var fetch in _fetch_mapper.unique_fetches())
{
switch (fetch)
{
case Operation val:
_assert_fetchable(graph, val);
_targets.Add((T)(object)val);
_targets.Add(val);
_ops.Add(true);
break;
case Tensor val:
@@ -82,7 +82,7 @@ namespace Tensorflow
return _final_fetches;
}

public List<T> targets()
public List<object> targets()
{
return _targets;
}


+ 4
- 4
src/TensorFlowNET.Core/Sessions/_FetchMapper.cs View File

@@ -4,13 +4,13 @@ using System.Text;

namespace Tensorflow
{
public class _FetchMapper<T>
public class _FetchMapper
{
public _ElementFetchMapper<T> for_fetch(T fetch)
public _ElementFetchMapper for_fetch(object fetch)
{
var fetches = new List<T> { fetch };
var fetches = new object[] { fetch };

return new _ElementFetchMapper<T>(fetches, (List<object> fetched_vals) =>
return new _ElementFetchMapper(fetches, (List<object> fetched_vals) =>
{
return fetched_vals[0];
});


+ 5
- 0
src/TensorFlowNET.Core/tf.cs View File

@@ -47,6 +47,11 @@ namespace Tensorflow
return g;
}

public static void ResetGraph()
{
g = new Graph();
}

public static Session Session()
{
defaultSession = new Session();


+ 2
- 0
test/TensorFlowNET.UnitTest/GradientTest.cs View File

@@ -12,6 +12,8 @@ namespace TensorFlowNET.UnitTest
[TestMethod]
public void Gradients()
{
tf.ResetGraph();

var a = tf.constant(0.0);
var b = 2.0 * a;
Assert.AreEqual(b.name, "mul:0");


Loading…
Cancel
Save