Browse Source

seperate TensorFlowNET.Graph.UnitTest project.

tags/v0.60-tf.numpy
Oceania2018 4 years ago
parent
commit
71488b4126
25 changed files with 533 additions and 2707 deletions
  1. +16
    -2
      TensorFlow.NET.sln
  2. +0
    -1
      test/TensorFlowNET.Graph.UnitTest/Basics/QueueTest.cs
  3. +1
    -2
      test/TensorFlowNET.Graph.UnitTest/ControlFlowTest/CondTestCases.cs
  4. +1
    -2
      test/TensorFlowNET.Graph.UnitTest/ControlFlowTest/ShapeTestCase.cs
  5. +1
    -2
      test/TensorFlowNET.Graph.UnitTest/ControlFlowTest/WhileContextTestCase.cs
  6. +5
    -4
      test/TensorFlowNET.Graph.UnitTest/FunctionalOpsTest/ScanTestCase.cs
  7. +0
    -1
      test/TensorFlowNET.Graph.UnitTest/GradientTest/GradientTest.cs
  8. +1
    -4
      test/TensorFlowNET.Graph.UnitTest/GraphModeTestBase.cs
  9. +6
    -8
      test/TensorFlowNET.Graph.UnitTest/ImageTest.cs
  10. +25
    -27
      test/TensorFlowNET.Graph.UnitTest/MultithreadingTests.cs
  11. +78
    -0
      test/TensorFlowNET.Graph.UnitTest/NameScopeTest.cs
  12. +0
    -2
      test/TensorFlowNET.Graph.UnitTest/OperationsTest.cs
  13. +335
    -0
      test/TensorFlowNET.Graph.UnitTest/PythonTest.cs
  14. +36
    -0
      test/TensorFlowNET.Graph.UnitTest/TensorFlowNET.Graph.UnitTest.csproj
  15. +0
    -0
      test/TensorFlowNET.Graph.UnitTest/Utilities/MultiThreadedUnitTestExecuter.cs
  16. +22
    -0
      test/TensorFlowNET.Graph.UnitTest/Utilities/TestHelper.cs
  17. +4
    -4
      test/TensorFlowNET.UnitTest/Basics/RandomTest.cs
  18. +0
    -67
      test/TensorFlowNET.UnitTest/Basics/TensorShapeTest.cs
  19. +1
    -1
      test/TensorFlowNET.UnitTest/EnforcedSinglethreadingTests.cs
  20. +0
    -1
      test/TensorFlowNET.UnitTest/GradientTest/GradientEagerTest.cs
  21. +0
    -1104
      test/TensorFlowNET.UnitTest/GradientTest/gradients_test.py
  22. +1
    -72
      test/TensorFlowNET.UnitTest/NameScopeTest.cs
  23. +0
    -917
      test/TensorFlowNET.UnitTest/Utilities/PrivateObject.cs
  24. +0
    -314
      test/TensorFlowNET.UnitTest/Utilities/PrivateObjectExtensions.cs
  25. +0
    -172
      test/TensorFlowNET.UnitTest/control_flow_ops_test/SwitchTestCase.cs

+ 16
- 2
TensorFlow.NET.sln View File

@@ -1,7 +1,7 @@

Microsoft Visual Studio Solution File, Format Version 12.00
# Visual Studio Version 16
VisualStudioVersion = 16.0.29102.190
# Visual Studio Version 17
VisualStudioVersion = 17.0.31423.177
MinimumVisualStudioVersion = 10.0.40219.1
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Binding", "src\TensorFlowNET.Core\Tensorflow.Binding.csproj", "{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}"
EndProject
@@ -21,6 +21,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Native.UnitTest"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Keras.UnitTest", "test\TensorFlowNET.Keras.UnitTest\Tensorflow.Keras.UnitTest.csproj", "{79EB56DF-E29E-4AE2-A7D9-FE403FD919BA}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Graph.UnitTest", "test\TensorFlowNET.Graph.UnitTest\TensorFlowNET.Graph.UnitTest.csproj", "{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@@ -139,6 +141,18 @@ Global
{79EB56DF-E29E-4AE2-A7D9-FE403FD919BA}.Release|x64.Build.0 = Release|Any CPU
{79EB56DF-E29E-4AE2-A7D9-FE403FD919BA}.Release|x86.ActiveCfg = Release|Any CPU
{79EB56DF-E29E-4AE2-A7D9-FE403FD919BA}.Release|x86.Build.0 = Release|Any CPU
{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Debug|Any CPU.Build.0 = Debug|Any CPU
{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Debug|x64.ActiveCfg = Debug|x64
{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Debug|x64.Build.0 = Debug|x64
{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Debug|x86.ActiveCfg = Debug|Any CPU
{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Debug|x86.Build.0 = Debug|Any CPU
{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|Any CPU.ActiveCfg = Release|Any CPU
{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|Any CPU.Build.0 = Release|Any CPU
{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x64.ActiveCfg = Release|Any CPU
{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x64.Build.0 = Release|Any CPU
{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x86.ActiveCfg = Release|Any CPU
{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x86.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE


test/TensorFlowNET.UnitTest/Basics/QueueTest.cs → test/TensorFlowNET.Graph.UnitTest/Basics/QueueTest.cs View File

@@ -1,7 +1,6 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System.Linq;
using Tensorflow;
using Tensorflow.UnitTest;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest.Basics

test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs → test/TensorFlowNET.Graph.UnitTest/ControlFlowTest/CondTestCases.cs View File

@@ -1,9 +1,8 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow;
using Tensorflow.UnitTest;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest.control_flow_ops_test
namespace TensorFlowNET.UnitTest.ControlFlowTest
{
/// <summary>
/// excerpt of tensorflow/python/framework/ops/control_flow_ops_test.py

test/TensorFlowNET.UnitTest/control_flow_ops_test/ShapeTestCase.cs → test/TensorFlowNET.Graph.UnitTest/ControlFlowTest/ShapeTestCase.cs View File

@@ -1,8 +1,7 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow;
using Tensorflow.UnitTest;

namespace TensorFlowNET.UnitTest.control_flow_ops_test
namespace TensorFlowNET.UnitTest.ControlFlowTest
{
/// <summary>
/// excerpt of tensorflow/python/framework/ops/control_flow_ops_test.py

test/TensorFlowNET.UnitTest/control_flow_ops_test/WhileContextTestCase.cs → test/TensorFlowNET.Graph.UnitTest/ControlFlowTest/WhileContextTestCase.cs View File

@@ -1,10 +1,9 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using Tensorflow;
using Tensorflow.UnitTest;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest.control_flow_ops_test
namespace TensorFlowNET.UnitTest.ControlFlowTest
{
[TestClass]
public class WhileContextTestCase : GraphModeTestBase

test/TensorFlowNET.UnitTest/functional_ops_test/ScanTestCase.cs → test/TensorFlowNET.Graph.UnitTest/FunctionalOpsTest/ScanTestCase.cs View File

@@ -2,10 +2,9 @@
using Tensorflow.NumPy;
using System;
using Tensorflow;
using Tensorflow.UnitTest;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest.functional_ops_test
namespace TensorFlowNET.UnitTest.FunctionalOpsTest
{
/// <summary>
/// https://www.tensorflow.org/api_docs/python/tf/scan
@@ -22,7 +21,8 @@ namespace TensorFlowNET.UnitTest.functional_ops_test

var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(6));
var scan = functional_ops.scan(fn, input);
sess.run(scan, (input, np.array(1, 2, 3, 4, 5, 6))).Should().Be(np.array(1, 3, 6, 10, 15, 21));
var result = sess.run(scan, (input, np.array(1, 2, 3, 4, 5, 6)));
Assert.AreEqual(result, np.array(1, 3, 6, 10, 15, 21));
}

[TestMethod, Ignore("need UpdateEdge API")]
@@ -34,7 +34,8 @@ namespace TensorFlowNET.UnitTest.functional_ops_test

var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(6));
var scan = functional_ops.scan(fn, input, reverse: true);
sess.run(scan, (input, np.array(1, 2, 3, 4, 5, 6))).Should().Be(np.array(21, 20, 18, 15, 11, 6));
var result = sess.run(scan, (input, np.array(1, 2, 3, 4, 5, 6)));
Assert.AreEqual(result, np.array(21, 20, 18, 15, 11, 6));
}
}
}

test/TensorFlowNET.UnitTest/GradientTest/GradientTest.cs → test/TensorFlowNET.Graph.UnitTest/GradientTest/GradientTest.cs View File

@@ -4,7 +4,6 @@ using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow;
using Tensorflow.UnitTest;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest.Gradient

test/TensorFlowNET.UnitTest/GraphModeTestBase.cs → test/TensorFlowNET.Graph.UnitTest/GraphModeTestBase.cs View File

@@ -1,9 +1,7 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using TensorFlowNET.UnitTest;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

namespace Tensorflow.UnitTest
namespace TensorFlowNET.UnitTest
{
public class GraphModeTestBase : PythonTest
{
@@ -16,7 +14,6 @@ namespace Tensorflow.UnitTest
[TestCleanup]
public void TestClean()
{
keras.backend.clear_session();
tf.enable_eager_execution();
}
}

test/TensorFlowNET.UnitTest/ImageTest.cs → test/TensorFlowNET.Graph.UnitTest/ImageTest.cs View File

@@ -1,12 +1,10 @@
using FluentAssertions;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow.NumPy;
using System.Linq;
using Tensorflow;
using Tensorflow.UnitTest;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest.Basics
namespace TensorFlowNET.UnitTest
{
/// <summary>
/// Find more examples in https://www.programcreek.com/python/example/90444/tensorflow.read_file
@@ -84,14 +82,14 @@ namespace TensorFlowNET.UnitTest.Basics

var result = sess.run(cropped);
// check if cropped to 1x1 center was succesfull
result.size.Should().Be(1);
result[0, 0, 0, 0].Should().Be(4f);
Assert.AreEqual(result.size, 1);
Assert.AreEqual(result[0, 0, 0, 0], 4f);

cropped = tf.image.crop_and_resize(image2, box, boxInd, cropSize2_2);
result = sess.run(cropped);
// check if flipped and no cropping occured
result.size.Should().Be(16);
result[0, 0, 0, 0].Should().Be(12f);
Assert.AreEqual(result.size, 16);
Assert.AreEqual(result[0, 0, 0, 0], 12f);

}
}

test/TensorFlowNET.UnitTest/MultithreadingTests.cs → test/TensorFlowNET.Graph.UnitTest/MultithreadingTests.cs View File

@@ -1,12 +1,10 @@
using FluentAssertions;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow.NumPy;
using System;
using System.IO;
using System.Linq;
using System.Runtime.InteropServices;
using Tensorflow;
using Tensorflow.UnitTest;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest
@@ -24,15 +22,15 @@ namespace TensorFlowNET.UnitTest
//the core method
void Core(int tid)
{
tf.peak_default_graph().Should().BeNull();
Assert.IsNull(tf.peak_default_graph());

using (var sess = tf.Session())
{
var default_graph = tf.peak_default_graph();
var sess_graph = sess.GetPrivate<Graph>("_graph");
sess_graph.Should().NotBeNull();
default_graph.Should().NotBeNull()
.And.BeEquivalentTo(sess_graph);
var sess_graph = sess.graph;
Assert.IsNotNull(default_graph);
Assert.IsNotNull(sess_graph);
Assert.AreEqual(default_graph, sess_graph);
}
}
}
@@ -47,15 +45,15 @@ namespace TensorFlowNET.UnitTest
//the core method
void Core(int tid)
{
tf.peak_default_graph().Should().BeNull();
Assert.IsNull(tf.peak_default_graph());
//tf.Session created an other graph
using (var sess = tf.Session())
{
var default_graph = tf.peak_default_graph();
var sess_graph = sess.GetPrivate<Graph>("_graph");
sess_graph.Should().NotBeNull();
default_graph.Should().NotBeNull()
.And.BeEquivalentTo(sess_graph);
var sess_graph = sess.graph;
Assert.IsNotNull(default_graph);
Assert.IsNotNull(sess_graph);
Assert.AreEqual(default_graph, sess_graph);
}
}
}
@@ -70,19 +68,18 @@ namespace TensorFlowNET.UnitTest
//the core method
void Core(int tid)
{
tf.peak_default_graph().Should().BeNull();
Assert.IsNull(tf.peak_default_graph());
var beforehand = tf.get_default_graph(); //this should create default automatically.
beforehand.graph_key.Should().NotContain("-0/", "Already created a graph in an other thread.");
beforehand.as_default();
tf.peak_default_graph().Should().NotBeNull();
Assert.IsNotNull(tf.peak_default_graph());

using (var sess = tf.Session())
{
var default_graph = tf.peak_default_graph();
var sess_graph = sess.GetPrivate<Graph>("_graph");
sess_graph.Should().NotBeNull();
default_graph.Should().NotBeNull()
.And.BeEquivalentTo(sess_graph);
var sess_graph = sess.graph;
Assert.IsNotNull(default_graph);
Assert.IsNotNull(sess_graph);
Assert.AreEqual(default_graph, sess_graph);

Console.WriteLine($"{tid}-{default_graph.graph_key}");

@@ -188,7 +185,7 @@ namespace TensorFlowNET.UnitTest
//the core method
void Core(int tid)
{
tf.peak_default_graph().Should().BeNull();
Assert.IsNull(tf.peak_default_graph());
//graph is created automatically to perform create these operations
var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 });
var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 });
@@ -197,7 +194,8 @@ namespace TensorFlowNET.UnitTest
{
using (var sess = tf.Session())
{
sess.run(math).GetAtIndex<float>(0).Should().Be(5);
var result = sess.run(math);
Assert.AreEqual(result.GetAtIndex<float>(0), 5f);
}
}
}
@@ -213,14 +211,14 @@ namespace TensorFlowNET.UnitTest
{
using (var sess = tf.Session())
{
tf.peak_default_graph().Should().NotBeNull();
Assert.IsNotNull(tf.peak_default_graph());
//graph is created automatically to perform create these operations
var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 });
var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 });
var math = a1 + a2;

var result = sess.run(math);
result.GetAtIndex<float>(0).Should().Be(5);
Assert.AreEqual(result.GetAtIndex<float>(0), 5f);
}
}
}
@@ -235,7 +233,7 @@ namespace TensorFlowNET.UnitTest
{
using (var sess = tf.Session())
{
tf.peak_default_graph().Should().NotBeNull();
Assert.IsNotNull(tf.peak_default_graph());
//graph is created automatically to perform create these operations
var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 });
var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 });
@@ -252,7 +250,7 @@ namespace TensorFlowNET.UnitTest
//the core method
void Core(int tid)
{
tf.peak_default_graph().Should().BeNull();
Assert.IsNull(tf.peak_default_graph());
//graph is created automatically to perform create these operations
var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 });
var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 });
@@ -268,7 +266,7 @@ namespace TensorFlowNET.UnitTest
//the core method
void Core(int tid)
{
tf.peak_default_graph().Should().BeNull();
Assert.IsNull(tf.peak_default_graph());
//graph is created automatically to perform create these operations
var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 });
var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 }, name: "ConstantK");

