Browse Source

Merge pull request #371 from SciSharp/multithreading

Multithreading Support and fixed critical heap corruption
tags/v0.12
Haiping GitHub 6 years ago
parent
commit
641a105bee
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
34 changed files with 3054 additions and 663 deletions
  1. +2
    -0
      TensorFlow.NET.sln.DotSettings
  2. +9
    -0
      src/TensorFlowNET.Core/APIs/c_api.cs
  3. +14
    -2
      src/TensorFlowNET.Core/APIs/tf.graph.cs
  4. +3
    -3
      src/TensorFlowNET.Core/Gradients/gradients_util.cs
  5. +16
    -3
      src/TensorFlowNET.Core/Graphs/DefaultGraphStack.cs
  6. +8
    -6
      src/TensorFlowNET.Core/Graphs/Graph.Operation.cs
  7. +102
    -105
      src/TensorFlowNET.Core/Operations/Operation.cs
  8. +1
    -1
      src/TensorFlowNET.Core/Operations/gen_ops.cs
  9. +107
    -97
      src/TensorFlowNET.Core/Sessions/BaseSession.cs
  10. +6
    -14
      src/TensorFlowNET.Core/Sessions/Session.cs
  11. +27
    -0
      src/TensorFlowNET.Core/Tensors/AllocationType.cs
  12. +203
    -257
      src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
  13. +26
    -0
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  14. +73
    -0
      src/TensorFlowNET.Core/Tensors/c_api.tensor.cs
  15. +21
    -0
      src/TensorFlowNET.Core/Util/Locks.cs
  16. +41
    -36
      src/TensorFlowNET.Core/ops.cs
  17. +10
    -1
      src/TensorFlowNET.Core/tensorflow.cs
  18. +1
    -1
      test/TensorFlowNET.Examples/ImageProcessing/ImageRecognitionInception.cs
  19. +1
    -1
      test/TensorFlowNET.Examples/ImageProcessing/InceptionArchGoogLeNet.cs
  20. +2
    -2
      test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs
  21. +1
    -1
      test/TensorFlowNET.UnitTest/CApiGradientsTest.cs
  22. +10
    -6
      test/TensorFlowNET.UnitTest/CSession.cs
  23. +4
    -4
      test/TensorFlowNET.UnitTest/GraphTest.cs
  24. +263
    -0
      test/TensorFlowNET.UnitTest/MultithreadingTests.cs
  25. +5
    -4
      test/TensorFlowNET.UnitTest/SessionTest.cs
  26. +2
    -0
      test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj
  27. +2
    -0
      test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj.DotSettings
  28. +44
    -57
      test/TensorFlowNET.UnitTest/TensorTest.cs
  29. +173
    -0
      test/TensorFlowNET.UnitTest/Utilities/MultiThreadedUnitTestExecuter.cs
  30. +914
    -0
      test/TensorFlowNET.UnitTest/Utilities/PrivateObject.cs
  31. +314
    -0
      test/TensorFlowNET.UnitTest/Utilities/PrivateObjectExtensions.cs
  32. +572
    -0
      test/TensorFlowNET.UnitTest/Utilities/PrivateType.cs
  33. +1
    -1
      test/TensorFlowNET.UnitTest/VariableTest.cs
  34. +76
    -61
      test/TensorFlowNET.UnitTest/c_test_util.cs

+ 2
- 0
TensorFlow.NET.sln.DotSettings View File

@@ -0,0 +1,2 @@
<wpf:ResourceDictionary xml:space="preserve" xmlns:x="http://schemas.microsoft.com/winfx/2006/xaml" xmlns:s="clr-namespace:System;assembly=mscorlib" xmlns:ss="urn:shemas-jetbrains-com:settings-storage-xaml" xmlns:wpf="http://schemas.microsoft.com/winfx/2006/xaml/presentation">
<s:Boolean x:Key="/Default/UserDictionary/Words/=Tensorflow/@EntryIndexedValue">True</s:Boolean></wpf:ResourceDictionary>

+ 9
- 0
src/TensorFlowNET.Core/APIs/c_api.cs View File

@@ -54,6 +54,15 @@ namespace Tensorflow


public struct DeallocatorArgs public struct DeallocatorArgs
{ {
internal static unsafe c_api.DeallocatorArgs* EmptyPtr;
internal static unsafe IntPtr Empty;

static unsafe DeallocatorArgs()
{
Empty = new IntPtr(EmptyPtr = (DeallocatorArgs*) Marshal.AllocHGlobal(Marshal.SizeOf<DeallocatorArgs>()));
*EmptyPtr = new DeallocatorArgs() {gc_handle = IntPtr.Zero, deallocator_called = false};
}

public bool deallocator_called; public bool deallocator_called;
public IntPtr gc_handle; public IntPtr gc_handle;
} }


+ 14
- 2
src/TensorFlowNET.Core/APIs/tf.graph.cs View File

@@ -29,7 +29,19 @@ namespace Tensorflow
return ops.get_default_graph(); return ops.get_default_graph();
} }


public Graph Graph()
/// <summary>
/// Equivalent to <see cref="get_default_graph"/> but does not create a new graph if it there is none.
/// </summary>
public Graph peak_default_graph()
{
return ops.default_graph_stack.peak_controller();
}

/// <summary>
/// Creates a new graph.
/// </summary>
///<remarks>Has no interaction with graph defaulting. Equivalent to new Graph();</remarks>
public Graph Graph()
=> new Graph(); => new Graph();
} }
}
}

+ 3
- 3
src/TensorFlowNET.Core/Gradients/gradients_util.cs View File

@@ -61,7 +61,7 @@ namespace Tensorflow
string grad_scope = scope; string grad_scope = scope;
// Get a uid for this call to gradients that can be used to help // Get a uid for this call to gradients that can be used to help
// cluster ops for compilation. // cluster ops for compilation.
var gradient_uid = ops.get_default_graph().unique_name("uid");
var gradient_uid = curr_graph.unique_name("uid");
ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name: "y"); ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name: "y");
xs = ops.internal_convert_n_to_tensor_or_indexed_slices(xs, name: "x", as_ref: true); xs = ops.internal_convert_n_to_tensor_or_indexed_slices(xs, name: "x", as_ref: true);
grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops, gradient_uid); grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops, gradient_uid);
@@ -80,7 +80,7 @@ namespace Tensorflow
var to_ops = ys.Select(x => x.op).ToList(); var to_ops = ys.Select(x => x.op).ToList();
var from_ops = xs.Select(x => x.op).ToList(); var from_ops = xs.Select(x => x.op).ToList();
var stop_gradient_ops = stop_gradients.Select(x => x.op).ToList(); var stop_gradient_ops = stop_gradients.Select(x => x.op).ToList();
(var reachable_to_ops, var pending_count, var loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List<object>(), xs);
var (reachable_to_ops, pending_count, loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List<object>(), xs);


foreach (var (y, grad_y) in zip(ys, grad_ys)) foreach (var (y, grad_y) in zip(ys, grad_ys))
_SetGrad(grads, y, grad_y); _SetGrad(grads, y, grad_y);
@@ -168,7 +168,7 @@ namespace Tensorflow
{ {
if (in_grad != null) if (in_grad != null)
{ {
if (in_grad is Tensor &&
if (!(in_grad is null) &&
in_grad.Tag == null && // maybe a IndexedSlice in_grad.Tag == null && // maybe a IndexedSlice
t_in.dtype != TF_DataType.TF_RESOURCE) t_in.dtype != TF_DataType.TF_RESOURCE)
{ {


+ 16
- 3
src/TensorFlowNET.Core/Graphs/DefaultGraphStack.cs View File

@@ -21,11 +21,10 @@ using static Tensorflow.Binding;


namespace Tensorflow namespace Tensorflow
{ {

/// <summary> /// <summary>
/// Serves as a stack for determining current default graph. /// Serves as a stack for determining current default graph.
/// </summary> /// </summary>
public class DefaultGraphStack
public class DefaultGraphStack
{ {
private readonly List<StackModel> _stack = new List<StackModel>(); private readonly List<StackModel> _stack = new List<StackModel>();


@@ -40,7 +39,7 @@ namespace Tensorflow


public Graph get_controller() public Graph get_controller()
{ {
if (_stack.Count(x => x.IsDefault) == 0)
if (_stack.Count == 0 || _stack.Count(x => x.IsDefault) == 0)
_stack.Add(new StackModel {Graph = tf.Graph(), IsDefault = true}); _stack.Add(new StackModel {Graph = tf.Graph(), IsDefault = true});
for (var i = _stack.Count - 1; i >= 0; i--) for (var i = _stack.Count - 1; i >= 0; i--)
{ {
@@ -52,6 +51,20 @@ namespace Tensorflow
throw new TensorflowException("Unable to find a default graph"); throw new TensorflowException("Unable to find a default graph");
} }


public Graph peak_controller()
{
if (_stack.Count == 0 || _stack.Count(x => x.IsDefault) == 0)
return null;
for (var i = _stack.Count - 1; i >= 0; i--)
{
var x = _stack[i];
if (x.IsDefault)
return x.Graph;
}

return null;
}

public bool remove(Graph g) public bool remove(Graph g)
{ {
if (_stack.Count == 0) if (_stack.Count == 0)


+ 8
- 6
src/TensorFlowNET.Core/Graphs/Graph.Operation.cs View File

@@ -54,19 +54,21 @@ namespace Tensorflow
var handle = return_oper_handle.node + tf_op_size * i; var handle = return_oper_handle.node + tf_op_size * i;
return_opers[i] = new Operation(*(IntPtr*)handle); return_opers[i] = new Operation(*(IntPtr*)handle);
} }
}
}
return return_opers; return return_opers;
} }


public Operation OperationByName(string operName) public Operation OperationByName(string operName)
{ {
var handle = c_api.TF_GraphOperationByName(_handle, operName); var handle = c_api.TF_GraphOperationByName(_handle, operName);
if(graph_key != tf.get_default_graph().graph_key)
{
Console.WriteLine($"Current graph is not default graph.");
// throw new ValueError($"Current graph is not default graph.");
var defaultKey = tf.get_default_graph().graph_key;
if (graph_key != defaultKey)
{
//Console.WriteLine($"Current graph is not default graph.");
throw new ValueError($"Current graph is not default graph. Default Graph Key: {defaultKey}, Current Graph Key: {graph_key}");
} }

return new Operation(handle, g: this); return new Operation(handle, g: this);
} }




+ 102
- 105
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -22,58 +22,54 @@ using System.Linq;
using Tensorflow.Util; using Tensorflow.Util;


namespace Tensorflow namespace Tensorflow
{
/// <summary>
/// Represents a graph node that performs computation on tensors.
///
/// An `Operation` is a node in a TensorFlow `Graph` that takes zero or
/// more `Tensor` objects as input, and produces zero or more `Tensor`
/// objects as output. Objects of type `Operation` are created by
/// calling an op constructor(such as `tf.matmul`)
/// or `tf.Graph.create_op`.
///
/// For example `c = tf.matmul(a, b)` creates an `Operation` of type
/// "MatMul" that takes tensors `a` and `b` as input, and produces `c`
/// as output.
///
/// After the graph has been launched in a session, an `Operation` can
/// be executed by passing it to
/// `tf.Session.run`.
/// `op.run()` is a shortcut for calling `tf.get_default_session().run(op)`.
{
/// <summary>
/// Represents a graph node that performs computation on tensors.
///
/// An `Operation` is a node in a TensorFlow `Graph` that takes zero or
/// more `Tensor` objects as input, and produces zero or more `Tensor`
/// objects as output. Objects of type `Operation` are created by
/// calling an op constructor(such as `tf.matmul`)
/// or `tf.Graph.create_op`.
///
/// For example `c = tf.matmul(a, b)` creates an `Operation` of type
/// "MatMul" that takes tensors `a` and `b` as input, and produces `c`
/// as output.
///
/// After the graph has been launched in a session, an `Operation` can
/// be executed by passing it to
/// `tf.Session.run`.
/// `op.run()` is a shortcut for calling `tf.get_default_session().run(op)`.
/// </summary> /// </summary>
public partial class Operation : ITensorOrOperation public partial class Operation : ITensorOrOperation
{ {
private readonly IntPtr _handle; // _c_op in python private readonly IntPtr _handle; // _c_op in python
private readonly IntPtr _operDesc;
private readonly IntPtr _operDesc;
private readonly Graph _graph;
private NodeDef _node_def;


private Graph _graph;
public string type => OpType; public string type => OpType;

public Graph graph => _graph; public Graph graph => _graph;
public int _id => _id_value; public int _id => _id_value;
public int _id_value; public int _id_value;
public Operation op => this; public Operation op => this;

public TF_DataType dtype => TF_DataType.DtInvalid; public TF_DataType dtype => TF_DataType.DtInvalid;

public string name => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationName(_handle)); public string name => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationName(_handle));
public string OpType => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationOpType(_handle)); public string OpType => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationOpType(_handle));
public string Device => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationDevice(_handle)); public string Device => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationDevice(_handle));


private NodeDef _node_def;
public NodeDef node_def public NodeDef node_def
{ {
get get
{ {
if(_node_def == null)
if (_node_def == null)
_node_def = GetNodeDef(); _node_def = GetNodeDef();


return _node_def; return _node_def;
} }
} }


public Operation(IntPtr handle, Graph g=null)
public Operation(IntPtr handle, Graph g = null)
{ {
if (handle == IntPtr.Zero) if (handle == IntPtr.Zero)
return; return;
@@ -97,14 +93,15 @@ namespace Tensorflow


_operDesc = c_api.TF_NewOperation(g, opType, oper_name); _operDesc = c_api.TF_NewOperation(g, opType, oper_name);
c_api.TF_SetAttrType(_operDesc, "dtype", TF_DataType.TF_INT32); c_api.TF_SetAttrType(_operDesc, "dtype", TF_DataType.TF_INT32);
using (var status = new Status())
{
_handle = c_api.TF_FinishOperation(_operDesc, status);
status.Check(true);
}
// Dict mapping op name to file and line information for op colocation
// context managers.
lock (Locks.ProcessWide)
using (var status = new Status())
{
_handle = c_api.TF_FinishOperation(_operDesc, status);
status.Check(true);
}

// Dict mapping op name to file and line information for op colocation
// context managers.
_control_flow_context = graph._get_control_flow_context(); _control_flow_context = graph._get_control_flow_context();
} }


@@ -133,9 +130,9 @@ namespace Tensorflow


// Build the list of control inputs. // Build the list of control inputs.
var control_input_ops = new List<Operation>(); var control_input_ops = new List<Operation>();
if(control_inputs != null)
if (control_inputs != null)
{ {
foreach(var c in control_inputs)
foreach (var c in control_inputs)
{ {
switch (c) switch (c)
{ {
@@ -196,15 +193,13 @@ namespace Tensorflow
{ {
if (!string.IsNullOrEmpty(input_arg.NumberAttr)) if (!string.IsNullOrEmpty(input_arg.NumberAttr))
{ {
input_len = (int)attrs[input_arg.NumberAttr].I;
input_len = (int) attrs[input_arg.NumberAttr].I;
is_sequence = true; is_sequence = true;
}
else if (!string.IsNullOrEmpty(input_arg.TypeListAttr))
} else if (!string.IsNullOrEmpty(input_arg.TypeListAttr))
{ {
input_len = attrs[input_arg.TypeListAttr].List.Type.Count; input_len = attrs[input_arg.TypeListAttr].List.Type.Count;
is_sequence = true; is_sequence = true;
}
else
} else
{ {
input_len = 1; input_len = 1;
is_sequence = false; is_sequence = false;
@@ -225,22 +220,21 @@ namespace Tensorflow
{ {
AttrValue x = null; AttrValue x = null;


using (var status = new Status())
using (var buf = new Buffer())
{
unsafe
lock (Locks.ProcessWide)
using (var status = new Status())
using (var buf = new Buffer())
{ {
c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status); c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status);
status.Check(true); status.Check(true);

x = AttrValue.Parser.ParseFrom(buf.MemoryBlock.Stream()); x = AttrValue.Parser.ParseFrom(buf.MemoryBlock.Stream());
} }
}


string oneof_value = x.ValueCase.ToString(); string oneof_value = x.ValueCase.ToString();
if (string.IsNullOrEmpty(oneof_value)) if (string.IsNullOrEmpty(oneof_value))
return null; return null;


if(oneof_value == "list")
if (oneof_value == "list")
throw new NotImplementedException($"Unsupported field type in {x.ToString()}"); throw new NotImplementedException($"Unsupported field type in {x.ToString()}");


if (oneof_value == "type") if (oneof_value == "type")
@@ -259,60 +253,63 @@ namespace Tensorflow


private NodeDef GetNodeDef() private NodeDef GetNodeDef()
{ {
using (var s = new Status())
using (var buffer = new Buffer())
{
c_api.TF_OperationToNodeDef(_handle, buffer, s);
s.Check();
return NodeDef.Parser.ParseFrom(buffer.MemoryBlock.Stream());
}
}
/// <summary>
/// Update the input to this operation at the given index.
///
/// NOTE: This is for TF internal use only.Please don't use it.
/// </summary>
/// <param name="index">the index of the input to update.</param>
/// <param name="tensor"> the Tensor to be used as the input at the given index.</param>
public void _update_input(int index, Tensor tensor)
{
_assert_same_graph(tensor);
var input = _tf_input(index);
var output = tensor._as_tf_output();
// Reset cached inputs.
_inputs = null;
// after the c_api call next time _inputs is accessed
// the updated inputs are reloaded from the c_api
using (var status = new Status())
{
c_api.UpdateEdge(_graph, output, input, status);
//var updated_inputs = inputs;
status.Check();
}
}
private void _assert_same_graph(Tensor tensor)
{
//TODO: implement
}
/// <summary>
/// Create and return a new TF_Output for output_idx'th output of this op.
/// </summary>
public TF_Output _tf_output(int output_idx)
{
return new TF_Output(op, output_idx);
}
/// <summary>
/// Create and return a new TF_Input for input_idx'th input of this op.
/// </summary>
public TF_Input _tf_input(int input_idx)
{
return new TF_Input(op, input_idx);
}
}
}
lock (Locks.ProcessWide)
using (var s = new Status())
using (var buffer = new Buffer())
{
c_api.TF_OperationToNodeDef(_handle, buffer, s);
s.Check();

return NodeDef.Parser.ParseFrom(buffer.MemoryBlock.Stream());
}
}

/// <summary>
/// Update the input to this operation at the given index.
///
/// NOTE: This is for TF internal use only.Please don't use it.
/// </summary>
/// <param name="index">the index of the input to update.</param>
/// <param name="tensor"> the Tensor to be used as the input at the given index.</param>
public void _update_input(int index, Tensor tensor)
{
_assert_same_graph(tensor);

var input = _tf_input(index);
var output = tensor._as_tf_output();

// Reset cached inputs.
_inputs = null;
// after the c_api call next time _inputs is accessed
// the updated inputs are reloaded from the c_api
lock (Locks.ProcessWide)
using (var status = new Status())
{
c_api.UpdateEdge(_graph, output, input, status);
//var updated_inputs = inputs;
status.Check();
}
}

private void _assert_same_graph(Tensor tensor)
{
//TODO: implement
}

/// <summary>
/// Create and return a new TF_Output for output_idx'th output of this op.
/// </summary>
public TF_Output _tf_output(int output_idx)
{
return new TF_Output(op, output_idx);
}

/// <summary>
/// Create and return a new TF_Input for input_idx'th input of this op.
/// </summary>
public TF_Input _tf_input(int input_idx)
{
return new TF_Input(op, input_idx);
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Operations/gen_ops.cs View File

@@ -7730,7 +7730,7 @@ namespace Tensorflow.Operations
/// </returns> /// </returns>
/// <remarks> /// <remarks>
/// RFC 4180 format is expected for the CSV records. /// RFC 4180 format is expected for the CSV records.
/// (https://tools.ietf.org/html/rfc4180)
/// (https://tools.ietensorflow.org/html/rfc4180)
/// Note that we allow leading and trailing spaces with int or float field. /// Note that we allow leading and trailing spaces with int or float field.
/// </remarks> /// </remarks>
public static Tensor[] decode_c_s_v (Tensor records, Tensor[] record_defaults, string field_delim = null, bool? use_quote_delim = null, string na_value = null, int[] select_cols = null, string name = "DecodeCSV") public static Tensor[] decode_c_s_v (Tensor records, Tensor[] record_defaults, string field_delim = null, bool? use_quote_delim = null, string na_value = null, int[] select_cols = null, string name = "DecodeCSV")


+ 107
- 97
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -36,23 +36,20 @@ namespace Tensorflow
protected byte[] _target; protected byte[] _target;
public Graph graph => _graph; public Graph graph => _graph;


public BaseSession(string target = "", Graph g = null, SessionOptions opts = null)
public BaseSession(string target = "", Graph g = null, SessionOptions opts = null, Status status = null)
{ {
_graph = g is null ? ops.get_default_graph() : g;
_graph = g ?? ops.get_default_graph();
_graph.as_default(); _graph.as_default();
_target = UTF8Encoding.UTF8.GetBytes(target);
_target = Encoding.UTF8.GetBytes(target);


SessionOptions newOpts = opts ?? new SessionOptions();
SessionOptions lopts = opts ?? new SessionOptions();


var status = new Status();

_handle = c_api.TF_NewSession(_graph, opts ?? newOpts, status);

// dispose opts only if not provided externally.
if (opts == null)
newOpts.Dispose();

status.Check(true);
lock (Locks.ProcessWide)
{
status = status ?? new Status();
_handle = c_api.TF_NewSession(_graph, opts ?? lopts, status);
status.Check(true);
}
} }


public virtual void run(Operation op, params FeedItem[] feed_dict) public virtual void run(Operation op, params FeedItem[] feed_dict)
@@ -72,19 +69,19 @@ namespace Tensorflow


public virtual (NDArray, NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) public virtual (NDArray, NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict)
{ {
var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4 }, feed_dict);
var results = _run(new object[] {fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4}, feed_dict);
return (results[0], results[1], results[2], results[3]); return (results[0], results[1], results[2], results[3]);
} }


public virtual (NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) public virtual (NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict)
{ {
var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3 }, feed_dict);
var results = _run(new object[] {fetches.Item1, fetches.Item2, fetches.Item3}, feed_dict);
return (results[0], results[1], results[2]); return (results[0], results[1], results[2]);
} }


public virtual (NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) public virtual (NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict)
{ {
var results = _run(new object[] { fetches.Item1, fetches.Item2 }, feed_dict);
var results = _run(new object[] {fetches.Item1, fetches.Item2}, feed_dict);
return (results[0], results[1]); return (results[0], results[1]);
} }


@@ -95,8 +92,7 @@ namespace Tensorflow


public virtual NDArray[] run(object fetches, Hashtable feed_dict = null) public virtual NDArray[] run(object fetches, Hashtable feed_dict = null)
{ {
var feed_items = feed_dict == null ? new FeedItem[0] :
feed_dict.Keys.OfType<object>().Select(key => new FeedItem(key, feed_dict[key])).ToArray();
var feed_items = feed_dict == null ? new FeedItem[0] : feed_dict.Keys.OfType<object>().Select(key => new FeedItem(key, feed_dict[key])).ToArray();
return _run(fetches, feed_items); return _run(fetches, feed_items);
} }


@@ -130,7 +126,7 @@ namespace Tensorflow


// We only want to really perform the run if fetches or targets are provided, // We only want to really perform the run if fetches or targets are provided,
// or if the call is a partial run that specifies feeds. // or if the call is a partial run that specifies feeds.
var results = _do_run(final_targets.Select(x => (Operation)x).ToList(), final_fetches, feed_dict_tensor);
var results = _do_run(final_targets.Select(x => (Operation) x).ToList(), final_fetches, feed_dict_tensor);


return fetch_handler.build_results(this, results); return fetch_handler.build_results(this, results);
} }
@@ -150,9 +146,7 @@ namespace Tensorflow
/// </returns> /// </returns>
private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict) private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict)
{ {

var feeds = new KeyValuePair<TF_Output, Tensor>[feed_dict.Count]; var feeds = new KeyValuePair<TF_Output, Tensor>[feed_dict.Count];
var ignoreDispose = new bool[feed_dict.Count];
int i = 0; int i = 0;
foreach (var x in feed_dict) foreach (var x in feed_dict)
{ {
@@ -160,15 +154,25 @@ namespace Tensorflow
{ {
switch (x.Value) switch (x.Value)
{ {
case Tensor v: ignoreDispose[i] = true; feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), v); break;
case NDArray v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v, tensor.dtype)); break;
case Tensor v:
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), v);
break;
case NDArray v:
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v, tensor.dtype));
break;
case IntPtr v:
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
break;
#if _REGEN #if _REGEN
// @formatter:off — disable formatter after this line
%types = ["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"] %types = ["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"]
%foreach types% %foreach types%
case #1 v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; case #1 v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case #1[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; case #1[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
% %
// @formatter:on — enable formatter after this line
#else #else
// @formatter:off — disable formatter after this line
case sbyte v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; case sbyte v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case sbyte[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; case sbyte[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case byte v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; case byte v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
@@ -191,10 +195,14 @@ namespace Tensorflow
case double[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; case double[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case Complex v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; case Complex v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case Complex[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break; case Complex[] v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
// @formatter:on — enable formatter after this line
#endif #endif
case bool v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor((byte) (v ? 1 : 0), TF_DataType.TF_BOOL)); break;
case string v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case IntPtr v: feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v)); break;
case bool v:
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor((byte) (v ? 1 : 0), TF_DataType.TF_BOOL));
break;
case string v:
feeds[i++] = new KeyValuePair<TF_Output, Tensor>(tensor._as_tf_output(), new Tensor(v));
break;
default: default:
throw new NotImplementedException($"feed_dict data type {x.Value?.GetType().Name ?? "<null>"}"); throw new NotImplementedException($"feed_dict data type {x.Value?.GetType().Name ?? "<null>"}");
} }
@@ -203,18 +211,7 @@ namespace Tensorflow


var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray();
//var targets = target_list; //var targets = target_list;
try
{
return _call_tf_sessionrun(feeds, fetches, target_list);
} finally
{
for (var idx = 0; idx < feeds.Length; idx++)
{
if (ignoreDispose[idx])
continue;
feeds[idx].Value.Dispose();
}
}
return _call_tf_sessionrun(feeds, fetches, target_list);
} }


private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list, List<Operation> target_list) private unsafe NDArray[] _call_tf_sessionrun(KeyValuePair<TF_Output, Tensor>[] feed_dict, TF_Output[] fetch_list, List<Operation> target_list)
@@ -229,12 +226,12 @@ namespace Tensorflow
c_api.TF_SessionRun(_handle, c_api.TF_SessionRun(_handle,
run_options: null, run_options: null,
inputs: feed_dict.Select(f => f.Key).ToArray(), inputs: feed_dict.Select(f => f.Key).ToArray(),
input_values: feed_dict.Select(f => (IntPtr)f.Value).ToArray(),
input_values: feed_dict.Select(f => (IntPtr) f.Value).ToArray(),
ninputs: feed_dict.Length, ninputs: feed_dict.Length,
outputs: fetch_list, outputs: fetch_list,
output_values: output_values, output_values: output_values,
noutputs: fetch_list.Length, noutputs: fetch_list.Length,
target_opers: target_list.Select(f => (IntPtr)f).ToArray(),
target_opers: target_list.Select(f => (IntPtr) f).ToArray(),
ntargets: target_list.Count, ntargets: target_list.Count,
run_metadata: IntPtr.Zero, run_metadata: IntPtr.Zero,
status: status); status: status);
@@ -265,7 +262,7 @@ namespace Tensorflow
ret = NDArray.Scalar(*(bool*) srcAddress); ret = NDArray.Scalar(*(bool*) srcAddress);
break; break;
case TF_DataType.TF_STRING: case TF_DataType.TF_STRING:
using (var reader = new CodedInputStream(new IntPtr(srcAddress).Stream(8, (long)tensor.bytesize)))
using (var reader = new CodedInputStream(new IntPtr(srcAddress).Stream(8, (long) tensor.bytesize)))
ret = NDArray.FromString(reader.ReadString()); ret = NDArray.FromString(reader.ReadString());
break; break;
case TF_DataType.TF_UINT8: case TF_DataType.TF_UINT8:
@@ -330,81 +327,95 @@ namespace Tensorflow
#endregion #endregion
#else #else


#region Compute
switch (tensor.dtype)
{
case TF_DataType.TF_BOOL:
{
#region Compute

switch (tensor.dtype)
{
case TF_DataType.TF_BOOL:
{
ret = new NDArray(NPTypeCode.Boolean, ndims, false); ret = new NDArray(NPTypeCode.Boolean, ndims, false);
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
break;
}
case TF_DataType.TF_UINT8:
{
break;
}

case TF_DataType.TF_UINT8:
{
ret = new NDArray(NPTypeCode.Byte, ndims, false); ret = new NDArray(NPTypeCode.Byte, ndims, false);
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
break;
}
case TF_DataType.TF_INT16:
{
break;
}

case TF_DataType.TF_INT16:
{
ret = new NDArray(NPTypeCode.Int16, ndims, false); ret = new NDArray(NPTypeCode.Int16, ndims, false);
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
break;
}
case TF_DataType.TF_UINT16:
{
break;
}

case TF_DataType.TF_UINT16:
{
ret = new NDArray(NPTypeCode.UInt16, ndims, false); ret = new NDArray(NPTypeCode.UInt16, ndims, false);
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
break;
}
case TF_DataType.TF_INT32:
{
break;
}

case TF_DataType.TF_INT32:
{
ret = new NDArray(NPTypeCode.Int32, ndims, false); ret = new NDArray(NPTypeCode.Int32, ndims, false);
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
break;
}
case TF_DataType.TF_UINT32:
{
break;
}

case TF_DataType.TF_UINT32:
{
ret = new NDArray(NPTypeCode.UInt32, ndims, false); ret = new NDArray(NPTypeCode.UInt32, ndims, false);
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
break;
}
case TF_DataType.TF_INT64:
{
break;
}

case TF_DataType.TF_INT64:
{
ret = new NDArray(NPTypeCode.Int64, ndims, false); ret = new NDArray(NPTypeCode.Int64, ndims, false);
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
break;
}
case TF_DataType.TF_UINT64:
{
break;
}

case TF_DataType.TF_UINT64:
{
ret = new NDArray(NPTypeCode.UInt64, ndims, false); ret = new NDArray(NPTypeCode.UInt64, ndims, false);
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
break;
}
case TF_DataType.TF_DOUBLE:
{
break;
}

case TF_DataType.TF_DOUBLE:
{
ret = new NDArray(NPTypeCode.Double, ndims, false); ret = new NDArray(NPTypeCode.Double, ndims, false);
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
break;
}
case TF_DataType.TF_FLOAT:
{
break;
}

case TF_DataType.TF_FLOAT:
{
ret = new NDArray(NPTypeCode.Single, ndims, false); ret = new NDArray(NPTypeCode.Single, ndims, false);
System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize); System.Buffer.MemoryCopy(src, ret.Unsafe.Address, bytesize, bytesize);
break;
}
break;
}

case TF_DataType.TF_STRING: case TF_DataType.TF_STRING:
{ {
throw new NotImplementedException(); throw new NotImplementedException();
//TODO:! This is not the way to handle string[], it should be done with TF_DecodeString //TODO:! This is not the way to handle string[], it should be done with TF_DecodeString
using (var reader = new CodedInputStream(new IntPtr(srcAddress).Stream(8, (long)tensor.bytesize)))
using (var reader = new CodedInputStream(new IntPtr(srcAddress).Stream(8, (long) tensor.bytesize)))
ret = NDArray.FromString(reader.ReadString()); ret = NDArray.FromString(reader.ReadString());
break; break;
} }
default:
throw new NotSupportedException();
}
#endregion

default:
throw new NotSupportedException();
}

#endregion

#endif #endif
} }
} }
@@ -423,9 +434,7 @@ namespace Tensorflow
} }


