Browse Source

addInConstant unit test successfully

tags/v0.1.0-Tensor
Oceania2018 6 years ago
parent
commit
fca0368361
6 changed files with 109 additions and 16 deletions
  1. +60
    -8
      src/TensorFlowNET.Core/Session/BaseSession.cs
  2. +5
    -0
      src/TensorFlowNET.Core/Session/_ElementFetchMapper.cs
  3. +16
    -0
      src/TensorFlowNET.Core/Session/_FetchHandler.cs
  4. +9
    -0
      src/TensorFlowNET.Core/Tensor.cs
  5. +3
    -0
      src/TensorFlowNET.Core/c_api.cs
  6. +16
    -8
      test/TensorFlowNET.UnitTest/OperationsTest.cs

+ 60
- 8
src/TensorFlowNET.Core/Session/BaseSession.cs View File

@@ -1,6 +1,7 @@
using NumSharp.Core;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace Tensorflow
@@ -38,12 +39,14 @@ namespace Tensorflow
}

public virtual byte[] run(Tensor fetches, Dictionary<Tensor, object> feed_dict = null)
public virtual object run(Tensor fetches, Dictionary<Tensor, object> feed_dict = null)
{
return _run(fetches, feed_dict);
var result = _run(fetches, feed_dict);

return result;
}

private unsafe byte[] _run(Tensor fetches, Dictionary<Tensor, object> feed_dict = null)
private unsafe object _run(Tensor fetches, Dictionary<Tensor, object> feed_dict = null)
{
var feed_dict_tensor = new Dictionary<Tensor, NDArray>();

@@ -66,22 +69,71 @@ namespace Tensorflow
// Create a fetch handler to take care of the structure of fetches.
var fetch_handler = new _FetchHandler(_graph, fetches);

// Run request and get response.
// We need to keep the returned movers alive for the following _do_run().
// These movers are no longer needed when _do_run() completes, and
// are deleted when `movers` goes out of scope when this _run() ends.
var _ = _update_with_movers();
var final_fetches = fetch_handler.fetches();
var final_targets = fetch_handler.targets();

// We only want to really perform the run if fetches or targets are provided,
// or if the call is a partial run that specifies feeds.
var results = _do_run(final_fetches);

return fetch_handler.build_results(null, results);
}

private object[] _do_run(List<object> fetch_list)
{
var fetches = fetch_list.Select(x => (x as Tensor)._as_tf_output()).ToArray();

return _call_tf_sessionrun(fetches);
}

private unsafe object[] _call_tf_sessionrun(TF_Output[] fetch_list)
{
// Ensure any changes to the graph are reflected in the runtime.
_extend_graph();

var status = new Status();

var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray();

c_api.TF_SessionRun(_session,
run_options: IntPtr.Zero,
inputs: new TF_Output[] { },
input_values: new IntPtr[] { },
ninputs: 0,
outputs: new TF_Output[] { new TF_Output() },
output_values: new IntPtr[] { },
noutputs: 1,
outputs: fetch_list,
output_values: output_values,
noutputs: fetch_list.Length,
target_opers: new IntPtr[] { },
ntargets: 1,
ntargets: 0,
run_metadata: IntPtr.Zero,
status: status.Handle);

return null;
var result = output_values.Select(x => new Tensor(x).buffer).Select(x =>
{
return (object)*(float*)x;
}).ToArray();

return result;
}

/// <summary>
/// If a tensor handle that is fed to a device incompatible placeholder,
/// we move the tensor to the right device, generate a new tensor handle,
/// and update feed_dict to use the new handle.
/// </summary>
private List<object> _update_with_movers()
{
return new List<object> { };
}

private void _extend_graph()
{

}
}
}

+ 5
- 0
src/TensorFlowNET.Core/Session/_ElementFetchMapper.cs View File

@@ -21,6 +21,11 @@ namespace Tensorflow
}
}

public object build_results(object[] values)
{
return values[0];
}

public List<Object> unique_fetches()
{
return _unique_fetches;


+ 16
- 0
src/TensorFlowNET.Core/Session/_FetchHandler.cs View File

@@ -13,6 +13,7 @@ namespace Tensorflow
private List<object> _fetches = new List<object>();
private List<bool> _ops = new List<bool>();
private List<object> _final_fetches = new List<object>();
private List<object> _targets = new List<object>();

public _FetchHandler(Graph graph, Tensor fetches, object feeds = null, object feed_handles = null)
{
@@ -33,6 +34,11 @@ namespace Tensorflow
_final_fetches = _fetches;
}

public object build_results(Session session, object[] results)
{
return _fetch_mapper.build_results(results);
}

private void _assert_fetchable(Graph graph, Operation op)
{
if (!graph.is_fetchable(op))
@@ -40,5 +46,15 @@ namespace Tensorflow
throw new Exception($"Operation {op.name} has been marked as not fetchable.");
}
}

public List<Object> fetches()
{
return _final_fetches;
}

public List<Object> targets()
{
return _targets;
}
}
}

+ 9
- 0
src/TensorFlowNET.Core/Tensor.cs View File

@@ -17,6 +17,15 @@ namespace Tensorflow

public string name;

private readonly IntPtr _handle;
public IntPtr handle => _handle;
public IntPtr buffer => c_api.TF_TensorData(_handle);

public Tensor(IntPtr handle)
{
_handle = handle;
}

public Tensor(Operation op, int value_index, DataType dtype)
{
_op = op;


+ 3
- 0
src/TensorFlowNET.Core/c_api.cs View File

@@ -77,6 +77,9 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)]
public static extern unsafe void TF_SetAttrType(TF_OperationDescription desc, string attr_name, TF_DataType value);

[DllImport(TensorFlowLibName)]
public static extern unsafe IntPtr TF_TensorData(TF_Tensor tensor);

[DllImport(TensorFlowLibName)]
public static extern TF_Session TF_NewSession(TF_Graph graph, TF_SessionOptions opts, TF_Status status);



+ 16
- 8
test/TensorFlowNET.UnitTest/OperationsTest.cs View File

@@ -12,13 +12,7 @@ namespace TensorFlowNET.UnitTest
[TestMethod]
public void constant()
{
var a = tf.constant(4.0f);
var b = tf.constant(5.0f);
var c = tf.add(a, b);
using (var sess = tf.Session())
{
var o = sess.run(c);
}
var x = tf.constant(4.0f);
}

[TestMethod]
@@ -28,7 +22,7 @@ namespace TensorFlowNET.UnitTest
}

[TestMethod]
public void add()
public void addInPlaceholder()
{
var a = tf.placeholder(tf.float32);
var b = tf.placeholder(tf.float32);
@@ -43,5 +37,19 @@ namespace TensorFlowNET.UnitTest
var o = sess.run(c, feed_dict);
}
}

[TestMethod]
public void addInConstant()
{
var a = tf.constant(4.0f);
var b = tf.constant(5.0f);
var c = tf.add(a, b);

using (var sess = tf.Session())
{
var o = sess.run(c);
Assert.AreEqual(o, 9.0f);
}
}
}
}

Loading…
Cancel
Save