+ 78
- 0
test/TensorFlowNET.Graph.UnitTest/NameScopeTest.cs View File

@@ -0,0 +1,78 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest.Basics
{
[TestClass]
public class NameScopeTest : GraphModeTestBase
{
string name = "";

[TestMethod]
public void NestedNameScope()
{
Graph g = tf.Graph().as_default();

tf_with(new ops.NameScope("scope1"), scope1 =>
{
name = scope1;
Assert.AreEqual("scope1", g._name_stack);
Assert.AreEqual("scope1/", name);

var const1 = tf.constant(1.0);
Assert.AreEqual("scope1/Const:0", const1.name);

tf_with(new ops.NameScope("scope2"), scope2 =>
{
name = scope2;
Assert.AreEqual("scope1/scope2", g._name_stack);
Assert.AreEqual("scope1/scope2/", name);

var const2 = tf.constant(2.0);
Assert.AreEqual("scope1/scope2/Const:0", const2.name);
});

Assert.AreEqual("scope1", g._name_stack);
var const3 = tf.constant(2.0);
Assert.AreEqual("scope1/Const_1:0", const3.name);
});

g.Dispose();

Assert.AreEqual("", g._name_stack);
}

[TestMethod, Ignore("Unimplemented Usage")]
public void NestedNameScope_Using()
{
Graph g = tf.Graph().as_default();

using (var name = new ops.NameScope("scope1"))
{
Assert.AreEqual("scope1", g._name_stack);
Assert.AreEqual("scope1/", name);

var const1 = tf.constant(1.0);
Assert.AreEqual("scope1/Const:0", const1.name);

using (var name2 = new ops.NameScope("scope2"))
{
Assert.AreEqual("scope1/scope2", g._name_stack);
Assert.AreEqual("scope1/scope2/", name);

var const2 = tf.constant(2.0);
Assert.AreEqual("scope1/scope2/Const:0", const2.name);
}

Assert.AreEqual("scope1", g._name_stack);
var const3 = tf.constant(2.0);
Assert.AreEqual("scope1/Const_1:0", const3.name);
};

g.Dispose();

Assert.AreEqual("", g._name_stack);
}
}
}

test/TensorFlowNET.UnitTest/OperationsTest.cs → test/TensorFlowNET.Graph.UnitTest/OperationsTest.cs View File

@@ -4,8 +4,6 @@ using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow;
using Tensorflow.UnitTest;
using Tensorflow.Util;
using static Tensorflow.Binding;
using Buffer = Tensorflow.Buffer;


+ 335
- 0
test/TensorFlowNET.Graph.UnitTest/PythonTest.cs View File