private void _extend_graph() private void _extend_graph()
{

}
{ }


public void close() public void close()
{ {
@@ -434,11 +443,12 @@ namespace Tensorflow


protected override void DisposeUnmanagedResources(IntPtr handle) protected override void DisposeUnmanagedResources(IntPtr handle)
{ {
using (var status = new Status())
{
c_api.TF_DeleteSession(handle, status);
status.Check(true);
}
lock (Locks.ProcessWide)
using (var status = new Status())
{
c_api.TF_DeleteSession(handle, status);
status.Check(true);
}
} }
} }
} }

+ 6
- 14
src/TensorFlowNET.Core/Sessions/Session.cs View File

@@ -21,28 +21,20 @@ namespace Tensorflow
{ {
public class Session : BaseSession, IObjectLife public class Session : BaseSession, IObjectLife
{ {
public Session(string target = "", Graph g = null)
: base(target, g, null)
{

}
public Session(string target = "", Graph g = null) : base(target, g, null)
{ }


public Session(IntPtr handle, Graph g = null)
: base("", g, null)
public Session(IntPtr handle, Graph g = null) : base("", g, null)
{ {
_handle = handle; _handle = handle;
} }


public Session(Graph g, SessionOptions opts = null, Status s = null)
: base("", g, opts)
{
if (s == null)
s = new Status();
}
public Session(Graph g, SessionOptions opts = null, Status s = null) : base("", g, opts, s)
{ }


public Session as_default() public Session as_default()
{ {
tf.defaultSession = this;
tf._defaultSessionFactory.Value = this;
return this; return this;
} }




+ 27
- 0
src/TensorFlowNET.Core/Tensors/AllocationType.cs View File

@@ -0,0 +1,27 @@
namespace Tensorflow
{
/// <summary>
/// Used internally to
/// </summary>
public enum AllocationType
{
None = 0,
/// <summary>
/// Allocation was done by passing in a pointer, might be also holding reference to a C# object.
/// </summary>
FromPointer = 1,
/// <summary>
/// Allocation was done by calling c_api.TF_AllocateTensor or TF decided it has to copy data during c_api.TF_NewTensor. <br></br>
/// Deallocation is handled solely by Tensorflow.
/// </summary>
Tensorflow = 2,
/// <summary>
/// Allocation was done by Marshal.AllocateHGlobal
/// </summary>
Marshal = 3,
/// <summary>
/// Allocation was done by GCHandle.Alloc
/// </summary>
GCHandle = 4,
}
}

+ 203
- 257
src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs View File

@@ -28,42 +28,37 @@ using static Tensorflow.c_api;


namespace Tensorflow namespace Tensorflow
{ {
[SuppressMessage("ReSharper", "InvokeAsExtensionMethod")]
public partial class Tensor public partial class Tensor
{ {
/// <summary> /// <summary>
/// true if unmanaged buffer has been freed.
/// When Tensor was created from an object that is managed by C#'s GC - this will hold reference to prevent it from being collected.
/// </summary> /// </summary>
private bool _deallocator_called => _deallocatorArgs.deallocator_called;
protected object AllocationReferenceHolder;


/// <summary> /// <summary>
/// true if the Tensor was created from a managed array
/// The handle that was used to allocate this tensor, dependent on <see cref="AllocationType"/>.
/// </summary> /// </summary>
private bool _isPinnedArray => _deallocatorArgs.gc_handle != IntPtr.Zero;
protected object AllocationHandle;


/// <summary> /// <summary>
/// True only if the Tensor object was created in a way that the Tensor object itself allocated memory or pinned a managed object.
/// False if the Tensor was created from a pointer
/// True if this Tensor holds data allocated by C#.
/// </summary> /// </summary>
public bool IsMemoryOwner { get; private set; }
public bool IsMemoryOwner => AllocationType >= AllocationType.Marshal;


/// <summary> /// <summary>
/// This holds values that are used by the unmanaged deallocator callback
/// The allocation method used to create this Tensor.
/// </summary> /// </summary>
private DeallocatorArgs _deallocatorArgs = new DeallocatorArgs() { gc_handle = IntPtr.Zero };

// note: they must be assigned to a static variable in order to work as unmanaged callbacks
private static readonly Deallocator _hGlobalDeallocator = FreeHGlobalMemory;
private static readonly Deallocator _gcHandleDeallocator = FreeGCHandle;
private static readonly Deallocator _nothingDeallocator = FreeNothing;
public AllocationType AllocationType { get; protected set; }


/// <summary> /// <summary>
/// Create a Tensor object from an existing TF handle
/// Create a Tensor object from an existing TF handle
/// </summary> /// </summary>
/// <param name="handle"></param>
/// <param name="handle">Handle to a <see cref="Tensor"/> object.</param>
public Tensor(IntPtr handle) public Tensor(IntPtr handle)
{ {
_handle = handle; _handle = handle;
IsMemoryOwner = false;
//no need to set AllocationType = AllocationType.None;
} }


/// <summary> /// <summary>
@@ -71,430 +66,412 @@ namespace Tensorflow
/// Note: the caller is responsible for freeing the memory. Calling Dispose on this object will dispose the TensorFlow tensor /// Note: the caller is responsible for freeing the memory. Calling Dispose on this object will dispose the TensorFlow tensor
/// but not the memory itself! /// but not the memory itself!
/// </summary> /// </summary>
/// <param name="ptr">Pointer to unmanaged, fixed or pinned memory which the caller owns</param>
/// <param name="data_ptr">Pointer to unmanaged, fixed or pinned memory which the caller owns</param>
/// <param name="shape">Tensor shape</param> /// <param name="shape">Tensor shape</param>
/// <param name="dType">TF data type</param> /// <param name="dType">TF data type</param>
/// <param name="num_bytes">Size of the tensor in memory</param> /// <param name="num_bytes">Size of the tensor in memory</param>
public Tensor(IntPtr ptr, long[] shape, TF_DataType dType, int num_bytes)
public Tensor(IntPtr data_ptr, long[] shape, TF_DataType dType, int num_bytes)
{ {
_handle = TF_NewTensor(dType, dims: shape, num_dims: shape.Length, data: ptr, len: (UIntPtr)num_bytes, deallocator: _nothingDeallocator, ref _deallocatorArgs);
IsMemoryOwner = false;
unsafe
{
_handle = TF_NewTensor(dType, dims: shape, num_dims: shape.Length, data: data_ptr, len: (UIntPtr) num_bytes);
AllocationType = TF_TensorData(_handle) == data_ptr ? AllocationType.FromPointer : AllocationType.Tensorflow;
}
}

/// <summary>
/// Create a new Tensor from the given unmanaged memory pointer (which must be allocated, fixed or pinned by the caller)
/// Note: the caller is responsible for freeing the memory. Calling Dispose on this object will dispose the TensorFlow tensor
/// but not the memory itself!
/// </summary>
/// <param name="data_ptr">Pointer to unmanaged, fixed or pinned memory which the caller owns</param>
/// <param name="shape">Tensor shape</param>
/// <param name="dType">TF data type</param>
/// <param name="num_bytes">Size of the tensor in memory</param>
public unsafe Tensor(void* data_ptr, long[] shape, TF_DataType dType, int num_bytes)
{
_handle = TF_NewTensor(dType, dims: shape, num_dims: shape.Length, data: data_ptr, len: (UIntPtr) num_bytes);
AllocationType = TF_TensorData(_handle).ToPointer() == data_ptr ? AllocationType.FromPointer : AllocationType.Tensorflow;
} }


#if _REGEN #if _REGEN
%types=["sbyte", "bool", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"]
%types = ["sbyte", "bool", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"]
%foreach types% %foreach types%
/// <summary> /// <summary>
/// Create a 1d Tensor from the given linear array and shape
/// Create a 1d Tensor from the given linear array and shape
/// </summary> /// </summary>
public Tensor(#1[] data, TF_DataType? dType = null) public Tensor(#1[] data, TF_DataType? dType = null)
{ {
_handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(#1)), new long[]{data.Length}, data, Marshal.SizeOf<#1>());
IsMemoryOwner=true;
_handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(#1)), new long[] {data.Length}, data, #(#1=="Complex"|"Marshal.SizeOf<Complex>()"|"sizeof(#(str(#1)))"));
} }


/// <summary> /// <summary>
/// Create a N-dimensional Tensor from the given array
/// Create a N-dimensional Tensor from the given array
/// </summary> /// </summary>
public Tensor(#1[] data, long[] shape, TF_DataType? dType = null) public Tensor(#1[] data, long[] shape, TF_DataType? dType = null)
{ {
_handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(#1)), shape, data, Marshal.SizeOf<#1>());
IsMemoryOwner=true;
_handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(#1)), shape, data, #(#1=="Complex"|"Marshal.SizeOf<Complex>()"|"sizeof(#(str(#1)))"));
} }


/// <summary> /// <summary>
/// Create a scalar Tensor from the given value
/// Create a scalar Tensor from the given value
/// </summary> /// </summary>
public unsafe Tensor(#1 value, TF_DataType? dType = null) public unsafe Tensor(#1 value, TF_DataType? dType = null)
{ {
var v = (#1*)Marshal.AllocHGlobal(sizeof(#1));
*v = value;
_handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(#1)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(#1), deallocator: _hGlobalDeallocator, ref _deallocatorArgs);
IsMemoryOwner=true;
_handle = TF_AllocateTensor(dType ?? dtypes.as_dtype(typeof(#1)), dims: new long[0], num_dims: 0, len: (UIntPtr) sizeof(#1));
*(#1*) TF_TensorData(_handle) = value;
AllocationType = AllocationType.Tensorflow;
} }
% %
#else #else

/// <summary> /// <summary>
/// Create a 1d Tensor from the given linear array and shape
/// Create a 1d Tensor from the given linear array and shape
/// </summary> /// </summary>
public Tensor(sbyte[] data, TF_DataType? dType = null) public Tensor(sbyte[] data, TF_DataType? dType = null)
{ {
_handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(sbyte)), new long[]{data.Length}, data, Marshal.SizeOf<sbyte>());
IsMemoryOwner=true;
_handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(sbyte)), new long[] {data.Length}, data, sizeof(sbyte));
} }


/// <summary> /// <summary>
/// Create a N-dimensional Tensor from the given array
/// Create a N-dimensional Tensor from the given array
/// </summary> /// </summary>
public Tensor(sbyte[] data, long[] shape, TF_DataType? dType = null) public Tensor(sbyte[] data, long[] shape, TF_DataType? dType = null)
{ {
_handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(sbyte)), shape, data, Marshal.SizeOf<sbyte>());
IsMemoryOwner=true;
_handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(sbyte)), shape, data, sizeof(sbyte));
} }


/// <summary> /// <summary>
/// Create a scalar Tensor from the given value
/// Create a scalar Tensor from the given value
/// </summary> /// </summary>
public unsafe Tensor(sbyte value, TF_DataType? dType = null) public unsafe Tensor(sbyte value, TF_DataType? dType = null)
{ {
var v = (sbyte*)Marshal.AllocHGlobal(sizeof(sbyte));
*v = value;
_handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(sbyte)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(sbyte), deallocator: _hGlobalDeallocator, ref _deallocatorArgs);
IsMemoryOwner=true;
_handle = TF_AllocateTensor(dType ?? dtypes.as_dtype(typeof(sbyte)), dims: new long[0], num_dims: 0, len: (UIntPtr) sizeof(sbyte));
*(sbyte*) TF_TensorData(_handle) = value;
AllocationType = AllocationType.Tensorflow;
} }
/// <summary> /// <summary>
/// Create a 1d Tensor from the given linear array and shape
/// Create a 1d Tensor from the given linear array and shape
/// </summary> /// </summary>
public Tensor(bool[] data, TF_DataType? dType = null) public Tensor(bool[] data, TF_DataType? dType = null)
{ {
_handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(bool)), new long[]{data.Length}, data, Marshal.SizeOf<bool>());
IsMemoryOwner=true;
_handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(bool)), new long[] {data.Length}, data, sizeof(bool));
} }


/// <summary> /// <summary>
/// Create a N-dimensional Tensor from the given array
/// Create a N-dimensional Tensor from the given array
/// </summary> /// </summary>
public Tensor(bool[] data, long[] shape, TF_DataType? dType = null) public Tensor(bool[] data, long[] shape, TF_DataType? dType = null)
{ {
_handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(bool)), shape, data, Marshal.SizeOf<bool>());
IsMemoryOwner=true;
_handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(bool)), shape, data, sizeof(bool));
} }


/// <summary> /// <summary>
/// Create a scalar Tensor from the given value
/// Create a scalar Tensor from the given value
/// </summary> /// </summary>
public unsafe Tensor(bool value, TF_DataType? dType = null) public unsafe Tensor(bool value, TF_DataType? dType = null)
{ {
var v = (bool*)Marshal.AllocHGlobal(sizeof(bool));
*v = value;
_handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(bool)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(bool), deallocator: _hGlobalDeallocator, ref _deallocatorArgs);
IsMemoryOwner=true;
_handle = TF_AllocateTensor(dType ?? dtypes.as_dtype(typeof(bool)), dims: new long[0], num_dims: 0, len: (UIntPtr) sizeof(bool));
*(bool*) TF_TensorData(_handle) = value;
AllocationType = AllocationType.Tensorflow;
} }
/// <summary> /// <summary>
/// Create a 1d Tensor from the given linear array and shape
/// Create a 1d Tensor from the given linear array and shape
/// </summary> /// </summary>
public Tensor(byte[] data, TF_DataType? dType = null) public Tensor(byte[] data, TF_DataType? dType = null)
{ {
_handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(byte)), new long[]{data.Length}, data, Marshal.SizeOf<byte>());
IsMemoryOwner=true;
_handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(byte)), new long[] {data.Length}, data, sizeof(byte));
} }


/// <summary> /// <summary>
/// Create a N-dimensional Tensor from the given array
/// Create a N-dimensional Tensor from the given array
/// </summary> /// </summary>
public Tensor(byte[] data, long[] shape, TF_DataType? dType = null) public Tensor(byte[] data, long[] shape, TF_DataType? dType = null)
{ {
_handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(byte)), shape, data, Marshal.SizeOf<byte>());
IsMemoryOwner=true;
_handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(byte)), shape, data, sizeof(byte));
} }


/// <summary> /// <summary>
/// Create a scalar Tensor from the given value
/// Create a scalar Tensor from the given value
/// </summary> /// </summary>
public unsafe Tensor(byte value, TF_DataType? dType = null) public unsafe Tensor(byte value, TF_DataType? dType = null)
{ {
var v = (byte*)Marshal.AllocHGlobal(sizeof(byte));
*v = value;
_handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(byte)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(byte), deallocator: _hGlobalDeallocator, ref _deallocatorArgs);
IsMemoryOwner=true;
_handle = TF_AllocateTensor(dType ?? dtypes.as_dtype(typeof(byte)), dims: new long[0], num_dims: 0, len: (UIntPtr) sizeof(byte));
*(byte*) TF_TensorData(_handle) = value;
AllocationType = AllocationType.Tensorflow;
} }
/// <summary> /// <summary>
/// Create a 1d Tensor from the given linear array and shape
/// Create a 1d Tensor from the given linear array and shape
/// </summary> /// </summary>
public Tensor(short[] data, TF_DataType? dType = null) public Tensor(short[] data, TF_DataType? dType = null)
{ {
_handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(short)), new long[]{data.Length}, data, Marshal.SizeOf<short>());
IsMemoryOwner=true;
_handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(short)), new long[] {data.Length}, data, sizeof(short));
} }


/// <summary> /// <summary>
/// Create a N-dimensional Tensor from the given array
/// Create a N-dimensional Tensor from the given array
/// </summary> /// </summary>
public Tensor(short[] data, long[] shape, TF_DataType? dType = null) public Tensor(short[] data, long[] shape, TF_DataType? dType = null)
{ {
_handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(short)), shape, data, Marshal.SizeOf<short>());
IsMemoryOwner=true;
_handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(short)), shape, data, sizeof(short));
} }


/// <summary> /// <summary>
/// Create a scalar Tensor from the given value
/// Create a scalar Tensor from the given value
/// </summary> /// </summary>
public unsafe Tensor(short value, TF_DataType? dType = null) public unsafe Tensor(short value, TF_DataType? dType = null)
{ {
var v = (short*)Marshal.AllocHGlobal(sizeof(short));
*v = value;
_handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(short)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(short), deallocator: _hGlobalDeallocator, ref _deallocatorArgs);
IsMemoryOwner=true;
_handle = TF_AllocateTensor(dType ?? dtypes.as_dtype(typeof(short)), dims: new long[0], num_dims: 0, len: (UIntPtr) sizeof(short));
*(short*) TF_TensorData(_handle) = value;
AllocationType = AllocationType.Tensorflow;
} }
/// <summary> /// <summary>
/// Create a 1d Tensor from the given linear array and shape
/// Create a 1d Tensor from the given linear array and shape
/// </summary> /// </summary>
public Tensor(ushort[] data, TF_DataType? dType = null) public Tensor(ushort[] data, TF_DataType? dType = null)
{ {
_handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(ushort)), new long[]{data.Length}, data, Marshal.SizeOf<ushort>());
IsMemoryOwner=true;
_handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(ushort)), new long[] {data.Length}, data, sizeof(ushort));
} }


