Browse Source

#172

tags/v0.8.0
Oceania2018 6 years ago
parent
commit
e237ba6bff
7 changed files with 49 additions and 26 deletions
  1. +7
    -7
      src/TensorFlowNET.Core/Graphs/Graph.cs
  2. +3
    -3
      src/TensorFlowNET.Core/Sessions/BaseSession.cs
  3. +21
    -5
      src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs
  4. +7
    -7
      src/TensorFlowNET.Core/Sessions/_FetchHandler.cs
  5. +4
    -4
      src/TensorFlowNET.Core/Sessions/_FetchMapper.cs
  6. +5
    -0
      src/TensorFlowNET.Core/tf.cs
  7. +2
    -0
      test/TensorFlowNET.UnitTest/GradientTest.cs

+ 7
- 7
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -15,8 +15,8 @@ namespace Tensorflow
public partial class Graph : IDisposable public partial class Graph : IDisposable
{ {
private IntPtr _handle; private IntPtr _handle;
private Dictionary<int, Operation> _nodes_by_id;
private Dictionary<string, Operation> _nodes_by_name;
private Dictionary<int, ITensorOrOperation> _nodes_by_id;
public Dictionary<string, ITensorOrOperation> _nodes_by_name;
private Dictionary<string, int> _names_in_use; private Dictionary<string, int> _names_in_use;
public int _version; public int _version;
private int _next_id_counter; private int _next_id_counter;
@@ -35,13 +35,13 @@ namespace Tensorflow
{ {
_handle = c_api.TF_NewGraph(); _handle = c_api.TF_NewGraph();
Status = new Status(); Status = new Status();
_nodes_by_id = new Dictionary<int, Operation>();
_nodes_by_name = new Dictionary<string, Operation>();
_nodes_by_id = new Dictionary<int, ITensorOrOperation>();
_nodes_by_name = new Dictionary<string, ITensorOrOperation>();
_names_in_use = new Dictionary<string, int>(); _names_in_use = new Dictionary<string, int>();
_graph_key = $"grap-key-{ops.uid()}/"; _graph_key = $"grap-key-{ops.uid()}/";
} }


public object as_graph_element(object obj, bool allow_tensor = true, bool allow_operation = true)
public ITensorOrOperation as_graph_element(ITensorOrOperation obj, bool allow_tensor = true, bool allow_operation = true)
{ {
return _as_graph_element_locked(obj, allow_tensor, allow_operation); return _as_graph_element_locked(obj, allow_tensor, allow_operation);
} }
@@ -54,7 +54,7 @@ namespace Tensorflow
return null; return null;
} }


private object _as_graph_element_locked(object obj, bool allow_tensor = true, bool allow_operation = true)
private ITensorOrOperation _as_graph_element_locked(ITensorOrOperation obj, bool allow_tensor = true, bool allow_operation = true)
{ {
string types_str = ""; string types_str = "";


@@ -294,7 +294,7 @@ namespace Tensorflow
return c_api.TF_GraphOperationByName(_handle, operName); return c_api.TF_GraphOperationByName(_handle, operName);
} }


public Operation[] get_operations()
public ITensorOrOperation[] get_operations()
{ {
return _nodes_by_name.Values.Select(x => x).ToArray(); return _nodes_by_name.Values.Select(x => x).ToArray();
} }


+ 3
- 3
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -36,12 +36,12 @@ namespace Tensorflow
} }




public virtual NDArray run<T>(T fetches, FeedItem[] feed_dict = null)
public virtual NDArray run(object fetches, FeedItem[] feed_dict = null)
{ {
return _run(fetches, feed_dict); return _run(fetches, feed_dict);
} }