@@ -0,0 +1,335 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Newtonsoft.Json.Linq;
using Tensorflow.NumPy;
using System;
using System.Collections;
using System.Linq;
using Tensorflow;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest
{
/// <summary>
/// Use as base class for test classes to get additional assertions
/// </summary>
public class PythonTest
{
#region python compatibility layer
protected PythonTest self { get => this; }
protected int None => -1;
#endregion

#region pytest assertions

public void assertItemsEqual(ICollection given, ICollection expected)
{
if (given is Hashtable && expected is Hashtable)
{
Assert.AreEqual(JObject.FromObject(expected).ToString(), JObject.FromObject(given).ToString());
return;
}
Assert.IsNotNull(expected);
Assert.IsNotNull(given);
var e = expected.OfType<object>().ToArray();
var g = given.OfType<object>().ToArray();
Assert.AreEqual(e.Length, g.Length, $"The collections differ in length expected {e.Length} but got {g.Length}");
for (int i = 0; i < e.Length; i++)
{
/*if (g[i] is NDArray && e[i] is NDArray)
assertItemsEqual((g[i] as NDArray).GetData<object>(), (e[i] as NDArray).GetData<object>());
else*/
if (e[i] is ICollection && g[i] is ICollection)
assertEqual(g[i], e[i]);
else
Assert.AreEqual(e[i], g[i], $"Items differ at index {i}, expected {e[i]} but got {g[i]}");
}
}

public void assertAllEqual(ICollection given, ICollection expected)
{
assertItemsEqual(given, expected);
}

public void assertFloat32Equal(float expected, float actual, string msg)
{
float eps = 1e-6f;
Assert.IsTrue(Math.Abs(expected - actual) < eps * Math.Max(1.0f, Math.Abs(expected)), $"{msg}: expected {expected} vs actual {actual}");
}

public void assertFloat64Equal(double expected, double actual, string msg)
{
double eps = 1e-16f;
Assert.IsTrue(Math.Abs(expected - actual) < eps * Math.Max(1.0f, Math.Abs(expected)), $"{msg}: expected {expected} vs actual {actual}");
}

public void assertEqual(object given, object expected)
{
/*if (given is NDArray && expected is NDArray)
{
assertItemsEqual((given as NDArray).GetData<object>(), (expected as NDArray).GetData<object>());
return;
}*/
if (given is Hashtable && expected is Hashtable)
{
Assert.AreEqual(JObject.FromObject(expected).ToString(), JObject.FromObject(given).ToString());
return;
}
if (given is ICollection && expected is ICollection)
{
assertItemsEqual(given as ICollection, expected as ICollection);
return;
}
if (given is float && expected is float)
{
assertFloat32Equal((float)expected, (float)given, "");
return;
}
if (given is double && expected is double)
{
assertFloat64Equal((double)expected, (double)given, "");
return;
}
Assert.AreEqual(expected, given);
}

public void assertEquals(object given, object expected)
{
assertEqual(given, expected);
}

public void assert(object given)
{
if (given is bool)
Assert.IsTrue((bool)given);
Assert.IsNotNull(given);
}

public void assertIsNotNone(object given)
{
Assert.IsNotNull(given);
}

public void assertFalse(bool cond)
{
Assert.IsFalse(cond);
}

public void assertTrue(bool cond)
{
Assert.IsTrue(cond);
}

public void assertAllClose(NDArray array1, NDArray array2, double eps = 1e-5)
{
Assert.IsTrue(np.allclose(array1, array2, rtol: eps));
}

public void assertAllClose(double value, NDArray array2, double eps = 1e-5)
{
var array1 = np.ones_like(array2) * value;
Assert.IsTrue(np.allclose(array1, array2, rtol: eps));
}

public void assertProtoEquals(object toProto, object o)
{
throw new NotImplementedException();
}

#endregion

#region tensor evaluation and test session

//protected object _eval_helper(Tensor[] tensors)
//{
// if (tensors == null)
// return null;
// return nest.map_structure(self._eval_tensor, tensors);
//}

protected object _eval_tensor(object tensor)
{
if (tensor == null)
return None;
//else if (callable(tensor))
// return self._eval_helper(tensor())
else
{
try
{
//TODO:
// if sparse_tensor.is_sparse(tensor):
// return sparse_tensor.SparseTensorValue(tensor.indices, tensor.values,
// tensor.dense_shape)
//return (tensor as Tensor).numpy();
}
catch (Exception)
{
throw new ValueError("Unsupported type: " + tensor.GetType());
}
return null;
}
}

/// <summary>
/// This function is used in many original tensorflow unit tests to evaluate tensors
/// in a test session with special settings (for instance constant folding off)
///
/// </summary>
public T evaluate<T>(Tensor tensor)
{
object result = null;
// if context.executing_eagerly():
// return self._eval_helper(tensors)
// else:
{
using (var sess = tf.Session())
{
var ndarray = tensor.eval(sess);
if (typeof(T) == typeof(double))
{
double x = ndarray;
result = x;
}
else if (typeof(T) == typeof(int))
{
int x = ndarray;
result = x;
}
else
{
result = ndarray;
}
}

return (T)result;
}
}


public Session cached_session()
{
throw new NotImplementedException();
}

//Returns a TensorFlow Session for use in executing tests.
public Session session(Graph graph = null, object config = null, bool use_gpu = false, bool force_gpu = false)
{
//Note that this will set this session and the graph as global defaults.

//Use the `use_gpu` and `force_gpu` options to control where ops are run.If
//`force_gpu` is True, all ops are pinned to `/device:GPU:0`. Otherwise, if
//`use_gpu` is True, TensorFlow tries to run as many ops on the GPU as
//possible.If both `force_gpu and `use_gpu` are False, all ops are pinned to
//the CPU.

//Example:
//```python
//class MyOperatorTest(test_util.TensorFlowTestCase):
// def testMyOperator(self):
// with self.session(use_gpu= True):
// valid_input = [1.0, 2.0, 3.0, 4.0, 5.0]
// result = MyOperator(valid_input).eval()
// self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0]
// invalid_input = [-1.0, 2.0, 7.0]
// with self.assertRaisesOpError("negative input not supported"):
// MyOperator(invalid_input).eval()
//```

//Args:
// graph: Optional graph to use during the returned session.
// config: An optional config_pb2.ConfigProto to use to configure the
// session.
// use_gpu: If True, attempt to run as many ops as possible on GPU.
// force_gpu: If True, pin all ops to `/device:GPU:0`.

//Yields:
// A Session object that should be used as a context manager to surround
// the graph building and execution code in a test case.

Session s = null;
//if (context.executing_eagerly())
// yield None
//else
//{
s = self._create_session(graph, config, force_gpu);
self._constrain_devices_and_set_default(s, use_gpu, force_gpu);
//}
return s.as_default();
}

private ITensorFlowObject _constrain_devices_and_set_default(Session sess, bool useGpu, bool forceGpu)
{
//def _constrain_devices_and_set_default(self, sess, use_gpu, force_gpu):
//"""Set the session and its graph to global default and constrain devices."""
//if context.executing_eagerly():
// yield None
//else:
// with sess.graph.as_default(), sess.as_default():
// if force_gpu:
// # Use the name of an actual device if one is detected, or
// # '/device:GPU:0' otherwise
// gpu_name = gpu_device_name()
// if not gpu_name:
// gpu_name = "/device:GPU:0"
// with sess.graph.device(gpu_name):
// yield sess
// elif use_gpu:
// yield sess
// else:
// with sess.graph.device("/device:CPU:0"):
// yield sess
return sess;
}

// See session() for details.
private Session _create_session(Graph graph, object cfg, bool forceGpu)
{
var prepare_config = new Func<object, object>((config) =>
{
// """Returns a config for sessions.
// Args:
// config: An optional config_pb2.ConfigProto to use to configure the
// session.
// Returns:
// A config_pb2.ConfigProto object.

//TODO: config

// # use_gpu=False. Currently many tests rely on the fact that any device
// # will be used even when a specific device is supposed to be used.
// allow_soft_placement = not force_gpu
// if config is None:
// config = config_pb2.ConfigProto()
// config.allow_soft_placement = allow_soft_placement
// config.gpu_options.per_process_gpu_memory_fraction = 0.3
// elif not allow_soft_placement and config.allow_soft_placement:
// config_copy = config_pb2.ConfigProto()
// config_copy.CopyFrom(config)
// config = config_copy
// config.allow_soft_placement = False
// # Don't perform optimizations for tests so we don't inadvertently run
// # gpu ops on cpu
// config.graph_options.optimizer_options.opt_level = -1
// # Disable Grappler constant folding since some tests & benchmarks
// # use constant input and become meaningless after constant folding.
// # DO NOT DISABLE GRAPPLER OPTIMIZERS WITHOUT CONSULTING WITH THE
// # GRAPPLER TEAM.
// config.graph_options.rewrite_options.constant_folding = (
// rewriter_config_pb2.RewriterConfig.OFF)
// config.graph_options.rewrite_options.pin_to_host_optimization = (
// rewriter_config_pb2.RewriterConfig.OFF)
return config;
});
//TODO: use this instead of normal session
//return new ErrorLoggingSession(graph = graph, config = prepare_config(config))
return new Session(graph);//, config = prepare_config(config))
}

#endregion

public void AssetSequenceEqual<T>(T[] a, T[] b)
{
Assert.IsTrue(Enumerable.SequenceEqual(a, b));
}
}
}

+ 36
- 0
test/TensorFlowNET.Graph.UnitTest/TensorFlowNET.Graph.UnitTest.csproj View File

@@ -0,0 +1,36 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFramework>net5.0</TargetFramework>
<LangVersion>9.0</LangVersion>
<IsPackable>false</IsPackable>
<AssemblyName>TensorFlowNET.UnitTest</AssemblyName>
<Platforms>AnyCPU;x64</Platforms>
</PropertyGroup>

<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'">
<DefineConstants>DEBUG;TRACE</DefineConstants>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
</PropertyGroup>

<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
<DefineConstants>DEBUG;TRACE</DefineConstants>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.10.0" />
<PackageReference Include="MSTest.TestAdapter" Version="2.2.5" />
<PackageReference Include="MSTest.TestFramework" Version="2.2.5" />
<PackageReference Include="coverlet.collector" Version="3.0.3">
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference>
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.6.0-rc0" />
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\..\src\TensorFlowNET.Core\Tensorflow.Binding.csproj" />
</ItemGroup>

</Project>

test/TensorFlowNET.UnitTest/Utilities/MultiThreadedUnitTestExecuter.cs → test/TensorFlowNET.Graph.UnitTest/Utilities/MultiThreadedUnitTestExecuter.cs View File


+ 22
- 0
test/TensorFlowNET.Graph.UnitTest/Utilities/TestHelper.cs View File

@@ -0,0 +1,22 @@
using System;
using System.IO;

namespace TensorFlowNET.UnitTest
{
public class TestHelper
{
public static string GetFullPathFromDataDir(string fileName)
{
var dataDir = GetRootContentDir(Directory.GetCurrentDirectory());
return Path.Combine(dataDir, fileName);
}

static string GetRootContentDir(string dir)
{
var path = Path.GetFullPath(Path.Combine(dir, "data"));
if (Directory.Exists(path))
return path;
return GetRootContentDir(Path.GetFullPath(Path.Combine(dir, "..")));
}
}
}

+ 4
- 4
test/TensorFlowNET.UnitTest/Basics/RandomTest.cs View File

@@ -30,8 +30,8 @@ namespace TensorFlowNET.UnitTest.Basics
tf.set_random_seed(1234);
var a2 = tf.random_uniform(1);
var b2 = tf.random_shuffle(tf.constant(initValue));
Assert.AreEqual(a1, a2);
Assert.AreEqual(b1, b2);
Assert.AreEqual(a1.numpy(), a2.numpy());
Assert.AreEqual(b1.numpy(), b2.numpy());
}
/// <summary>
@@ -76,8 +76,8 @@ namespace TensorFlowNET.UnitTest.Basics
var a2 = tf.random.normal(1);
var b2 = tf.random.truncated_normal(1);
Assert.AreEqual(a1, a2);
Assert.AreEqual(b1, b2);
Assert.AreEqual(a1.numpy(), a2.numpy());
Assert.AreEqual(b1.numpy(), b2.numpy());
}
/// <summary>


+ 0
- 67
test/TensorFlowNET.UnitTest/Basics/TensorShapeTest.cs View File

@@ -1,67 +0,0 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow.NumPy;
using System;
using Tensorflow;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest.Basics
{
[TestClass]
public class TensorShapeTest
{
[TestMethod]
public void Case1()
{
int a = 2;
int b = 3;
var dims = new[] { Unknown, a, b };
new TensorShape(dims).GetPrivate<Shape>("shape").Should().BeShaped(-1, 2, 3);
}

[TestMethod]
public void Case2()
{
int a = 2;
int b = 3;
var dims = new[] { Unknown, a, b };
//new TensorShape(new[] { dims }).GetPrivate<Shape>("shape").Should().BeShaped(-1, 2, 3);
}

[TestMethod]
public void Case3()
{
int a = 2;
int b = Unknown;
var dims = new[] { Unknown, a, b };
//new TensorShape(new[] { dims }).GetPrivate<Shape>("shape").Should().BeShaped(-1, 2, -1);
}

[TestMethod]
public void Case4()
{
TensorShape shape = (Unknown, Unknown);
shape.GetPrivate<Shape>("shape").Should().BeShaped(-1, -1);
}

[TestMethod]
public void Case5()
{
TensorShape shape = (1, Unknown, 3);
shape.GetPrivate<Shape>("shape").Should().BeShaped(1, -1, 3);
}

[TestMethod]
public void Case6()
{
TensorShape shape = (Unknown, 1, 2, 3, Unknown);
shape.GetPrivate<Shape>("shape").Should().BeShaped(-1, 1, 2, 3, -1);
}

[TestMethod]
public void Case7()
{
TensorShape shape = new TensorShape();
Assert.AreEqual(shape.rank, -1);
}
}
}

+ 1
- 1
test/TensorFlowNET.UnitTest/EnforcedSinglethreadingTests.cs View File