/// <summary> /// <summary>
/// Create a N-dimensional Tensor from the given array
/// Create a N-dimensional Tensor from the given array
/// </summary> /// </summary>
public Tensor(ushort[] data, long[] shape, TF_DataType? dType = null) public Tensor(ushort[] data, long[] shape, TF_DataType? dType = null)
{ {
_handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(ushort)), shape, data, Marshal.SizeOf<ushort>());
IsMemoryOwner=true;
_handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(ushort)), shape, data, sizeof(ushort));
} }


/// <summary> /// <summary>
/// Create a scalar Tensor from the given value
/// Create a scalar Tensor from the given value
/// </summary> /// </summary>
public unsafe Tensor(ushort value, TF_DataType? dType = null) public unsafe Tensor(ushort value, TF_DataType? dType = null)
{ {
var v = (ushort*)Marshal.AllocHGlobal(sizeof(ushort));
*v = value;
_handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(ushort)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(ushort), deallocator: _hGlobalDeallocator, ref _deallocatorArgs);
IsMemoryOwner=true;
_handle = TF_AllocateTensor(dType ?? dtypes.as_dtype(typeof(ushort)), dims: new long[0], num_dims: 0, len: (UIntPtr) sizeof(ushort));
*(ushort*) TF_TensorData(_handle) = value;
AllocationType = AllocationType.Tensorflow;
} }
/// <summary> /// <summary>
/// Create a 1d Tensor from the given linear array and shape
/// Create a 1d Tensor from the given linear array and shape
/// </summary> /// </summary>
public Tensor(int[] data, TF_DataType? dType = null) public Tensor(int[] data, TF_DataType? dType = null)
{ {
_handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(int)), new long[]{data.Length}, data, Marshal.SizeOf<int>());
IsMemoryOwner=true;
_handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(int)), new long[] {data.Length}, data, sizeof(int));
} }


/// <summary> /// <summary>
/// Create a N-dimensional Tensor from the given array
/// Create a N-dimensional Tensor from the given array
/// </summary> /// </summary>
public Tensor(int[] data, long[] shape, TF_DataType? dType = null) public Tensor(int[] data, long[] shape, TF_DataType? dType = null)
{ {
_handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(int)), shape, data, Marshal.SizeOf<int>());
IsMemoryOwner=true;
_handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(int)), shape, data, sizeof(int));
} }


/// <summary> /// <summary>
/// Create a scalar Tensor from the given value
/// Create a scalar Tensor from the given value
/// </summary> /// </summary>
public unsafe Tensor(int value, TF_DataType? dType = null) public unsafe Tensor(int value, TF_DataType? dType = null)
{ {
var v = (int*)Marshal.AllocHGlobal(sizeof(int));
*v = value;
_handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(int)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(int), deallocator: _hGlobalDeallocator, ref _deallocatorArgs);
IsMemoryOwner=true;
_handle = TF_AllocateTensor(dType ?? dtypes.as_dtype(typeof(int)), dims: new long[0], num_dims: 0, len: (UIntPtr) sizeof(int));
*(int*) TF_TensorData(_handle) = value;
AllocationType = AllocationType.Tensorflow;
} }
/// <summary> /// <summary>
/// Create a 1d Tensor from the given linear array and shape
/// Create a 1d Tensor from the given linear array and shape
/// </summary> /// </summary>
public Tensor(uint[] data, TF_DataType? dType = null) public Tensor(uint[] data, TF_DataType? dType = null)
{ {
_handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(uint)), new long[]{data.Length}, data, Marshal.SizeOf<uint>());
IsMemoryOwner=true;
_handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(uint)), new long[] {data.Length}, data, sizeof(uint));
} }


/// <summary> /// <summary>
/// Create a N-dimensional Tensor from the given array
/// Create a N-dimensional Tensor from the given array
/// </summary> /// </summary>
public Tensor(uint[] data, long[] shape, TF_DataType? dType = null) public Tensor(uint[] data, long[] shape, TF_DataType? dType = null)
{ {
_handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(uint)), shape, data, Marshal.SizeOf<uint>());
IsMemoryOwner=true;
_handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(uint)), shape, data, sizeof(uint));
} }


/// <summary> /// <summary>
/// Create a scalar Tensor from the given value
/// Create a scalar Tensor from the given value
/// </summary> /// </summary>
public unsafe Tensor(uint value, TF_DataType? dType = null) public unsafe Tensor(uint value, TF_DataType? dType = null)
{ {
var v = (uint*)Marshal.AllocHGlobal(sizeof(uint));
*v = value;
_handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(uint)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(uint), deallocator: _hGlobalDeallocator, ref _deallocatorArgs);
IsMemoryOwner=true;
_handle = TF_AllocateTensor(dType ?? dtypes.as_dtype(typeof(uint)), dims: new long[0], num_dims: 0, len: (UIntPtr) sizeof(uint));
*(uint*) TF_TensorData(_handle) = value;
AllocationType = AllocationType.Tensorflow;
} }
/// <summary> /// <summary>
/// Create a 1d Tensor from the given linear array and shape
/// Create a 1d Tensor from the given linear array and shape
/// </summary> /// </summary>
public Tensor(long[] data, TF_DataType? dType = null) public Tensor(long[] data, TF_DataType? dType = null)
{ {
_handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(long)), new long[]{data.Length}, data, Marshal.SizeOf<long>());
IsMemoryOwner=true;
_handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(long)), new long[] {data.Length}, data, sizeof(long));
} }


/// <summary> /// <summary>
/// Create a N-dimensional Tensor from the given array
/// Create a N-dimensional Tensor from the given array
/// </summary> /// </summary>
public Tensor(long[] data, long[] shape, TF_DataType? dType = null) public Tensor(long[] data, long[] shape, TF_DataType? dType = null)
{ {
_handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(long)), shape, data, Marshal.SizeOf<long>());
IsMemoryOwner=true;
_handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(long)), shape, data, sizeof(long));
} }


/// <summary> /// <summary>
/// Create a scalar Tensor from the given value
/// Create a scalar Tensor from the given value
/// </summary> /// </summary>
public unsafe Tensor(long value, TF_DataType? dType = null) public unsafe Tensor(long value, TF_DataType? dType = null)
{ {
var v = (long*)Marshal.AllocHGlobal(sizeof(long));
*v = value;
_handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(long)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(long), deallocator: _hGlobalDeallocator, ref _deallocatorArgs);
IsMemoryOwner=true;
_handle = TF_AllocateTensor(dType ?? dtypes.as_dtype(typeof(long)), dims: new long[0], num_dims: 0, len: (UIntPtr) sizeof(long));
*(long*) TF_TensorData(_handle) = value;
AllocationType = AllocationType.Tensorflow;
} }
/// <summary> /// <summary>
/// Create a 1d Tensor from the given linear array and shape
/// Create a 1d Tensor from the given linear array and shape
/// </summary> /// </summary>
public Tensor(ulong[] data, TF_DataType? dType = null) public Tensor(ulong[] data, TF_DataType? dType = null)
{ {
_handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(ulong)), new long[]{data.Length}, data, Marshal.SizeOf<ulong>());
IsMemoryOwner=true;
_handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(ulong)), new long[] {data.Length}, data, sizeof(ulong));
} }


/// <summary> /// <summary>
/// Create a N-dimensional Tensor from the given array
/// Create a N-dimensional Tensor from the given array
/// </summary> /// </summary>
public Tensor(ulong[] data, long[] shape, TF_DataType? dType = null) public Tensor(ulong[] data, long[] shape, TF_DataType? dType = null)
{ {
_handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(ulong)), shape, data, Marshal.SizeOf<ulong>());
IsMemoryOwner=true;
_handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(ulong)), shape, data, sizeof(ulong));
} }


/// <summary> /// <summary>
/// Create a scalar Tensor from the given value
/// Create a scalar Tensor from the given value
/// </summary> /// </summary>
public unsafe Tensor(ulong value, TF_DataType? dType = null) public unsafe Tensor(ulong value, TF_DataType? dType = null)
{ {
var v = (ulong*)Marshal.AllocHGlobal(sizeof(ulong));
*v = value;
_handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(ulong)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(ulong), deallocator: _hGlobalDeallocator, ref _deallocatorArgs);
IsMemoryOwner=true;
_handle = TF_AllocateTensor(dType ?? dtypes.as_dtype(typeof(ulong)), dims: new long[0], num_dims: 0, len: (UIntPtr) sizeof(ulong));
*(ulong*) TF_TensorData(_handle) = value;
AllocationType = AllocationType.Tensorflow;
} }
/// <summary> /// <summary>
/// Create a 1d Tensor from the given linear array and shape
/// Create a 1d Tensor from the given linear array and shape
/// </summary> /// </summary>
public Tensor(float[] data, TF_DataType? dType = null) public Tensor(float[] data, TF_DataType? dType = null)
{ {
_handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(float)), new long[]{data.Length}, data, Marshal.SizeOf<float>());
IsMemoryOwner=true;
_handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(float)), new long[] {data.Length}, data, sizeof(float));
} }


/// <summary> /// <summary>
/// Create a N-dimensional Tensor from the given array
/// Create a N-dimensional Tensor from the given array
/// </summary> /// </summary>
public Tensor(float[] data, long[] shape, TF_DataType? dType = null) public Tensor(float[] data, long[] shape, TF_DataType? dType = null)
{ {
_handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(float)), shape, data, Marshal.SizeOf<float>());
IsMemoryOwner=true;
_handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(float)), shape, data, sizeof(float));
} }


/// <summary> /// <summary>
/// Create a scalar Tensor from the given value
/// Create a scalar Tensor from the given value
/// </summary> /// </summary>
public unsafe Tensor(float value, TF_DataType? dType = null) public unsafe Tensor(float value, TF_DataType? dType = null)
{ {
var v = (float*)Marshal.AllocHGlobal(sizeof(float));
*v = value;
_handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(float)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(float), deallocator: _hGlobalDeallocator, ref _deallocatorArgs);
IsMemoryOwner=true;
_handle = TF_AllocateTensor(dType ?? dtypes.as_dtype(typeof(float)), dims: new long[0], num_dims: 0, len: (UIntPtr) sizeof(float));
*(float*) TF_TensorData(_handle) = value;
AllocationType = AllocationType.Tensorflow;
} }
/// <summary> /// <summary>
/// Create a 1d Tensor from the given linear array and shape
/// Create a 1d Tensor from the given linear array and shape
/// </summary> /// </summary>
public Tensor(double[] data, TF_DataType? dType = null) public Tensor(double[] data, TF_DataType? dType = null)
{ {
_handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(double)), new long[]{data.Length}, data, Marshal.SizeOf<double>());
IsMemoryOwner=true;
_handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(double)), new long[] {data.Length}, data, sizeof(double));
} }


/// <summary> /// <summary>
/// Create a N-dimensional Tensor from the given array
/// Create a N-dimensional Tensor from the given array
/// </summary> /// </summary>
public Tensor(double[] data, long[] shape, TF_DataType? dType = null) public Tensor(double[] data, long[] shape, TF_DataType? dType = null)
{ {
_handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(double)), shape, data, Marshal.SizeOf<double>());
IsMemoryOwner=true;
_handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(double)), shape, data, sizeof(double));
} }


/// <summary> /// <summary>
/// Create a scalar Tensor from the given value
/// Create a scalar Tensor from the given value
/// </summary> /// </summary>
public unsafe Tensor(double value, TF_DataType? dType = null) public unsafe Tensor(double value, TF_DataType? dType = null)
{ {
var v = (double*)Marshal.AllocHGlobal(sizeof(double));
*v = value;
_handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(double)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(double), deallocator: _hGlobalDeallocator, ref _deallocatorArgs);
IsMemoryOwner=true;
_handle = TF_AllocateTensor(dType ?? dtypes.as_dtype(typeof(double)), dims: new long[0], num_dims: 0, len: (UIntPtr) sizeof(double));
*(double*) TF_TensorData(_handle) = value;
AllocationType = AllocationType.Tensorflow;
} }
/// <summary> /// <summary>
/// Create a 1d Tensor from the given linear array and shape
/// Create a 1d Tensor from the given linear array and shape
/// </summary> /// </summary>
public Tensor(Complex[] data, TF_DataType? dType = null) public Tensor(Complex[] data, TF_DataType? dType = null)
{ {
_handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(Complex)), new long[]{data.Length}, data, Marshal.SizeOf<Complex>());
IsMemoryOwner=true;
_handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(Complex)), new long[] {data.Length}, data, Marshal.SizeOf<Complex>());
} }


/// <summary> /// <summary>
/// Create a N-dimensional Tensor from the given array
/// Create a N-dimensional Tensor from the given array
/// </summary> /// </summary>
public Tensor(Complex[] data, long[] shape, TF_DataType? dType = null) public Tensor(Complex[] data, long[] shape, TF_DataType? dType = null)
{ {
_handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(Complex)), shape, data, Marshal.SizeOf<Complex>());
IsMemoryOwner=true;
_handle = CreateTensorFromArray(dType ?? dtypes.as_dtype(typeof(Complex)), shape, data, Marshal.SizeOf<Complex>());
} }


/// <summary> /// <summary>
/// Create a scalar Tensor from the given value
/// Create a scalar Tensor from the given value
/// </summary> /// </summary>
public unsafe Tensor(Complex value, TF_DataType? dType = null) public unsafe Tensor(Complex value, TF_DataType? dType = null)
{ {
var v = (Complex*)Marshal.AllocHGlobal(sizeof(Complex));
*v = value;
_handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(Complex)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(Complex), deallocator: _hGlobalDeallocator, ref _deallocatorArgs);
IsMemoryOwner=true;
_handle = TF_AllocateTensor(dType ?? dtypes.as_dtype(typeof(Complex)), dims: new long[0], num_dims: 0, len: (UIntPtr) sizeof(Complex));
*(Complex*) TF_TensorData(_handle) = value;
AllocationType = AllocationType.Tensorflow;
} }
#endif #endif


/// <summary> /// <summary>
/// Create a string Tensor from the given string
/// Create a string Tensor from the given string
/// </summary> /// </summary>
public unsafe Tensor(string str) public unsafe Tensor(string str)
{ {
var status = new Status(); var status = new Status();
var buffer = Encoding.UTF8.GetBytes(str); var buffer = Encoding.UTF8.GetBytes(str);
var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length);
var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8));
var size = c_api.TF_StringEncodedSize((UIntPtr) buffer.Length);
var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr) ((ulong) size + 8));
AllocationType = AllocationType.Tensorflow;


IntPtr tensor = c_api.TF_TensorData(handle); IntPtr tensor = c_api.TF_TensorData(handle);
Marshal.WriteInt64(tensor, 0); Marshal.WriteInt64(tensor, 0);
fixed (byte* src = buffer) fixed (byte* src = buffer)
c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(Int64)), size, status);
c_api.TF_StringEncode(src, (UIntPtr) buffer.Length, (sbyte*) (tensor + sizeof(Int64)), size, status);
_handle = handle; _handle = handle;
status.Check(true); status.Check(true);
} }


public unsafe Tensor(NDArray nd, TF_DataType? tensorDType = null) public unsafe Tensor(NDArray nd, TF_DataType? tensorDType = null)
{ {
if (tensorDType == null)
tensorDType = nd.dtype.as_dtype();

// todo: handle nd of type "String" here too // todo: handle nd of type "String" here too
if (tensorDType == TF_DataType.TF_STRING && nd.typecode == NPTypeCode.Byte) if (tensorDType == TF_DataType.TF_STRING && nd.typecode == NPTypeCode.Byte)
{ {
if (nd.Unsafe.Storage.Shape.IsContiguous) if (nd.Unsafe.Storage.Shape.IsContiguous)
{ {
var bytesLength = (UIntPtr)nd.size;
var bytesLength = (UIntPtr) nd.size;
var size = c_api.TF_StringEncodedSize(bytesLength); var size = c_api.TF_StringEncodedSize(bytesLength);
var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr) ((ulong) size + 8)); var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr) ((ulong) size + 8));
AllocationType = AllocationType.Tensorflow;


IntPtr tensor = c_api.TF_TensorData(handle); IntPtr tensor = c_api.TF_TensorData(handle);
Marshal.WriteInt64(tensor, 0); Marshal.WriteInt64(tensor, 0);
@@ -504,13 +481,12 @@ namespace Tensorflow


status.Check(true); status.Check(true);
_handle = handle; _handle = handle;
IsMemoryOwner = false;
}
else
} else
{ {
var buffer = nd.ToArray<byte>(); var buffer = nd.ToArray<byte>();
var size = c_api.TF_StringEncodedSize((UIntPtr) buffer.Length); var size = c_api.TF_StringEncodedSize((UIntPtr) buffer.Length);
var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr) ((ulong) size + 8)); var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr) ((ulong) size + 8));
AllocationType = AllocationType.Tensorflow;


IntPtr tensor = c_api.TF_TensorData(handle); IntPtr tensor = c_api.TF_TensorData(handle);
Marshal.WriteInt64(tensor, 0); Marshal.WriteInt64(tensor, 0);
@@ -521,7 +497,6 @@ namespace Tensorflow


status.Check(true); status.Check(true);
_handle = handle; _handle = handle;
IsMemoryOwner = false;
} }


return; return;
@@ -532,27 +507,27 @@ namespace Tensorflow


private unsafe IntPtr CreateTensorFromNDArray(NDArray nd, TF_DataType? given_dtype) private unsafe IntPtr CreateTensorFromNDArray(NDArray nd, TF_DataType? given_dtype)
{ {
if (nd.dtype.Name == "String")
if (nd.typecode == NPTypeCode.String)
throw new NotImplementedException("Support for NDArray of type string not implemented yet"); throw new NotImplementedException("Support for NDArray of type string not implemented yet");
IArraySlice arraySlice;
if (nd.Unsafe.Storage.Shape.IsContiguous == false)
{
// the memory is NOT contiguous, so we have to copy the view into a contiguous memory block.
arraySlice = nd.CloneData();
}
else

var arraySlice = nd.Unsafe.Storage.Shape.IsContiguous ? nd.GetData() : nd.CloneData();

var handle = TF_NewTensor(
given_dtype ?? nd.dtype.as_dtype(),
dims: nd.shape.Select(i => (long) i).ToArray(),
num_dims: nd.ndim,
data: arraySlice.Address,
len: (UIntPtr) (nd.size * nd.dtypesize));

//if TF decided not to perform copy, hold reference for given NDArray.
if (TF_TensorData(handle).ToPointer() == arraySlice.Address)
{ {
// the memory is contiguous
arraySlice = nd.GetData();
}
this.Tag = arraySlice; // keep a reference to the memory block to make sure it is not disposed while TF is using it
var ptr = new IntPtr(arraySlice.Address);
int num_bytes = (nd.size * nd.dtypesize);
var dtype = given_dtype ?? nd.dtype.as_dtype();
var handle = TF_NewTensor(dtype, dims: nd.shape.Select(i=>(long)i).ToArray(), num_dims: nd.ndim, data: ptr, len: (UIntPtr)num_bytes, deallocator: _nothingDeallocator, ref _deallocatorArgs);
IsMemoryOwner = false;
return handle;
AllocationType = AllocationType.FromPointer;
AllocationReferenceHolder = arraySlice;
} else
AllocationType = AllocationType.Tensorflow;


return handle;
} }


public unsafe Tensor(byte[][] buffer, long[] shape) public unsafe Tensor(byte[][] buffer, long[] shape)
@@ -560,11 +535,13 @@ namespace Tensorflow
int size = 0; int size = 0;
foreach (var b in buffer) foreach (var b in buffer)
{ {
size += (int)TF_StringEncodedSize((UIntPtr)b.Length);
size += (int) TF_StringEncodedSize((UIntPtr) b.Length);
} }

int totalSize = size + buffer.Length * 8; int totalSize = size + buffer.Length * 8;
ulong offset = 0; ulong offset = 0;
IntPtr handle = TF_AllocateTensor(TF_DataType.TF_STRING, shape, shape.Length, (UIntPtr)totalSize);
IntPtr handle = TF_AllocateTensor(TF_DataType.TF_STRING, shape, shape.Length, (UIntPtr) totalSize);
AllocationType = AllocationType.Tensorflow;


// Clear offset table // Clear offset table
IntPtr pOffset = TF_TensorData(handle); IntPtr pOffset = TF_TensorData(handle);
@@ -572,15 +549,15 @@ namespace Tensorflow
IntPtr dstLimit = pOffset + totalSize; IntPtr dstLimit = pOffset + totalSize;
for (int i = 0; i < buffer.Length; i++) for (int i = 0; i < buffer.Length; i++)
{ {
Marshal.WriteInt64(pOffset, (long)offset);
Marshal.WriteInt64(pOffset, (long) offset);
using (var status = new Status()) using (var status = new Status())
{ {
fixed (byte* src = &buffer[i][0]) fixed (byte* src = &buffer[i][0])
{ {
var written = TF_StringEncode(src, (UIntPtr)buffer[i].Length, (sbyte*)dst, (UIntPtr)(dstLimit.ToInt64() - dst.ToInt64()), status);
var written = TF_StringEncode(src, (UIntPtr) buffer[i].Length, (sbyte*) dst, (UIntPtr) (dstLimit.ToInt64() - dst.ToInt64()), status);
status.Check(true); status.Check(true);
pOffset += 8; pOffset += 8;
dst += (int)written;
dst += (int) written;
offset += written; offset += written;
} }
} }
@@ -612,24 +589,26 @@ namespace Tensorflow
/// </remarks> /// </remarks>
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
[SuppressMessage("ReSharper", "LocalVariableHidesMember")] [SuppressMessage("ReSharper", "LocalVariableHidesMember")]
protected unsafe IntPtr CreateTensorWithoutCopying(TF_DataType dt, long[] shape, Array data, int element_size)
protected unsafe IntPtr CreateTensorFromArray(TF_DataType dt, long[] shape, Array data, int element_size)
{ {
if (dt == TF_DataType.TF_STRING && data is byte[] buffer) if (dt == TF_DataType.TF_STRING && data is byte[] buffer)
{ {
var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length);
var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8));
var size = c_api.TF_StringEncodedSize((UIntPtr) buffer.Length);
var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr) ((ulong) size + 8));
AllocationType = AllocationType.Tensorflow;


IntPtr tensor = c_api.TF_TensorData(handle); IntPtr tensor = c_api.TF_TensorData(handle);
Marshal.WriteInt64(tensor, 0); Marshal.WriteInt64(tensor, 0);


var status = new Status(); var status = new Status();
fixed (byte* src = buffer) fixed (byte* src = buffer)
c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(Int64)), size, status);
c_api.TF_StringEncode(src, (UIntPtr) buffer.Length, (sbyte*) (tensor + sizeof(Int64)), size, status);


status.Check(true); status.Check(true);
return handle; return handle;
} }
return CreateTensorWithoutCopying(dt, shape, data, 0, data.Length, element_size);

return CreateTensorFromArray(dt, shape, data, 0, data.Length, element_size);
} }


