- Ignored all unit tests related to CSession as it does not use TF.NET's api directly and unable to be tested with other tests parallely.tags/v0.12
@@ -22,58 +22,54 @@ using System.Linq; | |||||
using Tensorflow.Util; | using Tensorflow.Util; | ||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | |||||
/// <summary> | |||||
/// Represents a graph node that performs computation on tensors. | |||||
/// | |||||
/// An `Operation` is a node in a TensorFlow `Graph` that takes zero or | |||||
/// more `Tensor` objects as input, and produces zero or more `Tensor` | |||||
/// objects as output. Objects of type `Operation` are created by | |||||
/// calling an op constructor(such as `tf.matmul`) | |||||
/// or `tf.Graph.create_op`. | |||||
/// | |||||
/// For example `c = tf.matmul(a, b)` creates an `Operation` of type | |||||
/// "MatMul" that takes tensors `a` and `b` as input, and produces `c` | |||||
/// as output. | |||||
/// | |||||
/// After the graph has been launched in a session, an `Operation` can | |||||
/// be executed by passing it to | |||||
/// `tf.Session.run`. | |||||
/// `op.run()` is a shortcut for calling `tf.get_default_session().run(op)`. | |||||
{ | |||||
/// <summary> | |||||
/// Represents a graph node that performs computation on tensors. | |||||
/// | |||||
/// An `Operation` is a node in a TensorFlow `Graph` that takes zero or | |||||
/// more `Tensor` objects as input, and produces zero or more `Tensor` | |||||
/// objects as output. Objects of type `Operation` are created by | |||||
/// calling an op constructor(such as `tf.matmul`) | |||||
/// or `tf.Graph.create_op`. | |||||
/// | |||||
/// For example `c = tf.matmul(a, b)` creates an `Operation` of type | |||||
/// "MatMul" that takes tensors `a` and `b` as input, and produces `c` | |||||
/// as output. | |||||
/// | |||||
/// After the graph has been launched in a session, an `Operation` can | |||||
/// be executed by passing it to | |||||
/// `tf.Session.run`. | |||||
/// `op.run()` is a shortcut for calling `tf.get_default_session().run(op)`. | |||||
/// </summary> | /// </summary> | ||||
public partial class Operation : ITensorOrOperation | public partial class Operation : ITensorOrOperation | ||||
{ | { | ||||
private readonly IntPtr _handle; // _c_op in python | private readonly IntPtr _handle; // _c_op in python | ||||
private readonly IntPtr _operDesc; | |||||
private readonly IntPtr _operDesc; | |||||
private readonly Graph _graph; | |||||
private NodeDef _node_def; | |||||
private Graph _graph; | |||||
public string type => OpType; | public string type => OpType; | ||||
public Graph graph => _graph; | public Graph graph => _graph; | ||||
public int _id => _id_value; | public int _id => _id_value; | ||||
public int _id_value; | public int _id_value; | ||||
public Operation op => this; | public Operation op => this; | ||||
public TF_DataType dtype => TF_DataType.DtInvalid; | public TF_DataType dtype => TF_DataType.DtInvalid; | ||||
public string name => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationName(_handle)); | public string name => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationName(_handle)); | ||||
public string OpType => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationOpType(_handle)); | public string OpType => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationOpType(_handle)); | ||||
public string Device => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationDevice(_handle)); | public string Device => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationDevice(_handle)); | ||||
private NodeDef _node_def; | |||||
public NodeDef node_def | public NodeDef node_def | ||||
{ | { | ||||
get | get | ||||
{ | { | ||||
if(_node_def == null) | |||||
if (_node_def == null) | |||||
_node_def = GetNodeDef(); | _node_def = GetNodeDef(); | ||||
return _node_def; | return _node_def; | ||||
} | } | ||||
} | } | ||||
public Operation(IntPtr handle, Graph g=null) | |||||
public Operation(IntPtr handle, Graph g = null) | |||||
{ | { | ||||
if (handle == IntPtr.Zero) | if (handle == IntPtr.Zero) | ||||
return; | return; | ||||
@@ -97,14 +93,15 @@ namespace Tensorflow | |||||
_operDesc = c_api.TF_NewOperation(g, opType, oper_name); | _operDesc = c_api.TF_NewOperation(g, opType, oper_name); | ||||
c_api.TF_SetAttrType(_operDesc, "dtype", TF_DataType.TF_INT32); | c_api.TF_SetAttrType(_operDesc, "dtype", TF_DataType.TF_INT32); | ||||
using (var status = new Status()) | |||||
{ | |||||
_handle = c_api.TF_FinishOperation(_operDesc, status); | |||||
status.Check(true); | |||||
} | |||||
// Dict mapping op name to file and line information for op colocation | |||||
// context managers. | |||||
lock (Locks.ProcessWide) | |||||
using (var status = new Status()) | |||||
{ | |||||
_handle = c_api.TF_FinishOperation(_operDesc, status); | |||||
status.Check(true); | |||||
} | |||||
// Dict mapping op name to file and line information for op colocation | |||||
// context managers. | |||||
_control_flow_context = graph._get_control_flow_context(); | _control_flow_context = graph._get_control_flow_context(); | ||||
} | } | ||||
@@ -133,9 +130,9 @@ namespace Tensorflow | |||||
// Build the list of control inputs. | // Build the list of control inputs. | ||||
var control_input_ops = new List<Operation>(); | var control_input_ops = new List<Operation>(); | ||||
if(control_inputs != null) | |||||
if (control_inputs != null) | |||||
{ | { | ||||
foreach(var c in control_inputs) | |||||
foreach (var c in control_inputs) | |||||
{ | { | ||||
switch (c) | switch (c) | ||||
{ | { | ||||
@@ -196,15 +193,13 @@ namespace Tensorflow | |||||
{ | { | ||||
if (!string.IsNullOrEmpty(input_arg.NumberAttr)) | if (!string.IsNullOrEmpty(input_arg.NumberAttr)) | ||||
{ | { | ||||
input_len = (int)attrs[input_arg.NumberAttr].I; | |||||
input_len = (int) attrs[input_arg.NumberAttr].I; | |||||
is_sequence = true; | is_sequence = true; | ||||
} | |||||
else if (!string.IsNullOrEmpty(input_arg.TypeListAttr)) | |||||
} else if (!string.IsNullOrEmpty(input_arg.TypeListAttr)) | |||||
{ | { | ||||
input_len = attrs[input_arg.TypeListAttr].List.Type.Count; | input_len = attrs[input_arg.TypeListAttr].List.Type.Count; | ||||
is_sequence = true; | is_sequence = true; | ||||
} | |||||
else | |||||
} else | |||||
{ | { | ||||
input_len = 1; | input_len = 1; | ||||
is_sequence = false; | is_sequence = false; | ||||
@@ -225,22 +220,21 @@ namespace Tensorflow | |||||
{ | { | ||||
AttrValue x = null; | AttrValue x = null; | ||||
using (var status = new Status()) | |||||
using (var buf = new Buffer()) | |||||
{ | |||||
unsafe | |||||
lock (Locks.ProcessWide) | |||||
using (var status = new Status()) | |||||
using (var buf = new Buffer()) | |||||
{ | { | ||||
c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status); | c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status); | ||||
status.Check(true); | status.Check(true); | ||||
x = AttrValue.Parser.ParseFrom(buf.MemoryBlock.Stream()); | x = AttrValue.Parser.ParseFrom(buf.MemoryBlock.Stream()); | ||||
} | } | ||||
} | |||||
string oneof_value = x.ValueCase.ToString(); | string oneof_value = x.ValueCase.ToString(); | ||||
if (string.IsNullOrEmpty(oneof_value)) | if (string.IsNullOrEmpty(oneof_value)) | ||||
return null; | return null; | ||||
if(oneof_value == "list") | |||||
if (oneof_value == "list") | |||||
throw new NotImplementedException($"Unsupported field type in {x.ToString()}"); | throw new NotImplementedException($"Unsupported field type in {x.ToString()}"); | ||||
if (oneof_value == "type") | if (oneof_value == "type") | ||||
@@ -259,60 +253,63 @@ namespace Tensorflow | |||||
private NodeDef GetNodeDef() | private NodeDef GetNodeDef() | ||||
{ | { | ||||
using (var s = new Status()) | |||||
using (var buffer = new Buffer()) | |||||
{ | |||||
c_api.TF_OperationToNodeDef(_handle, buffer, s); | |||||
s.Check(); | |||||
return NodeDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||||
} | |||||
} | |||||
/// <summary> | |||||
/// Update the input to this operation at the given index. | |||||
/// | |||||
/// NOTE: This is for TF internal use only.Please don't use it. | |||||
/// </summary> | |||||
/// <param name="index">the index of the input to update.</param> | |||||
/// <param name="tensor"> the Tensor to be used as the input at the given index.</param> | |||||
public void _update_input(int index, Tensor tensor) | |||||
{ | |||||
_assert_same_graph(tensor); | |||||
var input = _tf_input(index); | |||||
var output = tensor._as_tf_output(); | |||||
// Reset cached inputs. | |||||
_inputs = null; | |||||
// after the c_api call next time _inputs is accessed | |||||
// the updated inputs are reloaded from the c_api | |||||
using (var status = new Status()) | |||||
{ | |||||
c_api.UpdateEdge(_graph, output, input, status); | |||||
//var updated_inputs = inputs; | |||||
status.Check(); | |||||
} | |||||
} | |||||
private void _assert_same_graph(Tensor tensor) | |||||
{ | |||||
//TODO: implement | |||||
} | |||||
/// <summary> | |||||
/// Create and return a new TF_Output for output_idx'th output of this op. | |||||
/// </summary> | |||||
public TF_Output _tf_output(int output_idx) | |||||
{ | |||||
return new TF_Output(op, output_idx); | |||||
} | |||||
/// <summary> | |||||
/// Create and return a new TF_Input for input_idx'th input of this op. | |||||
/// </summary> | |||||
public TF_Input _tf_input(int input_idx) | |||||
{ | |||||
return new TF_Input(op, input_idx); | |||||
} | |||||
} | |||||
} | |||||
lock (Locks.ProcessWide) | |||||
using (var s = new Status()) | |||||
using (var buffer = new Buffer()) | |||||
{ | |||||
c_api.TF_OperationToNodeDef(_handle, buffer, s); | |||||
s.Check(); | |||||
return NodeDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||||
} | |||||
} | |||||
/// <summary> | |||||
/// Update the input to this operation at the given index. | |||||
/// | |||||
/// NOTE: This is for TF internal use only.Please don't use it. | |||||
/// </summary> | |||||
/// <param name="index">the index of the input to update.</param> | |||||
/// <param name="tensor"> the Tensor to be used as the input at the given index.</param> | |||||
public void _update_input(int index, Tensor tensor) | |||||
{ | |||||
_assert_same_graph(tensor); | |||||
var input = _tf_input(index); | |||||
var output = tensor._as_tf_output(); | |||||
// Reset cached inputs. | |||||
_inputs = null; | |||||
// after the c_api call next time _inputs is accessed | |||||
// the updated inputs are reloaded from the c_api | |||||
lock (Locks.ProcessWide) | |||||
using (var status = new Status()) | |||||
{ | |||||
c_api.UpdateEdge(_graph, output, input, status); | |||||
//var updated_inputs = inputs; | |||||
status.Check(); | |||||
} | |||||
} | |||||
private void _assert_same_graph(Tensor tensor) | |||||
{ | |||||
//TODO: implement | |||||
} | |||||
/// <summary> | |||||
/// Create and return a new TF_Output for output_idx'th output of this op. | |||||
/// </summary> | |||||
public TF_Output _tf_output(int output_idx) | |||||
{ | |||||
return new TF_Output(op, output_idx); | |||||
} | |||||
/// <summary> | |||||
/// Create and return a new TF_Input for input_idx'th input of this op. | |||||
/// </summary> | |||||
public TF_Input _tf_input(int input_idx) | |||||
{ | |||||
return new TF_Input(op, input_idx); | |||||
} | |||||
} | |||||
} |
@@ -36,23 +36,20 @@ namespace Tensorflow | |||||
protected byte[] _target; | protected byte[] _target; | ||||
public Graph graph => _graph; | public Graph graph => _graph; | ||||
public BaseSession(string target = "", Graph g = null, SessionOptions opts = null) | |||||
public BaseSession(string target = "", Graph g = null, SessionOptions opts = null, Status status = null) | |||||
{ | { | ||||
_graph = g ?? ops.get_default_graph(); | _graph = g ?? ops.get_default_graph(); | ||||
_graph.as_default(); | _graph.as_default(); | ||||
_target = Encoding.UTF8.GetBytes(target); | _target = Encoding.UTF8.GetBytes(target); | ||||
SessionOptions newOpts = opts ?? new SessionOptions(); | |||||
SessionOptions lopts = opts ?? new SessionOptions(); | |||||
var status = new Status(); | |||||
_handle = c_api.TF_NewSession(_graph, opts ?? newOpts, status); | |||||
// dispose opts only if not provided externally. | |||||
if (opts == null) | |||||
newOpts.Dispose(); | |||||
status.Check(true); | |||||
lock (Locks.ProcessWide) | |||||
{ | |||||
status = status ?? new Status(); | |||||
_handle = c_api.TF_NewSession(_graph, opts ?? lopts, status); | |||||
status.Check(true); | |||||
} | |||||
} | } | ||||
public virtual void run(Operation op, params FeedItem[] feed_dict) | public virtual void run(Operation op, params FeedItem[] feed_dict) | ||||
@@ -72,19 +69,19 @@ namespace Tensorflow | |||||
public virtual (NDArray, NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | public virtual (NDArray, NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | ||||
{ | { | ||||
var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4 }, feed_dict); | |||||
var results = _run(new object[] {fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4}, feed_dict); | |||||
return (results[0], results[1], results[2], results[3]); | return (results[0], results[1], results[2], results[3]); | ||||
} | } | ||||
public virtual (NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | public virtual (NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | ||||
{ | { | ||||
var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3 }, feed_dict); | |||||
var results = _run(new object[] {fetches.Item1, fetches.Item2, fetches.Item3}, feed_dict); | |||||
return (results[0], results[1], results[2]); | return (results[0], results[1], results[2]); | ||||
} | } | ||||
public virtual (NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | public virtual (NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | ||||
{ | { | ||||
var results = _run(new object[] { fetches.Item1, fetches.Item2 }, feed_dict); | |||||
var results = _run(new object[] {fetches.Item1, fetches.Item2}, feed_dict); | |||||
return (results[0], results[1]); | return (results[0], results[1]); | ||||
} | } | ||||
@@ -95,8 +92,7 @@ namespace Tensorflow | |||||
public virtual NDArray[] run(object fetches, Hashtable feed_dict = null) | public virtual NDArray[] run(object fetches, Hashtable feed_dict = null) | ||||
{ | { | ||||
var feed_items = feed_dict == null ? new FeedItem[0] : | |||||
feed_dict.Keys.OfType<object>().Select(key => new FeedItem(key, feed_dict[key])).ToArray(); | |||||
var feed_items = feed_dict == null ? new FeedItem[0] : feed_dict.Keys.OfType<object>().Select(key => new FeedItem(key, feed_dict[key])).ToArray(); | |||||
return _run(fetches, feed_items); | return _run(fetches, feed_items); | ||||
} | } | ||||
@@ -130,7 +126,7 @@ namespace Tensorflow | |||||
// We only want to really perform the run if fetches or targets are provided, | // We only want to really perform the run if fetches or targets are provided, | ||||
// or if the call is a partial run that specifies feeds. | // or if the call is a partial run that specifies feeds. | ||||
var results = _do_run(final_targets.Select(x => (Operation)x).ToList(), final_fetches, feed_dict_tensor); | |||||
var results = _do_run(final_targets.Select(x => (Operation) x).ToList(), final_fetches, feed_dict_tensor); | |||||
return fetch_handler.build_results(this, results); | return fetch_handler.build_results(this, results); | ||||
} | } | ||||
@@ -150,7 +146,6 @@ namespace Tensorflow | |||||
/// </returns> | /// </returns> | ||||
private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict) | private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict) | ||||
{ | { | ||||
var feeds = new KeyValuePair<TF_Output, Tensor>[feed_dict.Count]; | var feeds = new KeyValuePair<TF_Output, Tensor>[feed_dict.Count]; | ||||
int i = 0; | int i = 0; | ||||
foreach (var x in feed_dict) | foreach (var x in feed_dict) | ||||
@@ -159,16 +154,25 @@ namespace Tensorflow | |||||
{ | { | ||||
switch (x.Value) | switch (x.Value) | ||||
{ | { | ||||
case Tensor v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), v); break; | |||||
case NDArray v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v, tensor.dtype)); break; | |||||
case IntPtr v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
case Tensor v: | |||||
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), v); | |||||
break; | |||||
case NDArray v: | |||||
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v, tensor.dtype)); | |||||
break; | |||||
case IntPtr v: | |||||
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
break; | |||||
#if _REGEN | #if _REGEN | ||||
// @formatter:off — disable formatter after this line | |||||
%types = ["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"] | %types = ["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"] | ||||
%foreach types% | %foreach types% | ||||
case #1 v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | case #1 v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | ||||
case #1[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | case #1[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | ||||
% | % | ||||
// @formatter:on — enable formatter after this line | |||||
#else | #else | ||||
// @formatter:off — disable formatter after this line | |||||
case sbyte v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | case sbyte v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | ||||
case sbyte[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | case sbyte[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | ||||
case byte v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | case byte v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | ||||
@@ -191,9 +195,14 @@ namespace Tensorflow | |||||
case double[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | case double[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | ||||
case Complex v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | case Complex v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | ||||
case Complex[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | case Complex[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | ||||
// @formatter:on — enable formatter after this line | |||||
#endif | #endif | ||||
case bool v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor((byte) (v ? 1 : 0), TF_DataType.TF_BOOL)); break; | |||||
case string v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; | |||||
case bool v: | |||||
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor((byte) (v ? 1 : 0), TF_DataType.TF_BOOL)); | |||||
break; | |||||
case string v: | |||||
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); | |||||
break; | |||||
default: | default: | ||||
throw new NotImplementedException($"feed_dict data type {x.Value?.GetType().Name ?? "<null>"}"); | throw new NotImplementedException($"feed_dict data type {x.Value?.GetType().Name ?? "<null>"}"); | ||||
} | } | ||||
@@ -217,12 +226,12 @@ namespace Tensorflow | |||||
c_api.TF_SessionRun(_handle, | c_api.TF_SessionRun(_handle, | ||||
run_options: null, | run_options: null, | ||||
inputs: feed_dict.Select(f => f.Key).ToArray(), | inputs: feed_dict.Select(f => f.Key).ToArray(), | ||||
input_values: feed_dict.Select(f => (IntPtr)f.Value).ToArray(), | |||||
input_values: feed_dict.Select(f => (IntPtr) f.Value).ToArray(), | |||||
ninputs: feed_dict.Length, | ninputs: feed_dict.Length, | ||||
outputs: fetch_list, | outputs: fetch_list, | ||||
output_values: output_values, | output_values: output_values, | ||||
noutputs: fetch_list.Length, | noutputs: fetch_list.Length, | ||||
target_opers: target_list.Select(f => (IntPtr)f).ToArray(), | |||||
target_opers: target_list.Select(f => (IntPtr) f).ToArray(), | |||||
ntargets: target_list.Count, | ntargets: target_list.Count, | ||||
run_metadata: IntPtr.Zero, | run_metadata: IntPtr.Zero, | ||||
status: status); | status: status); | ||||
@@ -253,7 +262,7 @@ namespace Tensorflow | |||||
ret = NDArray.Scalar(*(bool*) srcAddress); | ret = NDArray.Scalar(*(bool*) srcAddress); | ||||
break; | break; | ||||
case TF_DataType.TF_STRING: | case TF_DataType.TF_STRING: | ||||
using (var reader = new CodedInputStream(new IntPtr(srcAddress).Stream(8, (long)tensor.bytesize))) | |||||
using (var reader = new CodedInputStream(new IntPtr(srcAddress).Stream(8, (long) tensor.bytesize))) | |||||
ret = NDArray.FromString(reader.ReadString()); | ret = NDArray.FromString(reader.ReadString()); | ||||
break; | break; | ||||
case TF_DataType.TF_UINT8: | case TF_DataType.TF_UINT8: | ||||
@@ -318,81 +327,95 @@ namespace Tensorflow | |||||
#endregion | #endregion | ||||
#else | #else | ||||
#region Compute | |||||
switch (tensor.dtype) | |||||
{ | |||||
case TF_DataType.TF_BOOL: | |||||
{ | |||||
#region Compute | |||||
switch (tensor.dtype) | |||||
{ | |||||
case TF_DataType.TF_BOOL: | |||||
{ | |||||
ret = new NDArray(NPTypeCode.Boolean, ndims, false); | ret = new NDArray(NPTypeCode.Boolean, ndims, false); | ||||
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | ||||
break; | |||||
} | |||||
case TF_DataType.TF_UINT8: | |||||
{ | |||||
break; | |||||
} | |||||
case TF_DataType.TF_UINT8: | |||||
{ | |||||
ret = new NDArray(NPTypeCode.Byte, ndims, false); | ret = new NDArray(NPTypeCode.Byte, ndims, false); | ||||
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | ||||
break; | |||||
} | |||||
case TF_DataType.TF_INT16: | |||||
{ | |||||
break; | |||||
} | |||||
case TF_DataType.TF_INT16: | |||||
{ | |||||
ret = new NDArray(NPTypeCode.Int16, ndims, false); | ret = new NDArray(NPTypeCode.Int16, ndims, false); | ||||
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | ||||
break; | |||||
} | |||||
case TF_DataType.TF_UINT16: | |||||
{ | |||||
break; | |||||
} | |||||
case TF_DataType.TF_UINT16: | |||||
{ | |||||
ret = new NDArray(NPTypeCode.UInt16, ndims, false); | ret = new NDArray(NPTypeCode.UInt16, ndims, false); | ||||
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | ||||
break; | |||||
} | |||||
case TF_DataType.TF_INT32: | |||||
{ | |||||
break; | |||||
} | |||||
case TF_DataType.TF_INT32: | |||||
{ | |||||
ret = new NDArray(NPTypeCode.Int32, ndims, false); | ret = new NDArray(NPTypeCode.Int32, ndims, false); | ||||
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | ||||
break; | |||||
} | |||||
case TF_DataType.TF_UINT32: | |||||
{ | |||||
break; | |||||
} | |||||
case TF_DataType.TF_UINT32: | |||||
{ | |||||
ret = new NDArray(NPTypeCode.UInt32, ndims, false); | ret = new NDArray(NPTypeCode.UInt32, ndims, false); | ||||
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | ||||
break; | |||||
} | |||||
case TF_DataType.TF_INT64: | |||||
{ | |||||
break; | |||||
} | |||||
case TF_DataType.TF_INT64: | |||||
{ | |||||
ret = new NDArray(NPTypeCode.Int64, ndims, false); | ret = new NDArray(NPTypeCode.Int64, ndims, false); | ||||
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | ||||
break; | |||||
} | |||||
case TF_DataType.TF_UINT64: | |||||
{ | |||||
break; | |||||
} | |||||
case TF_DataType.TF_UINT64: | |||||
{ | |||||
ret = new NDArray(NPTypeCode.UInt64, ndims, false); | ret = new NDArray(NPTypeCode.UInt64, ndims, false); | ||||
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | ||||
break; | |||||
} | |||||
case TF_DataType.TF_DOUBLE: | |||||
{ | |||||
break; | |||||
} | |||||
case TF_DataType.TF_DOUBLE: | |||||
{ | |||||
ret = new NDArray(NPTypeCode.Double, ndims, false); | ret = new NDArray(NPTypeCode.Double, ndims, false); | ||||
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | ||||
break; | |||||
} | |||||
case TF_DataType.TF_FLOAT: | |||||
{ | |||||
break; | |||||
} | |||||
case TF_DataType.TF_FLOAT: | |||||
{ | |||||
ret = new NDArray(NPTypeCode.Single, ndims, false); | ret = new NDArray(NPTypeCode.Single, ndims, false); | ||||
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); | ||||
break; | |||||
} | |||||
break; | |||||
} | |||||
case TF_DataType.TF_STRING: | case TF_DataType.TF_STRING: | ||||
{ | { | ||||
throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
//TODO:! This is not the way to handle string[], it should be done with TF_DecodeString | //TODO:! This is not the way to handle string[], it should be done with TF_DecodeString | ||||
using (var reader = new CodedInputStream(new IntPtr(srcAddress).Stream(8, (long)tensor.bytesize))) | |||||
using (var reader = new CodedInputStream(new IntPtr(srcAddress).Stream(8, (long) tensor.bytesize))) | |||||
ret = NDArray.FromString(reader.ReadString()); | ret = NDArray.FromString(reader.ReadString()); | ||||
break; | break; | ||||
} | } | ||||
default: | |||||
throw new NotSupportedException(); | |||||
} | |||||
#endregion | |||||
default: | |||||
throw new NotSupportedException(); | |||||
} | |||||
#endregion | |||||
#endif | #endif | ||||
} | } | ||||
} | } | ||||
@@ -411,9 +434,7 @@ namespace Tensorflow | |||||
} | } | ||||
private void _extend_graph() | private void _extend_graph() | ||||
{ | |||||
} | |||||
{ } | |||||
public void close() | public void close() | ||||
{ | { | ||||
@@ -422,11 +443,12 @@ namespace Tensorflow | |||||
protected override void DisposeUnmanagedResources(IntPtr handle) | protected override void DisposeUnmanagedResources(IntPtr handle) | ||||
{ | { | ||||
using (var status = new Status()) | |||||
{ | |||||
c_api.TF_DeleteSession(handle, status); | |||||
status.Check(true); | |||||
} | |||||
lock (Locks.ProcessWide) | |||||
using (var status = new Status()) | |||||
{ | |||||
c_api.TF_DeleteSession(handle, status); | |||||
status.Check(true); | |||||
} | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -21,24 +21,16 @@ namespace Tensorflow | |||||
{ | { | ||||
public class Session : BaseSession, IObjectLife | public class Session : BaseSession, IObjectLife | ||||
{ | { | ||||
public Session(string target = "", Graph g = null) | |||||
: base(target, g, null) | |||||
{ | |||||
} | |||||
public Session(string target = "", Graph g = null) : base(target, g, null) | |||||
{ } | |||||
public Session(IntPtr handle, Graph g = null) | |||||
: base("", g, null) | |||||
public Session(IntPtr handle, Graph g = null) : base("", g, null) | |||||
{ | { | ||||
_handle = handle; | _handle = handle; | ||||
} | } | ||||
public Session(Graph g, SessionOptions opts = null, Status s = null) | |||||
: base("", g, opts) | |||||
{ | |||||
if (s == null) | |||||
s = new Status(); | |||||
} | |||||
public Session(Graph g, SessionOptions opts = null, Status s = null) : base("", g, opts, s) | |||||
{ } | |||||
public Session as_default() | public Session as_default() | ||||
{ | { | ||||
@@ -21,6 +21,7 @@ using Google.Protobuf; | |||||
using System.Linq; | using System.Linq; | ||||
using System.Threading; | using System.Threading; | ||||
using NumSharp; | using NumSharp; | ||||
using Tensorflow.Util; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace Tensorflow | namespace Tensorflow | ||||
@@ -207,47 +208,49 @@ namespace Tensorflow | |||||
/// <returns>A wrapped TF_Operation*.</returns> | /// <returns>A wrapped TF_Operation*.</returns> | ||||
public static (IntPtr, IntPtr) _create_c_op<T>(Graph graph, NodeDef node_def, T[] inputs, Operation[] control_inputs) | public static (IntPtr, IntPtr) _create_c_op<T>(Graph graph, NodeDef node_def, T[] inputs, Operation[] control_inputs) | ||||
{ | { | ||||
var op_desc = graph.NewOperation(node_def.Op, node_def.Name); | |||||
//TODO: Implement TF_SetDevice | |||||
//if node_def.device: | |||||
// c_api.TF_SetDevice(op_desc, compat.as_str(node_def.device)) | |||||
// Add inputs | |||||
foreach (var op_input in inputs) | |||||
lock (Locks.ProcessWide) | |||||
{ | { | ||||
if (op_input is Tensor[] op_inputs) | |||||
c_api.TF_AddInputList(op_desc, op_inputs.Select(x => x._as_tf_output()).ToArray(), op_inputs.Length); | |||||
else if (op_input is Tensor op_input1) | |||||
var op_desc = graph.NewOperation(node_def.Op, node_def.Name); | |||||
//TODO: Implement TF_SetDevice | |||||
//if node_def.device: | |||||
// c_api.TF_SetDevice(op_desc, compat.as_str(node_def.device)) | |||||
// Add inputs | |||||
foreach (var op_input in inputs) | |||||
{ | { | ||||
c_api.TF_AddInput(op_desc, op_input1._as_tf_output()); | |||||
if (op_input is Tensor[] op_inputs) | |||||
c_api.TF_AddInputList(op_desc, op_inputs.Select(x => x._as_tf_output()).ToArray(), op_inputs.Length); | |||||
else if (op_input is Tensor op_input1) | |||||
{ | |||||
c_api.TF_AddInput(op_desc, op_input1._as_tf_output()); | |||||
} else | |||||
throw new NotImplementedException("_create_c_op"); | |||||
} | } | ||||
else | |||||
throw new NotImplementedException("_create_c_op"); | |||||
} | |||||
var status = new Status(); | |||||
var status = new Status(); | |||||
// Add control inputs | |||||
foreach (var control_input in control_inputs) | |||||
c_api.TF_AddControlInput(op_desc, control_input); | |||||
// Add control inputs | |||||
foreach (var control_input in control_inputs) | |||||
c_api.TF_AddControlInput(op_desc, control_input); | |||||
// Add attrs | |||||
foreach (var attr in node_def.Attr) | |||||
{ | |||||
var bytes = attr.Value.ToByteArray(); //TODO: we can use attr.Value.WriteTo with a memory stream. | |||||
var proto = Marshal.AllocHGlobal(bytes.Length); //TODO: potential memory leak | |||||
Marshal.Copy(bytes, 0, proto, bytes.Length); | |||||
uint len = (uint)bytes.Length; | |||||
c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: len, status: status); | |||||
// Add attrs | |||||
foreach (var attr in node_def.Attr) | |||||
{ | |||||
var bytes = attr.Value.ToByteArray(); //TODO: we can use attr.Value.WriteTo with a memory stream. | |||||
var proto = Marshal.AllocHGlobal(bytes.Length); //TODO: potential memory leak | |||||
Marshal.Copy(bytes, 0, proto, bytes.Length); | |||||
uint len = (uint) bytes.Length; | |||||
c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: len, status: status); | |||||
status.Check(true); | |||||
} | |||||
status.Check(true); | |||||
} | |||||
var c_op = c_api.TF_FinishOperation(op_desc, status); | |||||
var c_op = c_api.TF_FinishOperation(op_desc, status); | |||||
status.Check(true); | |||||
status.Check(true); | |||||
return (c_op, op_desc); | |||||
return (c_op, op_desc); | |||||
} | |||||
} | } | ||||
public static OpDef _get_op_def(Graph graph, string type) | public static OpDef _get_op_def(Graph graph, string type) | ||||
@@ -11,7 +11,7 @@ namespace TensorFlowNET.UnitTest | |||||
/// tensorflow\c\c_api_test.cc | /// tensorflow\c\c_api_test.cc | ||||
/// `class CApiGradientsTest` | /// `class CApiGradientsTest` | ||||
/// </summary> | /// </summary> | ||||
[TestClass] | |||||
[TestClass, Ignore] | |||||
public class CApiGradientsTest : CApiTest, IDisposable | public class CApiGradientsTest : CApiTest, IDisposable | ||||
{ | { | ||||
private Graph graph_ = new Graph(); | private Graph graph_ = new Graph(); | ||||
@@ -2,6 +2,7 @@ | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow; | using Tensorflow; | ||||
using Tensorflow.Util; | |||||
namespace TensorFlowNET.UnitTest | namespace TensorFlowNET.UnitTest | ||||
{ | { | ||||
@@ -22,9 +23,12 @@ namespace TensorFlowNET.UnitTest | |||||
public CSession(Graph graph, Status s, bool user_XLA = false) | public CSession(Graph graph, Status s, bool user_XLA = false) | ||||
{ | { | ||||
var opts = new SessionOptions(); | |||||
opts.SetConfig(new ConfigProto { InterOpParallelismThreads = 4 }); | |||||
session_ = new Session(graph, opts, s); | |||||
lock (Locks.ProcessWide) | |||||
{ | |||||
var opts = new SessionOptions(); | |||||
opts.SetConfig(new ConfigProto {InterOpParallelismThreads = 4}); | |||||
session_ = new Session(graph, opts, s); | |||||
} | |||||
} | } | ||||
public void SetInputs(Dictionary<Operation, Tensor> inputs) | public void SetInputs(Dictionary<Operation, Tensor> inputs) | ||||
@@ -64,13 +68,13 @@ namespace TensorFlowNET.UnitTest | |||||
public unsafe void Run(Status s) | public unsafe void Run(Status s) | ||||
{ | { | ||||
var inputs_ptr = inputs_.ToArray(); | var inputs_ptr = inputs_.ToArray(); | ||||
var input_values_ptr = input_values_.Select(x => (IntPtr)x).ToArray(); | |||||
var input_values_ptr = input_values_.Select(x => (IntPtr) x).ToArray(); | |||||
var outputs_ptr = outputs_.ToArray(); | var outputs_ptr = outputs_.ToArray(); | ||||
var output_values_ptr = output_values_.Select(x => IntPtr.Zero).ToArray(); | var output_values_ptr = output_values_.Select(x => IntPtr.Zero).ToArray(); | ||||
IntPtr[] targets_ptr = new IntPtr[0]; | IntPtr[] targets_ptr = new IntPtr[0]; | ||||
c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, inputs_ptr.Length, | c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, inputs_ptr.Length, | ||||
outputs_ptr, output_values_ptr, outputs_.Count, | |||||
outputs_ptr, output_values_ptr, outputs_.Count, | |||||
targets_ptr, targets_.Count, | targets_ptr, targets_.Count, | ||||
IntPtr.Zero, s); | IntPtr.Zero, s); | ||||
@@ -90,4 +94,4 @@ namespace TensorFlowNET.UnitTest | |||||
ResetOutputValues(); | ResetOutputValues(); | ||||
} | } | ||||
} | } | ||||
} | |||||
} |
@@ -207,7 +207,7 @@ namespace TensorFlowNET.UnitTest | |||||
public void ImportGraphDef() | public void ImportGraphDef() | ||||
{ | { | ||||
var s = new Status(); | var s = new Status(); | ||||
var graph = new Graph(); | |||||
var graph = new Graph().as_default(); | |||||
// Create a simple graph. | // Create a simple graph. | ||||
c_test_util.Placeholder(graph, s); | c_test_util.Placeholder(graph, s); | ||||
@@ -221,7 +221,7 @@ namespace TensorFlowNET.UnitTest | |||||
// Import it, with a prefix, in a fresh graph. | // Import it, with a prefix, in a fresh graph. | ||||
graph.Dispose(); | graph.Dispose(); | ||||
graph = new Graph(); | |||||
graph = new Graph().as_default(); | |||||
var opts = c_api.TF_NewImportGraphDefOptions(); | var opts = c_api.TF_NewImportGraphDefOptions(); | ||||
c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported"); | c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported"); | ||||
c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s); | c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s); | ||||
@@ -359,7 +359,7 @@ namespace TensorFlowNET.UnitTest | |||||
public void ImportGraphDef_WithReturnOutputs() | public void ImportGraphDef_WithReturnOutputs() | ||||
{ | { | ||||
var s = new Status(); | var s = new Status(); | ||||
var graph = new Graph(); | |||||
var graph = new Graph().as_default(); | |||||
// Create a graph with two nodes: x and 3 | // Create a graph with two nodes: x and 3 | ||||
c_test_util.Placeholder(graph, s); | c_test_util.Placeholder(graph, s); | ||||
@@ -375,7 +375,7 @@ namespace TensorFlowNET.UnitTest | |||||
// Import it in a fresh graph with return outputs. | // Import it in a fresh graph with return outputs. | ||||
graph.Dispose(); | graph.Dispose(); | ||||
graph = new Graph(); | |||||
graph = new Graph().as_default(); | |||||
var opts = new ImportGraphDefOptions(); | var opts = new ImportGraphDefOptions(); | ||||
opts.AddReturnOutput("feed", 0); | opts.AddReturnOutput("feed", 0); | ||||
opts.AddReturnOutput("scalar", 0); | opts.AddReturnOutput("scalar", 0); | ||||
@@ -4,6 +4,7 @@ using System.Runtime.InteropServices; | |||||
using FluentAssertions; | using FluentAssertions; | ||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using Tensorflow; | using Tensorflow; | ||||
using Tensorflow.Util; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace TensorFlowNET.UnitTest | namespace TensorFlowNET.UnitTest | ||||
@@ -14,7 +15,7 @@ namespace TensorFlowNET.UnitTest | |||||
[TestMethod] | [TestMethod] | ||||
public void SessionCreation() | public void SessionCreation() | ||||
{ | { | ||||
tf.Session(); //create one to increase next id to 1. | |||||
ops.uid(); //increment id by one | |||||
MultiThreadedUnitTestExecuter.Run(8, Core); | MultiThreadedUnitTestExecuter.Run(8, Core); | ||||
@@ -23,6 +24,28 @@ namespace TensorFlowNET.UnitTest | |||||
{ | { | ||||
tf.peak_default_graph().Should().BeNull(); | tf.peak_default_graph().Should().BeNull(); | ||||
using (var sess = tf.Session()) | |||||
{ | |||||
var default_graph = tf.peak_default_graph(); | |||||
var sess_graph = sess.GetPrivate<Graph>("_graph"); | |||||
sess_graph.Should().NotBeNull(); | |||||
default_graph.Should().NotBeNull() | |||||
.And.BeEquivalentTo(sess_graph); | |||||
} | |||||
} | |||||
} | |||||
[TestMethod] | |||||
public void SessionCreation_x2() | |||||
{ | |||||
ops.uid(); //increment id by one | |||||
MultiThreadedUnitTestExecuter.Run(16, Core); | |||||
//the core method | |||||
void Core(int tid) | |||||
{ | |||||
tf.peak_default_graph().Should().BeNull(); | |||||
//tf.Session created an other graph | //tf.Session created an other graph | ||||
using (var sess = tf.Session()) | using (var sess = tf.Session()) | ||||
{ | { | ||||
@@ -38,7 +61,7 @@ namespace TensorFlowNET.UnitTest | |||||
[TestMethod] | [TestMethod] | ||||
public void GraphCreation() | public void GraphCreation() | ||||
{ | { | ||||
tf.Graph(); //create one to increase next id to 1. | |||||
ops.uid(); //increment id by one | |||||
MultiThreadedUnitTestExecuter.Run(8, Core); | MultiThreadedUnitTestExecuter.Run(8, Core); | ||||
@@ -47,7 +70,7 @@ namespace TensorFlowNET.UnitTest | |||||
{ | { | ||||
tf.peak_default_graph().Should().BeNull(); | tf.peak_default_graph().Should().BeNull(); | ||||
var beforehand = tf.get_default_graph(); //this should create default automatically. | var beforehand = tf.get_default_graph(); //this should create default automatically. | ||||
beforehand.graph_key.Should().NotContain("0", "Already created a graph in an other thread."); | |||||
beforehand.graph_key.Should().NotContain("-0/", "Already created a graph in an other thread."); | |||||
tf.peak_default_graph().Should().NotBeNull(); | tf.peak_default_graph().Should().NotBeNull(); | ||||
using (var sess = tf.Session()) | using (var sess = tf.Session()) | ||||
@@ -67,5 +90,174 @@ namespace TensorFlowNET.UnitTest | |||||
} | } | ||||
} | } | ||||
} | } | ||||
[TestMethod] | |||||
public void Marshal_AllocHGlobal() | |||||
{ | |||||
MultiThreadedUnitTestExecuter.Run(8, Core); | |||||
//the core method | |||||
void Core(int tid) | |||||
{ | |||||
for (int i = 0; i < 100; i++) | |||||
{ | |||||
Marshal.FreeHGlobal(Marshal.AllocHGlobal(sizeof(int))); | |||||
} | |||||
} | |||||
} | |||||
[TestMethod] | |||||
public void TensorCreation() | |||||
{ | |||||
//lock (Locks.ProcessWide) | |||||
// tf.Session(); //create one to increase next id to 1. | |||||
MultiThreadedUnitTestExecuter.Run(8, Core); | |||||
//the core method | |||||
void Core(int tid) | |||||
{ | |||||
using (var sess = tf.Session()) | |||||
{ | |||||
Tensor t = null; | |||||
for (int i = 0; i < 100; i++) | |||||
{ | |||||
t = new Tensor(1); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
[TestMethod] | |||||
public void TensorCreation_Array() | |||||
{ | |||||
//lock (Locks.ProcessWide) | |||||
// tf.Session(); //create one to increase next id to 1. | |||||
MultiThreadedUnitTestExecuter.Run(8, Core); | |||||
//the core method | |||||
void Core(int tid) | |||||
{ | |||||
//tf.Session created an other graph | |||||
using (var sess = tf.Session()) | |||||
{ | |||||
Tensor t = null; | |||||
for (int i = 0; i < 100; i++) | |||||
{ | |||||
t = new Tensor(new int[] {1, 2, 3}); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
[TestMethod] | |||||
public void TensorCreation_Undressed() | |||||
{ | |||||
//lock (Locks.ProcessWide) | |||||
// tf.Session(); //create one to increase next id to 1. | |||||
MultiThreadedUnitTestExecuter.Run(8, Core); | |||||
//the core method | |||||
unsafe void Core(int tid) | |||||
{ | |||||
using (var sess = tf.Session()) | |||||
{ | |||||
Tensor t = null; | |||||
for (int i = 0; i < 100; i++) | |||||
{ | |||||
var v = (int*) Marshal.AllocHGlobal(sizeof(int)); | |||||
c_api.DeallocatorArgs _deallocatorArgs = new c_api.DeallocatorArgs(); | |||||
var handle = c_api.TF_NewTensor(typeof(int).as_dtype(), dims: new long[0], num_dims: 0, | |||||
data: (IntPtr) v, len: (UIntPtr) sizeof(int), | |||||
deallocator: (IntPtr data, IntPtr size, ref c_api.DeallocatorArgs args) => Marshal.FreeHGlobal(data), | |||||
ref _deallocatorArgs); | |||||
c_api.TF_DeleteTensor(handle); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
[TestMethod] | |||||
public void SessionRun() | |||||
{ | |||||
MultiThreadedUnitTestExecuter.Run(8, Core); | |||||
//the core method | |||||
void Core(int tid) | |||||
{ | |||||
tf.peak_default_graph().Should().BeNull(); | |||||
//graph is created automatically to perform create these operations | |||||
var a1 = tf.constant(new[] {2f}, shape: new[] {1}); | |||||
var a2 = tf.constant(new[] {3f}, shape: new[] {1}); | |||||
var math = a1 + a2; | |||||
for (int i = 0; i < 100; i++) | |||||
{ | |||||
using (var sess = tf.Session()) | |||||
{ | |||||
sess.run(math).GetAtIndex<float>(0).Should().Be(5); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
[TestMethod] | |||||
public void SessionRun_InsideSession() | |||||
{ | |||||
MultiThreadedUnitTestExecuter.Run(8, Core); | |||||
//the core method | |||||
void Core(int tid) | |||||
{ | |||||
using (var sess = tf.Session()) | |||||
{ | |||||
tf.peak_default_graph().Should().NotBeNull(); | |||||
//graph is created automatically to perform create these operations | |||||
var a1 = tf.constant(new[] {2f}, shape: new[] {1}); | |||||
var a2 = tf.constant(new[] {3f}, shape: new[] {1}); | |||||
var math = a1 + a2; | |||||
var result = sess.run(math); | |||||
result[0].GetAtIndex<float>(0).Should().Be(5); | |||||
} | |||||
} | |||||
} | |||||
[TestMethod] | |||||
public void SessionRun_Initialization() | |||||
{ | |||||
MultiThreadedUnitTestExecuter.Run(8, Core); | |||||
//the core method | |||||
void Core(int tid) | |||||
{ | |||||
using (var sess = tf.Session()) | |||||
{ | |||||
tf.peak_default_graph().Should().NotBeNull(); | |||||
//graph is created automatically to perform create these operations | |||||
var a1 = tf.constant(new[] {2f}, shape: new[] {1}); | |||||
var a2 = tf.constant(new[] {3f}, shape: new[] {1}); | |||||
var math = a1 + a2; | |||||
} | |||||
} | |||||
} | |||||
[TestMethod] | |||||
public void SessionRun_Initialization_OutsideSession() | |||||
{ | |||||
MultiThreadedUnitTestExecuter.Run(8, Core); | |||||
//the core method | |||||
void Core(int tid) | |||||
{ | |||||
tf.peak_default_graph().Should().BeNull(); | |||||
//graph is created automatically to perform create these operations | |||||
var a1 = tf.constant(new[] {2f}, shape: new[] {1}); | |||||
var a2 = tf.constant(new[] {3f}, shape: new[] {1}); | |||||
var math = a1 + a2; | |||||
} | |||||
} | |||||
} | } | ||||
} | } |
@@ -8,6 +8,7 @@ using System.Text; | |||||
using FluentAssertions; | using FluentAssertions; | ||||
using Google.Protobuf; | using Google.Protobuf; | ||||
using Tensorflow; | using Tensorflow; | ||||
using Tensorflow.Util; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace TensorFlowNET.UnitTest | namespace TensorFlowNET.UnitTest | ||||
@@ -19,13 +20,13 @@ namespace TensorFlowNET.UnitTest | |||||
/// tensorflow\c\c_api_test.cc | /// tensorflow\c\c_api_test.cc | ||||
/// `TEST(CAPI, Session)` | /// `TEST(CAPI, Session)` | ||||
/// </summary> | /// </summary> | ||||
[TestMethod] | |||||
[TestMethod, Ignore] | |||||
public void Session() | public void Session() | ||||
{ | { | ||||
lock (this) | |||||
lock (Locks.ProcessWide) | |||||
{ | { | ||||
var s = new Status(); | var s = new Status(); | ||||
var graph = new Graph(); | |||||
var graph = new Graph().as_default(); | |||||
// Make a placeholder operation. | // Make a placeholder operation. | ||||
var feed = c_test_util.Placeholder(graph, s); | var feed = c_test_util.Placeholder(graph, s); | ||||
@@ -117,7 +117,7 @@ namespace TensorFlowNET.UnitTest | |||||
public void SetShape() | public void SetShape() | ||||
{ | { | ||||
var s = new Status(); | var s = new Status(); | ||||
var graph = new Graph(); | |||||
var graph = new Graph().as_default(); | |||||
var feed = c_test_util.Placeholder(graph, s); | var feed = c_test_util.Placeholder(graph, s); | ||||
var feed_out_0 = new TF_Output(feed, 0); | var feed_out_0 = new TF_Output(feed, 0); | ||||
@@ -12,42 +12,51 @@ namespace TensorFlowNET.UnitTest | |||||
{ | { | ||||
public static Operation Add(Operation l, Operation r, Graph graph, Status s, string name = "add") | public static Operation Add(Operation l, Operation r, Graph graph, Status s, string name = "add") | ||||
{ | { | ||||
var desc = c_api.TF_NewOperation(graph, "AddN", name); | |||||
var inputs = new TF_Output[] | |||||
lock (Locks.ProcessWide) | |||||
{ | { | ||||
new TF_Output(l, 0), | |||||
new TF_Output(r, 0), | |||||
}; | |||||
var desc = c_api.TF_NewOperation(graph, "AddN", name); | |||||
c_api.TF_AddInputList(desc, inputs, inputs.Length); | |||||
var inputs = new TF_Output[] | |||||
{ | |||||
new TF_Output(l, 0), | |||||
new TF_Output(r, 0), | |||||
}; | |||||
var op = c_api.TF_FinishOperation(desc, s); | |||||
s.Check(); | |||||
c_api.TF_AddInputList(desc, inputs, inputs.Length); | |||||
return op; | |||||
var op = c_api.TF_FinishOperation(desc, s); | |||||
s.Check(); | |||||
return op; | |||||
} | |||||
} | } | ||||
[SuppressMessage("ReSharper", "RedundantAssignment")] | [SuppressMessage("ReSharper", "RedundantAssignment")] | ||||
public static bool GetAttrValue(Operation oper, string attr_name, ref AttrValue attr_value, Status s) | public static bool GetAttrValue(Operation oper, string attr_name, ref AttrValue attr_value, Status s) | ||||
{ | { | ||||
using (var buffer = new Buffer()) | |||||
lock (Locks.ProcessWide) | |||||
{ | { | ||||
c_api.TF_OperationGetAttrValueProto(oper, attr_name, buffer, s); | |||||
attr_value = AttrValue.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||||
} | |||||
using (var buffer = new Buffer()) | |||||
{ | |||||
c_api.TF_OperationGetAttrValueProto(oper, attr_name, buffer, s); | |||||
attr_value = AttrValue.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||||
} | |||||
return s.Code == TF_Code.TF_OK; | |||||
return s.Code == TF_Code.TF_OK; | |||||
} | |||||
} | } | ||||
public static GraphDef GetGraphDef(Graph graph) | public static GraphDef GetGraphDef(Graph graph) | ||||
{ | { | ||||
using (var s = new Status()) | |||||
using (var buffer = new Buffer()) | |||||
lock (Locks.ProcessWide) | |||||
{ | { | ||||
c_api.TF_GraphToGraphDef(graph, buffer, s); | |||||
s.Check(); | |||||
return GraphDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||||
using (var s = new Status()) | |||||
using (var buffer = new Buffer()) | |||||
{ | |||||
c_api.TF_GraphToGraphDef(graph, buffer, s); | |||||
s.Check(); | |||||
return GraphDef.Parser.ParseFrom(buffer.MemoryBlock.Stream()); | |||||
} | |||||
} | } | ||||
} | } | ||||
@@ -58,6 +67,7 @@ namespace TensorFlowNET.UnitTest | |||||
{ | { | ||||
return false; | return false; | ||||
} | } | ||||
bool found_t = false; | bool found_t = false; | ||||
bool found_n = false; | bool found_n = false; | ||||
foreach (var attr in node_def.Attr) | foreach (var attr in node_def.Attr) | ||||
@@ -67,19 +77,16 @@ namespace TensorFlowNET.UnitTest | |||||
if (attr.Value.Type == DataType.DtInt32) | if (attr.Value.Type == DataType.DtInt32) | ||||
{ | { | ||||
found_t = true; | found_t = true; | ||||
} | |||||
else | |||||
} else | |||||
{ | { | ||||
return false; | return false; | ||||
} | } | ||||
} | |||||
else if (attr.Key == "N") | |||||
} else if (attr.Key == "N") | |||||
{ | { | ||||
if (attr.Value.I == n) | if (attr.Value.I == n) | ||||
{ | { | ||||
found_n = true; | found_n = true; | ||||
} | |||||
else | |||||
} else | |||||
{ | { | ||||
return false; | return false; | ||||
} | } | ||||
@@ -92,7 +99,7 @@ namespace TensorFlowNET.UnitTest | |||||
public static bool IsNeg(NodeDef node_def, string input) | public static bool IsNeg(NodeDef node_def, string input) | ||||
{ | { | ||||
return node_def.Op == "Neg" && node_def.Name == "neg" && | return node_def.Op == "Neg" && node_def.Name == "neg" && | ||||
node_def.Input.Count == 1 && node_def.Input[0] == input; | |||||
node_def.Input.Count == 1 && node_def.Input[0] == input; | |||||
} | } | ||||
public static bool IsPlaceholder(NodeDef node_def) | public static bool IsPlaceholder(NodeDef node_def) | ||||
@@ -111,13 +118,11 @@ namespace TensorFlowNET.UnitTest | |||||
if (attr.Value.Type == DataType.DtInt32) | if (attr.Value.Type == DataType.DtInt32) | ||||
{ | { | ||||
found_dtype = true; | found_dtype = true; | ||||
} | |||||
else | |||||
} else | |||||
{ | { | ||||
return false; | return false; | ||||
} | } | ||||
} | |||||
else if (attr.Key == "shape") | |||||
} else if (attr.Key == "shape") | |||||
{ | { | ||||
found_shape = true; | found_shape = true; | ||||
} | } | ||||
@@ -132,72 +137,82 @@ namespace TensorFlowNET.UnitTest | |||||
{ | { | ||||
return false; | return false; | ||||
} | } | ||||
bool found_dtype = false; | bool found_dtype = false; | ||||
bool found_value = false; | bool found_value = false; | ||||
foreach (var attr in node_def.Attr) { | |||||
foreach (var attr in node_def.Attr) | |||||
{ | |||||
if (attr.Key == "dtype") | if (attr.Key == "dtype") | ||||
{ | { | ||||
if (attr.Value.Type == DataType.DtInt32) | if (attr.Value.Type == DataType.DtInt32) | ||||
{ | { | ||||
found_dtype = true; | found_dtype = true; | ||||
} | |||||
else | |||||
} else | |||||
{ | { | ||||
return false; | return false; | ||||
} | } | ||||
} | |||||
else if (attr.Key == "value") | |||||
} else if (attr.Key == "value") | |||||
{ | { | ||||
if (attr.Value.Tensor != null && | if (attr.Value.Tensor != null && | ||||
attr.Value.Tensor.IntVal.Count == 1 && | attr.Value.Tensor.IntVal.Count == 1 && | ||||
attr.Value.Tensor.IntVal[0] == v) | attr.Value.Tensor.IntVal[0] == v) | ||||
{ | { | ||||
found_value = true; | found_value = true; | ||||
} | |||||
else | |||||
} else | |||||
{ | { | ||||
return false; | return false; | ||||
} | } | ||||
} | } | ||||
} | } | ||||
return found_dtype && found_value; | return found_dtype && found_value; | ||||
} | } | ||||
public static Operation Neg(Operation n, Graph graph, Status s, string name = "neg") | public static Operation Neg(Operation n, Graph graph, Status s, string name = "neg") | ||||
{ | { | ||||
OperationDescription desc = c_api.TF_NewOperation(graph, "Neg", name); | |||||
var neg_input = new TF_Output(n, 0); | |||||
c_api.TF_AddInput(desc, neg_input); | |||||
var op = c_api.TF_FinishOperation(desc, s); | |||||
s.Check(); | |||||
lock (Locks.ProcessWide) | |||||
{ | |||||
OperationDescription desc = c_api.TF_NewOperation(graph, "Neg", name); | |||||
var neg_input = new TF_Output(n, 0); | |||||
c_api.TF_AddInput(desc, neg_input); | |||||
var op = c_api.TF_FinishOperation(desc, s); | |||||
s.Check(); | |||||
return op; | |||||
return op; | |||||
} | |||||
} | } | ||||
public static Operation Placeholder(Graph graph, Status s, string name = "feed", TF_DataType dtype = TF_DataType.TF_INT32, long[] dims = null) | public static Operation Placeholder(Graph graph, Status s, string name = "feed", TF_DataType dtype = TF_DataType.TF_INT32, long[] dims = null) | ||||
{ | { | ||||
var desc = c_api.TF_NewOperation(graph, "Placeholder", name); | |||||
c_api.TF_SetAttrType(desc, "dtype", dtype); | |||||
if (dims != null) | |||||
lock (Locks.ProcessWide) | |||||
{ | { | ||||
c_api.TF_SetAttrShape(desc, "shape", dims, dims.Length); | |||||
} | |||||
var op = c_api.TF_FinishOperation(desc, s); | |||||
s.Check(); | |||||
var desc = c_api.TF_NewOperation(graph, "Placeholder", name); | |||||
c_api.TF_SetAttrType(desc, "dtype", dtype); | |||||
if (dims != null) | |||||
{ | |||||
c_api.TF_SetAttrShape(desc, "shape", dims, dims.Length); | |||||
} | |||||
var op = c_api.TF_FinishOperation(desc, s); | |||||
s.Check(); | |||||
return op; | |||||
return op; | |||||
} | |||||
} | } | ||||
public static Operation Const(Tensor t, Graph graph, Status s, string name) | public static Operation Const(Tensor t, Graph graph, Status s, string name) | ||||
{ | { | ||||
var desc = c_api.TF_NewOperation(graph, "Const", name); | |||||
c_api.TF_SetAttrTensor(desc, "value", t, s); | |||||
s.Check(); | |||||
c_api.TF_SetAttrType(desc, "dtype", t.dtype); | |||||
var op = c_api.TF_FinishOperation(desc, s); | |||||
s.Check(); | |||||
return op; | |||||
lock (Locks.ProcessWide) | |||||
{ | |||||
var desc = c_api.TF_NewOperation(graph, "Const", name); | |||||
c_api.TF_SetAttrTensor(desc, "value", t, s); | |||||
s.Check(); | |||||
c_api.TF_SetAttrType(desc, "dtype", t.dtype); | |||||
var op = c_api.TF_FinishOperation(desc, s); | |||||
s.Check(); | |||||
return op; | |||||
} | |||||
} | } | ||||
public static Operation ScalarConst(int v, Graph graph, Status s, string name = "scalar") | public static Operation ScalarConst(int v, Graph graph, Status s, string name = "scalar") | ||||
@@ -205,4 +220,4 @@ namespace TensorFlowNET.UnitTest | |||||
return Const(new Tensor(v), graph, s, name); | return Const(new Tensor(v), graph, s, name); | ||||
} | } | ||||
} | } | ||||
} | |||||
} |