private NDArray _run<T>(T fetches, FeedItem[] feed_dict = null)
private NDArray _run(object fetches, FeedItem[] feed_dict = null)
{ {
var feed_dict_tensor = new Dictionary<object, object>(); var feed_dict_tensor = new Dictionary<object, object>();


@@ -49,7 +49,7 @@ namespace Tensorflow
feed_dict.ToList().ForEach(x => feed_dict_tensor.Add(x.Key, x.Value)); feed_dict.ToList().ForEach(x => feed_dict_tensor.Add(x.Key, x.Value));


// Create a fetch handler to take care of the structure of fetches. // Create a fetch handler to take care of the structure of fetches.
var fetch_handler = new _FetchHandler<T>(_graph, fetches, feed_dict_tensor);
var fetch_handler = new _FetchHandler(_graph, fetches, feed_dict_tensor);


// Run request and get response. // Run request and get response.
// We need to keep the returned movers alive for the following _do_run(). // We need to keep the returned movers alive for the following _do_run().


+ 21
- 5
src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs View File

@@ -8,20 +8,36 @@ namespace Tensorflow
/// <summary> /// <summary>
/// Fetch mapper for singleton tensors and ops. /// Fetch mapper for singleton tensors and ops.
/// </summary> /// </summary>
public class _ElementFetchMapper<T> : _FetchMapper<T>
public class _ElementFetchMapper : _FetchMapper
{ {
private List<object> _unique_fetches = new List<object>(); private List<object> _unique_fetches = new List<object>();
private Func<List<object>, object> _contraction_fn; private Func<List<object>, object> _contraction_fn;


public _ElementFetchMapper(List<T> fetches, Func<List<object>, object> contraction_fn)
public _ElementFetchMapper(object[] fetches, Func<List<object>, object> contraction_fn)
{ {
var g = ops.get_default_graph();
ITensorOrOperation el = null;

foreach(var fetch in fetches) foreach(var fetch in fetches)
{ {
var g = ops.get_default_graph();
var el = g.as_graph_element(fetch, allow_tensor: true, allow_operation: true);
_unique_fetches.Add(el);
switch(fetch)
{
case Tensor tensor:
el = g.as_graph_element(tensor, allow_tensor: true, allow_operation: true);
break;
case Operation op:
el = g.as_graph_element(op, allow_tensor: true, allow_operation: true);
break;
case String str:
// Looks like a Tensor name and can be a Tensor.
el = g._nodes_by_name[str];
break;
default:
throw new NotImplementedException("_ElementFetchMapper");
}
} }


_unique_fetches.Add(el);
_contraction_fn = contraction_fn; _contraction_fn = contraction_fn;
} }




+ 7
- 7
src/TensorFlowNET.Core/Sessions/_FetchHandler.cs View File

@@ -8,24 +8,24 @@ namespace Tensorflow
/// <summary> /// <summary>
/// Handler for structured fetches. /// Handler for structured fetches.
/// </summary> /// </summary>
public class _FetchHandler<T>
public class _FetchHandler
{ {
private _ElementFetchMapper<T> _fetch_mapper;
private _ElementFetchMapper _fetch_mapper;
private List<Tensor> _fetches = new List<Tensor>(); private List<Tensor> _fetches = new List<Tensor>();
private List<bool> _ops = new List<bool>(); private List<bool> _ops = new List<bool>();
private List<Tensor> _final_fetches = new List<Tensor>(); private List<Tensor> _final_fetches = new List<Tensor>();
private List<T> _targets = new List<T>();
private List<object> _targets = new List<object>();


public _FetchHandler(Graph graph, T fetches, Dictionary<object, object> feeds = null, Action feed_handles = null)
public _FetchHandler(Graph graph, object fetches, Dictionary<object, object> feeds = null, Action feed_handles = null)
{ {
_fetch_mapper = new _FetchMapper<T>().for_fetch(fetches);
_fetch_mapper = new _FetchMapper().for_fetch(fetches);
foreach(var fetch in _fetch_mapper.unique_fetches()) foreach(var fetch in _fetch_mapper.unique_fetches())
{ {
switch (fetch) switch (fetch)
{ {
case Operation val: case Operation val:
_assert_fetchable(graph, val); _assert_fetchable(graph, val);
_targets.Add((T)(object)val);
_targets.Add(val);
_ops.Add(true); _ops.Add(true);
break; break;
case Tensor val: case Tensor val:
@@ -82,7 +82,7 @@ namespace Tensorflow
return _final_fetches; return _final_fetches;
} }


public List<T> targets()
public List<object> targets()
{ {
return _targets; return _targets;
} }


+ 4
- 4
src/TensorFlowNET.Core/Sessions/_FetchMapper.cs View File

@@ -4,13 +4,13 @@ using System.Text;


namespace Tensorflow namespace Tensorflow
{ {
public class _FetchMapper<T>
public class _FetchMapper
{ {
public _ElementFetchMapper<T> for_fetch(T fetch)
public _ElementFetchMapper for_fetch(object fetch)
{ {
var fetches = new List<T> { fetch };
var fetches = new object[] { fetch };


return new _ElementFetchMapper<T>(fetches, (List<object> fetched_vals) =>
return new _ElementFetchMapper(fetches, (List<object> fetched_vals) =>
{ {
return fetched_vals[0]; return fetched_vals[0];
}); });


+ 5
- 0
src/TensorFlowNET.Core/tf.cs View File

@@ -47,6 +47,11 @@ namespace Tensorflow
return g; return g;
} }


public static void ResetGraph()
{
g = new Graph();
}

public static Session Session() public static Session Session()
{ {
defaultSession = new Session(); defaultSession = new Session();


+ 2
- 0
test/TensorFlowNET.UnitTest/GradientTest.cs View File

@@ -12,6 +12,8 @@ namespace TensorFlowNET.UnitTest
[TestMethod] [TestMethod]
public void Gradients() public void Gradients()
{ {
tf.ResetGraph();

var a = tf.constant(0.0); var a = tf.constant(0.0);
var b = 2.0 * a; var b = 2.0 * a;
Assert.AreEqual(b.name, "mul:0"); Assert.AreEqual(b.name, "mul:0");


Loading…
Cancel
Save