/// <summary> /// <summary>
@@ -647,67 +626,34 @@ namespace Tensorflow
/// specified dimensions. /// specified dimensions.
/// </remarks> /// </remarks>
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
protected unsafe IntPtr CreateTensorWithoutCopying(TF_DataType dt, long[] shape, Array data, int start, int count, int element_size)
protected IntPtr CreateTensorFromArray(TF_DataType dt, long[] shape, Array data, int start, int count, int element_size)
{ {
if (start < 0 || start > data.Length - count) if (start < 0 || start > data.Length - count)
throw new ArgumentException($"Array length {data.Length} does not match the given shape {new Shape(shape.Cast<int>().ToArray())}"); throw new ArgumentException($"Array length {data.Length} does not match the given shape {new Shape(shape.Cast<int>().ToArray())}");


// get a handle to the pinned array which we will pass on to the tensor computation engine to use // get a handle to the pinned array which we will pass on to the tensor computation engine to use
var gcHandle = GCHandle.Alloc(data, GCHandleType.Pinned); var gcHandle = GCHandle.Alloc(data, GCHandleType.Pinned);
_deallocatorArgs = new DeallocatorArgs() { gc_handle = GCHandle.ToIntPtr(gcHandle) };
if (shape == null || shape.Length == 0)
return TF_NewTensor(dt, new long[0], 0, gcHandle.AddrOfPinnedObject() + start * element_size, (UIntPtr)(count * element_size), _gcHandleDeallocator, ref _deallocatorArgs);
else
return TF_NewTensor(dt, shape, shape.Length, gcHandle.AddrOfPinnedObject() + start * element_size, (UIntPtr)(count * element_size), _gcHandleDeallocator, ref _deallocatorArgs);
}

[MonoPInvokeCallback(typeof(Deallocator))]
internal static void FreeHGlobalMemory(IntPtr dataPtr, IntPtr len, ref DeallocatorArgs args)
{
if (args.deallocator_called || dataPtr == IntPtr.Zero)
return;
var pinnedAddr = gcHandle.AddrOfPinnedObject();


// NumSharp will dispose
Marshal.FreeHGlobal(dataPtr);
args.deallocator_called = true;
}
//call NewTensor
IntPtr handle;
if (shape == null || shape.Length == 0)
handle = TF_NewTensor(dt, new long[0], 0, pinnedAddr + start * element_size, (UIntPtr) (count * element_size));
else
handle = TF_NewTensor(dt, shape, shape.Length, pinnedAddr + start * element_size, (UIntPtr) (count * element_size));


[MonoPInvokeCallback(typeof(Deallocator))]
internal static void FreeGCHandle(IntPtr dataPtr, IntPtr len, ref DeallocatorArgs args)
{
if (args.deallocator_called || args.gc_handle == IntPtr.Zero)
return;
// note: since the ptr given to tensorflow is just the addr of the pinned object we can not directly free it! we need to free the gcHandle instead
GCHandle.FromIntPtr(args.gc_handle).Free();
args.deallocator_called = true;
}
//Figure if TF decided to clone or not.
if (c_api.TF_TensorData(handle) == pinnedAddr)
{
AllocationType = AllocationType.GCHandle;
AllocationHandle = gcHandle;
} else
{
AllocationType = AllocationType.Tensorflow;
gcHandle.Free();
}


[MonoPInvokeCallback(typeof(Deallocator))]
internal static void FreeNothing(IntPtr dataPtr, IntPtr len, ref DeallocatorArgs args)
{
args.deallocator_called = true;
return handle;
} }
} }

/// <summary>
/// This attribute can be applied to callback functions that will be invoked
/// from unmanaged code to managed code.
/// </summary>
/// <remarks>
/// <code>
/// [TensorFlow.MonoPInvokeCallback (typeof (BufferReleaseFunc))]
/// internal static void MyFreeFunc (IntPtr data, IntPtr length){..}
/// </code>
/// </remarks>
public sealed class MonoPInvokeCallbackAttribute : Attribute
{
/// <summary>
/// Use this constructor to annotate the type of the callback function that
/// will be invoked from unmanaged code.
/// </summary>
/// <param name="t">T.</param>
public MonoPInvokeCallbackAttribute(Type t) { }
}

}
}

+ 26
- 0
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -555,9 +555,35 @@ namespace Tensorflow
return $"tf.Tensor '{name}' shape=({string.Join(",", shape)}) dtype={dtype}"; return $"tf.Tensor '{name}' shape=({string.Join(",", shape)}) dtype={dtype}";
} }


/// <summary>
/// Dispose any managed resources.
/// </summary>
/// <remarks>Equivalent to what you would perform inside <see cref="DisposableObject.Dispose"/></remarks>
protected override void DisposeManagedResources()
{
AllocationReferenceHolder = null;
}

[SuppressMessage("ReSharper", "ConvertIfStatementToSwitchStatement")]
protected override void DisposeUnmanagedResources(IntPtr handle) protected override void DisposeUnmanagedResources(IntPtr handle)
{ {
c_api.TF_DeleteTensor(handle); c_api.TF_DeleteTensor(handle);

if (AllocationHandle == null)
return;

if (AllocationType == AllocationType.GCHandle)
{
((GCHandle) AllocationHandle).Free();
AllocationHandle = null;
AllocationType = AllocationType.None;
} else if (AllocationType == AllocationType.Marshal)
{
Marshal.FreeHGlobal((IntPtr) AllocationHandle);
AllocationHandle = null;
AllocationType = AllocationType.None;
} else
throw new InvalidOperationException($"Tensor.AllocationHandle is not null ({AllocationHandle}) but AllocationType is not matched to a C# allocation type ({AllocationType}).");
} }


public bool IsDisposed => _disposed; public bool IsDisposed => _disposed;


+ 73
- 0
src/TensorFlowNET.Core/Tensors/c_api.tensor.cs View File

@@ -15,6 +15,7 @@
******************************************************************************/ ******************************************************************************/


using System; using System;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;


namespace Tensorflow namespace Tensorflow
@@ -77,6 +78,51 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern IntPtr TF_NewTensor(TF_DataType dataType, long[] dims, int num_dims, IntPtr data, UIntPtr len, Deallocator deallocator, ref DeallocatorArgs deallocator_arg); public static extern IntPtr TF_NewTensor(TF_DataType dataType, long[] dims, int num_dims, IntPtr data, UIntPtr len, Deallocator deallocator, ref DeallocatorArgs deallocator_arg);


/// <summary>
/// Return a new tensor that holds the bytes data[0,len-1]
/// </summary>
/// <param name="dataType"></param>
/// <param name="dims"></param>
/// <param name="num_dims"></param>
/// <param name="data"></param>
/// <param name="len">num_bytes, ex: 6 * sizeof(float)</param>
/// <param name="deallocator"></param>
/// <param name="deallocator_arg"></param>
/// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern IntPtr TF_NewTensor(TF_DataType dataType, long[] dims, int num_dims, IntPtr data, UIntPtr len, Deallocator deallocator, IntPtr deallocator_arg);

/// <summary>
/// Return a new tensor that holds the bytes data[0,len-1]
/// </summary>
/// <param name="dataType"></param>
/// <param name="dims"></param>
/// <param name="num_dims"></param>
/// <param name="data"></param>
/// <param name="len">num_bytes, ex: 6 * sizeof(float)</param>
/// <param name="deallocator"></param>
/// <param name="deallocator_arg"></param>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static unsafe IntPtr TF_NewTensor(TF_DataType dataType, long[] dims, int num_dims, IntPtr data, UIntPtr len)
{
return TF_NewTensor(dataType, dims, num_dims, data, len, EmptyDeallocator, DeallocatorArgs.Empty);
}
/// <summary>
/// Return a new tensor that holds the bytes data[0,len-1]
/// </summary>
/// <param name="dataType"></param>
/// <param name="dims"></param>
/// <param name="num_dims"></param>
/// <param name="data"></param>
/// <param name="len">num_bytes, ex: 6 * sizeof(float)</param>
/// <param name="deallocator"></param>
/// <param name="deallocator_arg"></param>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static unsafe IntPtr TF_NewTensor(TF_DataType dataType, long[] dims, int num_dims, void* data, UIntPtr len)
{
return TF_NewTensor(dataType, dims, num_dims, new IntPtr(data), len);
}

/// <summary> /// <summary>
/// Return the number of dimensions that the tensor has. /// Return the number of dimensions that the tensor has.
/// </summary> /// </summary>
@@ -159,5 +205,32 @@ namespace Tensorflow


[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern unsafe UIntPtr TF_StringDecode(byte* src, UIntPtr src_len, byte** dst, UIntPtr* dst_len, IntPtr status); public static extern unsafe UIntPtr TF_StringDecode(byte* src, UIntPtr src_len, byte** dst, UIntPtr* dst_len, IntPtr status);


public static c_api.Deallocator EmptyDeallocator = FreeNothingDeallocator;

[MonoPInvokeCallback(typeof(c_api.Deallocator))]
private static void FreeNothingDeallocator(IntPtr dataPtr, IntPtr len, ref c_api.DeallocatorArgs args)
{ }

/// <summary>
/// This attribute can be applied to callback functions that will be invoked
/// from unmanaged code to managed code.
/// </summary>
/// <remarks>
/// <code>
/// [TensorFlow.MonoPInvokeCallback (typeof (BufferReleaseFunc))]
/// internal static void MyFreeFunc (IntPtr data, IntPtr length){..}
/// </code>
/// </remarks>
public sealed class MonoPInvokeCallbackAttribute : Attribute
{
/// <summary>
/// Use this constructor to annotate the type of the callback function that
/// will be invoked from unmanaged code.
/// </summary>
/// <param name="t">T.</param>
public MonoPInvokeCallbackAttribute(Type t) { }
}
} }
} }

+ 21
- 0
src/TensorFlowNET.Core/Util/Locks.cs View File

@@ -0,0 +1,21 @@
using System.Threading;

namespace Tensorflow.Util
{
/// <summary>
/// Provides a set of locks on different shared levels.
/// </summary>
public static class Locks
{
private static readonly ThreadLocal<object> _lockpool = new ThreadLocal<object>(() => new object());

/// <summary>
/// A seperate lock for every requesting thread.
/// </summary>
/// <remarks>This property is thread-safe.</remarks>
public static object ThreadWide => _lockpool.Value;


public static readonly object ProcessWide = new object();
}
}

+ 41
- 36
src/TensorFlowNET.Core/ops.cs View File