@@ -34,7 +34,7 @@ namespace TensorFlowNET.UnitTest
using (var sess = tf.Session())
{
var default_graph = tf.peak_default_graph();
var sess_graph = sess.GetPrivate<Graph>("_graph");
var sess_graph = sess.graph;
sess_graph.Should().NotBeNull();
default_graph.Should().NotBeNull()
.And.BeEquivalentTo(sess_graph);


+ 0
- 1
test/TensorFlowNET.UnitTest/GradientTest/GradientEagerTest.cs View File

@@ -4,7 +4,6 @@ using System.Collections.Generic;
using System.Linq;
using Tensorflow;
using Tensorflow.NumPy;
using Tensorflow.UnitTest;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest.Gradient


+ 0
- 1104
test/TensorFlowNET.UnitTest/GradientTest/gradients_test.py
File diff suppressed because it is too large
View File


+ 1
- 72
test/TensorFlowNET.UnitTest/NameScopeTest.cs View File

@@ -1,93 +1,22 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow;
using Tensorflow.UnitTest;
using static Tensorflow.Binding;

namespace TensorFlowNET.UnitTest.Basics
{
[TestClass]
public class NameScopeTest : GraphModeTestBase
public class NameScopeTest : EagerModeTestBase
{
string name = "";

[TestMethod]
public void NestedNameScope()
{
Graph g = tf.Graph().as_default();

tf_with(new ops.NameScope("scope1"), scope1 =>
{
name = scope1;
Assert.AreEqual("scope1", g._name_stack);
Assert.AreEqual("scope1/", name);

var const1 = tf.constant(1.0);
Assert.AreEqual("scope1/Const:0", const1.name);

tf_with(new ops.NameScope("scope2"), scope2 =>
{
name = scope2;
Assert.AreEqual("scope1/scope2", g._name_stack);
Assert.AreEqual("scope1/scope2/", name);

var const2 = tf.constant(2.0);
Assert.AreEqual("scope1/scope2/Const:0", const2.name);
});

Assert.AreEqual("scope1", g._name_stack);
var const3 = tf.constant(2.0);
Assert.AreEqual("scope1/Const_1:0", const3.name);
});

g.Dispose();

Assert.AreEqual("", g._name_stack);
}

[TestMethod]
public void NameScopeInEagerMode()
{
tf.enable_eager_execution();

tf_with(new ops.NameScope("scope"), scope =>
{
string name = scope;
var const1 = tf.constant(1.0);
});

tf.compat.v1.disable_eager_execution();
}

[TestMethod, Ignore("Unimplemented Usage")]
public void NestedNameScope_Using()
{
Graph g = tf.Graph().as_default();

using (var name = new ops.NameScope("scope1"))
{
Assert.AreEqual("scope1", g._name_stack);
Assert.AreEqual("scope1/", name);

var const1 = tf.constant(1.0);
Assert.AreEqual("scope1/Const:0", const1.name);

using (var name2 = new ops.NameScope("scope2"))
{
Assert.AreEqual("scope1/scope2", g._name_stack);
Assert.AreEqual("scope1/scope2/", name);

var const2 = tf.constant(2.0);
Assert.AreEqual("scope1/scope2/Const:0", const2.name);
}

Assert.AreEqual("scope1", g._name_stack);
var const3 = tf.constant(2.0);
Assert.AreEqual("scope1/Const_1:0", const3.name);
};

g.Dispose();

Assert.AreEqual("", g._name_stack);
}
}
}

+ 0
- 917
test/TensorFlowNET.UnitTest/Utilities/PrivateObject.cs View File

@@ -1,917 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

namespace Microsoft.VisualStudio.TestTools.UnitTesting
{
using System;
//using System.Diagnostics;
//using System.Diagnostics.CodeAnalysis;
using System.Globalization;
using System.Reflection;

/// <summary>
/// This class represents the live NON public INTERNAL object in the system
/// </summary>
internal class PrivateObject
{
#region Data

// bind everything
private const BindingFlags BindToEveryThing = BindingFlags.Default | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.Public;

#pragma warning disable CS0414 // The field 'PrivateObject.constructorFlags' is assigned but its value is never used
private static BindingFlags constructorFlags = BindingFlags.Instance | BindingFlags.Public | BindingFlags.CreateInstance | BindingFlags.NonPublic;
#pragma warning restore CS0414 // The field 'PrivateObject.constructorFlags' is assigned but its value is never used

private object target; // automatically initialized to null
private Type originalType; // automatically initialized to null

//private Dictionary<string, LinkedList<MethodInfo>> methodCache; // automatically initialized to null

#endregion

#region Constructors

///// <summary>
///// Initializes a new instance of the <see cref="PrivateObject"/> class that contains
///// the already existing object of the private class
///// </summary>
///// <param name="obj"> object that serves as starting point to reach the private members</param>
///// <param name="memberToAccess">the derefrencing string using . that points to the object to be retrived as in m_X.m_Y.m_Z</param>
//[SuppressMessage("Microsoft.Naming", "CA1720:IdentifiersShouldNotContainTypeNames", MessageId = "obj", Justification = "We don't know anything about the object other than that it's an object, so 'obj' seems reasonable")]
//public PrivateObject(object obj, string memberToAccess)
//{
// Helper.CheckParameterNotNull(obj, "obj", string.Empty);
// ValidateAccessString(memberToAccess);

// PrivateObject temp = obj as PrivateObject;
// if (temp == null)
// {
// temp = new PrivateObject(obj);
// }

// // Split The access string
// string[] arr = memberToAccess.Split(new char[] { '.' });

// for (int i = 0; i < arr.Length; i++)
// {
// object next = temp.InvokeHelper(arr[i], BindToEveryThing | BindingFlags.Instance | BindingFlags.GetField | BindingFlags.GetProperty, null, CultureInfo.InvariantCulture);
// temp = new PrivateObject(next);
// }

// this.target = temp.target;
// this.originalType = temp.originalType;
//}

///// <summary>
///// Initializes a new instance of the <see cref="PrivateObject"/> class that wraps the
///// specified type.
///// </summary>
///// <param name="assemblyName">Name of the assembly</param>
///// <param name="typeName">fully qualified name</param>
///// <param name="args">Argmenets to pass to the constructor</param>
//public PrivateObject(string assemblyName, string typeName, params object[] args)
// : this(assemblyName, typeName, null, args)
//{
//}

///// <summary>
///// Initializes a new instance of the <see cref="PrivateObject"/> class that wraps the
///// specified type.
///// </summary>
///// <param name="assemblyName">Name of the assembly</param>
///// <param name="typeName">fully qualified name</param>
///// <param name="parameterTypes">An array of <see cref="T:System.Type"/> objects representing the number, order, and type of the parameters for the constructor to get</param>
///// <param name="args">Argmenets to pass to the constructor</param>
//public PrivateObject(string assemblyName, string typeName, Type[] parameterTypes, object[] args)
// : this(Type.GetType(string.Format(CultureInfo.InvariantCulture, "{0}, {1}", typeName, assemblyName), false), parameterTypes, args)
//{
// Helper.CheckParameterNotNull(assemblyName, "assemblyName", string.Empty);
// Helper.CheckParameterNotNull(typeName, "typeName", string.Empty);
//}

///// <summary>
///// Initializes a new instance of the <see cref="PrivateObject"/> class that wraps the
///// specified type.
///// </summary>
///// <param name="type">type of the object to create</param>
///// <param name="args">Argmenets to pass to the constructor</param>
//public PrivateObject(Type type, params object[] args)
// : this(type, null, args)
//{
// Helper.CheckParameterNotNull(type, "type", string.Empty);
//}

///// <summary>
///// Initializes a new instance of the <see cref="PrivateObject"/> class that wraps the
///// specified type.
///// </summary>
///// <param name="type">type of the object to create</param>
///// <param name="parameterTypes">An array of <see cref="T:System.Type"/> objects representing the number, order, and type of the parameters for the constructor to get</param>
///// <param name="args">Argmenets to pass to the constructor</param>
//public PrivateObject(Type type, Type[] parameterTypes, object[] args)
//{
// Helper.CheckParameterNotNull(type, "type", string.Empty);
// object o;
// if (parameterTypes != null)
// {
// ConstructorInfo ci = type.GetConstructor(BindToEveryThing, null, parameterTypes, null);
// if (ci == null)
// {
// throw new ArgumentException(FrameworkMessages.PrivateAccessorConstructorNotFound);
// }

// try
// {
// o = ci.Invoke(args);
// }
// catch (TargetInvocationException e)
// {
// Debug.Assert(e.InnerException != null, "Inner exception should not be null.");
// if (e.InnerException != null)
// {
// throw e.InnerException;
// }

// throw;
// }
// }
// else
// {
// o = Activator.CreateInstance(type, constructorFlags, null, args, null);
// }

// this.ConstructFrom(o);
//}

/// <summary>
/// Initializes a new instance of the <see cref="PrivateObject"/> class that wraps
/// the given object.
/// </summary>
/// <param name="obj">object to wrap</param>
//[SuppressMessage("Microsoft.Naming", "CA1720:IdentifiersShouldNotContainTypeNames", MessageId = "obj", Justification = "We don't know anything about the object other than that it's an object, so 'obj' seems reasonable")]
public PrivateObject(object obj)
{
Helper.CheckParameterNotNull(obj, "obj", string.Empty);
this.ConstructFrom(obj);
}

/// <summary>
/// Initializes a new instance of the <see cref="PrivateObject"/> class that wraps
/// the given object.
/// </summary>
/// <param name="obj">object to wrap</param>
/// <param name="type">PrivateType object</param>
//[SuppressMessage("Microsoft.Naming", "CA1720:IdentifiersShouldNotContainTypeNames", MessageId = "obj", Justification = "We don't know anything about the object other than that it's an an object, so 'obj' seems reasonable")]
public PrivateObject(object obj, PrivateType type)
{
Helper.CheckParameterNotNull(type, "type", string.Empty);
this.target = obj;
this.originalType = type.ReferencedType;
}

#endregion

///// <summary>
///// Gets or sets the target
///// </summary>
//public object Target
//{
// get
// {
// return this.target;
// }

// set
// {
// Helper.CheckParameterNotNull(value, "Target", string.Empty);
// this.target = value;
// this.originalType = value.GetType();
// }
//}

///// <summary>
///// Gets the type of underlying object
///// </summary>
//public Type RealType
//{
// get
// {
// return this.originalType;
// }
//}

//private Dictionary<string, LinkedList<MethodInfo>> GenericMethodCache
//{
// get
// {
// if (this.methodCache == null)
// {
// this.BuildGenericMethodCacheForType(this.originalType);
// }

// Debug.Assert(this.methodCache != null, "Invalid method cache for type.");

// return this.methodCache;
// }
//}

/// <summary>
/// returns the hash code of the target object
/// </summary>
/// <returns>int representing hashcode of the target object</returns>
public override int GetHashCode()
{
//Debug.Assert(this.target != null, "target should not be null.");
return this.target.GetHashCode();
}

/// <summary>
/// Equals
/// </summary>
/// <param name="obj">Object with whom to compare</param>
/// <returns>returns true if the objects are equal.</returns>
public override bool Equals(object obj)
{
if (this != obj)
{
//Debug.Assert(this.target != null, "target should not be null.");
if (typeof(PrivateObject) == obj?.GetType())
{
return this.target.Equals(((PrivateObject)obj).target);
}
else
{
return false;
}
}

return true;
}

///// <summary>
///// Invokes the specified method
///// </summary>
///// <param name="name">Name of the method</param>
///// <param name="args">Arguments to pass to the member to invoke.</param>
///// <returns>Result of method call</returns>
//public object Invoke(string name, params object[] args)
//{
// Helper.CheckParameterNotNull(name, "name", string.Empty);
// return this.Invoke(name, null, args, CultureInfo.InvariantCulture);
//}

///// <summary>
///// Invokes the specified method
///// </summary>
///// <param name="name">Name of the method</param>
///// <param name="parameterTypes">An array of <see cref="T:System.Type"/> objects representing the number, order, and type of the parameters for the method to get.</param>
///// <param name="args">Arguments to pass to the member to invoke.</param>
///// <returns>Result of method call</returns>
//public object Invoke(string name, Type[] parameterTypes, object[] args)
//{
// return this.Invoke(name, parameterTypes, args, CultureInfo.InvariantCulture);
//}

///// <summary>
///// Invokes the specified method
///// </summary>
///// <param name="name">Name of the method</param>
///// <param name="parameterTypes">An array of <see cref="T:System.Type"/> objects representing the number, order, and type of the parameters for the method to get.</param>
///// <param name="args">Arguments to pass to the member to invoke.</param>
///// <param name="typeArguments">An array of types corresponding to the types of the generic arguments.</param>
///// <returns>Result of method call</returns>
//public object Invoke(string name, Type[] parameterTypes, object[] args, Type[] typeArguments)
//{
// return this.Invoke(name, BindToEveryThing, parameterTypes, args, CultureInfo.InvariantCulture, typeArguments);
//}

///// <summary>
///// Invokes the specified method
///// </summary>
///// <param name="name">Name of the method</param>
///// <param name="args">Arguments to pass to the member to invoke.</param>
///// <param name="culture">Culture info</param>
///// <returns>Result of method call</returns>
//public object Invoke(string name, object[] args, CultureInfo culture)
//{
// return this.Invoke(name, null, args, culture);
//}

///// <summary>
///// Invokes the specified method
///// </summary>
///// <param name="name">Name of the method</param>
///// <param name="parameterTypes">An array of <see cref="T:System.Type"/> objects representing the number, order, and type of the parameters for the method to get.</param>
///// <param name="args">Arguments to pass to the member to invoke.</param>
///// <param name="culture">Culture info</param>
///// <returns>Result of method call</returns>
//public object Invoke(string name, Type[] parameterTypes, object[] args, CultureInfo culture)
//{
// return this.Invoke(name, BindToEveryThing, parameterTypes, args, culture);
//}

///// <summary>
///// Invokes the specified method
///// </summary>
///// <param name="name">Name of the method</param>
///// <param name="bindingFlags">A bitmask comprised of one or more <see cref="T:System.Reflection.BindingFlags"/> that specify how the search is conducted.</param>
///// <param name="args">Arguments to pass to the member to invoke.</param>
///// <returns>Result of method call</returns>
//public object Invoke(string name, BindingFlags bindingFlags, params object[] args)
//{
// return this.Invoke(name, bindingFlags, null, args, CultureInfo.InvariantCulture);
//}

///// <summary>
///// Invokes the specified method
///// </summary>
///// <param name="name">Name of the method</param>
///// <param name="bindingFlags">A bitmask comprised of one or more <see cref="T:System.Reflection.BindingFlags"/> that specify how the search is conducted.</param>
///// <param name="parameterTypes">An array of <see cref="T:System.Type"/> objects representing the number, order, and type of the parameters for the method to get.</param>
///// <param name="args">Arguments to pass to the member to invoke.</param>
///// <returns>Result of method call</returns>
//public object Invoke(string name, BindingFlags bindingFlags, Type[] parameterTypes, object[] args)
//{
// return this.Invoke(name, bindingFlags, parameterTypes, args, CultureInfo.InvariantCulture);
//}

///// <summary>
///// Invokes the specified method
///// </summary>
///// <param name="name">Name of the method</param>
///// <param name="bindingFlags">A bitmask comprised of one or more <see cref="T:System.Reflection.BindingFlags"/> that specify how the search is conducted.</param>
///// <param name="args">Arguments to pass to the member to invoke.</param>
///// <param name="culture">Culture info</param>
///// <returns>Result of method call</returns>
//public object Invoke(string name, BindingFlags bindingFlags, object[] args, CultureInfo culture)
//{
// return this.Invoke(name, bindingFlags, null, args, culture);
//}

///// <summary>
///// Invokes the specified method
///// </summary>
///// <param name="name">Name of the method</param>
///// <param name="bindingFlags">A bitmask comprised of one or more <see cref="T:System.Reflection.BindingFlags"/> that specify how the search is conducted.</param>
///// <param name="parameterTypes">An array of <see cref="T:System.Type"/> objects representing the number, order, and type of the parameters for the method to get.</param>
///// <param name="args">Arguments to pass to the member to invoke.</param>
///// <param name="culture">Culture info</param>
///// <returns>Result of method call</returns>
//public object Invoke(string name, BindingFlags bindingFlags, Type[] parameterTypes, object[] args, CultureInfo culture)
//{
// return this.Invoke(name, bindingFlags, parameterTypes, args, culture, null);
//}

///// <summary>
///// Invokes the specified method
///// </summary>
///// <param name="name">Name of the method</param>
///// <param name="bindingFlags">A bitmask comprised of one or more <see cref="T:System.Reflection.BindingFlags"/> that specify how the search is conducted.</param>
///// <param name="parameterTypes">An array of <see cref="T:System.Type"/> objects representing the number, order, and type of the parameters for the method to get.</param>
///// <param name="args">Arguments to pass to the member to invoke.</param>
///// <param name="culture">Culture info</param>
///// <param name="typeArguments">An array of types corresponding to the types of the generic arguments.</param>
///// <returns>Result of method call</returns>
//public object Invoke(string name, BindingFlags bindingFlags, Type[] parameterTypes, object[] args, CultureInfo culture, Type[] typeArguments)
//{
// Helper.CheckParameterNotNull(name, "name", string.Empty);
// if (parameterTypes != null)
// {
// bindingFlags |= BindToEveryThing | BindingFlags.Instance;

// // Fix up the parameter types
// MethodInfo member = this.originalType.GetMethod(name, bindingFlags, null, parameterTypes, null);

// // If the method was not found and type arguments were provided for generic paramaters,
// // attempt to look up a generic method.
// if ((member == null) && (typeArguments != null))
// {
// // This method may contain generic parameters...if so, the previous call to
// // GetMethod() will fail because it doesn't fully support generic parameters.

// // Look in the method cache to see if there is a generic method
// // on the incoming type that contains the correct signature.
// member = this.GetGenericMethodFromCache(name, parameterTypes, typeArguments, bindingFlags, null);
// }

// if (member == null)
// {
// throw new ArgumentException(
// string.Format(CultureInfo.CurrentCulture, FrameworkMessages.PrivateAccessorMemberNotFound, name));
// }

// try
// {
// if (member.IsGenericMethodDefinition)
// {
// MethodInfo constructed = member.MakeGenericMethod(typeArguments);
// return constructed.Invoke(this.target, bindingFlags, null, args, culture);
// }
// else
// {
// return member.Invoke(this.target, bindingFlags, null, args, culture);
// }
// }
// catch (TargetInvocationException e)
// {
// Debug.Assert(e.InnerException != null, "Inner exception should not be null.");
// if (e.InnerException != null)
// {
// throw e.InnerException;
// }

// throw;
// }
// }
// else
// {
// return this.InvokeHelper(name, bindingFlags | BindingFlags.InvokeMethod, args, culture);
// }
//}

///// <summary>
///// Gets the array element using array of subsrcipts for each dimension
///// </summary>
///// <param name="name">Name of the member</param>
///// <param name="indices">the indices of array</param>
///// <returns>An arrya of elements.</returns>
//public object GetArrayElement(string name, params int[] indices)
//{
// Helper.CheckParameterNotNull(name, "name", string.Empty);
// return this.GetArrayElement(name, BindToEveryThing, indices);
//}

///// <summary>
///// Sets the array element using array of subsrcipts for each dimension
///// </summary>
///// <param name="name">Name of the member</param>
///// <param name="value">Value to set</param>
///// <param name="indices">the indices of array</param>
//public void SetArrayElement(string name, object value, params int[] indices)
//{
// Helper.CheckParameterNotNull(name, "name", string.Empty);
// this.SetArrayElement(name, BindToEveryThing, value, indices);
//}

///// <summary>
///// Gets the array element using array of subsrcipts for each dimension
///// </summary>
///// <param name="name">Name of the member</param>
///// <param name="bindingFlags">A bitmask comprised of one or more <see cref="T:System.Reflection.BindingFlags"/> that specify how the search is conducted.</param>
///// <param name="indices">the indices of array</param>
///// <returns>An arrya of elements.</returns>
//public object GetArrayElement(string name, BindingFlags bindingFlags, params int[] indices)
//{
// Helper.CheckParameterNotNull(name, "name", string.Empty);
// Array arr = (Array)this.InvokeHelper(name, BindingFlags.GetField | bindingFlags, null, CultureInfo.InvariantCulture);
// return arr.GetValue(indices);
//}

///// <summary>
///// Sets the array element using array of subsrcipts for each dimension
///// </summary>
///// <param name="name">Name of the member</param>
///// <param name="bindingFlags">A bitmask comprised of one or more <see cref="T:System.Reflection.BindingFlags"/> that specify how the search is conducted.</param>
///// <param name="value">Value to set</param>
///// <param name="indices">the indices of array</param>
//public void SetArrayElement(string name, BindingFlags bindingFlags, object value, params int[] indices)
//{
// Helper.CheckParameterNotNull(name, "name", string.Empty);
// Array arr = (Array)this.InvokeHelper(name, BindingFlags.GetField | bindingFlags, null, CultureInfo.InvariantCulture);
// arr.SetValue(value, indices);
//}

///// <summary>
///// Get the field
///// </summary>
///// <param name="name">Name of the field</param>
///// <returns>The field.</returns>
//public object GetField(string name)
//{
// Helper.CheckParameterNotNull(name, "name", string.Empty);
// return this.GetField(name, BindToEveryThing);
//}

///// <summary>
///// Sets the field
///// </summary>
///// <param name="name">Name of the field</param>
///// <param name="value">value to set</param>
//public void SetField(string name, object value)
//{
// Helper.CheckParameterNotNull(name, "name", string.Empty);
// this.SetField(name, BindToEveryThing, value);
//}

///// <summary>
///// Gets the field
///// </summary>
///// <param name="name">Name of the field</param>
///// <param name="bindingFlags">A bitmask comprised of one or more <see cref="T:System.Reflection.BindingFlags"/> that specify how the search is conducted.</param>
///// <returns>The field.</returns>
//public object GetField(string name, BindingFlags bindingFlags)
//{
// Helper.CheckParameterNotNull(name, "name", string.Empty);
// return this.InvokeHelper(name, BindingFlags.GetField | bindingFlags, null, CultureInfo.InvariantCulture);
//}

///// <summary>
///// Sets the field
///// </summary>
///// <param name="name">Name of the field</param>
///// <param name="bindingFlags">A bitmask comprised of one or more <see cref="T:System.Reflection.BindingFlags"/> that specify how the search is conducted.</param>
///// <param name="value">value to set</param>
//public void SetField(string name, BindingFlags bindingFlags, object value)
//{
// Helper.CheckParameterNotNull(name, "name", string.Empty);
// this.InvokeHelper(name, BindingFlags.SetField | bindingFlags, new object[] { value }, CultureInfo.InvariantCulture);
//}

/// <summary>
/// Get the field or property
/// </summary>
/// <param name="name">Name of the field or property</param>
/// <returns>The field or property.</returns>
public object GetFieldOrProperty(string name)
{
Helper.CheckParameterNotNull(name, "name", string.Empty);
return this.GetFieldOrProperty(name, BindToEveryThing);
}

/// <summary>
/// Sets the field or property
/// </summary>
/// <param name="name">Name of the field or property</param>
/// <param name="value">value to set</param>
public void SetFieldOrProperty(string name, object value)
{
Helper.CheckParameterNotNull(name, "name", string.Empty);
this.SetFieldOrProperty(name, BindToEveryThing, value);
}

/// <summary>
/// Gets the field or property
/// </summary>
/// <param name="name">Name of the field or property</param>
/// <param name="bindingFlags">A bitmask comprised of one or more <see cref="T:System.Reflection.BindingFlags"/> that specify how the search is conducted.</param>
/// <returns>The field or property.</returns>
public object GetFieldOrProperty(string name, BindingFlags bindingFlags)
{
Helper.CheckParameterNotNull(name, "name", string.Empty);
return this.InvokeHelper(name, BindingFlags.GetField | BindingFlags.GetProperty | bindingFlags, null, CultureInfo.InvariantCulture);
}

/// <summary>
/// Sets the field or property
/// </summary>
/// <param name="name">Name of the field or property</param>
/// <param name="bindingFlags">A bitmask comprised of one or more <see cref="T:System.Reflection.BindingFlags"/> that specify how the search is conducted.</param>
/// <param name="value">value to set</param>
public void SetFieldOrProperty(string name, BindingFlags bindingFlags, object value)
{
Helper.CheckParameterNotNull(name, "name", string.Empty);
this.InvokeHelper(name, BindingFlags.SetField | BindingFlags.SetProperty | bindingFlags, new object[] { value }, CultureInfo.InvariantCulture);
}

///// <summary>
///// Gets the property
///// </summary>
///// <param name="name">Name of the property</param>
///// <param name="args">Arguments to pass to the member to invoke.</param>
///// <returns>The property.</returns>
//public object GetProperty(string name, params object[] args)
//{
// return this.GetProperty(name, null, args);
//}

///// <summary>
///// Gets the property
///// </summary>
///// <param name="name">Name of the property</param>
///// <param name="parameterTypes">An array of <see cref="T:System.Type"/> objects representing the number, order, and type of the parameters for the indexed property.</param>
///// <param name="args">Arguments to pass to the member to invoke.</param>
///// <returns>The property.</returns>
//public object GetProperty(string name, Type[] parameterTypes, object[] args)
//{
// return this.GetProperty(name, BindToEveryThing, parameterTypes, args);
//}

///// <summary>
///// Set the property
///// </summary>
///// <param name="name">Name of the property</param>
///// <param name="value">value to set</param>
///// <param name="args">Arguments to pass to the member to invoke.</param>
//public void SetProperty(string name, object value, params object[] args)
//{
// this.SetProperty(name, null, value, args);
//}

///// <summary>
///// Set the property
///// </summary>
///// <param name="name">Name of the property</param>
///// <param name="parameterTypes">An array of <see cref="T:System.Type"/> objects representing the number, order, and type of the parameters for the indexed property.</param>
///// <param name="value">value to set</param>
///// <param name="args">Arguments to pass to the member to invoke.</param>
//public void SetProperty(string name, Type[] parameterTypes, object value, object[] args)
//{
// this.SetProperty(name, BindToEveryThing, value, parameterTypes, args);
//}

///// <summary>
///// Gets the property
///// </summary>
///// <param name="name">Name of the property</param>
///// <param name="bindingFlags">A bitmask comprised of one or more <see cref="T:System.Reflection.BindingFlags"/> that specify how the search is conducted.</param>
///// <param name="args">Arguments to pass to the member to invoke.</param>
///// <returns>The property.</returns>
//public object GetProperty(string name, BindingFlags bindingFlags, params object[] args)
//{
// return this.GetProperty(name, bindingFlags, null, args);
//}

///// <summary>
///// Gets the property
///// </summary>
///// <param name="name">Name of the property</param>
///// <param name="bindingFlags">A bitmask comprised of one or more <see cref="T:System.Reflection.BindingFlags"/> that specify how the search is conducted.</param>
///// <param name="parameterTypes">An array of <see cref="T:System.Type"/> objects representing the number, order, and type of the parameters for the indexed property.</param>
///// <param name="args">Arguments to pass to the member to invoke.</param>
///// <returns>The property.</returns>
//public object GetProperty(string name, BindingFlags bindingFlags, Type[] parameterTypes, object[] args)
//{
// Helper.CheckParameterNotNull(name, "name", string.Empty);
// if (parameterTypes != null)
// {
// PropertyInfo pi = this.originalType.GetProperty(name, bindingFlags, null, null, parameterTypes, null);
// if (pi == null)
// {
// throw new ArgumentException(
// string.Format(CultureInfo.CurrentCulture, FrameworkMessages.PrivateAccessorMemberNotFound, name));
// }

// return pi.GetValue(this.target, args);
// }
// else
// {
// return this.InvokeHelper(name, bindingFlags | BindingFlags.GetProperty, args, null);
// }
//}

///// <summary>
///// Sets the property
///// </summary>
///// <param name="name">Name of the property</param>
///// <param name="bindingFlags">A bitmask comprised of one or more <see cref="T:System.Reflection.BindingFlags"/> that specify how the search is conducted.</param>
///// <param name="value">value to set</param>
///// <param name="args">Arguments to pass to the member to invoke.</param>
//public void SetProperty(string name, BindingFlags bindingFlags, object value, params object[] args)
//{
// this.SetProperty(name, bindingFlags, value, null, args);
//}

///// <summary>
///// Sets the property
///// </summary>
///// <param name="name">Name of the property</param>
///// <param name="bindingFlags">A bitmask comprised of one or more <see cref="T:System.Reflection.BindingFlags"/> that specify how the search is conducted.</param>
///// <param name="value">value to set</param>
///// <param name="parameterTypes">An array of <see cref="T:System.Type"/> objects representing the number, order, and type of the parameters for the indexed property.</param>
///// <param name="args">Arguments to pass to the member to invoke.</param>
//public void SetProperty(string name, BindingFlags bindingFlags, object value, Type[] parameterTypes, object[] args)
//{
// Helper.CheckParameterNotNull(name, "name", string.Empty);

// if (parameterTypes != null)
// {
// PropertyInfo pi = this.originalType.GetProperty(name, bindingFlags, null, null, parameterTypes, null);
// if (pi == null)
// {
// throw new ArgumentException(
// string.Format(CultureInfo.CurrentCulture, FrameworkMessages.PrivateAccessorMemberNotFound, name));
// }

// pi.SetValue(this.target, value, args);
// }
// else
// {
// object[] pass = new object[(args?.Length ?? 0) + 1];
// pass[0] = value;
// args?.CopyTo(pass, 1);
// this.InvokeHelper(name, bindingFlags | BindingFlags.SetProperty, pass, null);
// }
//}

#region Private Helpers

///// <summary>
///// Validate access string
///// </summary>
///// <param name="access"> access string</param>
//private static void ValidateAccessString(string access)
//{
// Helper.CheckParameterNotNull(access, "access", string.Empty);
// if (access.Length == 0)
// {
// throw new ArgumentException(FrameworkMessages.AccessStringInvalidSyntax);
// }

// string[] arr = access.Split('.');
// foreach (string str in arr)
// {
// if ((str.Length == 0) || (str.IndexOfAny(new char[] { ' ', '\t', '\n' }) != -1))
// {
// throw new ArgumentException(FrameworkMessages.AccessStringInvalidSyntax);
// }
// }
//}

/// <summary>
/// Invokes the memeber
/// </summary>
/// <param name="name">Name of the member</param>
/// <param name="bindingFlags">Additional attributes</param>
/// <param name="args">Arguments for the invocation</param>
/// <param name="culture">Culture</param>
/// <returns>Result of the invocation</returns>
private object InvokeHelper(string name, BindingFlags bindingFlags, object[] args, CultureInfo culture)
{
Helper.CheckParameterNotNull(name, "name", string.Empty);
//Debug.Assert(this.target != null, "Internal Error: Null reference is returned for internal object");

// Invoke the actual Method
try
{
return this.originalType.InvokeMember(name, bindingFlags, null, this.target, args, culture);
}
catch (TargetInvocationException e)
{
//Debug.Assert(e.InnerException != null, "Inner exception should not be null.");
if (e.InnerException != null)
{
throw e.InnerException;
}

throw;
}
}

private void ConstructFrom(object obj)
{
Helper.CheckParameterNotNull(obj, "obj", string.Empty);
this.target = obj;
this.originalType = obj.GetType();
}

//private void BuildGenericMethodCacheForType(Type t)
//{
// Debug.Assert(t != null, "type should not be null.");
// this.methodCache = new Dictionary<string, LinkedList<MethodInfo>>();

// MethodInfo[] members = t.GetMethods(BindToEveryThing);
// LinkedList<MethodInfo> listByName; // automatically initialized to null

// foreach (MethodInfo member in members)
// {
// if (member.IsGenericMethod || member.IsGenericMethodDefinition)
// {
// if (!this.GenericMethodCache.TryGetValue(member.Name, out listByName))
// {
// listByName = new LinkedList<MethodInfo>();
// this.GenericMethodCache.Add(member.Name, listByName);
// }

// Debug.Assert(listByName != null, "list should not be null.");
// listByName.AddLast(member);
// }
// }
//}

///// <summary>
///// Extracts the most appropriate generic method signature from the current private type.
///// </summary>
///// <param name="methodName">The name of the method in which to search the signature cache.</param>
///// <param name="parameterTypes">An array of types corresponding to the types of the parameters in which to search.</param>
///// <param name="typeArguments">An array of types corresponding to the types of the generic arguments.</param>
///// <param name="bindingFlags"><see cref="BindingFlags"/> to further filter the method signatures.</param>
///// <param name="modifiers">Modifiers for parameters.</param>
///// <returns>A methodinfo instance.</returns>
//private MethodInfo GetGenericMethodFromCache(string methodName, Type[] parameterTypes, Type[] typeArguments, BindingFlags bindingFlags, ParameterModifier[] modifiers)
//{
// Debug.Assert(!string.IsNullOrEmpty(methodName), "Invalid method name.");
// Debug.Assert(parameterTypes != null, "Invalid parameter type array.");
// Debug.Assert(typeArguments != null, "Invalid type arguments array.");

// // Build a preliminary list of method candidates that contain roughly the same signature.
// var methodCandidates = this.GetMethodCandidates(methodName, parameterTypes, typeArguments, bindingFlags, modifiers);

// // Search of ambiguous methods (methods with the same signature).
// MethodInfo[] finalCandidates = new MethodInfo[methodCandidates.Count];
// methodCandidates.CopyTo(finalCandidates, 0);

// if ((parameterTypes != null) && (parameterTypes.Length == 0))
// {
// for (int i = 0; i < finalCandidates.Length; i++)
// {
// MethodInfo methodInfo = finalCandidates[i];

// if (!RuntimeTypeHelper.CompareMethodSigAndName(methodInfo, finalCandidates[0]))
// {
// throw new AmbiguousMatchException();
// }
// }

// // All the methods have the exact same name and sig so return the most derived one.
// return RuntimeTypeHelper.FindMostDerivedNewSlotMeth(finalCandidates, finalCandidates.Length) as MethodInfo;
// }

// // Now that we have a preliminary list of candidates, select the most appropriate one.
// return RuntimeTypeHelper.SelectMethod(bindingFlags, finalCandidates, parameterTypes, modifiers) as MethodInfo;
//}

//private LinkedList<MethodInfo> GetMethodCandidates(string methodName, Type[] parameterTypes, Type[] typeArguments, BindingFlags bindingFlags, ParameterModifier[] modifiers)
//{
// Debug.Assert(!string.IsNullOrEmpty(methodName), "methodName should not be null.");
// Debug.Assert(parameterTypes != null, "parameterTypes should not be null.");
// Debug.Assert(typeArguments != null, "typeArguments should not be null.");

// LinkedList<MethodInfo> methodCandidates = new LinkedList<MethodInfo>();
// LinkedList<MethodInfo> methods = null;

// if (!this.GenericMethodCache.TryGetValue(methodName, out methods))
// {
// return methodCandidates;
// }

// Debug.Assert(methods != null, "methods should not be null.");

// foreach (MethodInfo candidate in methods)
// {
// bool paramMatch = true;
// ParameterInfo[] candidateParams = null;
// Type[] genericArgs = candidate.GetGenericArguments();
// Type sourceParameterType = null;

// if (genericArgs.Length != typeArguments.Length)
// {
// continue;
// }

// // Since we can't just get the correct MethodInfo from Reflection,
// // we will just match the number of parameters, their order, and their type
// var methodCandidate = candidate;
// candidateParams = methodCandidate.GetParameters();

// if (candidateParams.Length != parameterTypes.Length)
// {
// continue;
// }

// // Exact binding
// if ((bindingFlags & BindingFlags.ExactBinding) != 0)
// {
// int i = 0;

// foreach (ParameterInfo candidateParam in candidateParams)
// {
// sourceParameterType = parameterTypes[i++];

// if (candidateParam.ParameterType.ContainsGenericParameters)
// {
// // Since we have a generic parameter here, just make sure the IsArray matches.
// if (candidateParam.ParameterType.IsArray != sourceParameterType.IsArray)
// {
// paramMatch = false;
// break;
// }
// }
// else
// {
// if (candidateParam.ParameterType != sourceParameterType)
// {
// paramMatch = false;
// break;
// }
// }
// }

// if (paramMatch)
// {
// methodCandidates.AddLast(methodCandidate);
// continue;
// }
// }
// else
// {
// methodCandidates.AddLast(methodCandidate);
// }
// }

// return methodCandidates;
//}

#endregion
}
}

+ 0
- 314
test/TensorFlowNET.UnitTest/Utilities/PrivateObjectExtensions.cs View File

@@ -1,314 +0,0 @@
// <copyright file="PrivateObjectExtensions.cs">
// Copyright (c) 2019 cactuaroid All Rights Reserved
// </copyright>
// <summary>
// Released under the MIT license
// https://github.com/cactuaroid/PrivateObjectExtensions
// </summary>

using Microsoft.VisualStudio.TestTools.UnitTesting;
using System.Linq;
using System.Reflection;

namespace System
{
/// <summary>
/// Extension methods for PrivateObject
/// </summary>
public static class PrivateObjectExtensions
{
private static readonly BindingFlags Static = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.DeclaredOnly | BindingFlags.Static;
private static readonly BindingFlags Instance = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.DeclaredOnly | BindingFlags.Instance;

/// <summary>
/// Get from private (and any other) field/property.
/// If the real type of specified object doesn't contain the specified field/property,
/// base types are searched automatically.
/// </summary>
/// <param name="obj">The object to get from</param>
/// <param name="name">The name of the field/property</param>
/// <returns>The object got from the field/property</returns>
/// <exception cref="ArgumentException">'name' is not found.</exception>
/// <exception cref="ArgumentNullException">Arguments contain null.</exception>
public static object GetPrivate(this object obj, string name)
{
if (obj == null) { throw new ArgumentNullException("obj"); }

return GetPrivate(obj, name, obj.GetType(), null);
}

/// <summary>
/// Get from private (and any other) field/property.
/// If the real type of specified object doesn't contain the specified field/property,
/// base types are searched automatically.
/// </summary>
/// <typeparam name="T">The type of the field/property</typeparam>
/// <param name="obj">The object to get from</param>
/// <param name="name">The name of the field/property</param>
/// <returns>The object got from the field/property</returns>
/// <exception cref="ArgumentException">'name' is not found.</exception>
/// <exception cref="ArgumentNullException">Arguments contain null.</exception>
public static T GetPrivate<T>(this object obj, string name)
{
if (obj == null) { throw new ArgumentNullException("obj"); }

return (T)GetPrivate(obj, name, obj.GetType(), typeof(T));
}

/// <summary>
/// Get from private (and any other) field/property with assuming the specified object as specified type.
/// If the specified type doesn't contain the specified field/property,
/// base types are searched automatically.
/// </summary>
/// <param name="obj">The object to get from</param>
/// <param name="name">The name of the field/property</param>
/// <param name="objType">The type of 'obj' for seaching member starting from. Real type of 'obj' is ignored.</param>
/// <returns>The object got from the field/property</returns>
/// <exception cref="ArgumentException">'name' is not found.</exception>
/// <exception cref="ArgumentException">'objType' is not assignable from 'obj'.</exception>
/// <exception cref="ArgumentNullException">Arguments contain null.</exception>
public static object GetPrivate(this object obj, string name, Type objType)
{
return GetPrivate(obj, name, objType, null);
}

/// <summary>
/// Get from private (and any other) field/property with assuming the specified object as specified type.
/// If the specified type doesn't contain the specified field/property,
/// base types are searched automatically.
/// </summary>
/// <typeparam name="T">The type of the field/property</typeparam>
/// <param name="obj">The object to get from</param>
/// <param name="name">The name of the field/property</param>
/// <param name="objType">The type of 'obj' for seaching member starting from. Real type of 'obj' is ignored.</param>
/// <returns>The object got from the field/property</returns>
/// <exception cref="ArgumentException">'name' is not found.</exception>
/// <exception cref="ArgumentException">'objType' is not assignable from 'obj'.</exception>
/// <exception cref="ArgumentNullException">Arguments contain null.</exception>
public static T GetPrivate<T>(this object obj, string name, Type objType)
{
return (T)GetPrivate(obj, name, objType, typeof(T));
}

private static object GetPrivate(object obj, string name, Type objType, Type memberType)
{
if (obj == null) { throw new ArgumentNullException("obj"); }
if (name == null) { throw new ArgumentNullException("name"); }
if (string.IsNullOrWhiteSpace(name)) { throw new ArgumentException("name is empty or white-space.", "name"); }
if (objType == null) { throw new ArgumentNullException("objType"); }
if (!objType.IsAssignableFrom(obj.GetType())) { throw new ArgumentException($"{objType} is not assignable from {obj.GetType()}.", "objType"); }

bool memberTypeMatching(Type actualType) => actualType == memberType;

if (TryFindFieldOrPropertyOwnerType(objType, name, memberType, memberTypeMatching, Instance, out var ownerType))
{
return new PrivateObject(obj, new PrivateType(ownerType)).GetFieldOrProperty(name);
}
else if (TryFindFieldOrPropertyOwnerType(objType, name, memberType, memberTypeMatching, Static, out ownerType))
{
return new PrivateType(ownerType).GetStaticFieldOrProperty(name);
}

throw new ArgumentException(((memberType != null) ? memberType + " " : "") + name + " is not found.");
}

/// <summary>
/// Get from private (and any other) static field/property.
/// </summary>
/// <param name="type">The type to get from</param>
/// <param name="name">The name of the static field/property</param>
/// <returns>The object got from the static field/property</returns>
/// <exception cref="ArgumentException">'name' is not found.</exception>
/// <exception cref="ArgumentNullException">Arguments contain null.</exception>
public static object GetPrivate(this Type type, string name)
{
return GetPrivate(type, name, null);
}

/// <summary>
/// Get from private (and any other) static field/property.
/// </summary>
/// <typeparam name="T">The type of the field/property</typeparam>
/// <param name="type">The type to get from</param>
/// <param name="name">The name of the static field/property</param>
/// <returns>The object got from the static field/property</returns>
/// <exception cref="ArgumentException">'name' is not found.</exception>
/// <exception cref="ArgumentNullException">Arguments contain null.</exception>
public static T GetPrivate<T>(this Type type, string name)
{
return (T)GetPrivate(type, name, typeof(T));
}

private static object GetPrivate(this Type type, string name, Type memberType)
{
if (type == null) { throw new ArgumentNullException("type"); }
if (name == null) { throw new ArgumentNullException("name"); }
if (string.IsNullOrWhiteSpace(name)) { throw new ArgumentException("name is empty or white-space.", "name"); }

bool memberTypeMatching(Type actualType) => actualType == memberType;

if (type.ContainsFieldOrProperty(name, memberType, memberTypeMatching, Static))
{
return new PrivateType(type).GetStaticFieldOrProperty(name);
}

throw new ArgumentException(((memberType != null) ? memberType + " " : "") + name + " is not found.");
}

/// <summary>
/// Set to private (and any other) field/property.
/// If the real type of specified object doesn't contain the specified field/property,
/// base types are searched automatically.
/// </summary>
/// <param name="obj">The object to set to</param>
/// <param name="name">The name of the field/property</param>
/// <param name="value">The value to set for 'name'</param>
/// <exception cref="ArgumentException">'name' is not found.</exception>
/// <exception cref="ArgumentNullException">Arguments contain null.</exception>
public static void SetPrivate<T>(this object obj, string name, T value)
{
if (obj == null) { throw new ArgumentNullException("obj"); }

SetPrivate(obj, name, value, obj.GetType());
}

/// <summary>
/// Set to private (and any other) field/property with assuming the specified object as specified type.
/// If the specified type doesn't contain the specified field/property,
/// base types are searched automatically.
/// </summary>
/// <param name="obj">The object to set to</param>
/// <param name="name">The name of the field/property</param>
/// <param name="value">The value to set for 'name'</param>
/// <param name="objType">The type of 'obj' for seaching member starting from. Real type of 'obj' is ignored.</param>
/// <exception cref="ArgumentException">'name' is not found.</exception>
/// <exception cref="ArgumentException">'objType' is not assignable from 'obj'.</exception>
/// <exception cref="ArgumentNullException">Arguments contain null.</exception>
public static void SetPrivate<T>(this object obj, string name, T value, Type objType)
{
if (obj == null) { throw new ArgumentNullException("obj"); }
if (name == null) { throw new ArgumentNullException("name"); }
if (string.IsNullOrWhiteSpace(name)) { throw new ArgumentException("name is empty or white-space.", "name"); }
if (value == null) { throw new ArgumentNullException("value"); }
if (objType == null) { throw new ArgumentNullException("objType"); }
if (!objType.IsAssignableFrom(obj.GetType())) { throw new ArgumentException($"{objType} is not assignable from {obj.GetType()}.", "objType"); }

if (TrySetPrivate(obj, name, value, objType)) { return; }

// retry for the case of getter only property
if (TrySetPrivate(obj, GetBackingFieldName(name), value, objType)) { return; }

throw new ArgumentException($"{typeof(T)} {name} is not found.");
}

private static bool TrySetPrivate<T>(object obj, string name, T value, Type objType)
{
var memberType = typeof(T);
bool memberTypeMatching(Type actualType) => actualType.IsAssignableFrom(memberType);

try
{
if (TryFindFieldOrPropertyOwnerType(objType, name, memberType, memberTypeMatching, Instance, out var ownerType))
{
new PrivateObject(obj, new PrivateType(ownerType)).SetFieldOrProperty(name, value);
return true;
}
else if (TryFindFieldOrPropertyOwnerType(objType, name, memberType, memberTypeMatching, Static, out ownerType))
{
new PrivateType(ownerType).SetStaticFieldOrProperty(name, value);
return true;
}
}
catch (MissingMethodException)
{
// When getter only property name is given, the property is found but fails to set.
return false;
}

return false;
}

/// <summary>
/// Set to private (and any other) static field/property.
/// </summary>
/// <param name="type">The type to set to</param>
/// <param name="name">The name of the field/property</param>
/// <param name="value">The value to set for 'name'</param>
/// <exception cref="ArgumentException">'name' is not found.</exception>
/// <exception cref="ArgumentNullException">Arguments contain null.</exception>
public static void SetPrivate<T>(this Type type, string name, T value)
{
if (type == null) { throw new ArgumentNullException("type"); }
if (name == null) { throw new ArgumentNullException("name"); }
if (string.IsNullOrWhiteSpace(name)) { throw new ArgumentException("name is empty or white-space.", "name"); }

if (TrySetPrivate(type, name, value)) { return; }

// retry for the case of getter only property
if (TrySetPrivate(type, GetBackingFieldName(name), value)) { return; }

throw new ArgumentException($"{typeof(T)} {name} is not found.");
}

private static bool TrySetPrivate<T>(this Type type, string name, T value)
{
var memberType = typeof(T);
bool memberTypeMatching(Type actualType) => actualType.IsAssignableFrom(memberType);

try
{
if (type.ContainsFieldOrProperty(name, memberType, memberTypeMatching, Static))
{
new PrivateType(type).SetStaticFieldOrProperty(name, value);
return true;
}
}
catch (MissingMethodException)
{
// When getter only property name is given, the property is found but fails to set.
return false;
}

return false;
}

private static string GetBackingFieldName(string propertyName)
=> $"<{propertyName}>k__BackingField"; // generated backing field name

private static bool TryFindFieldOrPropertyOwnerType(Type objType, string name, Type memberType, Func<Type, bool> memberTypeMatching, BindingFlags bindingFlag, out Type ownerType)
{
ownerType = FindFieldOrPropertyOwnerType(objType, name, memberType, memberTypeMatching, bindingFlag);

return (ownerType != null);
}

private static Type FindFieldOrPropertyOwnerType(Type objectType, string name, Type memberType, Func<Type, bool> memberTypeMatching, BindingFlags bindingFlags)
{
if (objectType == null) { return null; }

if (objectType.ContainsFieldOrProperty(name, memberType, memberTypeMatching, bindingFlags))
{
return objectType;
}

return FindFieldOrPropertyOwnerType(objectType.BaseType, name, memberType, memberTypeMatching, bindingFlags);
}

private static bool ContainsFieldOrProperty(this Type objectType, string name, Type memberType, Func<Type, bool> memberTypeMatching, BindingFlags bindingFlags)
{
var fields = objectType
.GetFields(bindingFlags)
.Select((x) => new { Type = x.FieldType, Member = x as MemberInfo });

var properties = objectType
.GetProperties(bindingFlags)
.Select((x) => new { Type = x.PropertyType, Member = x as MemberInfo });

var members = fields.Concat(properties);

return members.Any((actual) =>
(memberType == null || memberTypeMatching.Invoke(actual.Type))
&& actual.Member.Name == name);
}
}
}

+ 0
- 172
test/TensorFlowNET.UnitTest/control_flow_ops_test/SwitchTestCase.cs View File

@@ -1,172 +0,0 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;

namespace TensorFlowNET.UnitTest.control_flow_ops_test
{
/// <summary>
/// excerpt of tensorflow/python/framework/ops/control_flow_ops_test.py
/// </summary>
[TestClass]
public class SwitchTestCase : PythonTest
{

[Ignore("TODO")]
[TestMethod]
public void testResourceReadInLoop()
{

//var embedding_matrix = variable_scope.get_variable(
//"embedding_matrix", initializer: new double[,] { { 2.0 }, { 3.0 } }, use_resource: true);

/*
Tensor cond(Tensor it, Tensor _)
{
return it < 5;
}
*/

// TODO: below code doesn't compile
//(Tensor, Tensor) body(Tensor it, Tensor cost)
//{
// var embedding = embedding_ops.embedding_lookup(embedding_matrix, new int[]{0});
// cost += math_ops.reduce_sum(embedding);
// return (it + 1, cost);
//}
//var (_, cost1) = control_flow_ops.while_loop(
// cond, body, new[]
// {
// constant_op.constant(0),
// constant_op.constant(0.0)
// });
//with<Session>(this.cached_session(), sess =>
//{
// self.evaluate(variables.global_variables_initializer());
// self.assertAllEqual(10.0, self.evaluate(cost1));
//});
}


[Ignore("TODO")]
[TestMethod]
public void testIndexedSlicesGradientInCondInWhileLoop()
{
doTestIndexedSlicesGradientInCondInWhileLoop(use_resource: false);
}

[Ignore("TODO")]
[TestMethod]
public void testIndexedSlicesGradientInCondInWhileLoopResource()
{
doTestIndexedSlicesGradientInCondInWhileLoop(use_resource: true);
}

private void doTestIndexedSlicesGradientInCondInWhileLoop(bool use_resource = false)
{
//def doTestIndexedSlicesGradientInCondInWhileLoop(self, use_resource=False):
// embedding_matrix = variable_scope.get_variable(
// "embedding_matrix", [5, 5],
// initializer=init_ops.random_normal_initializer(),
// use_resource=use_resource)

// def cond(it, _):
// return it < 5

// def body(it, cost):
// embedding = embedding_ops.embedding_lookup(embedding_matrix, [0])
// cost = control_flow_ops.cond(
// math_ops.equal(it, 3), lambda: math_ops.square(cost),
// (lambda: cost + math_ops.reduce_sum(embedding)))
// return it + 1, cost

// _, cost = control_flow_ops.while_loop(
// cond, body, [constant_op.constant(0),
// constant_op.constant(0.0)])

// dynamic_grads = gradients_impl.gradients(cost, [embedding_matrix])[0]
// dynamic_grads = math_ops.segment_sum(dynamic_grads.values,
// dynamic_grads.indices)

// embedding = embedding_ops.embedding_lookup(embedding_matrix, [0])
// static = math_ops.square(
// math_ops.reduce_sum(embedding) + math_ops.reduce_sum(embedding) +
// math_ops.reduce_sum(embedding)) + math_ops.reduce_sum(embedding)
// static_grads = gradients_impl.gradients(static, [embedding_matrix])[0]
// static_grads = math_ops.segment_sum(static_grads.values,
// static_grads.indices)

// with self.cached_session():
// self.evaluate(variables.global_variables_initializer())
// self.assertAllEqual(*self.evaluate([static_grads, dynamic_grads]))
}

[Ignore("TODO")]
[TestMethod]
public void testIndexedSlicesWithShapeGradientInWhileLoop()
{
//@test_util.run_v1_only("b/120545219")
//def testIndexedSlicesWithShapeGradientInWhileLoop(self):
// for dtype in [dtypes.float32, dtypes.float64]:
// with self.cached_session() as sess:
// num_steps = 9

// inputs = array_ops.placeholder(dtype=dtype, shape=[num_steps])
// initial_outputs = tensor_array_ops.TensorArray(
// dtype=dtype, size=num_steps)
// initial_i = constant_op.constant(0, dtype=dtypes.int32)

// def cond(i, _):
// return i < num_steps # pylint: disable=cell-var-from-loop

// def body(i, outputs):
// x = array_ops.gather(inputs, i) # pylint: disable=cell-var-from-loop
// outputs = outputs.write(i, x)
// return i + 1, outputs

// _, outputs = control_flow_ops.while_loop(cond, body,
// [initial_i, initial_outputs])

// outputs = math_ops.reduce_sum(outputs.stack())
// r = gradients_impl.gradients([outputs], [inputs])[0]
// grad_wr_inputs = ops.convert_to_tensor(r)
// o, grad = sess.run([outputs, grad_wr_inputs],
// feed_dict={inputs: [4, 6, 0, 7, 0, 0, 1, 2, 0]})
// self.assertEquals(o, 20)
// self.assertAllEqual(grad, [1] * num_steps)

}

[Ignore("TODO")]
[TestMethod]
public void testIndexedSlicesWithDynamicShapeGradientInWhileLoop()
{
//@test_util.run_v1_only("b/120545219")
//def testIndexedSlicesWithDynamicShapeGradientInWhileLoop(self):
// for dtype in [dtypes.float32, dtypes.float64]:
// with self.cached_session() as sess:
// inputs = array_ops.placeholder(dtype=dtype)
// initial_outputs = tensor_array_ops.TensorArray(
// dtype=dtype, dynamic_size=True, size=1)
// initial_i = constant_op.constant(0, dtype=dtypes.int32)

// def cond(i, _):
// return i < array_ops.size(inputs) # pylint: disable=cell-var-from-loop

// def body(i, outputs):
// x = array_ops.gather(inputs, i) # pylint: disable=cell-var-from-loop
// outputs = outputs.write(i, x)
// return i + 1, outputs

// _, outputs = control_flow_ops.while_loop(cond, body,
// [initial_i, initial_outputs])

// outputs = math_ops.reduce_sum(outputs.stack())
// r = gradients_impl.gradients([outputs], [inputs])[0]
// grad_wr_inputs = ops.convert_to_tensor(r)
// o, grad = sess.run([outputs, grad_wr_inputs],
// feed_dict={inputs: [1, 3, 2]})
// self.assertEquals(o, 6)
// self.assertAllEqual(grad, [1] * 3)

}

}
}

Loading…
Cancel
Save