@@ -19,13 +19,19 @@ using System.Collections.Generic;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
using Google.Protobuf; using Google.Protobuf;
using System.Linq; using System.Linq;
using System.Threading;
using NumSharp; using NumSharp;
using Tensorflow.Util;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace Tensorflow namespace Tensorflow
{ {
public partial class ops public partial class ops
{ {
private static readonly ThreadLocal<DefaultGraphStack> _defaultGraphFactory = new ThreadLocal<DefaultGraphStack>(() => new DefaultGraphStack());

public static DefaultGraphStack default_graph_stack => _defaultGraphFactory.Value;

public static int tensor_id(Tensor tensor) public static int tensor_id(Tensor tensor)
{ {
return tensor.Id; return tensor.Id;
@@ -72,8 +78,6 @@ namespace Tensorflow
return get_default_graph().get_collection_ref(key); return get_default_graph().get_collection_ref(key);
} }


public static DefaultGraphStack default_graph_stack = new DefaultGraphStack();

/// <summary> /// <summary>
/// Returns the default graph for the current thread. /// Returns the default graph for the current thread.
/// ///
@@ -93,6 +97,7 @@ namespace Tensorflow
//return _default_graph_stack.get_default() //return _default_graph_stack.get_default()
return default_graph_stack.get_controller(); return default_graph_stack.get_controller();
} }

public static Graph set_default_graph(Graph graph) public static Graph set_default_graph(Graph graph)
{ {
//TODO: original source does not have a 'set_default_graph' and indicates there should be a _default_graph_stack! //TODO: original source does not have a 'set_default_graph' and indicates there should be a _default_graph_stack!
@@ -203,47 +208,49 @@ namespace Tensorflow
/// <returns>A wrapped TF_Operation*.</returns> /// <returns>A wrapped TF_Operation*.</returns>
public static (IntPtr, IntPtr) _create_c_op<T>(Graph graph, NodeDef node_def, T[] inputs, Operation[] control_inputs) public static (IntPtr, IntPtr) _create_c_op<T>(Graph graph, NodeDef node_def, T[] inputs, Operation[] control_inputs)
{ {
var op_desc = graph.NewOperation(node_def.Op, node_def.Name);

//TODO: Implement TF_SetDevice
//if node_def.device:
// c_api.TF_SetDevice(op_desc, compat.as_str(node_def.device))
// Add inputs
foreach (var op_input in inputs)
lock (Locks.ProcessWide)
{ {
if (op_input is Tensor[] op_inputs)
c_api.TF_AddInputList(op_desc, op_inputs.Select(x => x._as_tf_output()).ToArray(), op_inputs.Length);
else if (op_input is Tensor op_input1)
var op_desc = graph.NewOperation(node_def.Op, node_def.Name);

//TODO: Implement TF_SetDevice
//if node_def.device:
// c_api.TF_SetDevice(op_desc, compat.as_str(node_def.device))
// Add inputs
foreach (var op_input in inputs)
{ {
c_api.TF_AddInput(op_desc, op_input1._as_tf_output());
if (op_input is Tensor[] op_inputs)
c_api.TF_AddInputList(op_desc, op_inputs.Select(x => x._as_tf_output()).ToArray(), op_inputs.Length);
else if (op_input is Tensor op_input1)
{
c_api.TF_AddInput(op_desc, op_input1._as_tf_output());
} else
throw new NotImplementedException("_create_c_op");
} }
else
throw new NotImplementedException("_create_c_op");
}


var status = new Status();
var status = new Status();


// Add control inputs
foreach (var control_input in control_inputs)
c_api.TF_AddControlInput(op_desc, control_input);
// Add control inputs
foreach (var control_input in control_inputs)
c_api.TF_AddControlInput(op_desc, control_input);


// Add attrs
foreach (var attr in node_def.Attr)
{
var bytes = attr.Value.ToByteArray(); //TODO: we can use attr.Value.WriteTo with a memory stream.
var proto = Marshal.AllocHGlobal(bytes.Length); //TODO: potential memory leak
Marshal.Copy(bytes, 0, proto, bytes.Length);
uint len = (uint)bytes.Length;
c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: len, status: status);
// Add attrs
foreach (var attr in node_def.Attr)
{
var bytes = attr.Value.ToByteArray(); //TODO: we can use attr.Value.WriteTo with a memory stream.
var proto = Marshal.AllocHGlobal(bytes.Length); //TODO: potential memory leak
Marshal.Copy(bytes, 0, proto, bytes.Length);
uint len = (uint) bytes.Length;
c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: len, status: status);


status.Check(true);
}
status.Check(true);
}


var c_op = c_api.TF_FinishOperation(op_desc, status);
var c_op = c_api.TF_FinishOperation(op_desc, status);


status.Check(true);
status.Check(true);


return (c_op, op_desc);
return (c_op, op_desc);
}
} }


public static OpDef _get_op_def(Graph graph, string type) public static OpDef _get_op_def(Graph graph, string type)
@@ -311,7 +318,7 @@ namespace Tensorflow
/// <returns></returns> /// <returns></returns>
public static int uid() public static int uid()
{ {
return uid_number++;
return Interlocked.Increment(ref uid_number);
} }


public static void colocate_with(bool ignore_existing = false) public static void colocate_with(bool ignore_existing = false)
@@ -386,8 +393,6 @@ namespace Tensorflow
/// <returns>The default `Session` being used in the current thread.</returns> /// <returns>The default `Session` being used in the current thread.</returns>
public static Session get_default_session() public static Session get_default_session()
{ {
if (tf.defaultSession == null)
tf.defaultSession = tf.Session();
return tf.defaultSession; return tf.defaultSession;
} }




+ 10
- 1
src/TensorFlowNET.Core/tensorflow.cs View File

@@ -14,12 +14,15 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using System.Threading;
using Tensorflow.Eager; using Tensorflow.Eager;


namespace Tensorflow namespace Tensorflow
{ {
public partial class tensorflow : IObjectLife public partial class tensorflow : IObjectLife
{ {
protected internal readonly ThreadLocal<Session> _defaultSessionFactory;

public TF_DataType @byte = TF_DataType.TF_UINT8; public TF_DataType @byte = TF_DataType.TF_UINT8;
public TF_DataType @sbyte = TF_DataType.TF_INT8; public TF_DataType @sbyte = TF_DataType.TF_INT8;
public TF_DataType int16 = TF_DataType.TF_INT16; public TF_DataType int16 = TF_DataType.TF_INT16;
@@ -34,7 +37,13 @@ namespace Tensorflow


public Context context = new Context(new ContextOptions(), new Status()); public Context context = new Context(new ContextOptions(), new Status());


public Session defaultSession;

public tensorflow()
{
_defaultSessionFactory = new ThreadLocal<Session>(Session);
}

public Session defaultSession => _defaultSessionFactory.Value;


public RefVariable Variable<T>(T data, public RefVariable Variable<T>(T data,
bool trainable = true, bool trainable = true,


+ 1
- 1
test/TensorFlowNET.Examples/ImageProcessing/ImageRecognitionInception.cs View File

@@ -89,7 +89,7 @@ namespace TensorFlowNET.Examples
Directory.CreateDirectory(dir); Directory.CreateDirectory(dir);


// get model file // get model file
string url = "https://storage.googleapis.com/download.tf.org/models/inception5h.zip";
string url = "https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip";


Utility.Web.Download(url, dir, "inception5h.zip"); Utility.Web.Download(url, dir, "inception5h.zip");




+ 1
- 1
test/TensorFlowNET.Examples/ImageProcessing/InceptionArchGoogLeNet.cs View File

@@ -93,7 +93,7 @@ namespace TensorFlowNET.Examples
Directory.CreateDirectory(dir); Directory.CreateDirectory(dir);


// get model file // get model file
string url = "https://storage.googleapis.com/download.tf.org/models/inception_v3_2016_08_28_frozen.pb.tar.gz";
string url = "https://storage.googleapis.com/download.tensorflow.org/models/inception_v3_2016_08_28_frozen.pb.tar.gz";
Utility.Web.Download(url, dir, $"{pbFile}.tar.gz"); Utility.Web.Download(url, dir, $"{pbFile}.tar.gz");




+ 2
- 2
test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs View File

@@ -33,7 +33,7 @@ namespace TensorFlowNET.Examples
/// and simply train a new classification layer on top. Transfer learning is a technique that shortcuts much of this /// and simply train a new classification layer on top. Transfer learning is a technique that shortcuts much of this
/// by taking a piece of a model that has already been trained on a related task and reusing it in a new model. /// by taking a piece of a model that has already been trained on a related task and reusing it in a new model.
/// ///
/// https://www.tf.org/hub/tutorials/image_retraining
/// https://www.tensorflow.org/hub/tutorials/image_retraining
/// </summary> /// </summary>
public class RetrainImageClassifier : IExample public class RetrainImageClassifier : IExample
{ {
@@ -168,7 +168,7 @@ namespace TensorFlowNET.Examples
/// weights, and then sets up all the gradients for the backward pass. /// weights, and then sets up all the gradients for the backward pass.
/// ///
/// The set up for the softmax and fully-connected layers is based on: /// The set up for the softmax and fully-connected layers is based on:
/// https://www.tf.org/tutorials/mnist/beginners/index.html
/// https://www.tensorflow.org/tutorials/mnist/beginners/index.html
/// </summary> /// </summary>
/// <param name="class_count"></param> /// <param name="class_count"></param>
/// <param name="final_tensor_name"></param> /// <param name="final_tensor_name"></param>


+ 1
- 1
test/TensorFlowNET.UnitTest/CApiGradientsTest.cs View File

@@ -11,7 +11,7 @@ namespace TensorFlowNET.UnitTest
/// tensorflow\c\c_api_test.cc /// tensorflow\c\c_api_test.cc
/// `class CApiGradientsTest` /// `class CApiGradientsTest`
/// </summary> /// </summary>
[TestClass]
[TestClass, Ignore]
public class CApiGradientsTest : CApiTest, IDisposable public class CApiGradientsTest : CApiTest, IDisposable
{ {
private Graph graph_ = new Graph(); private Graph graph_ = new Graph();


+ 10
- 6
test/TensorFlowNET.UnitTest/CSession.cs View File

@@ -2,6 +2,7 @@
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using Tensorflow; using Tensorflow;
using Tensorflow.Util;


namespace TensorFlowNET.UnitTest namespace TensorFlowNET.UnitTest
{ {
@@ -22,9 +23,12 @@ namespace TensorFlowNET.UnitTest


public CSession(Graph graph, Status s, bool user_XLA = false) public CSession(Graph graph, Status s, bool user_XLA = false)
{ {
var opts = new SessionOptions();
opts.SetConfig(new ConfigProto { InterOpParallelismThreads = 4 });
session_ = new Session(graph, opts, s);
lock (Locks.ProcessWide)
{
var opts = new SessionOptions();
opts.SetConfig(new ConfigProto {InterOpParallelismThreads = 4});
session_ = new Session(graph, opts, s);
}
} }


public void SetInputs(Dictionary<Operation, Tensor> inputs) public void SetInputs(Dictionary<Operation, Tensor> inputs)
@@ -64,13 +68,13 @@ namespace TensorFlowNET.UnitTest
public unsafe void Run(Status s) public unsafe void Run(Status s)
{ {
var inputs_ptr = inputs_.ToArray(); var inputs_ptr = inputs_.ToArray();
var input_values_ptr = input_values_.Select(x => (IntPtr)x).ToArray();
var input_values_ptr = input_values_.Select(x => (IntPtr) x).ToArray();
var outputs_ptr = outputs_.ToArray(); var outputs_ptr = outputs_.ToArray();
var output_values_ptr = output_values_.Select(x => IntPtr.Zero).ToArray(); var output_values_ptr = output_values_.Select(x => IntPtr.Zero).ToArray();
IntPtr[] targets_ptr = new IntPtr[0]; IntPtr[] targets_ptr = new IntPtr[0];


c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, inputs_ptr.Length, c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, inputs_ptr.Length,
outputs_ptr, output_values_ptr, outputs_.Count,
outputs_ptr, output_values_ptr, outputs_.Count,
targets_ptr, targets_.Count, targets_ptr, targets_.Count,
IntPtr.Zero, s); IntPtr.Zero, s);


@@ -90,4 +94,4 @@ namespace TensorFlowNET.UnitTest
ResetOutputValues(); ResetOutputValues();
} }
} }
}
}

+ 4
- 4
test/TensorFlowNET.UnitTest/GraphTest.cs View File

@@ -207,7 +207,7 @@ namespace TensorFlowNET.UnitTest
public void ImportGraphDef() public void ImportGraphDef()
{ {
var s = new Status(); var s = new Status();
var graph = new Graph();
var graph = new Graph().as_default();


// Create a simple graph. // Create a simple graph.
c_test_util.Placeholder(graph, s); c_test_util.Placeholder(graph, s);
@@ -221,7 +221,7 @@ namespace TensorFlowNET.UnitTest


// Import it, with a prefix, in a fresh graph. // Import it, with a prefix, in a fresh graph.
graph.Dispose(); graph.Dispose();
graph = new Graph();
graph = new Graph().as_default();
var opts = c_api.TF_NewImportGraphDefOptions(); var opts = c_api.TF_NewImportGraphDefOptions();
c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported"); c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported");
c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s); c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s);
@@ -359,7 +359,7 @@ namespace TensorFlowNET.UnitTest
public void ImportGraphDef_WithReturnOutputs() public void ImportGraphDef_WithReturnOutputs()
{ {
var s = new Status(); var s = new Status();
var graph = new Graph();
var graph = new Graph().as_default();


// Create a graph with two nodes: x and 3 // Create a graph with two nodes: x and 3
c_test_util.Placeholder(graph, s); c_test_util.Placeholder(graph, s);
@@ -375,7 +375,7 @@ namespace TensorFlowNET.UnitTest


// Import it in a fresh graph with return outputs. // Import it in a fresh graph with return outputs.
graph.Dispose(); graph.Dispose();
graph = new Graph();
graph = new Graph().as_default();
var opts = new ImportGraphDefOptions(); var opts = new ImportGraphDefOptions();
opts.AddReturnOutput("feed", 0); opts.AddReturnOutput("feed", 0);
opts.AddReturnOutput("scalar", 0); opts.AddReturnOutput("scalar", 0);


+ 263
- 0
test/TensorFlowNET.UnitTest/MultithreadingTests.cs View File

@@ -0,0 +1,263 @@
using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using FluentAssertions;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow;
using Tensorflow.Util;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest
{
[TestClass]
public class MultithreadingTests
{
[TestMethod]
public void SessionCreation()
{
ops.uid(); //increment id by one

MultiThreadedUnitTestExecuter.Run(8, Core);

//the core method
void Core(int tid)
{
tf.peak_default_graph().Should().BeNull();

using (var sess = tf.Session())
{
var default_graph = tf.peak_default_graph();
var sess_graph = sess.GetPrivate<Graph>("_graph");
sess_graph.Should().NotBeNull();
default_graph.Should().NotBeNull()
.And.BeEquivalentTo(sess_graph);
}
}
}

[TestMethod]
public void SessionCreation_x2()
{
ops.uid(); //increment id by one

MultiThreadedUnitTestExecuter.Run(16, Core);

//the core method
void Core(int tid)
{
tf.peak_default_graph().Should().BeNull();
//tf.Session created an other graph
using (var sess = tf.Session())
{
var default_graph = tf.peak_default_graph();
var sess_graph = sess.GetPrivate<Graph>("_graph");
sess_graph.Should().NotBeNull();
default_graph.Should().NotBeNull()
.And.BeEquivalentTo(sess_graph);
}
}
}

[TestMethod]
public void GraphCreation()
{
ops.uid(); //increment id by one

MultiThreadedUnitTestExecuter.Run(8, Core);

//the core method
void Core(int tid)
{
tf.peak_default_graph().Should().BeNull();
var beforehand = tf.get_default_graph(); //this should create default automatically.
beforehand.graph_key.Should().NotContain("-0/", "Already created a graph in an other thread.");
tf.peak_default_graph().Should().NotBeNull();

using (var sess = tf.Session())
{
var default_graph = tf.peak_default_graph();
var sess_graph = sess.GetPrivate<Graph>("_graph");
sess_graph.Should().NotBeNull();
default_graph.Should().NotBeNull()
.And.BeEquivalentTo(sess_graph)
.And.BeEquivalentTo(beforehand);

Console.WriteLine($"{tid}-{default_graph.graph_key}");

//var result = sess.run(new object[] {g, a});
//var actualDeriv = result[0].GetData<float>()[0];
//var actual = result[1].GetData<float>()[0];
}
}
}


[TestMethod]
public void Marshal_AllocHGlobal()
{
MultiThreadedUnitTestExecuter.Run(8, Core);

//the core method
void Core(int tid)
{
for (int i = 0; i < 100; i++)
{
Marshal.FreeHGlobal(Marshal.AllocHGlobal(sizeof(int)));
}
}
}

[TestMethod]
public void TensorCreation()
{
//lock (Locks.ProcessWide)
// tf.Session(); //create one to increase next id to 1.

MultiThreadedUnitTestExecuter.Run(8, Core);

//the core method
void Core(int tid)
{
using (var sess = tf.Session())
{
Tensor t = null;
for (int i = 0; i < 100; i++)
{
t = new Tensor(1);
}
}
}
}

[TestMethod]
public void TensorCreation_Array()
{
//lock (Locks.ProcessWide)
// tf.Session(); //create one to increase next id to 1.

MultiThreadedUnitTestExecuter.Run(8, Core);

//the core method
void Core(int tid)
{
//tf.Session created an other graph
using (var sess = tf.Session())
{
Tensor t = null;
for (int i = 0; i < 100; i++)
{
t = new Tensor(new int[] {1, 2, 3});
}
}
}
}

[TestMethod]
public void TensorCreation_Undressed()
{
//lock (Locks.ProcessWide)
// tf.Session(); //create one to increase next id to 1.

MultiThreadedUnitTestExecuter.Run(8, Core);

//the core method
unsafe void Core(int tid)
{
using (var sess = tf.Session())
{
Tensor t = null;
for (int i = 0; i < 100; i++)
{
var v = (int*) Marshal.AllocHGlobal(sizeof(int));
c_api.DeallocatorArgs _deallocatorArgs = new c_api.DeallocatorArgs();
var handle = c_api.TF_NewTensor(typeof(int).as_dtype(), dims: new long[0], num_dims: 0,
data: (IntPtr) v, len: (UIntPtr) sizeof(int),
deallocator: (IntPtr data, IntPtr size, ref c_api.DeallocatorArgs args) => Marshal.FreeHGlobal(data),
ref _deallocatorArgs);
c_api.TF_DeleteTensor(handle);
}
}
}
}

[TestMethod]
public void SessionRun()
{
MultiThreadedUnitTestExecuter.Run(8, Core);

//the core method
void Core(int tid)
{
tf.peak_default_graph().Should().BeNull();
//graph is created automatically to perform create these operations
var a1 = tf.constant(new[] {2f}, shape: new[] {1});
var a2 = tf.constant(new[] {3f}, shape: new[] {1});
var math = a1 + a2;
for (int i = 0; i < 100; i++)
{
using (var sess = tf.Session())
{
sess.run(math).GetAtIndex<float>(0).Should().Be(5);
}
}
}
}

[TestMethod]
public void SessionRun_InsideSession()
{
MultiThreadedUnitTestExecuter.Run(8, Core);

//the core method
void Core(int tid)
{
using (var sess = tf.Session())
{
tf.peak_default_graph().Should().NotBeNull();
//graph is created automatically to perform create these operations
var a1 = tf.constant(new[] {2f}, shape: new[] {1});
var a2 = tf.constant(new[] {3f}, shape: new[] {1});
var math = a1 + a2;

var result = sess.run(math);
result[0].GetAtIndex<float>(0).Should().Be(5);
}
}
}

[TestMethod]
public void SessionRun_Initialization()
{
MultiThreadedUnitTestExecuter.Run(8, Core);

//the core method
void Core(int tid)
{
using (var sess = tf.Session())
{
tf.peak_default_graph().Should().NotBeNull();
//graph is created automatically to perform create these operations
var a1 = tf.constant(new[] {2f}, shape: new[] {1});
var a2 = tf.constant(new[] {3f}, shape: new[] {1});
var math = a1 + a2;
}
}
}

[TestMethod]
public void SessionRun_Initialization_OutsideSession()
{
MultiThreadedUnitTestExecuter.Run(8, Core);

//the core method
void Core(int tid)
{
tf.peak_default_graph().Should().BeNull();
//graph is created automatically to perform create these operations
var a1 = tf.constant(new[] {2f}, shape: new[] {1});
var a2 = tf.constant(new[] {3f}, shape: new[] {1});
var math = a1 + a2;
}
}
}
}

+ 5
- 4
test/TensorFlowNET.UnitTest/SessionTest.cs View File

@@ -8,6 +8,7 @@ using System.Text;
using FluentAssertions; using FluentAssertions;
using Google.Protobuf; using Google.Protobuf;
using Tensorflow; using Tensorflow;
using Tensorflow.Util;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace TensorFlowNET.UnitTest namespace TensorFlowNET.UnitTest
@@ -19,13 +20,13 @@ namespace TensorFlowNET.UnitTest
/// tensorflow\c\c_api_test.cc /// tensorflow\c\c_api_test.cc
/// `TEST(CAPI, Session)` /// `TEST(CAPI, Session)`
/// </summary> /// </summary>
[TestMethod]
[TestMethod, Ignore]
public void Session() public void Session()
{ {
lock (this)
lock (Locks.ProcessWide)
{ {
var s = new Status(); var s = new Status();
var graph = new Graph();
var graph = new Graph().as_default();


// Make a placeholder operation. // Make a placeholder operation.
var feed = c_test_util.Placeholder(graph, s); var feed = c_test_util.Placeholder(graph, s);
@@ -93,7 +94,7 @@ namespace TensorFlowNET.UnitTest
using (var sess = tf.Session()) using (var sess = tf.Session())
{ {
var result = c.eval(sess); var result = c.eval(sess);
Assert.AreEqual(6, result.Data<double>()[0]);
Assert.AreEqual(6, result.GetAtIndex<double>(0));
} }
} }
} }


+ 2
- 0
test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj View File

@@ -10,6 +10,8 @@
<DelaySign>false</DelaySign> <DelaySign>false</DelaySign>


<AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile> <AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile>

<LangVersion>latest</LangVersion>
</PropertyGroup> </PropertyGroup>


<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'">


+ 2
- 0
test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj.DotSettings View File

@@ -0,0 +1,2 @@
<wpf:ResourceDictionary xml:space="preserve" xmlns:x="http://schemas.microsoft.com/winfx/2006/xaml" xmlns:s="clr-namespace:System;assembly=mscorlib" xmlns:ss="urn:shemas-jetbrains-com:settings-storage-xaml" xmlns:wpf="http://schemas.microsoft.com/winfx/2006/xaml/presentation">
<s:Boolean x:Key="/Default/CodeInspection/NamespaceProvider/NamespaceFoldersToSkip/=utilities/@EntryIndexedValue">True</s:Boolean></wpf:ResourceDictionary>

+ 44
- 57
test/TensorFlowNET.UnitTest/TensorTest.cs View File

@@ -4,6 +4,7 @@ using System;
using System.Linq; using System.Linq;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
using System.Threading; using System.Threading;
using FluentAssertions;
using Tensorflow; using Tensorflow;
using static Tensorflow.Binding; using static Tensorflow.Binding;


@@ -12,77 +13,63 @@ namespace TensorFlowNET.UnitTest
[TestClass] [TestClass]
public class TensorTest : CApiTest public class TensorTest : CApiTest
{ {
[Ignore("Not for mult-thread")]
public void TensorDeallocationThreadSafety()
{
var tensors = new Tensor[1000];
foreach (var i in range(1000))
{
tensors[i] = new Tensor(new int[1000]);
}
SemaphoreSlim s = new SemaphoreSlim(0, 2);
SemaphoreSlim s_done = new SemaphoreSlim(0, 2);

var t1 = new Thread(() =>
{
s.Wait();
foreach (var t in tensors)
t.Dispose();
s_done.Release();
});

var t2 = new Thread(() =>
{
s.Wait();
foreach (var t in tensors)
t.Dispose();
s_done.Release();
});

t1.Start();
t2.Start();
s.Release(2);
s_done.Wait();
s_done.Wait();

foreach (var t in tensors)
Assert.IsTrue(t.IsDisposed);
}

[TestMethod] [TestMethod]
public unsafe void TensorFromFixed() public unsafe void TensorFromFixed()
{ {
var array = new float[1000]; var array = new float[1000];
var span = new Span<float>(array, 100, 500); var span = new Span<float>(array, 100, 500);
fixed (float* ptr=&MemoryMarshal.GetReference(span))
fixed (float* ptr = &MemoryMarshal.GetReference(span))
{ {
using (var t = new Tensor((IntPtr)ptr, new long[] {span.Length}, tf.float32, 4*span.Length))
using (var t = new Tensor((IntPtr) ptr, new long[] {span.Length}, tf.float32, 4 * span.Length))
{ {
Assert.IsFalse(t.IsDisposed); Assert.IsFalse(t.IsDisposed);
Assert.IsFalse(t.IsMemoryOwner);
Assert.AreEqual(2000, (int) t.bytesize); Assert.AreEqual(2000, (int) t.bytesize);
} }
} }

fixed (float* ptr = &array[0]) fixed (float* ptr = &array[0])
{ {
using (var t = new Tensor((IntPtr)ptr, new long[] { array.Length }, tf.float32, 4 * array.Length))
using (var t = new Tensor((IntPtr) ptr, new long[] {array.Length}, tf.float32, 4 * array.Length))
{ {
Assert.IsFalse(t.IsDisposed); Assert.IsFalse(t.IsDisposed);
Assert.IsFalse(t.IsMemoryOwner);
Assert.AreEqual(4000, (int)t.bytesize);
Assert.AreEqual(4000, (int) t.bytesize);
} }
} }
} }


[TestMethod]
public unsafe void TensorFromArray()
{
var array = new float[1000];
using (var t = new Tensor(array, new long[] {array.Length}, tf.float32))
{
Assert.IsFalse(t.IsDisposed);
Assert.AreEqual(1000 * sizeof(float), (int) t.bytesize);
}

using (var t = new Tensor(new float[] {1}, new long[] {1}, tf.float32))
{
Assert.IsFalse(t.IsDisposed);
Assert.AreEqual(1 * sizeof(float), (int) t.bytesize);
}

using (var t = new Tensor(new float[] {1}, null, tf.float32))
{
Assert.IsFalse(t.IsDisposed);
Assert.AreEqual(1 * sizeof(float), (int) t.bytesize);
t.shape.Should().BeEmpty();
}
}

[TestMethod] [TestMethod]
public void AllocateTensor() public void AllocateTensor()
{ {
ulong num_bytes = 6 * sizeof(float); ulong num_bytes = 6 * sizeof(float);
long[] dims = { 2, 3 };
long[] dims = {2, 3};
Tensor t = c_api.TF_AllocateTensor(TF_DataType.TF_FLOAT, dims, 2, num_bytes); Tensor t = c_api.TF_AllocateTensor(TF_DataType.TF_FLOAT, dims, 2, num_bytes);
EXPECT_EQ(TF_DataType.TF_FLOAT, t.dtype); EXPECT_EQ(TF_DataType.TF_FLOAT, t.dtype);
EXPECT_EQ(2, t.NDims); EXPECT_EQ(2, t.NDims);
EXPECT_EQ((int)dims[0], t.shape[0]);
EXPECT_EQ((int) dims[0], t.shape[0]);
EXPECT_EQ(num_bytes, t.bytesize); EXPECT_EQ(num_bytes, t.bytesize);
t.Dispose(); t.Dispose();
} }
@@ -98,7 +85,7 @@ namespace TensorFlowNET.UnitTest
NDArray nd = np.array(2, 3); NDArray nd = np.array(2, 3);
Tensor t = new Tensor(nd); Tensor t = new Tensor(nd);
Tensor o = t.MaybeMove(); Tensor o = t.MaybeMove();
ASSERT_TRUE(o == IntPtr.Zero); // It is unsafe to move memory TF might not own.
ASSERT_TRUE(o == IntPtr.Zero); // It is unsafe to move memory TF might not own.
t.Dispose(); t.Dispose();
} }


@@ -116,10 +103,10 @@ namespace TensorFlowNET.UnitTest


EXPECT_EQ(tensor.dtype, TF_DataType.TF_FLOAT); EXPECT_EQ(tensor.dtype, TF_DataType.TF_FLOAT);
EXPECT_EQ(tensor.rank, nd.ndim); EXPECT_EQ(tensor.rank, nd.ndim);
EXPECT_EQ((int)tensor.shape[0], nd.shape[0]);
EXPECT_EQ((int)tensor.shape[1], nd.shape[1]);
EXPECT_EQ(tensor.bytesize, (ulong)nd.size * sizeof(float));
Assert.IsTrue(Enumerable.SequenceEqual(nd.Data<float>(), new float[] { 1, 2, 3, 4, 5, 6 }));
EXPECT_EQ((int) tensor.shape[0], nd.shape[0]);
EXPECT_EQ((int) tensor.shape[1], nd.shape[1]);
EXPECT_EQ(tensor.bytesize, (ulong) nd.size * sizeof(float));
Assert.IsTrue(Enumerable.SequenceEqual(nd.Data<float>(), new float[] {1, 2, 3, 4, 5, 6}));
} }


/// <summary> /// <summary>
@@ -130,7 +117,7 @@ namespace TensorFlowNET.UnitTest
public void SetShape() public void SetShape()
{ {
var s = new Status(); var s = new Status();
var graph = new Graph();
var graph = new Graph().as_default();


var feed = c_test_util.Placeholder(graph, s); var feed = c_test_util.Placeholder(graph, s);
var feed_out_0 = new TF_Output(feed, 0); var feed_out_0 = new TF_Output(feed, 0);
@@ -148,7 +135,7 @@ namespace TensorFlowNET.UnitTest
EXPECT_EQ(-1, num_dims); EXPECT_EQ(-1, num_dims);


// Set the shape to be 2 x Unknown // Set the shape to be 2 x Unknown
long[] dims = { 2, -1 };
long[] dims = {2, -1};
c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s); c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s);
Assert.IsTrue(s.Code == TF_Code.TF_OK); Assert.IsTrue(s.Code == TF_Code.TF_OK);
num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s);
@@ -177,8 +164,8 @@ namespace TensorFlowNET.UnitTest
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
Assert.IsTrue(s.Code == TF_Code.TF_OK); Assert.IsTrue(s.Code == TF_Code.TF_OK);
EXPECT_EQ(2, num_dims); EXPECT_EQ(2, num_dims);
EXPECT_EQ(2, (int)returned_dims[0]);
EXPECT_EQ(3, (int)returned_dims[1]);
EXPECT_EQ(2, (int) returned_dims[0]);
EXPECT_EQ(3, (int) returned_dims[1]);


// Try to set 'unknown' with same rank on the shape and see that // Try to set 'unknown' with same rank on the shape and see that
// it doesn't change. // it doesn't change.
@@ -189,8 +176,8 @@ namespace TensorFlowNET.UnitTest
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s);
Assert.IsTrue(s.Code == TF_Code.TF_OK); Assert.IsTrue(s.Code == TF_Code.TF_OK);
EXPECT_EQ(2, num_dims); EXPECT_EQ(2, num_dims);
EXPECT_EQ(2, (int)returned_dims[0]);
EXPECT_EQ(3, (int)returned_dims[1]);
EXPECT_EQ(2, (int) returned_dims[0]);
EXPECT_EQ(3, (int) returned_dims[1]);


// Try to fetch a shape with the wrong num_dims // Try to fetch a shape with the wrong num_dims
c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, 5, s); c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, 5, s);
@@ -216,4 +203,4 @@ namespace TensorFlowNET.UnitTest
s.Dispose(); s.Dispose();
} }
} }
}
}

+ 173
- 0
test/TensorFlowNET.UnitTest/Utilities/MultiThreadedUnitTestExecuter.cs View File

@@ -0,0 +1,173 @@
using System;
using System.Diagnostics;
using System.Threading;
using Microsoft.VisualStudio.TestTools.UnitTesting;

namespace TensorFlowNET.UnitTest
{
public delegate void MultiThreadedTestDelegate(int threadid);

/// <summary>
/// Creates a synchronized eco-system of running code.
/// </summary>
public class MultiThreadedUnitTestExecuter : IDisposable
{
public int ThreadCount { get; }
public Thread[] Threads { get; }
public Exception[] Exceptions { get; }
private readonly SemaphoreSlim barrier_threadstarted;
private readonly ManualResetEventSlim barrier_corestart;
private readonly SemaphoreSlim done_barrier2;

public Action<MultiThreadedUnitTestExecuter> PostRun { get; set; }

#region Static

[DebuggerHidden]
public static void Run(int threadCount, MultiThreadedTestDelegate workload)
{
if (workload == null) throw new ArgumentNullException(nameof(workload));
if (threadCount <= 0) throw new ArgumentOutOfRangeException(nameof(threadCount));
new MultiThreadedUnitTestExecuter(threadCount).Run(workload);
}

[DebuggerHidden]
public static void Run(int threadCount, params MultiThreadedTestDelegate[] workloads)
{
if (workloads == null) throw new ArgumentNullException(nameof(workloads));
if (workloads.Length == 0) throw new ArgumentException("Value cannot be an empty collection.", nameof(workloads));
if (threadCount <= 0) throw new ArgumentOutOfRangeException(nameof(threadCount));
new MultiThreadedUnitTestExecuter(threadCount).Run(workloads);
}

[DebuggerHidden]
public static void Run(int threadCount, MultiThreadedTestDelegate workload, Action<MultiThreadedUnitTestExecuter> postRun)
{
if (workload == null) throw new ArgumentNullException(nameof(workload));
if (postRun == null) throw new ArgumentNullException(nameof(postRun));
if (threadCount <= 0) throw new ArgumentOutOfRangeException(nameof(threadCount));
new MultiThreadedUnitTestExecuter(threadCount) {PostRun = postRun}.Run(workload);
}

#endregion


/// <summary>Initializes a new instance of the <see cref="T:System.Object"></see> class.</summary>
public MultiThreadedUnitTestExecuter(int threadCount)
{
if (threadCount <= 0)
throw new ArgumentOutOfRangeException(nameof(threadCount));
ThreadCount = threadCount;
Threads = new Thread[ThreadCount];
Exceptions = new Exception[ThreadCount];
done_barrier2 = new SemaphoreSlim(0, threadCount);
barrier_corestart = new ManualResetEventSlim();
barrier_threadstarted = new SemaphoreSlim(0, threadCount);
}

[DebuggerHidden]
public void Run(params MultiThreadedTestDelegate[] workloads)
{
if (workloads == null)
throw new ArgumentNullException(nameof(workloads));
if (workloads.Length != 1 && workloads.Length % ThreadCount != 0)
throw new InvalidOperationException($"Run method must accept either 1 workload or n-threads workloads. Got {workloads.Length} workloads.");

if (ThreadCount == 1)
{
Exception ex = null;
new Thread(() =>
{
try
{
workloads[0](0);
} catch (Exception e)
{
if (Debugger.IsAttached)
throw;
ex = e;
} finally
{
done_barrier2.Release(1);
}
}).Start();

done_barrier2.Wait();

if (ex != null)
throw new Exception($"Thread 0 has failed: ", ex);

PostRun?.Invoke(this);

return;
}

//thread core
Exception ThreadCore(MultiThreadedTestDelegate core, int threadid)
{
barrier_threadstarted.Release(1);
barrier_corestart.Wait();
//workload
try
{
core(threadid);
} catch (Exception e)
{
if (Debugger.IsAttached)
throw;
return e;
} finally
{
done_barrier2.Release(1);
}

return null;
}

//initialize all threads
if (workloads.Length == 1)
{
var workload = workloads[0];
for (int i = 0; i < ThreadCount; i++)
{
var i_local = i;
Threads[i] = new Thread(() => Exceptions[i_local] = ThreadCore(workload, i_local));
}
} else
{
for (int i = 0; i < ThreadCount; i++)
{
var i_local = i;
var workload = workloads[i_local % workloads.Length];
Threads[i] = new Thread(() => Exceptions[i_local] = ThreadCore(workload, i_local));
}
}

//run all threads
for (int i = 0; i < ThreadCount; i++) Threads[i].Start();
//wait for threads to be started and ready
for (int i = 0; i < ThreadCount; i++) barrier_threadstarted.Wait();

//signal threads to start
barrier_corestart.Set();

//wait for threads to finish
for (int i = 0; i < ThreadCount; i++) done_barrier2.Wait();

//handle fails
for (int i = 0; i < ThreadCount; i++)
if (Exceptions[i] != null)
throw new Exception($"Thread {i} has failed: ", Exceptions[i]);

//checks after ended
PostRun?.Invoke(this);
}

public void Dispose()
{
barrier_threadstarted.Dispose();
barrier_corestart.Dispose();
done_barrier2.Dispose();
}
}
}

+ 914
- 0
test/TensorFlowNET.UnitTest/Utilities/PrivateObject.cs View File

@@ -0,0 +1,914 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

namespace Microsoft.VisualStudio.TestTools.UnitTesting
{
using System;
using System.Collections.Generic;
//using System.Diagnostics;
//using System.Diagnostics.CodeAnalysis;
using System.Globalization;
using System.Reflection;

/// <summary>
/// This class represents the live NON public INTERNAL object in the system
/// </summary>
internal class PrivateObject
{
#region Data

// bind everything
private const BindingFlags BindToEveryThing = BindingFlags.Default | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.Public;

private static BindingFlags constructorFlags = BindingFlags.Instance | BindingFlags.Public | BindingFlags.CreateInstance | BindingFlags.NonPublic;

private object target; // automatically initialized to null
private Type originalType; // automatically initialized to null

//private Dictionary<string, LinkedList<MethodInfo>> methodCache; // automatically initialized to null

#endregion

#region Constructors

///// <summary>
///// Initializes a new instance of the <see cref="PrivateObject"/> class that contains
///// the already existing object of the private class
///// </summary>
///// <param name="obj"> object that serves as starting point to reach the private members</param>
///// <param name="memberToAccess">the derefrencing string using . that points to the object to be retrived as in m_X.m_Y.m_Z</param>
//[SuppressMessage("Microsoft.Naming", "CA1720:IdentifiersShouldNotContainTypeNames", MessageId = "obj", Justification = "We don't know anything about the object other than that it's an object, so 'obj' seems reasonable")]
//public PrivateObject(object obj, string memberToAccess)
//{
// Helper.CheckParameterNotNull(obj, "obj", string.Empty);
// ValidateAccessString(memberToAccess);

// PrivateObject temp = obj as PrivateObject;
// if (temp == null)
// {
// temp = new PrivateObject(obj);
// }

// // Split The access string
// string[] arr = memberToAccess.Split(new char[] { '.' });

// for (int i = 0; i < arr.Length; i++)
// {
// object next = temp.InvokeHelper(arr[i], BindToEveryThing | BindingFlags.Instance | BindingFlags.GetField | BindingFlags.GetProperty, null, CultureInfo.InvariantCulture);
// temp = new PrivateObject(next);
// }

// this.target = temp.target;
// this.originalType = temp.originalType;
//}

///// <summary>
///// Initializes a new instance of the <see cref="PrivateObject"/> class that wraps the
///// specified type.
///// </summary>
///// <param name="assemblyName">Name of the assembly</param>
///// <param name="typeName">fully qualified name</param>
///// <param name="args">Argmenets to pass to the constructor</param>
//public PrivateObject(string assemblyName, string typeName, params object[] args)
// : this(assemblyName, typeName, null, args)
//{
//}

///// <summary>
///// Initializes a new instance of the <see cref="PrivateObject"/> class that wraps the
///// specified type.
///// </summary>
///// <param name="assemblyName">Name of the assembly</param>
///// <param name="typeName">fully qualified name</param>
///// <param name="parameterTypes">An array of <see cref="T:System.Type"/> objects representing the number, order, and type of the parameters for the constructor to get</param>
///// <param name="args">Argmenets to pass to the constructor</param>
//public PrivateObject(string assemblyName, string typeName, Type[] parameterTypes, object[] args)
// : this(Type.GetType(string.Format(CultureInfo.InvariantCulture, "{0}, {1}", typeName, assemblyName), false), parameterTypes, args)
//{
// Helper.CheckParameterNotNull(assemblyName, "assemblyName", string.Empty);
// Helper.CheckParameterNotNull(typeName, "typeName", string.Empty);
//}

///// <summary>
///// Initializes a new instance of the <see cref="PrivateObject"/> class that wraps the
///// specified type.
///// </summary>
///// <param name="type">type of the object to create</param>
///// <param name="args">Argmenets to pass to the constructor</param>
//public PrivateObject(Type type, params object[] args)
// : this(type, null, args)
//{
// Helper.CheckParameterNotNull(type, "type", string.Empty);
//}

///// <summary>
///// Initializes a new instance of the <see cref="PrivateObject"/> class that wraps the
///// specified type.
///// </summary>
///// <param name="type">type of the object to create</param>
///// <param name="parameterTypes">An array of <see cref="T:System.Type"/> objects representing the number, order, and type of the parameters for the constructor to get</param>
///// <param name="args">Argmenets to pass to the constructor</param>
//public PrivateObject(Type type, Type[] parameterTypes, object[] args)
//{
// Helper.CheckParameterNotNull(type, "type", string.Empty);
// object o;
// if (parameterTypes != null)
// {
// ConstructorInfo ci = type.GetConstructor(BindToEveryThing, null, parameterTypes, null);
// if (ci == null)
// {
// throw new ArgumentException(FrameworkMessages.PrivateAccessorConstructorNotFound);
// }

// try
// {
// o = ci.Invoke(args);
// }
// catch (TargetInvocationException e)
// {
// Debug.Assert(e.InnerException != null, "Inner exception should not be null.");
// if (e.InnerException != null)
// {
// throw e.InnerException;
// }

// throw;
// }
// }
// else
// {
// o = Activator.CreateInstance(type, constructorFlags, null, args, null);
// }

// this.ConstructFrom(o);
//}

/// <summary>
/// Initializes a new instance of the <see cref="PrivateObject"/> class that wraps
/// the given object.
/// </summary>
/// <param name="obj">object to wrap</param>
//[SuppressMessage("Microsoft.Naming", "CA1720:IdentifiersShouldNotContainTypeNames", MessageId = "obj", Justification = "We don't know anything about the object other than that it's an object, so 'obj' seems reasonable")]
public PrivateObject(object obj)
{
Helper.CheckParameterNotNull(obj, "obj", string.Empty);
this.ConstructFrom(obj);
}

/// <summary>
/// Initializes a new instance of the <see cref="PrivateObject"/> class that wraps
/// the given object.
/// </summary>
/// <param name="obj">object to wrap</param>
/// <param name="type">PrivateType object</param>
//[SuppressMessage("Microsoft.Naming", "CA1720:IdentifiersShouldNotContainTypeNames", MessageId = "obj", Justification = "We don't know anything about the object other than that it's an an object, so 'obj' seems reasonable")]
public PrivateObject(object obj, PrivateType type)
{
Helper.CheckParameterNotNull(type, "type", string.Empty);
this.target = obj;
this.originalType = type.ReferencedType;
}

#endregion

///// <summary>
///// Gets or sets the target
///// </summary>
//public object Target
//{
// get
// {
// return this.target;
// }

// set
// {
// Helper.CheckParameterNotNull(value, "Target", string.Empty);
// this.target = value;
// this.originalType = value.GetType();
// }
//}

///// <summary>
///// Gets the type of underlying object
///// </summary>
//public Type RealType
//{
// get
// {
// return this.originalType;
// }
//}

//private Dictionary<string, LinkedList<MethodInfo>> GenericMethodCache
//{
// get
// {
// if (this.methodCache == null)
// {
// this.BuildGenericMethodCacheForType(this.originalType);
// }

// Debug.Assert(this.methodCache != null, "Invalid method cache for type.");

// return this.methodCache;
// }
//}

/// <summary>
/// returns the hash code of the target object
/// </summary>
/// <returns>int representing hashcode of the target object</returns>
public override int GetHashCode()
{
//Debug.Assert(this.target != null, "target should not be null.");
return this.target.GetHashCode();
}

/// <summary>
/// Equals
/// </summary>
/// <param name="obj">Object with whom to compare</param>
/// <returns>returns true if the objects are equal.</returns>
public override bool Equals(object obj)
{
if (this != obj)
{
//Debug.Assert(this.target != null, "target should not be null.");
if (typeof(PrivateObject) == obj?.GetType())
{
return this.target.Equals(((PrivateObject) obj).target);
} else
{
return false;
}
}

return true;
}

///// <summary>
///// Invokes the specified method
///// </summary>
///// <param name="name">Name of the method</param>
///// <param name="args">Arguments to pass to the member to invoke.</param>
///// <returns>Result of method call</returns>
//public object Invoke(string name, params object[] args)
//{
// Helper.CheckParameterNotNull(name, "name", string.Empty);
// return this.Invoke(name, null, args, CultureInfo.InvariantCulture);
//}

///// <summary>
///// Invokes the specified method
///// </summary>
///// <param name="name">Name of the method</param>
///// <param name="parameterTypes">An array of <see cref="T:System.Type"/> objects representing the number, order, and type of the parameters for the method to get.</param>
///// <param name="args">Arguments to pass to the member to invoke.</param>
///// <returns>Result of method call</returns>
//public object Invoke(string name, Type[] parameterTypes, object[] args)
//{
// return this.Invoke(name, parameterTypes, args, CultureInfo.InvariantCulture);
//}

///// <summary>
///// Invokes the specified method
///// </summary>
///// <param name="name">Name of the method</param>
///// <param name="parameterTypes">An array of <see cref="T:System.Type"/> objects representing the number, order, and type of the parameters for the method to get.</param>
///// <param name="args">Arguments to pass to the member to invoke.</param>
///// <param name="typeArguments">An array of types corresponding to the types of the generic arguments.</param>
///// <returns>Result of method call</returns>
//public object Invoke(string name, Type[] parameterTypes, object[] args, Type[] typeArguments)
//{
// return this.Invoke(name, BindToEveryThing, parameterTypes, args, CultureInfo.InvariantCulture, typeArguments);
//}

///// <summary>
///// Invokes the specified method
///// </summary>
///// <param name="name">Name of the method</param>
///// <param name="args">Arguments to pass to the member to invoke.</param>
///// <param name="culture">Culture info</param>
///// <returns>Result of method call</returns>
//public object Invoke(string name, object[] args, CultureInfo culture)
//{
// return this.Invoke(name, null, args, culture);
//}

///// <summary>
///// Invokes the specified method
///// </summary>
///// <param name="name">Name of the method</param>
///// <param name="parameterTypes">An array of <see cref="T:System.Type"/> objects representing the number, order, and type of the parameters for the method to get.</param>
///// <param name="args">Arguments to pass to the member to invoke.</param>
///// <param name="culture">Culture info</param>
///// <returns>Result of method call</returns>
//public object Invoke(string name, Type[] parameterTypes, object[] args, CultureInfo culture)
//{
// return this.Invoke(name, BindToEveryThing, parameterTypes, args, culture);
//}

///// <summary>
///// Invokes the specified method
///// </summary>
///// <param name="name">Name of the method</param>
///// <param name="bindingFlags">A bitmask comprised of one or more <see cref="T:System.Reflection.BindingFlags"/> that specify how the search is conducted.</param>
///// <param name="args">Arguments to pass to the member to invoke.</param>
///// <returns>Result of method call</returns>
//public object Invoke(string name, BindingFlags bindingFlags, params object[] args)
//{
// return this.Invoke(name, bindingFlags, null, args, CultureInfo.InvariantCulture);
//}

///// <summary>
///// Invokes the specified method
///// </summary>
///// <param name="name">Name of the method</param>
///// <param name="bindingFlags">A bitmask comprised of one or more <see cref="T:System.Reflection.BindingFlags"/> that specify how the search is conducted.</param>
///// <param name="parameterTypes">An array of <see cref="T:System.Type"/> objects representing the number, order, and type of the parameters for the method to get.</param>
///// <param name="args">Arguments to pass to the member to invoke.</param>
///// <returns>Result of method call</returns>
//public object Invoke(string name, BindingFlags bindingFlags, Type[] parameterTypes, object[] args)
//{
// return this.Invoke(name, bindingFlags, parameterTypes, args, CultureInfo.InvariantCulture);
//}

///// <summary>
///// Invokes the specified method
///// </summary>
///// <param name="name">Name of the method</param>
///// <param name="bindingFlags">A bitmask comprised of one or more <see cref="T:System.Reflection.BindingFlags"/> that specify how the search is conducted.</param>
///// <param name="args">Arguments to pass to the member to invoke.</param>
///// <param name="culture">Culture info</param>
///// <returns>Result of method call</returns>
//public object Invoke(string name, BindingFlags bindingFlags, object[] args, CultureInfo culture)
//{
// return this.Invoke(name, bindingFlags, null, args, culture);
//}

///// <summary>
///// Invokes the specified method
///// </summary>
///// <param name="name">Name of the method</param>
///// <param name="bindingFlags">A bitmask comprised of one or more <see cref="T:System.Reflection.BindingFlags"/> that specify how the search is conducted.</param>
///// <param name="parameterTypes">An array of <see cref="T:System.Type"/> objects representing the number, order, and type of the parameters for the method to get.</param>
///// <param name="args">Arguments to pass to the member to invoke.</param>
///// <param name="culture">Culture info</param>
///// <returns>Result of method call</returns>
//public object Invoke(string name, BindingFlags bindingFlags, Type[] parameterTypes, object[] args, CultureInfo culture)
//{
// return this.Invoke(name, bindingFlags, parameterTypes, args, culture, null);
//}

///// <summary>
///// Invokes the specified method
///// </summary>
///// <param name="name">Name of the method</param>
///// <param name="bindingFlags">A bitmask comprised of one or more <see cref="T:System.Reflection.BindingFlags"/> that specify how the search is conducted.</param>
///// <param name="parameterTypes">An array of <see cref="T:System.Type"/> objects representing the number, order, and type of the parameters for the method to get.</param>
///// <param name="args">Arguments to pass to the member to invoke.</param>
///// <param name="culture">Culture info</param>
///// <param name="typeArguments">An array of types corresponding to the types of the generic arguments.</param>
///// <returns>Result of method call</returns>
//public object Invoke(string name, BindingFlags bindingFlags, Type[] parameterTypes, object[] args, CultureInfo culture, Type[] typeArguments)
//{
// Helper.CheckParameterNotNull(name, "name", string.Empty);
// if (parameterTypes != null)
// {
// bindingFlags |= BindToEveryThing | BindingFlags.Instance;

// // Fix up the parameter types
// MethodInfo member = this.originalType.GetMethod(name, bindingFlags, null, parameterTypes, null);

// // If the method was not found and type arguments were provided for generic paramaters,
// // attempt to look up a generic method.
// if ((member == null) && (typeArguments != null))
// {
// // This method may contain generic parameters...if so, the previous call to
// // GetMethod() will fail because it doesn't fully support generic parameters.

// // Look in the method cache to see if there is a generic method
// // on the incoming type that contains the correct signature.
// member = this.GetGenericMethodFromCache(name, parameterTypes, typeArguments, bindingFlags, null);
// }

// if (member == null)
// {
// throw new ArgumentException(
// string.Format(CultureInfo.CurrentCulture, FrameworkMessages.PrivateAccessorMemberNotFound, name));
// }

// try
// {
// if (member.IsGenericMethodDefinition)
// {
// MethodInfo constructed = member.MakeGenericMethod(typeArguments);
// return constructed.Invoke(this.target, bindingFlags, null, args, culture);
// }
// else
// {
// return member.Invoke(this.target, bindingFlags, null, args, culture);
// }
// }
// catch (TargetInvocationException e)
// {
// Debug.Assert(e.InnerException != null, "Inner exception should not be null.");
// if (e.InnerException != null)
// {
// throw e.InnerException;
// }

// throw;
// }
// }
// else
// {
// return this.InvokeHelper(name, bindingFlags | BindingFlags.InvokeMethod, args, culture);
// }
//}

///// <summary>
///// Gets the array element using array of subsrcipts for each dimension
///// </summary>
///// <param name="name">Name of the member</param>
///// <param name="indices">the indices of array</param>
///// <returns>An arrya of elements.</returns>
//public object GetArrayElement(string name, params int[] indices)
//{
// Helper.CheckParameterNotNull(name, "name", string.Empty);
// return this.GetArrayElement(name, BindToEveryThing, indices);
//}

///// <summary>
///// Sets the array element using array of subsrcipts for each dimension
///// </summary>
///// <param name="name">Name of the member</param>
///// <param name="value">Value to set</param>
///// <param name="indices">the indices of array</param>
//public void SetArrayElement(string name, object value, params int[] indices)
//{
// Helper.CheckParameterNotNull(name, "name", string.Empty);
// this.SetArrayElement(name, BindToEveryThing, value, indices);
//}

///// <summary>
///// Gets the array element using array of subsrcipts for each dimension
///// </summary>
///// <param name="name">Name of the member</param>
///// <param name="bindingFlags">A bitmask comprised of one or more <see cref="T:System.Reflection.BindingFlags"/> that specify how the search is conducted.</param>
///// <param name="indices">the indices of array</param>
///// <returns>An arrya of elements.</returns>
//public object GetArrayElement(string name, BindingFlags bindingFlags, params int[] indices)
//{
// Helper.CheckParameterNotNull(name, "name", string.Empty);
// Array arr = (Array)this.InvokeHelper(name, BindingFlags.GetField | bindingFlags, null, CultureInfo.InvariantCulture);
// return arr.GetValue(indices);
//}

///// <summary>
///// Sets the array element using array of subsrcipts for each dimension
///// </summary>
///// <param name="name">Name of the member</param>
///// <param name="bindingFlags">A bitmask comprised of one or more <see cref="T:System.Reflection.BindingFlags"/> that specify how the search is conducted.</param>
///// <param name="value">Value to set</param>
///// <param name="indices">the indices of array</param>
//public void SetArrayElement(string name, BindingFlags bindingFlags, object value, params int[] indices)
//{
// Helper.CheckParameterNotNull(name, "name", string.Empty);
// Array arr = (Array)this.InvokeHelper(name, BindingFlags.GetField | bindingFlags, null, CultureInfo.InvariantCulture);
// arr.SetValue(value, indices);
//}

///// <summary>
///// Get the field
///// </summary>
///// <param name="name">Name of the field</param>
///// <returns>The field.</returns>
//public object GetField(string name)
//{
// Helper.CheckParameterNotNull(name, "name", string.Empty);
// return this.GetField(name, BindToEveryThing);
//}

///// <summary>
///// Sets the field
///// </summary>
///// <param name="name">Name of the field</param>
///// <param name="value">value to set</param>
//public void SetField(string name, object value)
//{
// Helper.CheckParameterNotNull(name, "name", string.Empty);
// this.SetField(name, BindToEveryThing, value);
//}

///// <summary>
///// Gets the field
///// </summary>
///// <param name="name">Name of the field</param>
///// <param name="bindingFlags">A bitmask comprised of one or more <see cref="T:System.Reflection.BindingFlags"/> that specify how the search is conducted.</param>
///// <returns>The field.</returns>
//public object GetField(string name, BindingFlags bindingFlags)
//{
// Helper.CheckParameterNotNull(name, "name", string.Empty);
// return this.InvokeHelper(name, BindingFlags.GetField | bindingFlags, null, CultureInfo.InvariantCulture);
//}

///// <summary>
///// Sets the field
///// </summary>
///// <param name="name">Name of the field</param>
///// <param name="bindingFlags">A bitmask comprised of one or more <see cref="T:System.Reflection.BindingFlags"/> that specify how the search is conducted.</param>
///// <param name="value">value to set</param>
//public void SetField(string name, BindingFlags bindingFlags, object value)
//{
// Helper.CheckParameterNotNull(name, "name", string.Empty);
// this.InvokeHelper(name, BindingFlags.SetField | bindingFlags, new object[] { value }, CultureInfo.InvariantCulture);
//}

/// <summary>
/// Get the field or property
/// </summary>
/// <param name="name">Name of the field or property</param>
/// <returns>The field or property.</returns>
public object GetFieldOrProperty(string name)
{
Helper.CheckParameterNotNull(name, "name", string.Empty);
return this.GetFieldOrProperty(name, BindToEveryThing);
}

/// <summary>
/// Sets the field or property
/// </summary>
/// <param name="name">Name of the field or property</param>
/// <param name="value">value to set</param>
public void SetFieldOrProperty(string name, object value)
{
Helper.CheckParameterNotNull(name, "name", string.Empty);
this.SetFieldOrProperty(name, BindToEveryThing, value);
}

/// <summary>
/// Gets the field or property
/// </summary>
/// <param name="name">Name of the field or property</param>
/// <param name="bindingFlags">A bitmask comprised of one or more <see cref="T:System.Reflection.BindingFlags"/> that specify how the search is conducted.</param>
/// <returns>The field or property.</returns>
public object GetFieldOrProperty(string name, BindingFlags bindingFlags)
{
Helper.CheckParameterNotNull(name, "name", string.Empty);
return this.InvokeHelper(name, BindingFlags.GetField | BindingFlags.GetProperty | bindingFlags, null, CultureInfo.InvariantCulture);
}

/// <summary>
/// Sets the field or property
/// </summary>
/// <param name="name">Name of the field or property</param>
/// <param name="bindingFlags">A bitmask comprised of one or more <see cref="T:System.Reflection.BindingFlags"/> that specify how the search is conducted.</param>
/// <param name="value">value to set</param>
public void SetFieldOrProperty(string name, BindingFlags bindingFlags, object value)
{
Helper.CheckParameterNotNull(name, "name", string.Empty);
this.InvokeHelper(name, BindingFlags.SetField | BindingFlags.SetProperty | bindingFlags, new object[] {value}, CultureInfo.InvariantCulture);
}

///// <summary>
///// Gets the property
///// </summary>
///// <param name="name">Name of the property</param>
///// <param name="args">Arguments to pass to the member to invoke.</param>
///// <returns>The property.</returns>
//public object GetProperty(string name, params object[] args)
//{
// return this.GetProperty(name, null, args);
//}

///// <summary>
///// Gets the property
///// </summary>
///// <param name="name">Name of the property</param>
///// <param name="parameterTypes">An array of <see cref="T:System.Type"/> objects representing the number, order, and type of the parameters for the indexed property.</param>
///// <param name="args">Arguments to pass to the member to invoke.</param>
///// <returns>The property.</returns>
//public object GetProperty(string name, Type[] parameterTypes, object[] args)
//{
// return this.GetProperty(name, BindToEveryThing, parameterTypes, args);
//}

///// <summary>
///// Set the property
///// </summary>
///// <param name="name">Name of the property</param>
///// <param name="value">value to set</param>
///// <param name="args">Arguments to pass to the member to invoke.</param>
//public void SetProperty(string name, object value, params object[] args)
//{
// this.SetProperty(name, null, value, args);
//}

///// <summary>
///// Set the property
///// </summary>
///// <param name="name">Name of the property</param>
///// <param name="parameterTypes">An array of <see cref="T:System.Type"/> objects representing the number, order, and type of the parameters for the indexed property.</param>
///// <param name="value">value to set</param>
///// <param name="args">Arguments to pass to the member to invoke.</param>
//public void SetProperty(string name, Type[] parameterTypes, object value, object[] args)
//{
// this.SetProperty(name, BindToEveryThing, value, parameterTypes, args);
//}

///// <summary>
///// Gets the property
///// </summary>
///// <param name="name">Name of the property</param>
///// <param name="bindingFlags">A bitmask comprised of one or more <see cref="T:System.Reflection.BindingFlags"/> that specify how the search is conducted.</param>
///// <param name="args">Arguments to pass to the member to invoke.</param>
///// <returns>The property.</returns>
//public object GetProperty(string name, BindingFlags bindingFlags, params object[] args)
//{
// return this.GetProperty(name, bindingFlags, null, args);
//}

///// <summary>
///// Gets the property
///// </summary>
///// <param name="name">Name of the property</param>
///// <param name="bindingFlags">A bitmask comprised of one or more <see cref="T:System.Reflection.BindingFlags"/> that specify how the search is conducted.</param>
///// <param name="parameterTypes">An array of <see cref="T:System.Type"/> objects representing the number, order, and type of the parameters for the indexed property.</param>
///// <param name="args">Arguments to pass to the member to invoke.</param>
///// <returns>The property.</returns>
//public object GetProperty(string name, BindingFlags bindingFlags, Type[] parameterTypes, object[] args)
//{
// Helper.CheckParameterNotNull(name, "name", string.Empty);
// if (parameterTypes != null)
// {
// PropertyInfo pi = this.originalType.GetProperty(name, bindingFlags, null, null, parameterTypes, null);
// if (pi == null)
// {
// throw new ArgumentException(
// string.Format(CultureInfo.CurrentCulture, FrameworkMessages.PrivateAccessorMemberNotFound, name));
// }

// return pi.GetValue(this.target, args);
// }
// else
// {
// return this.InvokeHelper(name, bindingFlags | BindingFlags.GetProperty, args, null);
// }
//}

///// <summary>
///// Sets the property
///// </summary>
///// <param name="name">Name of the property</param>
///// <param name="bindingFlags">A bitmask comprised of one or more <see cref="T:System.Reflection.BindingFlags"/> that specify how the search is conducted.</param>
///// <param name="value">value to set</param>
///// <param name="args">Arguments to pass to the member to invoke.</param>
//public void SetProperty(string name, BindingFlags bindingFlags, object value, params object[] args)
//{
// this.SetProperty(name, bindingFlags, value, null, args);
//}

///// <summary>
///// Sets the property
///// </summary>
///// <param name="name">Name of the property</param>
///// <param name="bindingFlags">A bitmask comprised of one or more <see cref="T:System.Reflection.BindingFlags"/> that specify how the search is conducted.</param>
///// <param name="value">value to set</param>
///// <param name="parameterTypes">An array of <see cref="T:System.Type"/> objects representing the number, order, and type of the parameters for the indexed property.</param>
///// <param name="args">Arguments to pass to the member to invoke.</param>
//public void SetProperty(string name, BindingFlags bindingFlags, object value, Type[] parameterTypes, object[] args)
//{
// Helper.CheckParameterNotNull(name, "name", string.Empty);

// if (parameterTypes != null)
// {
// PropertyInfo pi = this.originalType.GetProperty(name, bindingFlags, null, null, parameterTypes, null);
// if (pi == null)
// {
// throw new ArgumentException(
// string.Format(CultureInfo.CurrentCulture, FrameworkMessages.PrivateAccessorMemberNotFound, name));
// }

// pi.SetValue(this.target, value, args);
// }
// else
// {
// object[] pass = new object[(args?.Length ?? 0) + 1];
// pass[0] = value;
// args?.CopyTo(pass, 1);
// this.InvokeHelper(name, bindingFlags | BindingFlags.SetProperty, pass, null);
// }
//}

#region Private Helpers

///// <summary>
///// Validate access string
///// </summary>
///// <param name="access"> access string</param>
//private static void ValidateAccessString(string access)
//{
// Helper.CheckParameterNotNull(access, "access", string.Empty);
// if (access.Length == 0)
// {
// throw new ArgumentException(FrameworkMessages.AccessStringInvalidSyntax);
// }

// string[] arr = access.Split('.');
// foreach (string str in arr)
// {
// if ((str.Length == 0) || (str.IndexOfAny(new char[] { ' ', '\t', '\n' }) != -1))
// {
// throw new ArgumentException(FrameworkMessages.AccessStringInvalidSyntax);
// }
// }
//}

/// <summary>
/// Invokes the memeber
/// </summary>
/// <param name="name">Name of the member</param>
/// <param name="bindingFlags">Additional attributes</param>
/// <param name="args">Arguments for the invocation</param>
/// <param name="culture">Culture</param>
/// <returns>Result of the invocation</returns>
private object InvokeHelper(string name, BindingFlags bindingFlags, object[] args, CultureInfo culture)
{
Helper.CheckParameterNotNull(name, "name", string.Empty);
//Debug.Assert(this.target != null, "Internal Error: Null reference is returned for internal object");

// Invoke the actual Method
try
{
return this.originalType.InvokeMember(name, bindingFlags, null, this.target, args, culture);
} catch (TargetInvocationException e)
{
//Debug.Assert(e.InnerException != null, "Inner exception should not be null.");
if (e.InnerException != null)
{
throw e.InnerException;
}

throw;
}
}

private void ConstructFrom(object obj)
{
Helper.CheckParameterNotNull(obj, "obj", string.Empty);
this.target = obj;
this.originalType = obj.GetType();
}

//private void BuildGenericMethodCacheForType(Type t)
//{
// Debug.Assert(t != null, "type should not be null.");
// this.methodCache = new Dictionary<string, LinkedList<MethodInfo>>();

// MethodInfo[] members = t.GetMethods(BindToEveryThing);
// LinkedList<MethodInfo> listByName; // automatically initialized to null

// foreach (MethodInfo member in members)
// {
// if (member.IsGenericMethod || member.IsGenericMethodDefinition)
// {
// if (!this.GenericMethodCache.TryGetValue(member.Name, out listByName))
// {
// listByName = new LinkedList<MethodInfo>();
// this.GenericMethodCache.Add(member.Name, listByName);
// }

// Debug.Assert(listByName != null, "list should not be null.");
// listByName.AddLast(member);
// }
// }
//}

///// <summary>
///// Extracts the most appropriate generic method signature from the current private type.
///// </summary>
///// <param name="methodName">The name of the method in which to search the signature cache.</param>
///// <param name="parameterTypes">An array of types corresponding to the types of the parameters in which to search.</param>
///// <param name="typeArguments">An array of types corresponding to the types of the generic arguments.</param>
///// <param name="bindingFlags"><see cref="BindingFlags"/> to further filter the method signatures.</param>
///// <param name="modifiers">Modifiers for parameters.</param>
///// <returns>A methodinfo instance.</returns>
//private MethodInfo GetGenericMethodFromCache(string methodName, Type[] parameterTypes, Type[] typeArguments, BindingFlags bindingFlags, ParameterModifier[] modifiers)
//{
// Debug.Assert(!string.IsNullOrEmpty(methodName), "Invalid method name.");
// Debug.Assert(parameterTypes != null, "Invalid parameter type array.");
// Debug.Assert(typeArguments != null, "Invalid type arguments array.");

// // Build a preliminary list of method candidates that contain roughly the same signature.
// var methodCandidates = this.GetMethodCandidates(methodName, parameterTypes, typeArguments, bindingFlags, modifiers);

// // Search of ambiguous methods (methods with the same signature).
// MethodInfo[] finalCandidates = new MethodInfo[methodCandidates.Count];
// methodCandidates.CopyTo(finalCandidates, 0);

// if ((parameterTypes != null) && (parameterTypes.Length == 0))
// {
// for (int i = 0; i < finalCandidates.Length; i++)
// {
// MethodInfo methodInfo = finalCandidates[i];

// if (!RuntimeTypeHelper.CompareMethodSigAndName(methodInfo, finalCandidates[0]))
// {
// throw new AmbiguousMatchException();
// }
// }

// // All the methods have the exact same name and sig so return the most derived one.
// return RuntimeTypeHelper.FindMostDerivedNewSlotMeth(finalCandidates, finalCandidates.Length) as MethodInfo;
// }

// // Now that we have a preliminary list of candidates, select the most appropriate one.
// return RuntimeTypeHelper.SelectMethod(bindingFlags, finalCandidates, parameterTypes, modifiers) as MethodInfo;
//}

//private LinkedList<MethodInfo> GetMethodCandidates(string methodName, Type[] parameterTypes, Type[] typeArguments, BindingFlags bindingFlags, ParameterModifier[] modifiers)
//{
// Debug.Assert(!string.IsNullOrEmpty(methodName), "methodName should not be null.");
// Debug.Assert(parameterTypes != null, "parameterTypes should not be null.");
// Debug.Assert(typeArguments != null, "typeArguments should not be null.");

// LinkedList<MethodInfo> methodCandidates = new LinkedList<MethodInfo>();
// LinkedList<MethodInfo> methods = null;

// if (!this.GenericMethodCache.TryGetValue(methodName, out methods))
// {
// return methodCandidates;
// }

// Debug.Assert(methods != null, "methods should not be null.");

// foreach (MethodInfo candidate in methods)
// {
// bool paramMatch = true;
// ParameterInfo[] candidateParams = null;
// Type[] genericArgs = candidate.GetGenericArguments();
// Type sourceParameterType = null;

// if (genericArgs.Length != typeArguments.Length)
// {
// continue;
// }

// // Since we can't just get the correct MethodInfo from Reflection,
// // we will just match the number of parameters, their order, and their type
// var methodCandidate = candidate;
// candidateParams = methodCandidate.GetParameters();

// if (candidateParams.Length != parameterTypes.Length)
// {
// continue;
// }

// // Exact binding
// if ((bindingFlags & BindingFlags.ExactBinding) != 0)
// {
// int i = 0;

// foreach (ParameterInfo candidateParam in candidateParams)
// {
// sourceParameterType = parameterTypes[i++];

// if (candidateParam.ParameterType.ContainsGenericParameters)
// {
// // Since we have a generic parameter here, just make sure the IsArray matches.
// if (candidateParam.ParameterType.IsArray != sourceParameterType.IsArray)
// {
// paramMatch = false;
// break;
// }
// }
// else
// {
// if (candidateParam.ParameterType != sourceParameterType)
// {
// paramMatch = false;
// break;
// }
// }
// }

// if (paramMatch)
// {
// methodCandidates.AddLast(methodCandidate);
// continue;
// }
// }
// else
// {
// methodCandidates.AddLast(methodCandidate);
// }
// }

// return methodCandidates;
//}

#endregion
}
}

+ 314
- 0
test/TensorFlowNET.UnitTest/Utilities/PrivateObjectExtensions.cs View File

@@ -0,0 +1,314 @@
// <copyright file="PrivateObjectExtensions.cs">
// Copyright (c) 2019 cactuaroid All Rights Reserved
// </copyright>
// <summary>
// Released under the MIT license
// https://github.com/cactuaroid/PrivateObjectExtensions
// </summary>

using Microsoft.VisualStudio.TestTools.UnitTesting;
using System.Linq;
using System.Reflection;

namespace System
{
/// <summary>
/// Extension methods for PrivateObject
/// </summary>
public static class PrivateObjectExtensions
{
private static readonly BindingFlags Static = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.DeclaredOnly | BindingFlags.Static;
private static readonly BindingFlags Instance = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.DeclaredOnly | BindingFlags.Instance;

/// <summary>
/// Get from private (and any other) field/property.
/// If the real type of specified object doesn't contain the specified field/property,
/// base types are searched automatically.
/// </summary>
/// <param name="obj">The object to get from</param>
/// <param name="name">The name of the field/property</param>
/// <returns>The object got from the field/property</returns>
/// <exception cref="ArgumentException">'name' is not found.</exception>
/// <exception cref="ArgumentNullException">Arguments contain null.</exception>
public static object GetPrivate(this object obj, string name)
{
if (obj == null) { throw new ArgumentNullException("obj"); }

return GetPrivate(obj, name, obj.GetType(), null);
}

/// <summary>
/// Get from private (and any other) field/property.
/// If the real type of specified object doesn't contain the specified field/property,
/// base types are searched automatically.
/// </summary>
/// <typeparam name="T">The type of the field/property</typeparam>
/// <param name="obj">The object to get from</param>
/// <param name="name">The name of the field/property</param>
/// <returns>The object got from the field/property</returns>
/// <exception cref="ArgumentException">'name' is not found.</exception>
/// <exception cref="ArgumentNullException">Arguments contain null.</exception>
public static T GetPrivate<T>(this object obj, string name)
{
if (obj == null) { throw new ArgumentNullException("obj"); }

return (T)GetPrivate(obj, name, obj.GetType(), typeof(T));
}

/// <summary>
/// Get from private (and any other) field/property with assuming the specified object as specified type.
/// If the specified type doesn't contain the specified field/property,
/// base types are searched automatically.
/// </summary>
/// <param name="obj">The object to get from</param>
/// <param name="name">The name of the field/property</param>
/// <param name="objType">The type of 'obj' for seaching member starting from. Real type of 'obj' is ignored.</param>
/// <returns>The object got from the field/property</returns>
/// <exception cref="ArgumentException">'name' is not found.</exception>
/// <exception cref="ArgumentException">'objType' is not assignable from 'obj'.</exception>
/// <exception cref="ArgumentNullException">Arguments contain null.</exception>
public static object GetPrivate(this object obj, string name, Type objType)
{
return GetPrivate(obj, name, objType, null);
}

/// <summary>
/// Get from private (and any other) field/property with assuming the specified object as specified type.
/// If the specified type doesn't contain the specified field/property,
/// base types are searched automatically.
/// </summary>
/// <typeparam name="T">The type of the field/property</typeparam>
/// <param name="obj">The object to get from</param>
/// <param name="name">The name of the field/property</param>
/// <param name="objType">The type of 'obj' for seaching member starting from. Real type of 'obj' is ignored.</param>
/// <returns>The object got from the field/property</returns>
/// <exception cref="ArgumentException">'name' is not found.</exception>
/// <exception cref="ArgumentException">'objType' is not assignable from 'obj'.</exception>
/// <exception cref="ArgumentNullException">Arguments contain null.</exception>
public static T GetPrivate<T>(this object obj, string name, Type objType)
{
return (T)GetPrivate(obj, name, objType, typeof(T));
}

private static object GetPrivate(object obj, string name, Type objType, Type memberType)
{
if (obj == null) { throw new ArgumentNullException("obj"); }
if (name == null) { throw new ArgumentNullException("name"); }
if (string.IsNullOrWhiteSpace(name)) { throw new ArgumentException("name is empty or white-space.", "name"); }
if (objType == null) { throw new ArgumentNullException("objType"); }
if (!objType.IsAssignableFrom(obj.GetType())) { throw new ArgumentException($"{objType} is not assignable from {obj.GetType()}.", "objType"); }

bool memberTypeMatching(Type actualType) => actualType == memberType;

if (TryFindFieldOrPropertyOwnerType(objType, name, memberType, memberTypeMatching, Instance, out var ownerType))
{
return new PrivateObject(obj, new PrivateType(ownerType)).GetFieldOrProperty(name);
}
else if (TryFindFieldOrPropertyOwnerType(objType, name, memberType, memberTypeMatching, Static, out ownerType))
{
return new PrivateType(ownerType).GetStaticFieldOrProperty(name);
}

throw new ArgumentException(((memberType != null) ? memberType + " " : "") + name + " is not found.");
}

/// <summary>
/// Get from private (and any other) static field/property.
/// </summary>
/// <param name="type">The type to get from</param>
/// <param name="name">The name of the static field/property</param>
/// <returns>The object got from the static field/property</returns>
/// <exception cref="ArgumentException">'name' is not found.</exception>
/// <exception cref="ArgumentNullException">Arguments contain null.</exception>
public static object GetPrivate(this Type type, string name)
{
return GetPrivate(type, name, null);
}

/// <summary>
/// Get from private (and any other) static field/property.
/// </summary>
/// <typeparam name="T">The type of the field/property</typeparam>
/// <param name="type">The type to get from</param>
/// <param name="name">The name of the static field/property</param>
/// <returns>The object got from the static field/property</returns>
/// <exception cref="ArgumentException">'name' is not found.</exception>
/// <exception cref="ArgumentNullException">Arguments contain null.</exception>
public static T GetPrivate<T>(this Type type, string name)
{
return (T)GetPrivate(type, name, typeof(T));
}

private static object GetPrivate(this Type type, string name, Type memberType)
{
if (type == null) { throw new ArgumentNullException("type"); }
if (name == null) { throw new ArgumentNullException("name"); }
if (string.IsNullOrWhiteSpace(name)) { throw new ArgumentException("name is empty or white-space.", "name"); }

bool memberTypeMatching(Type actualType) => actualType == memberType;

if (type.ContainsFieldOrProperty(name, memberType, memberTypeMatching, Static))
{
return new PrivateType(type).GetStaticFieldOrProperty(name);
}

throw new ArgumentException(((memberType != null) ? memberType + " " : "") + name + " is not found.");
}

/// <summary>
/// Set to private (and any other) field/property.
/// If the real type of specified object doesn't contain the specified field/property,
/// base types are searched automatically.
/// </summary>
/// <param name="obj">The object to set to</param>
/// <param name="name">The name of the field/property</param>
/// <param name="value">The value to set for 'name'</param>
/// <exception cref="ArgumentException">'name' is not found.</exception>
/// <exception cref="ArgumentNullException">Arguments contain null.</exception>
public static void SetPrivate<T>(this object obj, string name, T value)
{
if (obj == null) { throw new ArgumentNullException("obj"); }

SetPrivate(obj, name, value, obj.GetType());
}

/// <summary>
/// Set to private (and any other) field/property with assuming the specified object as specified type.
/// If the specified type doesn't contain the specified field/property,
/// base types are searched automatically.
/// </summary>
/// <param name="obj">The object to set to</param>
/// <param name="name">The name of the field/property</param>
/// <param name="value">The value to set for 'name'</param>
/// <param name="objType">The type of 'obj' for seaching member starting from. Real type of 'obj' is ignored.</param>
/// <exception cref="ArgumentException">'name' is not found.</exception>
/// <exception cref="ArgumentException">'objType' is not assignable from 'obj'.</exception>
/// <exception cref="ArgumentNullException">Arguments contain null.</exception>
public static void SetPrivate<T>(this object obj, string name, T value, Type objType)
{
if (obj == null) { throw new ArgumentNullException("obj"); }
if (name == null) { throw new ArgumentNullException("name"); }
if (string.IsNullOrWhiteSpace(name)) { throw new ArgumentException("name is empty or white-space.", "name"); }
if (value == null) { throw new ArgumentNullException("value"); }
if (objType == null) { throw new ArgumentNullException("objType"); }
if (!objType.IsAssignableFrom(obj.GetType())) { throw new ArgumentException($"{objType} is not assignable from {obj.GetType()}.", "objType"); }

if (TrySetPrivate(obj, name, value, objType)) { return; }

// retry for the case of getter only property
if (TrySetPrivate(obj, GetBackingFieldName(name), value, objType)) { return; }

throw new ArgumentException($"{typeof(T)} {name} is not found.");
}

private static bool TrySetPrivate<T>(object obj, string name, T value, Type objType)
{
var memberType = typeof(T);
bool memberTypeMatching(Type actualType) => actualType.IsAssignableFrom(memberType);

try
{
if (TryFindFieldOrPropertyOwnerType(objType, name, memberType, memberTypeMatching, Instance, out var ownerType))
{
new PrivateObject(obj, new PrivateType(ownerType)).SetFieldOrProperty(name, value);
return true;
}
else if (TryFindFieldOrPropertyOwnerType(objType, name, memberType, memberTypeMatching, Static, out ownerType))
{
new PrivateType(ownerType).SetStaticFieldOrProperty(name, value);
return true;
}
}
catch(MissingMethodException)
{
// When getter only property name is given, the property is found but fails to set.
return false;
}

return false;
}

/// <summary>
/// Set to private (and any other) static field/property.
/// </summary>
/// <param name="type">The type to set to</param>
/// <param name="name">The name of the field/property</param>
/// <param name="value">The value to set for 'name'</param>
/// <exception cref="ArgumentException">'name' is not found.</exception>
/// <exception cref="ArgumentNullException">Arguments contain null.</exception>
public static void SetPrivate<T>(this Type type, string name, T value)
{
if (type == null) { throw new ArgumentNullException("type"); }
if (name == null) { throw new ArgumentNullException("name"); }
if (string.IsNullOrWhiteSpace(name)) { throw new ArgumentException("name is empty or white-space.", "name"); }

if (TrySetPrivate(type, name, value)) { return; }

// retry for the case of getter only property
if (TrySetPrivate(type, GetBackingFieldName(name), value)) { return; }

throw new ArgumentException($"{typeof(T)} {name} is not found.");
}

private static bool TrySetPrivate<T>(this Type type, string name, T value)
{
var memberType = typeof(T);
bool memberTypeMatching(Type actualType) => actualType.IsAssignableFrom(memberType);

try
{
if (type.ContainsFieldOrProperty(name, memberType, memberTypeMatching, Static))
{
new PrivateType(type).SetStaticFieldOrProperty(name, value);
return true;
}
}
catch (MissingMethodException)
{
// When getter only property name is given, the property is found but fails to set.
return false;
}

return false;
}

private static string GetBackingFieldName(string propertyName)
=> $"<{propertyName}>k__BackingField"; // generated backing field name

private static bool TryFindFieldOrPropertyOwnerType(Type objType, string name, Type memberType, Func<Type, bool> memberTypeMatching, BindingFlags bindingFlag, out Type ownerType)
{
ownerType = FindFieldOrPropertyOwnerType(objType, name, memberType, memberTypeMatching, bindingFlag);

return (ownerType != null);
}

private static Type FindFieldOrPropertyOwnerType(Type objectType, string name, Type memberType, Func<Type, bool> memberTypeMatching, BindingFlags bindingFlags)
{
if (objectType == null) { return null; }

if (objectType.ContainsFieldOrProperty(name, memberType, memberTypeMatching, bindingFlags))
{
return objectType;
}

return FindFieldOrPropertyOwnerType(objectType.BaseType, name, memberType, memberTypeMatching, bindingFlags);
}

private static bool ContainsFieldOrProperty(this Type objectType, string name, Type memberType, Func<Type, bool> memberTypeMatching, BindingFlags bindingFlags)
{
var fields = objectType
.GetFields(bindingFlags)
.Select((x) => new { Type = x.FieldType, Member = x as MemberInfo });

var properties = objectType
.GetProperties(bindingFlags)
.Select((x) => new { Type = x.PropertyType, Member = x as MemberInfo });

var members = fields.Concat(properties);

return members.Any((actual) =>
(memberType == null || memberTypeMatching.Invoke(actual.Type))
&& actual.Member.Name == name);
}
}
}

+ 572
- 0
test/TensorFlowNET.UnitTest/Utilities/PrivateType.cs View File

@@ -0,0 +1,572 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

namespace Microsoft.VisualStudio.TestTools.UnitTesting
{
using System;
//using System.Diagnostics;
using System.Globalization;
using System.Reflection;

/// <summary>
/// This class represents a private class for the Private Accessor functionality.
/// </summary>
internal class PrivateType
{
/// <summary>
/// Binds to everything
/// </summary>
private const BindingFlags BindToEveryThing = BindingFlags.Default
| BindingFlags.NonPublic | BindingFlags.Instance
| BindingFlags.Public | BindingFlags.Static | BindingFlags.FlattenHierarchy;

/// <summary>
/// The wrapped type.
/// </summary>
private Type type;

///// <summary>
///// Initializes a new instance of the <see cref="PrivateType"/> class that contains the private type.
///// </summary>
///// <param name="assemblyName">Assembly name</param>
///// <param name="typeName">fully qualified name of the </param>
//public PrivateType(string assemblyName, string typeName)
//{
// Helper.CheckParameterNotNullOrEmpty(assemblyName, "assemblyName", string.Empty);
// Helper.CheckParameterNotNullOrEmpty(typeName, "typeName", string.Empty);
// Assembly asm = Assembly.Load(assemblyName);

// this.type = asm.GetType(typeName, true);
//}

/// <summary>
/// Initializes a new instance of the <see cref="PrivateType"/> class that contains
/// the private type from the type object
/// </summary>
/// <param name="type">The wrapped Type to create.</param>
public PrivateType(Type type)
{
if (type == null)
{
throw new ArgumentNullException("type");
}

this.type = type;
}

/// <summary>
/// Gets the referenced type
/// </summary>
public Type ReferencedType => this.type;

///// <summary>
///// Invokes static member
///// </summary>
///// <param name="name">Name of the member to InvokeHelper</param>
///// <param name="args">Arguements to the invoction</param>
///// <returns>Result of invocation</returns>
//public object InvokeStatic(string name, params object[] args)
//{
// return this.InvokeStatic(name, null, args, CultureInfo.InvariantCulture);
//}

///// <summary>
///// Invokes static member
///// </summary>
///// <param name="name">Name of the member to InvokeHelper</param>
///// <param name="parameterTypes">An array of <see cref="T:System.Type"/> objects representing the number, order, and type of the parameters for the method to invoke</param>
///// <param name="args">Arguements to the invoction</param>
///// <returns>Result of invocation</returns>
//public object InvokeStatic(string name, Type[] parameterTypes, object[] args)
//{
// return this.InvokeStatic(name, parameterTypes, args, CultureInfo.InvariantCulture);
//}

///// <summary>
///// Invokes static member
///// </summary>
///// <param name="name">Name of the member to InvokeHelper</param>
///// <param name="parameterTypes">An array of <see cref="T:System.Type"/> objects representing the number, order, and type of the parameters for the method to invoke</param>
///// <param name="args">Arguements to the invoction</param>
///// <param name="typeArguments">An array of types corresponding to the types of the generic arguments.</param>
///// <returns>Result of invocation</returns>
//public object InvokeStatic(string name, Type[] parameterTypes, object[] args, Type[] typeArguments)
//{
// return this.InvokeStatic(name, BindToEveryThing, parameterTypes, args, CultureInfo.InvariantCulture, typeArguments);
//}

///// <summary>
///// Invokes the static method
///// </summary>
///// <param name="name">Name of the member</param>
///// <param name="args">Arguements to the invocation</param>
///// <param name="culture">Culture</param>
///// <returns>Result of invocation</returns>
//public object InvokeStatic(string name, object[] args, CultureInfo culture)
//{
// return this.InvokeStatic(name, null, args, culture);
//}

///// <summary>
///// Invokes the static method
///// </summary>
///// <param name="name">Name of the member</param>
///// <param name="parameterTypes">An array of <see cref="T:System.Type"/> objects representing the number, order, and type of the parameters for the method to invoke</param>
///// <param name="args">Arguements to the invocation</param>
///// <param name="culture">Culture info</param>
///// <returns>Result of invocation</returns>
//public object InvokeStatic(string name, Type[] parameterTypes, object[] args, CultureInfo culture)
//{
// return this.InvokeStatic(name, BindingFlags.InvokeMethod, parameterTypes, args, culture);
//}

///// <summary>
///// Invokes the static method
///// </summary>
///// <param name="name">Name of the member</param>
///// <param name="bindingFlags">Additional invocation attributes</param>
///// <param name="args">Arguements to the invocation</param>
///// <returns>Result of invocation</returns>
//public object InvokeStatic(string name, BindingFlags bindingFlags, params object[] args)
//{
// return this.InvokeStatic(name, bindingFlags, null, args, CultureInfo.InvariantCulture);
//}

///// <summary>
///// Invokes the static method
///// </summary>
///// <param name="name">Name of the member</param>
///// <param name="bindingFlags">Additional invocation attributes</param>
///// <param name="parameterTypes">An array of <see cref="T:System.Type"/> objects representing the number, order, and type of the parameters for the method to invoke</param>
///// <param name="args">Arguements to the invocation</param>
///// <returns>Result of invocation</returns>
//public object InvokeStatic(string name, BindingFlags bindingFlags, Type[] parameterTypes, object[] args)
//{
// return this.InvokeStatic(name, bindingFlags, parameterTypes, args, CultureInfo.InvariantCulture);
//}

///// <summary>
///// Invokes the static method
///// </summary>
///// <param name="name">Name of the member</param>
///// <param name="bindingFlags">Additional invocation attributes</param>
///// <param name="args">Arguements to the invocation</param>
///// <param name="culture">Culture</param>
///// <returns>Result of invocation</returns>
//public object InvokeStatic(string name, BindingFlags bindingFlags, object[] args, CultureInfo culture)
//{
// return this.InvokeStatic(name, bindingFlags, null, args, culture);
//}

///// <summary>
///// Invokes the static method
///// </summary>
///// <param name="name">Name of the member</param>
///// <param name="bindingFlags">Additional invocation attributes</param>
///// /// <param name="parameterTypes">An array of <see cref="T:System.Type"/> objects representing the number, order, and type of the parameters for the method to invoke</param>
///// <param name="args">Arguements to the invocation</param>
///// <param name="culture">Culture</param>
///// <returns>Result of invocation</returns>
//public object InvokeStatic(string name, BindingFlags bindingFlags, Type[] parameterTypes, object[] args, CultureInfo culture)
//{
// return this.InvokeStatic(name, bindingFlags, parameterTypes, args, culture, null);
//}

///// <summary>
///// Invokes the static method
///// </summary>
///// <param name="name">Name of the member</param>
///// <param name="bindingFlags">Additional invocation attributes</param>
///// /// <param name="parameterTypes">An array of <see cref="T:System.Type"/> objects representing the number, order, and type of the parameters for the method to invoke</param>
///// <param name="args">Arguements to the invocation</param>
///// <param name="culture">Culture</param>
///// <param name="typeArguments">An array of types corresponding to the types of the generic arguments.</param>
///// <returns>Result of invocation</returns>
//public object InvokeStatic(string name, BindingFlags bindingFlags, Type[] parameterTypes, object[] args, CultureInfo culture, Type[] typeArguments)
//{
// Helper.CheckParameterNotNull(name, "name", string.Empty);
// if (parameterTypes != null)
// {
// MethodInfo member = this.type.GetMethod(name, bindingFlags | BindToEveryThing | BindingFlags.Static, null, parameterTypes, null);
// if (member == null)
// {
// throw new ArgumentException(string.Format(CultureInfo.CurrentCulture, FrameworkMessages.PrivateAccessorMemberNotFound, name));
// }

// try
// {
// if (member.IsGenericMethodDefinition)
// {
// MethodInfo constructed = member.MakeGenericMethod(typeArguments);
// return constructed.Invoke(null, bindingFlags, null, args, culture);
// }
// else
// {
// return member.Invoke(null, bindingFlags, null, args, culture);
// }
// }
// catch (TargetInvocationException e)
// {
// Debug.Assert(e.InnerException != null, "Inner Exception should not be null.");
// if (e.InnerException != null)
// {
// throw e.InnerException;
// }

// throw;
// }
// }
// else
// {
// return this.InvokeHelperStatic(name, bindingFlags | BindingFlags.InvokeMethod, args, culture);
// }
//}

///// <summary>
///// Gets the element in static array
///// </summary>
///// <param name="name">Name of the array</param>
///// <param name="indices">
///// A one-dimensional array of 32-bit integers that represent the indexes specifying
///// the position of the element to get. For instance, to access a[10][11] the indices would be {10,11}
///// </param>
///// <returns>element at the specified location</returns>
//public object GetStaticArrayElement(string name, params int[] indices)
//{
// Helper.CheckParameterNotNull(name, "name", string.Empty);
// return this.GetStaticArrayElement(name, BindToEveryThing, indices);
//}

///// <summary>
///// Sets the memeber of the static array
///// </summary>
///// <param name="name">Name of the array</param>
///// <param name="value">value to set</param>
///// <param name="indices">
///// A one-dimensional array of 32-bit integers that represent the indexes specifying
///// the position of the element to set. For instance, to access a[10][11] the array would be {10,11}
///// </param>
//public void SetStaticArrayElement(string name, object value, params int[] indices)
//{
// Helper.CheckParameterNotNull(name, "name", string.Empty);
// this.SetStaticArrayElement(name, BindToEveryThing, value, indices);
//}

///// <summary>
///// Gets the element in satatic array
///// </summary>
///// <param name="name">Name of the array</param>
///// <param name="bindingFlags">Additional InvokeHelper attributes</param>
///// <param name="indices">
///// A one-dimensional array of 32-bit integers that represent the indexes specifying
///// the position of the element to get. For instance, to access a[10][11] the array would be {10,11}
///// </param>
///// <returns>element at the spcified location</returns>
//public object GetStaticArrayElement(string name, BindingFlags bindingFlags, params int[] indices)
//{
// Helper.CheckParameterNotNull(name, "name", string.Empty);
// Array arr = (Array)this.InvokeHelperStatic(name, BindingFlags.GetField | BindingFlags.GetProperty | bindingFlags, null, CultureInfo.InvariantCulture);
// return arr.GetValue(indices);
//}

///// <summary>
///// Sets the memeber of the static array
///// </summary>
///// <param name="name">Name of the array</param>
///// <param name="bindingFlags">Additional InvokeHelper attributes</param>
///// <param name="value">value to set</param>
///// <param name="indices">
///// A one-dimensional array of 32-bit integers that represent the indexes specifying
///// the position of the element to set. For instance, to access a[10][11] the array would be {10,11}
///// </param>
//public void SetStaticArrayElement(string name, BindingFlags bindingFlags, object value, params int[] indices)
//{
// Helper.CheckParameterNotNull(name, "name", string.Empty);
// Array arr = (Array)this.InvokeHelperStatic(name, BindingFlags.GetField | BindingFlags.GetProperty | BindingFlags.Static | bindingFlags, null, CultureInfo.InvariantCulture);
// arr.SetValue(value, indices);
//}

///// <summary>
///// Gets the static field
///// </summary>
///// <param name="name">Name of the field</param>
///// <returns>The static field.</returns>
//public object GetStaticField(string name)
//{
// Helper.CheckParameterNotNull(name, "name", string.Empty);
// return this.GetStaticField(name, BindToEveryThing);
//}

///// <summary>
///// Sets the static field
///// </summary>
///// <param name="name">Name of the field</param>
///// <param name="value">Arguement to the invocation</param>
//public void SetStaticField(string name, object value)
//{
// Helper.CheckParameterNotNull(name, "name", string.Empty);
// this.SetStaticField(name, BindToEveryThing, value);
//}

///// <summary>
///// Gets the static field using specified InvokeHelper attributes
///// </summary>
///// <param name="name">Name of the field</param>
///// <param name="bindingFlags">Additional invocation attributes</param>
///// <returns>The static field.</returns>
//public object GetStaticField(string name, BindingFlags bindingFlags)
//{
// Helper.CheckParameterNotNull(name, "name", string.Empty);
// return this.InvokeHelperStatic(name, BindingFlags.GetField | BindingFlags.Static | bindingFlags, null, CultureInfo.InvariantCulture);
//}

///// <summary>
///// Sets the static field using binding attributes
///// </summary>
///// <param name="name">Name of the field</param>
///// <param name="bindingFlags">Additional InvokeHelper attributes</param>
///// <param name="value">Arguement to the invocation</param>
//public void SetStaticField(string name, BindingFlags bindingFlags, object value)
//{
// Helper.CheckParameterNotNull(name, "name", string.Empty);
// this.InvokeHelperStatic(name, BindingFlags.SetField | bindingFlags | BindingFlags.Static, new[] { value }, CultureInfo.InvariantCulture);
//}

/// <summary>
/// Gets the static field or property
/// </summary>
/// <param name="name">Name of the field or property</param>
/// <returns>The static field or property.</returns>
public object GetStaticFieldOrProperty(string name)
{
Helper.CheckParameterNotNull(name, "name", string.Empty);
return this.GetStaticFieldOrProperty(name, BindToEveryThing);
}

/// <summary>
/// Sets the static field or property
/// </summary>
/// <param name="name">Name of the field or property</param>
/// <param name="value">Value to be set to field or property</param>
public void SetStaticFieldOrProperty(string name, object value)
{
Helper.CheckParameterNotNull(name, "name", string.Empty);
this.SetStaticFieldOrProperty(name, BindToEveryThing, value);
}

/// <summary>
/// Gets the static field or property using specified InvokeHelper attributes
/// </summary>
/// <param name="name">Name of the field or property</param>
/// <param name="bindingFlags">Additional invocation attributes</param>
/// <returns>The static field or property.</returns>
public object GetStaticFieldOrProperty(string name, BindingFlags bindingFlags)
{
Helper.CheckParameterNotNull(name, "name", string.Empty);
return this.InvokeHelperStatic(name, BindingFlags.GetField | BindingFlags.GetProperty | BindingFlags.Static | bindingFlags, null, CultureInfo.InvariantCulture);
}

/// <summary>
/// Sets the static field or property using binding attributes
/// </summary>
/// <param name="name">Name of the field or property</param>
/// <param name="bindingFlags">Additional invocation attributes</param>
/// <param name="value">Value to be set to field or property</param>
public void SetStaticFieldOrProperty(string name, BindingFlags bindingFlags, object value)
{
Helper.CheckParameterNotNull(name, "name", string.Empty);
this.InvokeHelperStatic(name, BindingFlags.SetField | BindingFlags.SetProperty | bindingFlags | BindingFlags.Static, new[] {value}, CultureInfo.InvariantCulture);
}

///// <summary>
///// Gets the static property
///// </summary>
///// <param name="name">Name of the field or property</param>
///// <param name="args">Arguements to the invocation</param>
///// <returns>The static property.</returns>
//public object GetStaticProperty(string name, params object[] args)
//{
// return this.GetStaticProperty(name, BindToEveryThing, args);
//}

///// <summary>
///// Sets the static property
///// </summary>
///// <param name="name">Name of the property</param>
///// <param name="value">Value to be set to field or property</param>
///// <param name="args">Arguments to pass to the member to invoke.</param>
//public void SetStaticProperty(string name, object value, params object[] args)
//{
// this.SetStaticProperty(name, BindToEveryThing, value, null, args);
//}

///// <summary>
///// Sets the static property
///// </summary>
///// <param name="name">Name of the property</param>
///// <param name="value">Value to be set to field or property</param>
///// <param name="parameterTypes">An array of <see cref="T:System.Type"/> objects representing the number, order, and type of the parameters for the indexed property.</param>
///// <param name="args">Arguments to pass to the member to invoke.</param>
//public void SetStaticProperty(string name, object value, Type[] parameterTypes, object[] args)
//{
// this.SetStaticProperty(name, BindingFlags.SetProperty, value, parameterTypes, args);
//}

///// <summary>
///// Gets the static property
///// </summary>
///// <param name="name">Name of the property</param>
///// <param name="bindingFlags">Additional invocation attributes.</param>
///// <param name="args">Arguments to pass to the member to invoke.</param>
///// <returns>The static property.</returns>
//public object GetStaticProperty(string name, BindingFlags bindingFlags, params object[] args)
//{
// return this.GetStaticProperty(name, BindingFlags.GetProperty | BindingFlags.Static | bindingFlags, null, args);
//}

///// <summary>
///// Gets the static property
///// </summary>
///// <param name="name">Name of the property</param>
///// <param name="bindingFlags">Additional invocation attributes.</param>
///// <param name="parameterTypes">An array of <see cref="T:System.Type"/> objects representing the number, order, and type of the parameters for the indexed property.</param>
///// <param name="args">Arguments to pass to the member to invoke.</param>
///// <returns>The static property.</returns>
//public object GetStaticProperty(string name, BindingFlags bindingFlags, Type[] parameterTypes, object[] args)
//{
// Helper.CheckParameterNotNull(name, "name", string.Empty);
// if (parameterTypes != null)
// {
// PropertyInfo pi = this.type.GetProperty(name, bindingFlags | BindingFlags.Static, null, null, parameterTypes, null);
// if (pi == null)
// {
// throw new ArgumentException(string.Format(CultureInfo.CurrentCulture, FrameworkMessages.PrivateAccessorMemberNotFound, name));
// }

// return pi.GetValue(null, args);
// }
// else
// {
// return this.InvokeHelperStatic(name, bindingFlags | BindingFlags.GetProperty, args, null);
// }
//}

///// <summary>
///// Sets the static property
///// </summary>
///// <param name="name">Name of the property</param>
///// <param name="bindingFlags">Additional invocation attributes.</param>
///// <param name="value">Value to be set to field or property</param>
///// <param name="args">Optional index values for indexed properties. The indexes of indexed properties are zero-based. This value should be null for non-indexed properties. </param>
//public void SetStaticProperty(string name, BindingFlags bindingFlags, object value, params object[] args)
//{
// this.SetStaticProperty(name, bindingFlags, value, null, args);
//}

///// <summary>
///// Sets the static property
///// </summary>
///// <param name="name">Name of the property</param>
///// <param name="bindingFlags">Additional invocation attributes.</param>
///// <param name="value">Value to be set to field or property</param>
///// <param name="parameterTypes">An array of <see cref="T:System.Type"/> objects representing the number, order, and type of the parameters for the indexed property.</param>
///// <param name="args">Arguments to pass to the member to invoke.</param>
//public void SetStaticProperty(string name, BindingFlags bindingFlags, object value, Type[] parameterTypes, object[] args)
//{
// Helper.CheckParameterNotNull(name, "name", string.Empty);

// if (parameterTypes != null)
// {
// PropertyInfo pi = this.type.GetProperty(name, bindingFlags | BindingFlags.Static, null, null, parameterTypes, null);
// if (pi == null)
// {
// throw new ArgumentException(
// string.Format(CultureInfo.CurrentCulture, FrameworkMessages.PrivateAccessorMemberNotFound, name));
// }

// pi.SetValue(null, value, args);
// }
// else
// {
// object[] pass = new object[(args?.Length ?? 0) + 1];
// pass[0] = value;
// args?.CopyTo(pass, 1);
// this.InvokeHelperStatic(name, bindingFlags | BindingFlags.SetProperty, pass, null);
// }
//}

/// <summary>
/// Invokes the static method
/// </summary>
/// <param name="name">Name of the member</param>
/// <param name="bindingFlags">Additional invocation attributes</param>
/// <param name="args">Arguements to the invocation</param>
/// <param name="culture">Culture</param>
/// <returns>Result of invocation</returns>
private object InvokeHelperStatic(string name, BindingFlags bindingFlags, object[] args, CultureInfo culture)
{
Helper.CheckParameterNotNull(name, "name", string.Empty);
try
{
return this.type.InvokeMember(name, bindingFlags | BindToEveryThing | BindingFlags.Static, null, null, args, culture);
} catch (TargetInvocationException e)
{
//Debug.Assert(e.InnerException != null, "Inner Exception should not be null.");
if (e.InnerException != null)
{
throw e.InnerException;
}

throw;
}
}
}

/// <summary>
/// The helper.
/// </summary>
internal static class Helper
{
/// <summary>
/// The check parameter not null.
/// </summary>
/// <param name="param">
/// The parameter.
/// </param>
/// <param name="parameterName">
/// The parameter name.
/// </param>
/// <param name="message">
/// The message.
/// </param>
/// <exception cref="ArgumentNullException"> Throws argument null exception when parameter is null. </exception>
internal static void CheckParameterNotNull(object param, string parameterName, string message)
{
if (param == null)
{
throw new ArgumentNullException(parameterName, message);
}
}

/// <summary>
/// The check parameter not null or empty.
/// </summary>
/// <param name="param">
/// The parameter.
/// </param>
/// <param name="parameterName">
/// The parameter name.
/// </param>
/// <param name="message">
/// The message.
/// </param>
/// <exception cref="ArgumentException"> Throws ArgumentException when parameter is null. </exception>
//internal static void CheckParameterNotNullOrEmpty(string param, string parameterName, string message)
//{
// if (string.IsNullOrEmpty(param))
// {
// throw new ArgumentException(message, parameterName);
// }
//}
}
}

+ 1
- 1
test/TensorFlowNET.UnitTest/VariableTest.cs View File

@@ -29,7 +29,7 @@ namespace TensorFlowNET.UnitTest
} }


/// <summary> /// <summary>
/// https://www.tf.org/api_docs/python/tf/variable_scope
/// https://www.tensorflow.org/api_docs/python/tf/variable_scope
/// how to create a new variable /// how to create a new variable
/// </summary> /// </summary>
[TestMethod] [TestMethod]


+ 76
- 61
test/TensorFlowNET.UnitTest/c_test_util.cs View File

@@ -12,42 +12,51 @@ namespace TensorFlowNET.UnitTest
{ {
public static Operation Add(Operation l, Operation r, Graph graph, Status s, string name = "add") public static Operation Add(Operation l, Operation r, Graph graph, Status s, string name = "add")
{ {
var desc = c_api.TF_NewOperation(graph, "AddN", name);

var inputs = new TF_Output[]
lock (Locks.ProcessWide)
{ {
new TF_Output(l, 0),
new TF_Output(r, 0),
};
var desc = c_api.TF_NewOperation(graph, "AddN", name);


c_api.TF_AddInputList(desc, inputs, inputs.Length);
var inputs = new TF_Output[]
{
new TF_Output(l, 0),
new TF_Output(r, 0),
};


var op = c_api.TF_FinishOperation(desc, s);
s.Check();
c_api.TF_AddInputList(desc, inputs, inputs.Length);


return op;
var op = c_api.TF_FinishOperation(desc, s);
s.Check();

return op;
}
} }


[SuppressMessage("ReSharper", "RedundantAssignment")] [SuppressMessage("ReSharper", "RedundantAssignment")]
public static bool GetAttrValue(Operation oper, string attr_name, ref AttrValue attr_value, Status s) public static bool GetAttrValue(Operation oper, string attr_name, ref AttrValue attr_value, Status s)
{ {
using (var buffer = new Buffer())
lock (Locks.ProcessWide)
{ {
c_api.TF_OperationGetAttrValueProto(oper, attr_name, buffer, s);
attr_value = AttrValue.Parser.ParseFrom(buffer.MemoryBlock.Stream());
}
using (var buffer = new Buffer())
{
c_api.TF_OperationGetAttrValueProto(oper, attr_name, buffer, s);
attr_value = AttrValue.Parser.ParseFrom(buffer.MemoryBlock.Stream());
}


return s.Code == TF_Code.TF_OK;
return s.Code == TF_Code.TF_OK;
}
} }


public static GraphDef GetGraphDef(Graph graph) public static GraphDef GetGraphDef(Graph graph)
{ {
using (var s = new Status())
using (var buffer = new Buffer())
lock (Locks.ProcessWide)
{ {
c_api.TF_GraphToGraphDef(graph, buffer, s);
s.Check();
return GraphDef.Parser.ParseFrom(buffer.MemoryBlock.Stream());
using (var s = new Status())
using (var buffer = new Buffer())
{
c_api.TF_GraphToGraphDef(graph, buffer, s);
s.Check();
return GraphDef.Parser.ParseFrom(buffer.MemoryBlock.Stream());
}
} }
} }


@@ -58,6 +67,7 @@ namespace TensorFlowNET.UnitTest
{ {
return false; return false;
} }

bool found_t = false; bool found_t = false;
bool found_n = false; bool found_n = false;
foreach (var attr in node_def.Attr) foreach (var attr in node_def.Attr)
@@ -67,19 +77,16 @@ namespace TensorFlowNET.UnitTest
if (attr.Value.Type == DataType.DtInt32) if (attr.Value.Type == DataType.DtInt32)
{ {
found_t = true; found_t = true;
}
else
} else
{ {
return false; return false;
} }
}
else if (attr.Key == "N")
} else if (attr.Key == "N")
{ {
if (attr.Value.I == n) if (attr.Value.I == n)
{ {
found_n = true; found_n = true;
}
else
} else
{ {
return false; return false;
} }
@@ -92,7 +99,7 @@ namespace TensorFlowNET.UnitTest
public static bool IsNeg(NodeDef node_def, string input) public static bool IsNeg(NodeDef node_def, string input)
{ {
return node_def.Op == "Neg" && node_def.Name == "neg" && return node_def.Op == "Neg" && node_def.Name == "neg" &&
node_def.Input.Count == 1 && node_def.Input[0] == input;
node_def.Input.Count == 1 && node_def.Input[0] == input;
} }


public static bool IsPlaceholder(NodeDef node_def) public static bool IsPlaceholder(NodeDef node_def)
@@ -111,13 +118,11 @@ namespace TensorFlowNET.UnitTest
if (attr.Value.Type == DataType.DtInt32) if (attr.Value.Type == DataType.DtInt32)
{ {
found_dtype = true; found_dtype = true;
}
else
} else
{ {
return false; return false;
} }
}
else if (attr.Key == "shape")
} else if (attr.Key == "shape")
{ {
found_shape = true; found_shape = true;
} }
@@ -132,72 +137,82 @@ namespace TensorFlowNET.UnitTest
{ {
return false; return false;
} }

bool found_dtype = false; bool found_dtype = false;
bool found_value = false; bool found_value = false;
foreach (var attr in node_def.Attr) {
foreach (var attr in node_def.Attr)
{
if (attr.Key == "dtype") if (attr.Key == "dtype")
{ {
if (attr.Value.Type == DataType.DtInt32) if (attr.Value.Type == DataType.DtInt32)
{ {
found_dtype = true; found_dtype = true;
}
else
} else
{ {
return false; return false;
} }
}
else if (attr.Key == "value")
} else if (attr.Key == "value")
{ {
if (attr.Value.Tensor != null && if (attr.Value.Tensor != null &&
attr.Value.Tensor.IntVal.Count == 1 && attr.Value.Tensor.IntVal.Count == 1 &&
attr.Value.Tensor.IntVal[0] == v) attr.Value.Tensor.IntVal[0] == v)
{ {
found_value = true; found_value = true;
}
else
} else
{ {
return false; return false;
} }
} }
} }

return found_dtype && found_value; return found_dtype && found_value;
} }


public static Operation Neg(Operation n, Graph graph, Status s, string name = "neg") public static Operation Neg(Operation n, Graph graph, Status s, string name = "neg")
{ {
OperationDescription desc = c_api.TF_NewOperation(graph, "Neg", name);
var neg_input = new TF_Output(n, 0);
c_api.TF_AddInput(desc, neg_input);
var op = c_api.TF_FinishOperation(desc, s);
s.Check();
lock (Locks.ProcessWide)
{
OperationDescription desc = c_api.TF_NewOperation(graph, "Neg", name);
var neg_input = new TF_Output(n, 0);
c_api.TF_AddInput(desc, neg_input);
var op = c_api.TF_FinishOperation(desc, s);
s.Check();


return op;
return op;
}
} }


public static Operation Placeholder(Graph graph, Status s, string name = "feed", TF_DataType dtype = TF_DataType.TF_INT32, long[] dims = null) public static Operation Placeholder(Graph graph, Status s, string name = "feed", TF_DataType dtype = TF_DataType.TF_INT32, long[] dims = null)
{ {
var desc = c_api.TF_NewOperation(graph, "Placeholder", name);
c_api.TF_SetAttrType(desc, "dtype", dtype);
if (dims != null)
lock (Locks.ProcessWide)
{ {
c_api.TF_SetAttrShape(desc, "shape", dims, dims.Length);
}
var op = c_api.TF_FinishOperation(desc, s);
s.Check();
var desc = c_api.TF_NewOperation(graph, "Placeholder", name);
c_api.TF_SetAttrType(desc, "dtype", dtype);
if (dims != null)
{
c_api.TF_SetAttrShape(desc, "shape", dims, dims.Length);
}

var op = c_api.TF_FinishOperation(desc, s);
s.Check();


return op;
return op;
}
} }


public static Operation Const(Tensor t, Graph graph, Status s, string name) public static Operation Const(Tensor t, Graph graph, Status s, string name)
{ {
var desc = c_api.TF_NewOperation(graph, "Const", name);
c_api.TF_SetAttrTensor(desc, "value", t, s);
s.Check();
c_api.TF_SetAttrType(desc, "dtype", t.dtype);
var op = c_api.TF_FinishOperation(desc, s);
s.Check();

return op;
lock (Locks.ProcessWide)
{
var desc = c_api.TF_NewOperation(graph, "Const", name);
c_api.TF_SetAttrTensor(desc, "value", t, s);
s.Check();
c_api.TF_SetAttrType(desc, "dtype", t.dtype);
var op = c_api.TF_FinishOperation(desc, s);
s.Check();

return op;
}
} }


public static Operation ScalarConst(int v, Graph graph, Status s, string name = "scalar") public static Operation ScalarConst(int v, Graph graph, Status s, string name = "scalar")
@@ -205,4 +220,4 @@ namespace TensorFlowNET.UnitTest
return Const(new Tensor(v), graph, s, name); return Const(new Tensor(v), graph, s, name);
} }
} }
}
}

Loading…
Cancel
Save