@@ -1,7 +1,7 @@ | |||
| |||
Microsoft Visual Studio Solution File, Format Version 12.00 | |||
# Visual Studio Version 16 | |||
VisualStudioVersion = 16.0.29102.190 | |||
# Visual Studio Version 17 | |||
VisualStudioVersion = 17.0.31423.177 | |||
MinimumVisualStudioVersion = 10.0.40219.1 | |||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Binding", "src\TensorFlowNET.Core\Tensorflow.Binding.csproj", "{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}" | |||
EndProject | |||
@@ -21,6 +21,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Native.UnitTest" | |||
EndProject | |||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Keras.UnitTest", "test\TensorFlowNET.Keras.UnitTest\Tensorflow.Keras.UnitTest.csproj", "{79EB56DF-E29E-4AE2-A7D9-FE403FD919BA}" | |||
EndProject | |||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Graph.UnitTest", "test\TensorFlowNET.Graph.UnitTest\TensorFlowNET.Graph.UnitTest.csproj", "{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}" | |||
EndProject | |||
Global | |||
GlobalSection(SolutionConfigurationPlatforms) = preSolution | |||
Debug|Any CPU = Debug|Any CPU | |||
@@ -139,6 +141,18 @@ Global | |||
{79EB56DF-E29E-4AE2-A7D9-FE403FD919BA}.Release|x64.Build.0 = Release|Any CPU | |||
{79EB56DF-E29E-4AE2-A7D9-FE403FD919BA}.Release|x86.ActiveCfg = Release|Any CPU | |||
{79EB56DF-E29E-4AE2-A7D9-FE403FD919BA}.Release|x86.Build.0 = Release|Any CPU | |||
{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | |||
{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||
{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Debug|x64.ActiveCfg = Debug|x64 | |||
{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Debug|x64.Build.0 = Debug|x64 | |||
{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Debug|x86.ActiveCfg = Debug|Any CPU | |||
{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Debug|x86.Build.0 = Debug|Any CPU | |||
{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||
{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|Any CPU.Build.0 = Release|Any CPU | |||
{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x64.ActiveCfg = Release|Any CPU | |||
{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x64.Build.0 = Release|Any CPU | |||
{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x86.ActiveCfg = Release|Any CPU | |||
{3F5388FF-FBB4-462B-8F6F-829FFBAEB8A3}.Release|x86.Build.0 = Release|Any CPU | |||
EndGlobalSection | |||
GlobalSection(SolutionProperties) = preSolution | |||
HideSolutionNode = FALSE | |||
@@ -1,7 +1,6 @@ | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using System.Linq; | |||
using Tensorflow; | |||
using Tensorflow.UnitTest; | |||
using static Tensorflow.Binding; | |||
namespace TensorFlowNET.UnitTest.Basics |
@@ -1,9 +1,8 @@ | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using Tensorflow; | |||
using Tensorflow.UnitTest; | |||
using static Tensorflow.Binding; | |||
namespace TensorFlowNET.UnitTest.control_flow_ops_test | |||
namespace TensorFlowNET.UnitTest.ControlFlowTest | |||
{ | |||
/// <summary> | |||
/// excerpt of tensorflow/python/framework/ops/control_flow_ops_test.py |
@@ -1,8 +1,7 @@ | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using Tensorflow; | |||
using Tensorflow.UnitTest; | |||
namespace TensorFlowNET.UnitTest.control_flow_ops_test | |||
namespace TensorFlowNET.UnitTest.ControlFlowTest | |||
{ | |||
/// <summary> | |||
/// excerpt of tensorflow/python/framework/ops/control_flow_ops_test.py |
@@ -1,10 +1,9 @@ | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using System; | |||
using Tensorflow; | |||
using Tensorflow.UnitTest; | |||
using static Tensorflow.Binding; | |||
namespace TensorFlowNET.UnitTest.control_flow_ops_test | |||
namespace TensorFlowNET.UnitTest.ControlFlowTest | |||
{ | |||
[TestClass] | |||
public class WhileContextTestCase : GraphModeTestBase |
@@ -2,10 +2,9 @@ | |||
using Tensorflow.NumPy; | |||
using System; | |||
using Tensorflow; | |||
using Tensorflow.UnitTest; | |||
using static Tensorflow.Binding; | |||
namespace TensorFlowNET.UnitTest.functional_ops_test | |||
namespace TensorFlowNET.UnitTest.FunctionalOpsTest | |||
{ | |||
/// <summary> | |||
/// https://www.tensorflow.org/api_docs/python/tf/scan | |||
@@ -22,7 +21,8 @@ namespace TensorFlowNET.UnitTest.functional_ops_test | |||
var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(6)); | |||
var scan = functional_ops.scan(fn, input); | |||
sess.run(scan, (input, np.array(1, 2, 3, 4, 5, 6))).Should().Be(np.array(1, 3, 6, 10, 15, 21)); | |||
var result = sess.run(scan, (input, np.array(1, 2, 3, 4, 5, 6))); | |||
Assert.AreEqual(result, np.array(1, 3, 6, 10, 15, 21)); | |||
} | |||
[TestMethod, Ignore("need UpdateEdge API")] | |||
@@ -34,7 +34,8 @@ namespace TensorFlowNET.UnitTest.functional_ops_test | |||
var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(6)); | |||
var scan = functional_ops.scan(fn, input, reverse: true); | |||
sess.run(scan, (input, np.array(1, 2, 3, 4, 5, 6))).Should().Be(np.array(21, 20, 18, 15, 11, 6)); | |||
var result = sess.run(scan, (input, np.array(1, 2, 3, 4, 5, 6))); | |||
Assert.AreEqual(result, np.array(21, 20, 18, 15, 11, 6)); | |||
} | |||
} | |||
} |
@@ -4,7 +4,6 @@ using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using Tensorflow; | |||
using Tensorflow.UnitTest; | |||
using static Tensorflow.Binding; | |||
namespace TensorFlowNET.UnitTest.Gradient |
@@ -1,9 +1,7 @@ | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using TensorFlowNET.UnitTest; | |||
using static Tensorflow.Binding; | |||
using static Tensorflow.KerasApi; | |||
namespace Tensorflow.UnitTest | |||
namespace TensorFlowNET.UnitTest | |||
{ | |||
public class GraphModeTestBase : PythonTest | |||
{ | |||
@@ -16,7 +14,6 @@ namespace Tensorflow.UnitTest | |||
[TestCleanup] | |||
public void TestClean() | |||
{ | |||
keras.backend.clear_session(); | |||
tf.enable_eager_execution(); | |||
} | |||
} |
@@ -1,12 +1,10 @@ | |||
using FluentAssertions; | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using Tensorflow.NumPy; | |||
using System.Linq; | |||
using Tensorflow; | |||
using Tensorflow.UnitTest; | |||
using static Tensorflow.Binding; | |||
namespace TensorFlowNET.UnitTest.Basics | |||
namespace TensorFlowNET.UnitTest | |||
{ | |||
/// <summary> | |||
/// Find more examples in https://www.programcreek.com/python/example/90444/tensorflow.read_file | |||
@@ -84,14 +82,14 @@ namespace TensorFlowNET.UnitTest.Basics | |||
var result = sess.run(cropped); | |||
// check if cropped to 1x1 center was succesfull | |||
result.size.Should().Be(1); | |||
result[0, 0, 0, 0].Should().Be(4f); | |||
Assert.AreEqual(result.size, 1); | |||
Assert.AreEqual(result[0, 0, 0, 0], 4f); | |||
cropped = tf.image.crop_and_resize(image2, box, boxInd, cropSize2_2); | |||
result = sess.run(cropped); | |||
// check if flipped and no cropping occured | |||
result.size.Should().Be(16); | |||
result[0, 0, 0, 0].Should().Be(12f); | |||
Assert.AreEqual(result.size, 16); | |||
Assert.AreEqual(result[0, 0, 0, 0], 12f); | |||
} | |||
} |
@@ -1,12 +1,10 @@ | |||
using FluentAssertions; | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using Tensorflow.NumPy; | |||
using System; | |||
using System.IO; | |||
using System.Linq; | |||
using System.Runtime.InteropServices; | |||
using Tensorflow; | |||
using Tensorflow.UnitTest; | |||
using static Tensorflow.Binding; | |||
namespace TensorFlowNET.UnitTest | |||
@@ -24,15 +22,15 @@ namespace TensorFlowNET.UnitTest | |||
//the core method | |||
void Core(int tid) | |||
{ | |||
tf.peak_default_graph().Should().BeNull(); | |||
Assert.IsNull(tf.peak_default_graph()); | |||
using (var sess = tf.Session()) | |||
{ | |||
var default_graph = tf.peak_default_graph(); | |||
var sess_graph = sess.GetPrivate<Graph>("_graph"); | |||
sess_graph.Should().NotBeNull(); | |||
default_graph.Should().NotBeNull() | |||
.And.BeEquivalentTo(sess_graph); | |||
var sess_graph = sess.graph; | |||
Assert.IsNotNull(default_graph); | |||
Assert.IsNotNull(sess_graph); | |||
Assert.AreEqual(default_graph, sess_graph); | |||
} | |||
} | |||
} | |||
@@ -47,15 +45,15 @@ namespace TensorFlowNET.UnitTest | |||
//the core method | |||
void Core(int tid) | |||
{ | |||
tf.peak_default_graph().Should().BeNull(); | |||
Assert.IsNull(tf.peak_default_graph()); | |||
//tf.Session created an other graph | |||
using (var sess = tf.Session()) | |||
{ | |||
var default_graph = tf.peak_default_graph(); | |||
var sess_graph = sess.GetPrivate<Graph>("_graph"); | |||
sess_graph.Should().NotBeNull(); | |||
default_graph.Should().NotBeNull() | |||
.And.BeEquivalentTo(sess_graph); | |||
var sess_graph = sess.graph; | |||
Assert.IsNotNull(default_graph); | |||
Assert.IsNotNull(sess_graph); | |||
Assert.AreEqual(default_graph, sess_graph); | |||
} | |||
} | |||
} | |||
@@ -70,19 +68,18 @@ namespace TensorFlowNET.UnitTest | |||
//the core method | |||
void Core(int tid) | |||
{ | |||
tf.peak_default_graph().Should().BeNull(); | |||
Assert.IsNull(tf.peak_default_graph()); | |||
var beforehand = tf.get_default_graph(); //this should create default automatically. | |||
beforehand.graph_key.Should().NotContain("-0/", "Already created a graph in an other thread."); | |||
beforehand.as_default(); | |||
tf.peak_default_graph().Should().NotBeNull(); | |||
Assert.IsNotNull(tf.peak_default_graph()); | |||
using (var sess = tf.Session()) | |||
{ | |||
var default_graph = tf.peak_default_graph(); | |||
var sess_graph = sess.GetPrivate<Graph>("_graph"); | |||
sess_graph.Should().NotBeNull(); | |||
default_graph.Should().NotBeNull() | |||
.And.BeEquivalentTo(sess_graph); | |||
var sess_graph = sess.graph; | |||
Assert.IsNotNull(default_graph); | |||
Assert.IsNotNull(sess_graph); | |||
Assert.AreEqual(default_graph, sess_graph); | |||
Console.WriteLine($"{tid}-{default_graph.graph_key}"); | |||
@@ -188,7 +185,7 @@ namespace TensorFlowNET.UnitTest | |||
//the core method | |||
void Core(int tid) | |||
{ | |||
tf.peak_default_graph().Should().BeNull(); | |||
Assert.IsNull(tf.peak_default_graph()); | |||
//graph is created automatically to perform create these operations | |||
var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); | |||
var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 }); | |||
@@ -197,7 +194,8 @@ namespace TensorFlowNET.UnitTest | |||
{ | |||
using (var sess = tf.Session()) | |||
{ | |||
sess.run(math).GetAtIndex<float>(0).Should().Be(5); | |||
var result = sess.run(math); | |||
Assert.AreEqual(result.GetAtIndex<float>(0), 5f); | |||
} | |||
} | |||
} | |||
@@ -213,14 +211,14 @@ namespace TensorFlowNET.UnitTest | |||
{ | |||
using (var sess = tf.Session()) | |||
{ | |||
tf.peak_default_graph().Should().NotBeNull(); | |||
Assert.IsNotNull(tf.peak_default_graph()); | |||
//graph is created automatically to perform create these operations | |||
var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); | |||
var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 }); | |||
var math = a1 + a2; | |||
var result = sess.run(math); | |||
result.GetAtIndex<float>(0).Should().Be(5); | |||
Assert.AreEqual(result.GetAtIndex<float>(0), 5f); | |||
} | |||
} | |||
} | |||
@@ -235,7 +233,7 @@ namespace TensorFlowNET.UnitTest | |||
{ | |||
using (var sess = tf.Session()) | |||
{ | |||
tf.peak_default_graph().Should().NotBeNull(); | |||
Assert.IsNotNull(tf.peak_default_graph()); | |||
//graph is created automatically to perform create these operations | |||
var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); | |||
var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 }); | |||
@@ -252,7 +250,7 @@ namespace TensorFlowNET.UnitTest | |||
//the core method | |||
void Core(int tid) | |||
{ | |||
tf.peak_default_graph().Should().BeNull(); | |||
Assert.IsNull(tf.peak_default_graph()); | |||
//graph is created automatically to perform create these operations | |||
var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); | |||
var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 }); | |||
@@ -268,7 +266,7 @@ namespace TensorFlowNET.UnitTest | |||
//the core method | |||
void Core(int tid) | |||
{ | |||
tf.peak_default_graph().Should().BeNull(); | |||
Assert.IsNull(tf.peak_default_graph()); | |||
//graph is created automatically to perform create these operations | |||
var a1 = tf.constant(new[] { 2f }, shape: new[] { 1 }); | |||
var a2 = tf.constant(new[] { 3f }, shape: new[] { 1 }, name: "ConstantK"); |
@@ -0,0 +1,78 @@ | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using Tensorflow; | |||
using static Tensorflow.Binding; | |||
namespace TensorFlowNET.UnitTest.Basics | |||
{ | |||
[TestClass] | |||
public class NameScopeTest : GraphModeTestBase | |||
{ | |||
string name = ""; | |||
[TestMethod] | |||
public void NestedNameScope() | |||
{ | |||
Graph g = tf.Graph().as_default(); | |||
tf_with(new ops.NameScope("scope1"), scope1 => | |||
{ | |||
name = scope1; | |||
Assert.AreEqual("scope1", g._name_stack); | |||
Assert.AreEqual("scope1/", name); | |||
var const1 = tf.constant(1.0); | |||
Assert.AreEqual("scope1/Const:0", const1.name); | |||
tf_with(new ops.NameScope("scope2"), scope2 => | |||
{ | |||
name = scope2; | |||
Assert.AreEqual("scope1/scope2", g._name_stack); | |||
Assert.AreEqual("scope1/scope2/", name); | |||
var const2 = tf.constant(2.0); | |||
Assert.AreEqual("scope1/scope2/Const:0", const2.name); | |||
}); | |||
Assert.AreEqual("scope1", g._name_stack); | |||
var const3 = tf.constant(2.0); | |||
Assert.AreEqual("scope1/Const_1:0", const3.name); | |||
}); | |||
g.Dispose(); | |||
Assert.AreEqual("", g._name_stack); | |||
} | |||
[TestMethod, Ignore("Unimplemented Usage")] | |||
public void NestedNameScope_Using() | |||
{ | |||
Graph g = tf.Graph().as_default(); | |||
using (var name = new ops.NameScope("scope1")) | |||
{ | |||
Assert.AreEqual("scope1", g._name_stack); | |||
Assert.AreEqual("scope1/", name); | |||
var const1 = tf.constant(1.0); | |||
Assert.AreEqual("scope1/Const:0", const1.name); | |||
using (var name2 = new ops.NameScope("scope2")) | |||
{ | |||
Assert.AreEqual("scope1/scope2", g._name_stack); | |||
Assert.AreEqual("scope1/scope2/", name); | |||
var const2 = tf.constant(2.0); | |||
Assert.AreEqual("scope1/scope2/Const:0", const2.name); | |||
} | |||
Assert.AreEqual("scope1", g._name_stack); | |||
var const3 = tf.constant(2.0); | |||
Assert.AreEqual("scope1/Const_1:0", const3.name); | |||
}; | |||
g.Dispose(); | |||
Assert.AreEqual("", g._name_stack); | |||
} | |||
} | |||
} |
@@ -4,8 +4,6 @@ using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using Tensorflow; | |||
using Tensorflow.UnitTest; | |||
using Tensorflow.Util; | |||
using static Tensorflow.Binding; | |||
using Buffer = Tensorflow.Buffer; | |||
@@ -0,0 +1,335 @@ | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using Newtonsoft.Json.Linq; | |||
using Tensorflow.NumPy; | |||
using System; | |||
using System.Collections; | |||
using System.Linq; | |||
using Tensorflow; | |||
using static Tensorflow.Binding; | |||
namespace TensorFlowNET.UnitTest | |||
{ | |||
/// <summary> | |||
/// Use as base class for test classes to get additional assertions | |||
/// </summary> | |||
public class PythonTest | |||
{ | |||
#region python compatibility layer | |||
protected PythonTest self { get => this; } | |||
protected int None => -1; | |||
#endregion | |||
#region pytest assertions | |||
public void assertItemsEqual(ICollection given, ICollection expected) | |||
{ | |||
if (given is Hashtable && expected is Hashtable) | |||
{ | |||
Assert.AreEqual(JObject.FromObject(expected).ToString(), JObject.FromObject(given).ToString()); | |||
return; | |||
} | |||
Assert.IsNotNull(expected); | |||
Assert.IsNotNull(given); | |||
var e = expected.OfType<object>().ToArray(); | |||
var g = given.OfType<object>().ToArray(); | |||
Assert.AreEqual(e.Length, g.Length, $"The collections differ in length expected {e.Length} but got {g.Length}"); | |||
for (int i = 0; i < e.Length; i++) | |||
{ | |||
/*if (g[i] is NDArray && e[i] is NDArray) | |||
assertItemsEqual((g[i] as NDArray).GetData<object>(), (e[i] as NDArray).GetData<object>()); | |||
else*/ | |||
if (e[i] is ICollection && g[i] is ICollection) | |||
assertEqual(g[i], e[i]); | |||
else | |||
Assert.AreEqual(e[i], g[i], $"Items differ at index {i}, expected {e[i]} but got {g[i]}"); | |||
} | |||
} | |||
public void assertAllEqual(ICollection given, ICollection expected) | |||
{ | |||
assertItemsEqual(given, expected); | |||
} | |||
public void assertFloat32Equal(float expected, float actual, string msg) | |||
{ | |||
float eps = 1e-6f; | |||
Assert.IsTrue(Math.Abs(expected - actual) < eps * Math.Max(1.0f, Math.Abs(expected)), $"{msg}: expected {expected} vs actual {actual}"); | |||
} | |||
public void assertFloat64Equal(double expected, double actual, string msg) | |||
{ | |||
double eps = 1e-16f; | |||
Assert.IsTrue(Math.Abs(expected - actual) < eps * Math.Max(1.0f, Math.Abs(expected)), $"{msg}: expected {expected} vs actual {actual}"); | |||
} | |||
public void assertEqual(object given, object expected) | |||
{ | |||
/*if (given is NDArray && expected is NDArray) | |||
{ | |||
assertItemsEqual((given as NDArray).GetData<object>(), (expected as NDArray).GetData<object>()); | |||
return; | |||
}*/ | |||
if (given is Hashtable && expected is Hashtable) | |||
{ | |||
Assert.AreEqual(JObject.FromObject(expected).ToString(), JObject.FromObject(given).ToString()); | |||
return; | |||
} | |||
if (given is ICollection && expected is ICollection) | |||
{ | |||
assertItemsEqual(given as ICollection, expected as ICollection); | |||
return; | |||
} | |||
if (given is float && expected is float) | |||
{ | |||
assertFloat32Equal((float)expected, (float)given, ""); | |||
return; | |||
} | |||
if (given is double && expected is double) | |||
{ | |||
assertFloat64Equal((double)expected, (double)given, ""); | |||
return; | |||
} | |||
Assert.AreEqual(expected, given); | |||
} | |||
public void assertEquals(object given, object expected) | |||
{ | |||
assertEqual(given, expected); | |||
} | |||
public void assert(object given) | |||
{ | |||
if (given is bool) | |||
Assert.IsTrue((bool)given); | |||
Assert.IsNotNull(given); | |||
} | |||
public void assertIsNotNone(object given) | |||
{ | |||
Assert.IsNotNull(given); | |||
} | |||
public void assertFalse(bool cond) | |||
{ | |||
Assert.IsFalse(cond); | |||
} | |||
public void assertTrue(bool cond) | |||
{ | |||
Assert.IsTrue(cond); | |||
} | |||
public void assertAllClose(NDArray array1, NDArray array2, double eps = 1e-5) | |||
{ | |||
Assert.IsTrue(np.allclose(array1, array2, rtol: eps)); | |||
} | |||
public void assertAllClose(double value, NDArray array2, double eps = 1e-5) | |||
{ | |||
var array1 = np.ones_like(array2) * value; | |||
Assert.IsTrue(np.allclose(array1, array2, rtol: eps)); | |||
} | |||
public void assertProtoEquals(object toProto, object o) | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
#endregion | |||
#region tensor evaluation and test session | |||
//protected object _eval_helper(Tensor[] tensors) | |||
//{ | |||
// if (tensors == null) | |||
// return null; | |||
// return nest.map_structure(self._eval_tensor, tensors); | |||
//} | |||
protected object _eval_tensor(object tensor) | |||
{ | |||
if (tensor == null) | |||
return None; | |||
//else if (callable(tensor)) | |||
// return self._eval_helper(tensor()) | |||
else | |||
{ | |||
try | |||
{ | |||
//TODO: | |||
// if sparse_tensor.is_sparse(tensor): | |||
// return sparse_tensor.SparseTensorValue(tensor.indices, tensor.values, | |||
// tensor.dense_shape) | |||
//return (tensor as Tensor).numpy(); | |||
} | |||
catch (Exception) | |||
{ | |||
throw new ValueError("Unsupported type: " + tensor.GetType()); | |||
} | |||
return null; | |||
} | |||
} | |||
/// <summary> | |||
/// This function is used in many original tensorflow unit tests to evaluate tensors | |||
/// in a test session with special settings (for instance constant folding off) | |||
/// | |||
/// </summary> | |||
public T evaluate<T>(Tensor tensor) | |||
{ | |||
object result = null; | |||
// if context.executing_eagerly(): | |||
// return self._eval_helper(tensors) | |||
// else: | |||
{ | |||
using (var sess = tf.Session()) | |||
{ | |||
var ndarray = tensor.eval(sess); | |||
if (typeof(T) == typeof(double)) | |||
{ | |||
double x = ndarray; | |||
result = x; | |||
} | |||
else if (typeof(T) == typeof(int)) | |||
{ | |||
int x = ndarray; | |||
result = x; | |||
} | |||
else | |||
{ | |||
result = ndarray; | |||
} | |||
} | |||
return (T)result; | |||
} | |||
} | |||
public Session cached_session() | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
//Returns a TensorFlow Session for use in executing tests. | |||
public Session session(Graph graph = null, object config = null, bool use_gpu = false, bool force_gpu = false) | |||
{ | |||
//Note that this will set this session and the graph as global defaults. | |||
//Use the `use_gpu` and `force_gpu` options to control where ops are run.If | |||
//`force_gpu` is True, all ops are pinned to `/device:GPU:0`. Otherwise, if | |||
//`use_gpu` is True, TensorFlow tries to run as many ops on the GPU as | |||
//possible.If both `force_gpu and `use_gpu` are False, all ops are pinned to | |||
//the CPU. | |||
//Example: | |||
//```python | |||
//class MyOperatorTest(test_util.TensorFlowTestCase): | |||
// def testMyOperator(self): | |||
// with self.session(use_gpu= True): | |||
// valid_input = [1.0, 2.0, 3.0, 4.0, 5.0] | |||
// result = MyOperator(valid_input).eval() | |||
// self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0] | |||
// invalid_input = [-1.0, 2.0, 7.0] | |||
// with self.assertRaisesOpError("negative input not supported"): | |||
// MyOperator(invalid_input).eval() | |||
//``` | |||
//Args: | |||
// graph: Optional graph to use during the returned session. | |||
// config: An optional config_pb2.ConfigProto to use to configure the | |||
// session. | |||
// use_gpu: If True, attempt to run as many ops as possible on GPU. | |||
// force_gpu: If True, pin all ops to `/device:GPU:0`. | |||
//Yields: | |||
// A Session object that should be used as a context manager to surround | |||
// the graph building and execution code in a test case. | |||
Session s = null; | |||
//if (context.executing_eagerly()) | |||
// yield None | |||
//else | |||
//{ | |||
s = self._create_session(graph, config, force_gpu); | |||
self._constrain_devices_and_set_default(s, use_gpu, force_gpu); | |||
//} | |||
return s.as_default(); | |||
} | |||
private ITensorFlowObject _constrain_devices_and_set_default(Session sess, bool useGpu, bool forceGpu) | |||
{ | |||
//def _constrain_devices_and_set_default(self, sess, use_gpu, force_gpu): | |||
//"""Set the session and its graph to global default and constrain devices.""" | |||
//if context.executing_eagerly(): | |||
// yield None | |||
//else: | |||
// with sess.graph.as_default(), sess.as_default(): | |||
// if force_gpu: | |||
// # Use the name of an actual device if one is detected, or | |||
// # '/device:GPU:0' otherwise | |||
// gpu_name = gpu_device_name() | |||
// if not gpu_name: | |||
// gpu_name = "/device:GPU:0" | |||
// with sess.graph.device(gpu_name): | |||
// yield sess | |||
// elif use_gpu: | |||
// yield sess | |||
// else: | |||
// with sess.graph.device("/device:CPU:0"): | |||
// yield sess | |||
return sess; | |||
} | |||
// See session() for details. | |||
private Session _create_session(Graph graph, object cfg, bool forceGpu) | |||
{ | |||
var prepare_config = new Func<object, object>((config) => | |||
{ | |||
// """Returns a config for sessions. | |||
// Args: | |||
// config: An optional config_pb2.ConfigProto to use to configure the | |||
// session. | |||
// Returns: | |||
// A config_pb2.ConfigProto object. | |||
//TODO: config | |||
// # use_gpu=False. Currently many tests rely on the fact that any device | |||
// # will be used even when a specific device is supposed to be used. | |||
// allow_soft_placement = not force_gpu | |||
// if config is None: | |||
// config = config_pb2.ConfigProto() | |||
// config.allow_soft_placement = allow_soft_placement | |||
// config.gpu_options.per_process_gpu_memory_fraction = 0.3 | |||
// elif not allow_soft_placement and config.allow_soft_placement: | |||
// config_copy = config_pb2.ConfigProto() | |||
// config_copy.CopyFrom(config) | |||
// config = config_copy | |||
// config.allow_soft_placement = False | |||
// # Don't perform optimizations for tests so we don't inadvertently run | |||
// # gpu ops on cpu | |||
// config.graph_options.optimizer_options.opt_level = -1 | |||
// # Disable Grappler constant folding since some tests & benchmarks | |||
// # use constant input and become meaningless after constant folding. | |||
// # DO NOT DISABLE GRAPPLER OPTIMIZERS WITHOUT CONSULTING WITH THE | |||
// # GRAPPLER TEAM. | |||
// config.graph_options.rewrite_options.constant_folding = ( | |||
// rewriter_config_pb2.RewriterConfig.OFF) | |||
// config.graph_options.rewrite_options.pin_to_host_optimization = ( | |||
// rewriter_config_pb2.RewriterConfig.OFF) | |||
return config; | |||
}); | |||
//TODO: use this instead of normal session | |||
//return new ErrorLoggingSession(graph = graph, config = prepare_config(config)) | |||
return new Session(graph);//, config = prepare_config(config)) | |||
} | |||
#endregion | |||
public void AssetSequenceEqual<T>(T[] a, T[] b) | |||
{ | |||
Assert.IsTrue(Enumerable.SequenceEqual(a, b)); | |||
} | |||
} | |||
} |
@@ -0,0 +1,36 @@ | |||
<Project Sdk="Microsoft.NET.Sdk"> | |||
<PropertyGroup> | |||
<TargetFramework>net5.0</TargetFramework> | |||
<LangVersion>9.0</LangVersion> | |||
<IsPackable>false</IsPackable> | |||
<AssemblyName>TensorFlowNET.UnitTest</AssemblyName> | |||
<Platforms>AnyCPU;x64</Platforms> | |||
</PropertyGroup> | |||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | |||
<DefineConstants>DEBUG;TRACE</DefineConstants> | |||
<AllowUnsafeBlocks>true</AllowUnsafeBlocks> | |||
</PropertyGroup> | |||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'"> | |||
<DefineConstants>DEBUG;TRACE</DefineConstants> | |||
<AllowUnsafeBlocks>true</AllowUnsafeBlocks> | |||
</PropertyGroup> | |||
<ItemGroup> | |||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.10.0" /> | |||
<PackageReference Include="MSTest.TestAdapter" Version="2.2.5" /> | |||
<PackageReference Include="MSTest.TestFramework" Version="2.2.5" /> | |||
<PackageReference Include="coverlet.collector" Version="3.0.3"> | |||
<PrivateAssets>all</PrivateAssets> | |||
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets> | |||
</PackageReference> | |||
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.6.0-rc0" /> | |||
</ItemGroup> | |||
<ItemGroup> | |||
<ProjectReference Include="..\..\src\TensorFlowNET.Core\Tensorflow.Binding.csproj" /> | |||
</ItemGroup> | |||
</Project> |
@@ -0,0 +1,22 @@ | |||
using System; | |||
using System.IO; | |||
namespace TensorFlowNET.UnitTest | |||
{ | |||
public class TestHelper | |||
{ | |||
public static string GetFullPathFromDataDir(string fileName) | |||
{ | |||
var dataDir = GetRootContentDir(Directory.GetCurrentDirectory()); | |||
return Path.Combine(dataDir, fileName); | |||
} | |||
static string GetRootContentDir(string dir) | |||
{ | |||
var path = Path.GetFullPath(Path.Combine(dir, "data")); | |||
if (Directory.Exists(path)) | |||
return path; | |||
return GetRootContentDir(Path.GetFullPath(Path.Combine(dir, ".."))); | |||
} | |||
} | |||
} |
@@ -30,8 +30,8 @@ namespace TensorFlowNET.UnitTest.Basics | |||
tf.set_random_seed(1234); | |||
var a2 = tf.random_uniform(1); | |||
var b2 = tf.random_shuffle(tf.constant(initValue)); | |||
Assert.AreEqual(a1, a2); | |||
Assert.AreEqual(b1, b2); | |||
Assert.AreEqual(a1.numpy(), a2.numpy()); | |||
Assert.AreEqual(b1.numpy(), b2.numpy()); | |||
} | |||
/// <summary> | |||
@@ -76,8 +76,8 @@ namespace TensorFlowNET.UnitTest.Basics | |||
var a2 = tf.random.normal(1); | |||
var b2 = tf.random.truncated_normal(1); | |||
Assert.AreEqual(a1, a2); | |||
Assert.AreEqual(b1, b2); | |||
Assert.AreEqual(a1.numpy(), a2.numpy()); | |||
Assert.AreEqual(b1.numpy(), b2.numpy()); | |||
} | |||
/// <summary> | |||
@@ -1,67 +0,0 @@ | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using Tensorflow.NumPy; | |||
using System; | |||
using Tensorflow; | |||
using static Tensorflow.Binding; | |||
namespace TensorFlowNET.UnitTest.Basics | |||
{ | |||
[TestClass] | |||
public class TensorShapeTest | |||
{ | |||
[TestMethod] | |||
public void Case1() | |||
{ | |||
int a = 2; | |||
int b = 3; | |||
var dims = new[] { Unknown, a, b }; | |||
new TensorShape(dims).GetPrivate<Shape>("shape").Should().BeShaped(-1, 2, 3); | |||
} | |||
[TestMethod] | |||
public void Case2() | |||
{ | |||
int a = 2; | |||
int b = 3; | |||
var dims = new[] { Unknown, a, b }; | |||
//new TensorShape(new[] { dims }).GetPrivate<Shape>("shape").Should().BeShaped(-1, 2, 3); | |||
} | |||
[TestMethod] | |||
public void Case3() | |||
{ | |||
int a = 2; | |||
int b = Unknown; | |||
var dims = new[] { Unknown, a, b }; | |||
//new TensorShape(new[] { dims }).GetPrivate<Shape>("shape").Should().BeShaped(-1, 2, -1); | |||
} | |||
[TestMethod] | |||
public void Case4() | |||
{ | |||
TensorShape shape = (Unknown, Unknown); | |||
shape.GetPrivate<Shape>("shape").Should().BeShaped(-1, -1); | |||
} | |||
[TestMethod] | |||
public void Case5() | |||
{ | |||
TensorShape shape = (1, Unknown, 3); | |||
shape.GetPrivate<Shape>("shape").Should().BeShaped(1, -1, 3); | |||
} | |||
[TestMethod] | |||
public void Case6() | |||
{ | |||
TensorShape shape = (Unknown, 1, 2, 3, Unknown); | |||
shape.GetPrivate<Shape>("shape").Should().BeShaped(-1, 1, 2, 3, -1); | |||
} | |||
[TestMethod] | |||
public void Case7() | |||
{ | |||
TensorShape shape = new TensorShape(); | |||
Assert.AreEqual(shape.rank, -1); | |||
} | |||
} | |||
} |
@@ -34,7 +34,7 @@ namespace TensorFlowNET.UnitTest | |||
using (var sess = tf.Session()) | |||
{ | |||
var default_graph = tf.peak_default_graph(); | |||
var sess_graph = sess.GetPrivate<Graph>("_graph"); | |||
var sess_graph = sess.graph; | |||
sess_graph.Should().NotBeNull(); | |||
default_graph.Should().NotBeNull() | |||
.And.BeEquivalentTo(sess_graph); | |||
@@ -4,7 +4,6 @@ using System.Collections.Generic; | |||
using System.Linq; | |||
using Tensorflow; | |||
using Tensorflow.NumPy; | |||
using Tensorflow.UnitTest; | |||
using static Tensorflow.Binding; | |||
namespace TensorFlowNET.UnitTest.Gradient | |||
@@ -1,93 +1,22 @@ | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using Tensorflow; | |||
using Tensorflow.UnitTest; | |||
using static Tensorflow.Binding; | |||
namespace TensorFlowNET.UnitTest.Basics | |||
{ | |||
[TestClass] | |||
public class NameScopeTest : GraphModeTestBase | |||
public class NameScopeTest : EagerModeTestBase | |||
{ | |||
string name = ""; | |||
[TestMethod] | |||
public void NestedNameScope() | |||
{ | |||
Graph g = tf.Graph().as_default(); | |||
tf_with(new ops.NameScope("scope1"), scope1 => | |||
{ | |||
name = scope1; | |||
Assert.AreEqual("scope1", g._name_stack); | |||
Assert.AreEqual("scope1/", name); | |||
var const1 = tf.constant(1.0); | |||
Assert.AreEqual("scope1/Const:0", const1.name); | |||
tf_with(new ops.NameScope("scope2"), scope2 => | |||
{ | |||
name = scope2; | |||
Assert.AreEqual("scope1/scope2", g._name_stack); | |||
Assert.AreEqual("scope1/scope2/", name); | |||
var const2 = tf.constant(2.0); | |||
Assert.AreEqual("scope1/scope2/Const:0", const2.name); | |||
}); | |||
Assert.AreEqual("scope1", g._name_stack); | |||
var const3 = tf.constant(2.0); | |||
Assert.AreEqual("scope1/Const_1:0", const3.name); | |||
}); | |||
g.Dispose(); | |||
Assert.AreEqual("", g._name_stack); | |||
} | |||
[TestMethod] | |||
public void NameScopeInEagerMode() | |||
{ | |||
tf.enable_eager_execution(); | |||
tf_with(new ops.NameScope("scope"), scope => | |||
{ | |||
string name = scope; | |||
var const1 = tf.constant(1.0); | |||
}); | |||
tf.compat.v1.disable_eager_execution(); | |||
} | |||
[TestMethod, Ignore("Unimplemented Usage")] | |||
public void NestedNameScope_Using() | |||
{ | |||
Graph g = tf.Graph().as_default(); | |||
using (var name = new ops.NameScope("scope1")) | |||
{ | |||
Assert.AreEqual("scope1", g._name_stack); | |||
Assert.AreEqual("scope1/", name); | |||
var const1 = tf.constant(1.0); | |||
Assert.AreEqual("scope1/Const:0", const1.name); | |||
using (var name2 = new ops.NameScope("scope2")) | |||
{ | |||
Assert.AreEqual("scope1/scope2", g._name_stack); | |||
Assert.AreEqual("scope1/scope2/", name); | |||
var const2 = tf.constant(2.0); | |||
Assert.AreEqual("scope1/scope2/Const:0", const2.name); | |||
} | |||
Assert.AreEqual("scope1", g._name_stack); | |||
var const3 = tf.constant(2.0); | |||
Assert.AreEqual("scope1/Const_1:0", const3.name); | |||
}; | |||
g.Dispose(); | |||
Assert.AreEqual("", g._name_stack); | |||
} | |||
} | |||
} |
@@ -1,917 +0,0 @@ | |||
// Copyright (c) Microsoft Corporation. All rights reserved. | |||
// Licensed under the MIT license. See LICENSE file in the project root for full license information. | |||
namespace Microsoft.VisualStudio.TestTools.UnitTesting | |||
{ | |||
using System; | |||
//using System.Diagnostics; | |||
//using System.Diagnostics.CodeAnalysis; | |||
using System.Globalization; | |||
using System.Reflection; | |||
/// <summary> | |||
/// This class represents the live NON public INTERNAL object in the system | |||
/// </summary> | |||
internal class PrivateObject | |||
{ | |||
#region Data | |||
// bind everything | |||
private const BindingFlags BindToEveryThing = BindingFlags.Default | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.Public; | |||
#pragma warning disable CS0414 // The field 'PrivateObject.constructorFlags' is assigned but its value is never used | |||
private static BindingFlags constructorFlags = BindingFlags.Instance | BindingFlags.Public | BindingFlags.CreateInstance | BindingFlags.NonPublic; | |||
#pragma warning restore CS0414 // The field 'PrivateObject.constructorFlags' is assigned but its value is never used | |||
private object target; // automatically initialized to null | |||
private Type originalType; // automatically initialized to null | |||
//private Dictionary<string, LinkedList<MethodInfo>> methodCache; // automatically initialized to null | |||
#endregion | |||
#region Constructors | |||
///// <summary> | |||
///// Initializes a new instance of the <see cref="PrivateObject"/> class that contains | |||
///// the already existing object of the private class | |||
///// </summary> | |||
///// <param name="obj"> object that serves as starting point to reach the private members</param> | |||
///// <param name="memberToAccess">the derefrencing string using . that points to the object to be retrived as in m_X.m_Y.m_Z</param> | |||
//[SuppressMessage("Microsoft.Naming", "CA1720:IdentifiersShouldNotContainTypeNames", MessageId = "obj", Justification = "We don't know anything about the object other than that it's an object, so 'obj' seems reasonable")] | |||
//public PrivateObject(object obj, string memberToAccess) | |||
//{ | |||
// Helper.CheckParameterNotNull(obj, "obj", string.Empty); | |||
// ValidateAccessString(memberToAccess); | |||
// PrivateObject temp = obj as PrivateObject; | |||
// if (temp == null) | |||
// { | |||
// temp = new PrivateObject(obj); | |||
// } | |||
// // Split The access string | |||
// string[] arr = memberToAccess.Split(new char[] { '.' }); | |||
// for (int i = 0; i < arr.Length; i++) | |||
// { | |||
// object next = temp.InvokeHelper(arr[i], BindToEveryThing | BindingFlags.Instance | BindingFlags.GetField | BindingFlags.GetProperty, null, CultureInfo.InvariantCulture); | |||
// temp = new PrivateObject(next); | |||
// } | |||
// this.target = temp.target; | |||
// this.originalType = temp.originalType; | |||
//} | |||
///// <summary> | |||
///// Initializes a new instance of the <see cref="PrivateObject"/> class that wraps the | |||
///// specified type. | |||
///// </summary> | |||
///// <param name="assemblyName">Name of the assembly</param> | |||
///// <param name="typeName">fully qualified name</param> | |||
///// <param name="args">Argmenets to pass to the constructor</param> | |||
//public PrivateObject(string assemblyName, string typeName, params object[] args) | |||
// : this(assemblyName, typeName, null, args) | |||
//{ | |||
//} | |||
///// <summary> | |||
///// Initializes a new instance of the <see cref="PrivateObject"/> class that wraps the | |||
///// specified type. | |||
///// </summary> | |||
///// <param name="assemblyName">Name of the assembly</param> | |||
///// <param name="typeName">fully qualified name</param> | |||
///// <param name="parameterTypes">An array of <see cref="T:System.Type"/> objects representing the number, order, and type of the parameters for the constructor to get</param> | |||
///// <param name="args">Argmenets to pass to the constructor</param> | |||
//public PrivateObject(string assemblyName, string typeName, Type[] parameterTypes, object[] args) | |||
// : this(Type.GetType(string.Format(CultureInfo.InvariantCulture, "{0}, {1}", typeName, assemblyName), false), parameterTypes, args) | |||
//{ | |||
// Helper.CheckParameterNotNull(assemblyName, "assemblyName", string.Empty); | |||
// Helper.CheckParameterNotNull(typeName, "typeName", string.Empty); | |||
//} | |||
///// <summary> | |||
///// Initializes a new instance of the <see cref="PrivateObject"/> class that wraps the | |||
///// specified type. | |||
///// </summary> | |||
///// <param name="type">type of the object to create</param> | |||
///// <param name="args">Argmenets to pass to the constructor</param> | |||
//public PrivateObject(Type type, params object[] args) | |||
// : this(type, null, args) | |||
//{ | |||
// Helper.CheckParameterNotNull(type, "type", string.Empty); | |||
//} | |||
///// <summary> | |||
///// Initializes a new instance of the <see cref="PrivateObject"/> class that wraps the | |||
///// specified type. | |||
///// </summary> | |||
///// <param name="type">type of the object to create</param> | |||
///// <param name="parameterTypes">An array of <see cref="T:System.Type"/> objects representing the number, order, and type of the parameters for the constructor to get</param> | |||
///// <param name="args">Argmenets to pass to the constructor</param> | |||
//public PrivateObject(Type type, Type[] parameterTypes, object[] args) | |||
//{ | |||
// Helper.CheckParameterNotNull(type, "type", string.Empty); | |||
// object o; | |||
// if (parameterTypes != null) | |||
// { | |||
// ConstructorInfo ci = type.GetConstructor(BindToEveryThing, null, parameterTypes, null); | |||
// if (ci == null) | |||
// { | |||
// throw new ArgumentException(FrameworkMessages.PrivateAccessorConstructorNotFound); | |||
// } | |||
// try | |||
// { | |||
// o = ci.Invoke(args); | |||
// } | |||
// catch (TargetInvocationException e) | |||
// { | |||
// Debug.Assert(e.InnerException != null, "Inner exception should not be null."); | |||
// if (e.InnerException != null) | |||
// { | |||
// throw e.InnerException; | |||
// } | |||
// throw; | |||
// } | |||
// } | |||
// else | |||
// { | |||
// o = Activator.CreateInstance(type, constructorFlags, null, args, null); | |||
// } | |||
// this.ConstructFrom(o); | |||
//} | |||
/// <summary> | |||
/// Initializes a new instance of the <see cref="PrivateObject"/> class that wraps | |||
/// the given object. | |||
/// </summary> | |||
/// <param name="obj">object to wrap</param> | |||
//[SuppressMessage("Microsoft.Naming", "CA1720:IdentifiersShouldNotContainTypeNames", MessageId = "obj", Justification = "We don't know anything about the object other than that it's an object, so 'obj' seems reasonable")] | |||
public PrivateObject(object obj) | |||
{ | |||
Helper.CheckParameterNotNull(obj, "obj", string.Empty); | |||
this.ConstructFrom(obj); | |||
} | |||
/// <summary> | |||
/// Initializes a new instance of the <see cref="PrivateObject"/> class that wraps | |||
/// the given object. | |||
/// </summary> | |||
/// <param name="obj">object to wrap</param> | |||
/// <param name="type">PrivateType object</param> | |||
//[SuppressMessage("Microsoft.Naming", "CA1720:IdentifiersShouldNotContainTypeNames", MessageId = "obj", Justification = "We don't know anything about the object other than that it's an an object, so 'obj' seems reasonable")] | |||
public PrivateObject(object obj, PrivateType type) | |||
{ | |||
Helper.CheckParameterNotNull(type, "type", string.Empty); | |||
this.target = obj; | |||
this.originalType = type.ReferencedType; | |||
} | |||
#endregion | |||
///// <summary> | |||
///// Gets or sets the target | |||
///// </summary> | |||
//public object Target | |||
//{ | |||
// get | |||
// { | |||
// return this.target; | |||
// } | |||
// set | |||
// { | |||
// Helper.CheckParameterNotNull(value, "Target", string.Empty); | |||
// this.target = value; | |||
// this.originalType = value.GetType(); | |||
// } | |||
//} | |||
///// <summary> | |||
///// Gets the type of underlying object | |||
///// </summary> | |||
//public Type RealType | |||
//{ | |||
// get | |||
// { | |||
// return this.originalType; | |||
// } | |||
//} | |||
//private Dictionary<string, LinkedList<MethodInfo>> GenericMethodCache | |||
//{ | |||
// get | |||
// { | |||
// if (this.methodCache == null) | |||
// { | |||
// this.BuildGenericMethodCacheForType(this.originalType); | |||
// } | |||
// Debug.Assert(this.methodCache != null, "Invalid method cache for type."); | |||
// return this.methodCache; | |||
// } | |||
//} | |||
/// <summary> | |||
/// returns the hash code of the target object | |||
/// </summary> | |||
/// <returns>int representing hashcode of the target object</returns> | |||
public override int GetHashCode() | |||
{ | |||
//Debug.Assert(this.target != null, "target should not be null."); | |||
return this.target.GetHashCode(); | |||
} | |||
/// <summary> | |||
/// Equals | |||
/// </summary> | |||
/// <param name="obj">Object with whom to compare</param> | |||
/// <returns>returns true if the objects are equal.</returns> | |||
public override bool Equals(object obj) | |||
{ | |||
if (this != obj) | |||
{ | |||
//Debug.Assert(this.target != null, "target should not be null."); | |||
if (typeof(PrivateObject) == obj?.GetType()) | |||
{ | |||
return this.target.Equals(((PrivateObject)obj).target); | |||
} | |||
else | |||
{ | |||
return false; | |||
} | |||
} | |||
return true; | |||
} | |||
///// <summary> | |||
///// Invokes the specified method | |||
///// </summary> | |||
///// <param name="name">Name of the method</param> | |||
///// <param name="args">Arguments to pass to the member to invoke.</param> | |||
///// <returns>Result of method call</returns> | |||
//public object Invoke(string name, params object[] args) | |||
//{ | |||
// Helper.CheckParameterNotNull(name, "name", string.Empty); | |||
// return this.Invoke(name, null, args, CultureInfo.InvariantCulture); | |||
//} | |||
///// <summary> | |||
///// Invokes the specified method | |||
///// </summary> | |||
///// <param name="name">Name of the method</param> | |||
///// <param name="parameterTypes">An array of <see cref="T:System.Type"/> objects representing the number, order, and type of the parameters for the method to get.</param> | |||
///// <param name="args">Arguments to pass to the member to invoke.</param> | |||
///// <returns>Result of method call</returns> | |||
//public object Invoke(string name, Type[] parameterTypes, object[] args) | |||
//{ | |||
// return this.Invoke(name, parameterTypes, args, CultureInfo.InvariantCulture); | |||
//} | |||
///// <summary> | |||
///// Invokes the specified method | |||
///// </summary> | |||
///// <param name="name">Name of the method</param> | |||
///// <param name="parameterTypes">An array of <see cref="T:System.Type"/> objects representing the number, order, and type of the parameters for the method to get.</param> | |||
///// <param name="args">Arguments to pass to the member to invoke.</param> | |||
///// <param name="typeArguments">An array of types corresponding to the types of the generic arguments.</param> | |||
///// <returns>Result of method call</returns> | |||
//public object Invoke(string name, Type[] parameterTypes, object[] args, Type[] typeArguments) | |||
//{ | |||
// return this.Invoke(name, BindToEveryThing, parameterTypes, args, CultureInfo.InvariantCulture, typeArguments); | |||
//} | |||
///// <summary> | |||
///// Invokes the specified method | |||
///// </summary> | |||
///// <param name="name">Name of the method</param> | |||
///// <param name="args">Arguments to pass to the member to invoke.</param> | |||
///// <param name="culture">Culture info</param> | |||
///// <returns>Result of method call</returns> | |||
//public object Invoke(string name, object[] args, CultureInfo culture) | |||
//{ | |||
// return this.Invoke(name, null, args, culture); | |||
//} | |||
///// <summary> | |||
///// Invokes the specified method | |||
///// </summary> | |||
///// <param name="name">Name of the method</param> | |||
///// <param name="parameterTypes">An array of <see cref="T:System.Type"/> objects representing the number, order, and type of the parameters for the method to get.</param> | |||
///// <param name="args">Arguments to pass to the member to invoke.</param> | |||
///// <param name="culture">Culture info</param> | |||
///// <returns>Result of method call</returns> | |||
//public object Invoke(string name, Type[] parameterTypes, object[] args, CultureInfo culture) | |||
//{ | |||
// return this.Invoke(name, BindToEveryThing, parameterTypes, args, culture); | |||
//} | |||
///// <summary> | |||
///// Invokes the specified method | |||
///// </summary> | |||
///// <param name="name">Name of the method</param> | |||
///// <param name="bindingFlags">A bitmask comprised of one or more <see cref="T:System.Reflection.BindingFlags"/> that specify how the search is conducted.</param> | |||
///// <param name="args">Arguments to pass to the member to invoke.</param> | |||
///// <returns>Result of method call</returns> | |||
//public object Invoke(string name, BindingFlags bindingFlags, params object[] args) | |||
//{ | |||
// return this.Invoke(name, bindingFlags, null, args, CultureInfo.InvariantCulture); | |||
//} | |||
///// <summary> | |||
///// Invokes the specified method | |||
///// </summary> | |||
///// <param name="name">Name of the method</param> | |||
///// <param name="bindingFlags">A bitmask comprised of one or more <see cref="T:System.Reflection.BindingFlags"/> that specify how the search is conducted.</param> | |||
///// <param name="parameterTypes">An array of <see cref="T:System.Type"/> objects representing the number, order, and type of the parameters for the method to get.</param> | |||
///// <param name="args">Arguments to pass to the member to invoke.</param> | |||
///// <returns>Result of method call</returns> | |||
//public object Invoke(string name, BindingFlags bindingFlags, Type[] parameterTypes, object[] args) | |||
//{ | |||
// return this.Invoke(name, bindingFlags, parameterTypes, args, CultureInfo.InvariantCulture); | |||
//} | |||
///// <summary> | |||
///// Invokes the specified method | |||
///// </summary> | |||
///// <param name="name">Name of the method</param> | |||
///// <param name="bindingFlags">A bitmask comprised of one or more <see cref="T:System.Reflection.BindingFlags"/> that specify how the search is conducted.</param> | |||
///// <param name="args">Arguments to pass to the member to invoke.</param> | |||
///// <param name="culture">Culture info</param> | |||
///// <returns>Result of method call</returns> | |||
//public object Invoke(string name, BindingFlags bindingFlags, object[] args, CultureInfo culture) | |||
//{ | |||
// return this.Invoke(name, bindingFlags, null, args, culture); | |||
//} | |||
///// <summary> | |||
///// Invokes the specified method | |||
///// </summary> | |||
///// <param name="name">Name of the method</param> | |||
///// <param name="bindingFlags">A bitmask comprised of one or more <see cref="T:System.Reflection.BindingFlags"/> that specify how the search is conducted.</param> | |||
///// <param name="parameterTypes">An array of <see cref="T:System.Type"/> objects representing the number, order, and type of the parameters for the method to get.</param> | |||
///// <param name="args">Arguments to pass to the member to invoke.</param> | |||
///// <param name="culture">Culture info</param> | |||
///// <returns>Result of method call</returns> | |||
//public object Invoke(string name, BindingFlags bindingFlags, Type[] parameterTypes, object[] args, CultureInfo culture) | |||
//{ | |||
// return this.Invoke(name, bindingFlags, parameterTypes, args, culture, null); | |||
//} | |||
///// <summary> | |||
///// Invokes the specified method | |||
///// </summary> | |||
///// <param name="name">Name of the method</param> | |||
///// <param name="bindingFlags">A bitmask comprised of one or more <see cref="T:System.Reflection.BindingFlags"/> that specify how the search is conducted.</param> | |||
///// <param name="parameterTypes">An array of <see cref="T:System.Type"/> objects representing the number, order, and type of the parameters for the method to get.</param> | |||
///// <param name="args">Arguments to pass to the member to invoke.</param> | |||
///// <param name="culture">Culture info</param> | |||
///// <param name="typeArguments">An array of types corresponding to the types of the generic arguments.</param> | |||
///// <returns>Result of method call</returns> | |||
//public object Invoke(string name, BindingFlags bindingFlags, Type[] parameterTypes, object[] args, CultureInfo culture, Type[] typeArguments) | |||
//{ | |||
// Helper.CheckParameterNotNull(name, "name", string.Empty); | |||
// if (parameterTypes != null) | |||
// { | |||
// bindingFlags |= BindToEveryThing | BindingFlags.Instance; | |||
// // Fix up the parameter types | |||
// MethodInfo member = this.originalType.GetMethod(name, bindingFlags, null, parameterTypes, null); | |||
// // If the method was not found and type arguments were provided for generic paramaters, | |||
// // attempt to look up a generic method. | |||
// if ((member == null) && (typeArguments != null)) | |||
// { | |||
// // This method may contain generic parameters...if so, the previous call to | |||
// // GetMethod() will fail because it doesn't fully support generic parameters. | |||
// // Look in the method cache to see if there is a generic method | |||
// // on the incoming type that contains the correct signature. | |||
// member = this.GetGenericMethodFromCache(name, parameterTypes, typeArguments, bindingFlags, null); | |||
// } | |||
// if (member == null) | |||
// { | |||
// throw new ArgumentException( | |||
// string.Format(CultureInfo.CurrentCulture, FrameworkMessages.PrivateAccessorMemberNotFound, name)); | |||
// } | |||
// try | |||
// { | |||
// if (member.IsGenericMethodDefinition) | |||
// { | |||
// MethodInfo constructed = member.MakeGenericMethod(typeArguments); | |||
// return constructed.Invoke(this.target, bindingFlags, null, args, culture); | |||
// } | |||
// else | |||
// { | |||
// return member.Invoke(this.target, bindingFlags, null, args, culture); | |||
// } | |||
// } | |||
// catch (TargetInvocationException e) | |||
// { | |||
// Debug.Assert(e.InnerException != null, "Inner exception should not be null."); | |||
// if (e.InnerException != null) | |||
// { | |||
// throw e.InnerException; | |||
// } | |||
// throw; | |||
// } | |||
// } | |||
// else | |||
// { | |||
// return this.InvokeHelper(name, bindingFlags | BindingFlags.InvokeMethod, args, culture); | |||
// } | |||
//} | |||
///// <summary> | |||
///// Gets the array element using array of subsrcipts for each dimension | |||
///// </summary> | |||
///// <param name="name">Name of the member</param> | |||
///// <param name="indices">the indices of array</param> | |||
///// <returns>An arrya of elements.</returns> | |||
//public object GetArrayElement(string name, params int[] indices) | |||
//{ | |||
// Helper.CheckParameterNotNull(name, "name", string.Empty); | |||
// return this.GetArrayElement(name, BindToEveryThing, indices); | |||
//} | |||
///// <summary> | |||
///// Sets the array element using array of subsrcipts for each dimension | |||
///// </summary> | |||
///// <param name="name">Name of the member</param> | |||
///// <param name="value">Value to set</param> | |||
///// <param name="indices">the indices of array</param> | |||
//public void SetArrayElement(string name, object value, params int[] indices) | |||
//{ | |||
// Helper.CheckParameterNotNull(name, "name", string.Empty); | |||
// this.SetArrayElement(name, BindToEveryThing, value, indices); | |||
//} | |||
///// <summary> | |||
///// Gets the array element using array of subsrcipts for each dimension | |||
///// </summary> | |||
///// <param name="name">Name of the member</param> | |||
///// <param name="bindingFlags">A bitmask comprised of one or more <see cref="T:System.Reflection.BindingFlags"/> that specify how the search is conducted.</param> | |||
///// <param name="indices">the indices of array</param> | |||
///// <returns>An arrya of elements.</returns> | |||
//public object GetArrayElement(string name, BindingFlags bindingFlags, params int[] indices) | |||
//{ | |||
// Helper.CheckParameterNotNull(name, "name", string.Empty); | |||
// Array arr = (Array)this.InvokeHelper(name, BindingFlags.GetField | bindingFlags, null, CultureInfo.InvariantCulture); | |||
// return arr.GetValue(indices); | |||
//} | |||
///// <summary> | |||
///// Sets the array element using array of subsrcipts for each dimension | |||
///// </summary> | |||
///// <param name="name">Name of the member</param> | |||
///// <param name="bindingFlags">A bitmask comprised of one or more <see cref="T:System.Reflection.BindingFlags"/> that specify how the search is conducted.</param> | |||
///// <param name="value">Value to set</param> | |||
///// <param name="indices">the indices of array</param> | |||
//public void SetArrayElement(string name, BindingFlags bindingFlags, object value, params int[] indices) | |||
//{ | |||
// Helper.CheckParameterNotNull(name, "name", string.Empty); | |||
// Array arr = (Array)this.InvokeHelper(name, BindingFlags.GetField | bindingFlags, null, CultureInfo.InvariantCulture); | |||
// arr.SetValue(value, indices); | |||
//} | |||
///// <summary> | |||
///// Get the field | |||
///// </summary> | |||
///// <param name="name">Name of the field</param> | |||
///// <returns>The field.</returns> | |||
//public object GetField(string name) | |||
//{ | |||
// Helper.CheckParameterNotNull(name, "name", string.Empty); | |||
// return this.GetField(name, BindToEveryThing); | |||
//} | |||
///// <summary> | |||
///// Sets the field | |||
///// </summary> | |||
///// <param name="name">Name of the field</param> | |||
///// <param name="value">value to set</param> | |||
//public void SetField(string name, object value) | |||
//{ | |||
// Helper.CheckParameterNotNull(name, "name", string.Empty); | |||
// this.SetField(name, BindToEveryThing, value); | |||
//} | |||
///// <summary> | |||
///// Gets the field | |||
///// </summary> | |||
///// <param name="name">Name of the field</param> | |||
///// <param name="bindingFlags">A bitmask comprised of one or more <see cref="T:System.Reflection.BindingFlags"/> that specify how the search is conducted.</param> | |||
///// <returns>The field.</returns> | |||
//public object GetField(string name, BindingFlags bindingFlags) | |||
//{ | |||
// Helper.CheckParameterNotNull(name, "name", string.Empty); | |||
// return this.InvokeHelper(name, BindingFlags.GetField | bindingFlags, null, CultureInfo.InvariantCulture); | |||
//} | |||
///// <summary> | |||
///// Sets the field | |||
///// </summary> | |||
///// <param name="name">Name of the field</param> | |||
///// <param name="bindingFlags">A bitmask comprised of one or more <see cref="T:System.Reflection.BindingFlags"/> that specify how the search is conducted.</param> | |||
///// <param name="value">value to set</param> | |||
//public void SetField(string name, BindingFlags bindingFlags, object value) | |||
//{ | |||
// Helper.CheckParameterNotNull(name, "name", string.Empty); | |||
// this.InvokeHelper(name, BindingFlags.SetField | bindingFlags, new object[] { value }, CultureInfo.InvariantCulture); | |||
//} | |||
/// <summary> | |||
/// Get the field or property | |||
/// </summary> | |||
/// <param name="name">Name of the field or property</param> | |||
/// <returns>The field or property.</returns> | |||
public object GetFieldOrProperty(string name) | |||
{ | |||
Helper.CheckParameterNotNull(name, "name", string.Empty); | |||
return this.GetFieldOrProperty(name, BindToEveryThing); | |||
} | |||
/// <summary> | |||
/// Sets the field or property | |||
/// </summary> | |||
/// <param name="name">Name of the field or property</param> | |||
/// <param name="value">value to set</param> | |||
public void SetFieldOrProperty(string name, object value) | |||
{ | |||
Helper.CheckParameterNotNull(name, "name", string.Empty); | |||
this.SetFieldOrProperty(name, BindToEveryThing, value); | |||
} | |||
/// <summary> | |||
/// Gets the field or property | |||
/// </summary> | |||
/// <param name="name">Name of the field or property</param> | |||
/// <param name="bindingFlags">A bitmask comprised of one or more <see cref="T:System.Reflection.BindingFlags"/> that specify how the search is conducted.</param> | |||
/// <returns>The field or property.</returns> | |||
public object GetFieldOrProperty(string name, BindingFlags bindingFlags) | |||
{ | |||
Helper.CheckParameterNotNull(name, "name", string.Empty); | |||
return this.InvokeHelper(name, BindingFlags.GetField | BindingFlags.GetProperty | bindingFlags, null, CultureInfo.InvariantCulture); | |||
} | |||
/// <summary> | |||
/// Sets the field or property | |||
/// </summary> | |||
/// <param name="name">Name of the field or property</param> | |||
/// <param name="bindingFlags">A bitmask comprised of one or more <see cref="T:System.Reflection.BindingFlags"/> that specify how the search is conducted.</param> | |||
/// <param name="value">value to set</param> | |||
public void SetFieldOrProperty(string name, BindingFlags bindingFlags, object value) | |||
{ | |||
Helper.CheckParameterNotNull(name, "name", string.Empty); | |||
this.InvokeHelper(name, BindingFlags.SetField | BindingFlags.SetProperty | bindingFlags, new object[] { value }, CultureInfo.InvariantCulture); | |||
} | |||
///// <summary> | |||
///// Gets the property | |||
///// </summary> | |||
///// <param name="name">Name of the property</param> | |||
///// <param name="args">Arguments to pass to the member to invoke.</param> | |||
///// <returns>The property.</returns> | |||
//public object GetProperty(string name, params object[] args) | |||
//{ | |||
// return this.GetProperty(name, null, args); | |||
//} | |||
///// <summary> | |||
///// Gets the property | |||
///// </summary> | |||
///// <param name="name">Name of the property</param> | |||
///// <param name="parameterTypes">An array of <see cref="T:System.Type"/> objects representing the number, order, and type of the parameters for the indexed property.</param> | |||
///// <param name="args">Arguments to pass to the member to invoke.</param> | |||
///// <returns>The property.</returns> | |||
//public object GetProperty(string name, Type[] parameterTypes, object[] args) | |||
//{ | |||
// return this.GetProperty(name, BindToEveryThing, parameterTypes, args); | |||
//} | |||
///// <summary> | |||
///// Set the property | |||
///// </summary> | |||
///// <param name="name">Name of the property</param> | |||
///// <param name="value">value to set</param> | |||
///// <param name="args">Arguments to pass to the member to invoke.</param> | |||
//public void SetProperty(string name, object value, params object[] args) | |||
//{ | |||
// this.SetProperty(name, null, value, args); | |||
//} | |||
///// <summary> | |||
///// Set the property | |||
///// </summary> | |||
///// <param name="name">Name of the property</param> | |||
///// <param name="parameterTypes">An array of <see cref="T:System.Type"/> objects representing the number, order, and type of the parameters for the indexed property.</param> | |||
///// <param name="value">value to set</param> | |||
///// <param name="args">Arguments to pass to the member to invoke.</param> | |||
//public void SetProperty(string name, Type[] parameterTypes, object value, object[] args) | |||
//{ | |||
// this.SetProperty(name, BindToEveryThing, value, parameterTypes, args); | |||
//} | |||
///// <summary> | |||
///// Gets the property | |||
///// </summary> | |||
///// <param name="name">Name of the property</param> | |||
///// <param name="bindingFlags">A bitmask comprised of one or more <see cref="T:System.Reflection.BindingFlags"/> that specify how the search is conducted.</param> | |||
///// <param name="args">Arguments to pass to the member to invoke.</param> | |||
///// <returns>The property.</returns> | |||
//public object GetProperty(string name, BindingFlags bindingFlags, params object[] args) | |||
//{ | |||
// return this.GetProperty(name, bindingFlags, null, args); | |||
//} | |||
///// <summary> | |||
///// Gets the property | |||
///// </summary> | |||
///// <param name="name">Name of the property</param> | |||
///// <param name="bindingFlags">A bitmask comprised of one or more <see cref="T:System.Reflection.BindingFlags"/> that specify how the search is conducted.</param> | |||
///// <param name="parameterTypes">An array of <see cref="T:System.Type"/> objects representing the number, order, and type of the parameters for the indexed property.</param> | |||
///// <param name="args">Arguments to pass to the member to invoke.</param> | |||
///// <returns>The property.</returns> | |||
//public object GetProperty(string name, BindingFlags bindingFlags, Type[] parameterTypes, object[] args) | |||
//{ | |||
// Helper.CheckParameterNotNull(name, "name", string.Empty); | |||
// if (parameterTypes != null) | |||
// { | |||
// PropertyInfo pi = this.originalType.GetProperty(name, bindingFlags, null, null, parameterTypes, null); | |||
// if (pi == null) | |||
// { | |||
// throw new ArgumentException( | |||
// string.Format(CultureInfo.CurrentCulture, FrameworkMessages.PrivateAccessorMemberNotFound, name)); | |||
// } | |||
// return pi.GetValue(this.target, args); | |||
// } | |||
// else | |||
// { | |||
// return this.InvokeHelper(name, bindingFlags | BindingFlags.GetProperty, args, null); | |||
// } | |||
//} | |||
///// <summary> | |||
///// Sets the property | |||
///// </summary> | |||
///// <param name="name">Name of the property</param> | |||
///// <param name="bindingFlags">A bitmask comprised of one or more <see cref="T:System.Reflection.BindingFlags"/> that specify how the search is conducted.</param> | |||
///// <param name="value">value to set</param> | |||
///// <param name="args">Arguments to pass to the member to invoke.</param> | |||
//public void SetProperty(string name, BindingFlags bindingFlags, object value, params object[] args) | |||
//{ | |||
// this.SetProperty(name, bindingFlags, value, null, args); | |||
//} | |||
///// <summary> | |||
///// Sets the property | |||
///// </summary> | |||
///// <param name="name">Name of the property</param> | |||
///// <param name="bindingFlags">A bitmask comprised of one or more <see cref="T:System.Reflection.BindingFlags"/> that specify how the search is conducted.</param> | |||
///// <param name="value">value to set</param> | |||
///// <param name="parameterTypes">An array of <see cref="T:System.Type"/> objects representing the number, order, and type of the parameters for the indexed property.</param> | |||
///// <param name="args">Arguments to pass to the member to invoke.</param> | |||
//public void SetProperty(string name, BindingFlags bindingFlags, object value, Type[] parameterTypes, object[] args) | |||
//{ | |||
// Helper.CheckParameterNotNull(name, "name", string.Empty); | |||
// if (parameterTypes != null) | |||
// { | |||
// PropertyInfo pi = this.originalType.GetProperty(name, bindingFlags, null, null, parameterTypes, null); | |||
// if (pi == null) | |||
// { | |||
// throw new ArgumentException( | |||
// string.Format(CultureInfo.CurrentCulture, FrameworkMessages.PrivateAccessorMemberNotFound, name)); | |||
// } | |||
// pi.SetValue(this.target, value, args); | |||
// } | |||
// else | |||
// { | |||
// object[] pass = new object[(args?.Length ?? 0) + 1]; | |||
// pass[0] = value; | |||
// args?.CopyTo(pass, 1); | |||
// this.InvokeHelper(name, bindingFlags | BindingFlags.SetProperty, pass, null); | |||
// } | |||
//} | |||
#region Private Helpers | |||
///// <summary> | |||
///// Validate access string | |||
///// </summary> | |||
///// <param name="access"> access string</param> | |||
//private static void ValidateAccessString(string access) | |||
//{ | |||
// Helper.CheckParameterNotNull(access, "access", string.Empty); | |||
// if (access.Length == 0) | |||
// { | |||
// throw new ArgumentException(FrameworkMessages.AccessStringInvalidSyntax); | |||
// } | |||
// string[] arr = access.Split('.'); | |||
// foreach (string str in arr) | |||
// { | |||
// if ((str.Length == 0) || (str.IndexOfAny(new char[] { ' ', '\t', '\n' }) != -1)) | |||
// { | |||
// throw new ArgumentException(FrameworkMessages.AccessStringInvalidSyntax); | |||
// } | |||
// } | |||
//} | |||
/// <summary> | |||
/// Invokes the memeber | |||
/// </summary> | |||
/// <param name="name">Name of the member</param> | |||
/// <param name="bindingFlags">Additional attributes</param> | |||
/// <param name="args">Arguments for the invocation</param> | |||
/// <param name="culture">Culture</param> | |||
/// <returns>Result of the invocation</returns> | |||
private object InvokeHelper(string name, BindingFlags bindingFlags, object[] args, CultureInfo culture) | |||
{ | |||
Helper.CheckParameterNotNull(name, "name", string.Empty); | |||
//Debug.Assert(this.target != null, "Internal Error: Null reference is returned for internal object"); | |||
// Invoke the actual Method | |||
try | |||
{ | |||
return this.originalType.InvokeMember(name, bindingFlags, null, this.target, args, culture); | |||
} | |||
catch (TargetInvocationException e) | |||
{ | |||
//Debug.Assert(e.InnerException != null, "Inner exception should not be null."); | |||
if (e.InnerException != null) | |||
{ | |||
throw e.InnerException; | |||
} | |||
throw; | |||
} | |||
} | |||
private void ConstructFrom(object obj) | |||
{ | |||
Helper.CheckParameterNotNull(obj, "obj", string.Empty); | |||
this.target = obj; | |||
this.originalType = obj.GetType(); | |||
} | |||
//private void BuildGenericMethodCacheForType(Type t) | |||
//{ | |||
// Debug.Assert(t != null, "type should not be null."); | |||
// this.methodCache = new Dictionary<string, LinkedList<MethodInfo>>(); | |||
// MethodInfo[] members = t.GetMethods(BindToEveryThing); | |||
// LinkedList<MethodInfo> listByName; // automatically initialized to null | |||
// foreach (MethodInfo member in members) | |||
// { | |||
// if (member.IsGenericMethod || member.IsGenericMethodDefinition) | |||
// { | |||
// if (!this.GenericMethodCache.TryGetValue(member.Name, out listByName)) | |||
// { | |||
// listByName = new LinkedList<MethodInfo>(); | |||
// this.GenericMethodCache.Add(member.Name, listByName); | |||
// } | |||
// Debug.Assert(listByName != null, "list should not be null."); | |||
// listByName.AddLast(member); | |||
// } | |||
// } | |||
//} | |||
///// <summary> | |||
///// Extracts the most appropriate generic method signature from the current private type. | |||
///// </summary> | |||
///// <param name="methodName">The name of the method in which to search the signature cache.</param> | |||
///// <param name="parameterTypes">An array of types corresponding to the types of the parameters in which to search.</param> | |||
///// <param name="typeArguments">An array of types corresponding to the types of the generic arguments.</param> | |||
///// <param name="bindingFlags"><see cref="BindingFlags"/> to further filter the method signatures.</param> | |||
///// <param name="modifiers">Modifiers for parameters.</param> | |||
///// <returns>A methodinfo instance.</returns> | |||
//private MethodInfo GetGenericMethodFromCache(string methodName, Type[] parameterTypes, Type[] typeArguments, BindingFlags bindingFlags, ParameterModifier[] modifiers) | |||
//{ | |||
// Debug.Assert(!string.IsNullOrEmpty(methodName), "Invalid method name."); | |||
// Debug.Assert(parameterTypes != null, "Invalid parameter type array."); | |||
// Debug.Assert(typeArguments != null, "Invalid type arguments array."); | |||
// // Build a preliminary list of method candidates that contain roughly the same signature. | |||
// var methodCandidates = this.GetMethodCandidates(methodName, parameterTypes, typeArguments, bindingFlags, modifiers); | |||
// // Search of ambiguous methods (methods with the same signature). | |||
// MethodInfo[] finalCandidates = new MethodInfo[methodCandidates.Count]; | |||
// methodCandidates.CopyTo(finalCandidates, 0); | |||
// if ((parameterTypes != null) && (parameterTypes.Length == 0)) | |||
// { | |||
// for (int i = 0; i < finalCandidates.Length; i++) | |||
// { | |||
// MethodInfo methodInfo = finalCandidates[i]; | |||
// if (!RuntimeTypeHelper.CompareMethodSigAndName(methodInfo, finalCandidates[0])) | |||
// { | |||
// throw new AmbiguousMatchException(); | |||
// } | |||
// } | |||
// // All the methods have the exact same name and sig so return the most derived one. | |||
// return RuntimeTypeHelper.FindMostDerivedNewSlotMeth(finalCandidates, finalCandidates.Length) as MethodInfo; | |||
// } | |||
// // Now that we have a preliminary list of candidates, select the most appropriate one. | |||
// return RuntimeTypeHelper.SelectMethod(bindingFlags, finalCandidates, parameterTypes, modifiers) as MethodInfo; | |||
//} | |||
//private LinkedList<MethodInfo> GetMethodCandidates(string methodName, Type[] parameterTypes, Type[] typeArguments, BindingFlags bindingFlags, ParameterModifier[] modifiers) | |||
//{ | |||
// Debug.Assert(!string.IsNullOrEmpty(methodName), "methodName should not be null."); | |||
// Debug.Assert(parameterTypes != null, "parameterTypes should not be null."); | |||
// Debug.Assert(typeArguments != null, "typeArguments should not be null."); | |||
// LinkedList<MethodInfo> methodCandidates = new LinkedList<MethodInfo>(); | |||
// LinkedList<MethodInfo> methods = null; | |||
// if (!this.GenericMethodCache.TryGetValue(methodName, out methods)) | |||
// { | |||
// return methodCandidates; | |||
// } | |||
// Debug.Assert(methods != null, "methods should not be null."); | |||
// foreach (MethodInfo candidate in methods) | |||
// { | |||
// bool paramMatch = true; | |||
// ParameterInfo[] candidateParams = null; | |||
// Type[] genericArgs = candidate.GetGenericArguments(); | |||
// Type sourceParameterType = null; | |||
// if (genericArgs.Length != typeArguments.Length) | |||
// { | |||
// continue; | |||
// } | |||
// // Since we can't just get the correct MethodInfo from Reflection, | |||
// // we will just match the number of parameters, their order, and their type | |||
// var methodCandidate = candidate; | |||
// candidateParams = methodCandidate.GetParameters(); | |||
// if (candidateParams.Length != parameterTypes.Length) | |||
// { | |||
// continue; | |||
// } | |||
// // Exact binding | |||
// if ((bindingFlags & BindingFlags.ExactBinding) != 0) | |||
// { | |||
// int i = 0; | |||
// foreach (ParameterInfo candidateParam in candidateParams) | |||
// { | |||
// sourceParameterType = parameterTypes[i++]; | |||
// if (candidateParam.ParameterType.ContainsGenericParameters) | |||
// { | |||
// // Since we have a generic parameter here, just make sure the IsArray matches. | |||
// if (candidateParam.ParameterType.IsArray != sourceParameterType.IsArray) | |||
// { | |||
// paramMatch = false; | |||
// break; | |||
// } | |||
// } | |||
// else | |||
// { | |||
// if (candidateParam.ParameterType != sourceParameterType) | |||
// { | |||
// paramMatch = false; | |||
// break; | |||
// } | |||
// } | |||
// } | |||
// if (paramMatch) | |||
// { | |||
// methodCandidates.AddLast(methodCandidate); | |||
// continue; | |||
// } | |||
// } | |||
// else | |||
// { | |||
// methodCandidates.AddLast(methodCandidate); | |||
// } | |||
// } | |||
// return methodCandidates; | |||
//} | |||
#endregion | |||
} | |||
} |
@@ -1,314 +0,0 @@ | |||
// <copyright file="PrivateObjectExtensions.cs"> | |||
// Copyright (c) 2019 cactuaroid All Rights Reserved | |||
// </copyright> | |||
// <summary> | |||
// Released under the MIT license | |||
// https://github.com/cactuaroid/PrivateObjectExtensions | |||
// </summary> | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using System.Linq; | |||
using System.Reflection; | |||
namespace System | |||
{ | |||
/// <summary> | |||
/// Extension methods for PrivateObject | |||
/// </summary> | |||
public static class PrivateObjectExtensions | |||
{ | |||
private static readonly BindingFlags Static = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.DeclaredOnly | BindingFlags.Static; | |||
private static readonly BindingFlags Instance = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.DeclaredOnly | BindingFlags.Instance; | |||
/// <summary> | |||
/// Get from private (and any other) field/property. | |||
/// If the real type of specified object doesn't contain the specified field/property, | |||
/// base types are searched automatically. | |||
/// </summary> | |||
/// <param name="obj">The object to get from</param> | |||
/// <param name="name">The name of the field/property</param> | |||
/// <returns>The object got from the field/property</returns> | |||
/// <exception cref="ArgumentException">'name' is not found.</exception> | |||
/// <exception cref="ArgumentNullException">Arguments contain null.</exception> | |||
public static object GetPrivate(this object obj, string name) | |||
{ | |||
if (obj == null) { throw new ArgumentNullException("obj"); } | |||
return GetPrivate(obj, name, obj.GetType(), null); | |||
} | |||
/// <summary> | |||
/// Get from private (and any other) field/property. | |||
/// If the real type of specified object doesn't contain the specified field/property, | |||
/// base types are searched automatically. | |||
/// </summary> | |||
/// <typeparam name="T">The type of the field/property</typeparam> | |||
/// <param name="obj">The object to get from</param> | |||
/// <param name="name">The name of the field/property</param> | |||
/// <returns>The object got from the field/property</returns> | |||
/// <exception cref="ArgumentException">'name' is not found.</exception> | |||
/// <exception cref="ArgumentNullException">Arguments contain null.</exception> | |||
public static T GetPrivate<T>(this object obj, string name) | |||
{ | |||
if (obj == null) { throw new ArgumentNullException("obj"); } | |||
return (T)GetPrivate(obj, name, obj.GetType(), typeof(T)); | |||
} | |||
/// <summary> | |||
/// Get from private (and any other) field/property with assuming the specified object as specified type. | |||
/// If the specified type doesn't contain the specified field/property, | |||
/// base types are searched automatically. | |||
/// </summary> | |||
/// <param name="obj">The object to get from</param> | |||
/// <param name="name">The name of the field/property</param> | |||
/// <param name="objType">The type of 'obj' for seaching member starting from. Real type of 'obj' is ignored.</param> | |||
/// <returns>The object got from the field/property</returns> | |||
/// <exception cref="ArgumentException">'name' is not found.</exception> | |||
/// <exception cref="ArgumentException">'objType' is not assignable from 'obj'.</exception> | |||
/// <exception cref="ArgumentNullException">Arguments contain null.</exception> | |||
public static object GetPrivate(this object obj, string name, Type objType) | |||
{ | |||
return GetPrivate(obj, name, objType, null); | |||
} | |||
/// <summary> | |||
/// Get from private (and any other) field/property with assuming the specified object as specified type. | |||
/// If the specified type doesn't contain the specified field/property, | |||
/// base types are searched automatically. | |||
/// </summary> | |||
/// <typeparam name="T">The type of the field/property</typeparam> | |||
/// <param name="obj">The object to get from</param> | |||
/// <param name="name">The name of the field/property</param> | |||
/// <param name="objType">The type of 'obj' for seaching member starting from. Real type of 'obj' is ignored.</param> | |||
/// <returns>The object got from the field/property</returns> | |||
/// <exception cref="ArgumentException">'name' is not found.</exception> | |||
/// <exception cref="ArgumentException">'objType' is not assignable from 'obj'.</exception> | |||
/// <exception cref="ArgumentNullException">Arguments contain null.</exception> | |||
public static T GetPrivate<T>(this object obj, string name, Type objType) | |||
{ | |||
return (T)GetPrivate(obj, name, objType, typeof(T)); | |||
} | |||
private static object GetPrivate(object obj, string name, Type objType, Type memberType) | |||
{ | |||
if (obj == null) { throw new ArgumentNullException("obj"); } | |||
if (name == null) { throw new ArgumentNullException("name"); } | |||
if (string.IsNullOrWhiteSpace(name)) { throw new ArgumentException("name is empty or white-space.", "name"); } | |||
if (objType == null) { throw new ArgumentNullException("objType"); } | |||
if (!objType.IsAssignableFrom(obj.GetType())) { throw new ArgumentException($"{objType} is not assignable from {obj.GetType()}.", "objType"); } | |||
bool memberTypeMatching(Type actualType) => actualType == memberType; | |||
if (TryFindFieldOrPropertyOwnerType(objType, name, memberType, memberTypeMatching, Instance, out var ownerType)) | |||
{ | |||
return new PrivateObject(obj, new PrivateType(ownerType)).GetFieldOrProperty(name); | |||
} | |||
else if (TryFindFieldOrPropertyOwnerType(objType, name, memberType, memberTypeMatching, Static, out ownerType)) | |||
{ | |||
return new PrivateType(ownerType).GetStaticFieldOrProperty(name); | |||
} | |||
throw new ArgumentException(((memberType != null) ? memberType + " " : "") + name + " is not found."); | |||
} | |||
/// <summary> | |||
/// Get from private (and any other) static field/property. | |||
/// </summary> | |||
/// <param name="type">The type to get from</param> | |||
/// <param name="name">The name of the static field/property</param> | |||
/// <returns>The object got from the static field/property</returns> | |||
/// <exception cref="ArgumentException">'name' is not found.</exception> | |||
/// <exception cref="ArgumentNullException">Arguments contain null.</exception> | |||
public static object GetPrivate(this Type type, string name) | |||
{ | |||
return GetPrivate(type, name, null); | |||
} | |||
/// <summary> | |||
/// Get from private (and any other) static field/property. | |||
/// </summary> | |||
/// <typeparam name="T">The type of the field/property</typeparam> | |||
/// <param name="type">The type to get from</param> | |||
/// <param name="name">The name of the static field/property</param> | |||
/// <returns>The object got from the static field/property</returns> | |||
/// <exception cref="ArgumentException">'name' is not found.</exception> | |||
/// <exception cref="ArgumentNullException">Arguments contain null.</exception> | |||
public static T GetPrivate<T>(this Type type, string name) | |||
{ | |||
return (T)GetPrivate(type, name, typeof(T)); | |||
} | |||
private static object GetPrivate(this Type type, string name, Type memberType) | |||
{ | |||
if (type == null) { throw new ArgumentNullException("type"); } | |||
if (name == null) { throw new ArgumentNullException("name"); } | |||
if (string.IsNullOrWhiteSpace(name)) { throw new ArgumentException("name is empty or white-space.", "name"); } | |||
bool memberTypeMatching(Type actualType) => actualType == memberType; | |||
if (type.ContainsFieldOrProperty(name, memberType, memberTypeMatching, Static)) | |||
{ | |||
return new PrivateType(type).GetStaticFieldOrProperty(name); | |||
} | |||
throw new ArgumentException(((memberType != null) ? memberType + " " : "") + name + " is not found."); | |||
} | |||
/// <summary> | |||
/// Set to private (and any other) field/property. | |||
/// If the real type of specified object doesn't contain the specified field/property, | |||
/// base types are searched automatically. | |||
/// </summary> | |||
/// <param name="obj">The object to set to</param> | |||
/// <param name="name">The name of the field/property</param> | |||
/// <param name="value">The value to set for 'name'</param> | |||
/// <exception cref="ArgumentException">'name' is not found.</exception> | |||
/// <exception cref="ArgumentNullException">Arguments contain null.</exception> | |||
public static void SetPrivate<T>(this object obj, string name, T value) | |||
{ | |||
if (obj == null) { throw new ArgumentNullException("obj"); } | |||
SetPrivate(obj, name, value, obj.GetType()); | |||
} | |||
/// <summary> | |||
/// Set to private (and any other) field/property with assuming the specified object as specified type. | |||
/// If the specified type doesn't contain the specified field/property, | |||
/// base types are searched automatically. | |||
/// </summary> | |||
/// <param name="obj">The object to set to</param> | |||
/// <param name="name">The name of the field/property</param> | |||
/// <param name="value">The value to set for 'name'</param> | |||
/// <param name="objType">The type of 'obj' for seaching member starting from. Real type of 'obj' is ignored.</param> | |||
/// <exception cref="ArgumentException">'name' is not found.</exception> | |||
/// <exception cref="ArgumentException">'objType' is not assignable from 'obj'.</exception> | |||
/// <exception cref="ArgumentNullException">Arguments contain null.</exception> | |||
public static void SetPrivate<T>(this object obj, string name, T value, Type objType) | |||
{ | |||
if (obj == null) { throw new ArgumentNullException("obj"); } | |||
if (name == null) { throw new ArgumentNullException("name"); } | |||
if (string.IsNullOrWhiteSpace(name)) { throw new ArgumentException("name is empty or white-space.", "name"); } | |||
if (value == null) { throw new ArgumentNullException("value"); } | |||
if (objType == null) { throw new ArgumentNullException("objType"); } | |||
if (!objType.IsAssignableFrom(obj.GetType())) { throw new ArgumentException($"{objType} is not assignable from {obj.GetType()}.", "objType"); } | |||
if (TrySetPrivate(obj, name, value, objType)) { return; } | |||
// retry for the case of getter only property | |||
if (TrySetPrivate(obj, GetBackingFieldName(name), value, objType)) { return; } | |||
throw new ArgumentException($"{typeof(T)} {name} is not found."); | |||
} | |||
private static bool TrySetPrivate<T>(object obj, string name, T value, Type objType) | |||
{ | |||
var memberType = typeof(T); | |||
bool memberTypeMatching(Type actualType) => actualType.IsAssignableFrom(memberType); | |||
try | |||
{ | |||
if (TryFindFieldOrPropertyOwnerType(objType, name, memberType, memberTypeMatching, Instance, out var ownerType)) | |||
{ | |||
new PrivateObject(obj, new PrivateType(ownerType)).SetFieldOrProperty(name, value); | |||
return true; | |||
} | |||
else if (TryFindFieldOrPropertyOwnerType(objType, name, memberType, memberTypeMatching, Static, out ownerType)) | |||
{ | |||
new PrivateType(ownerType).SetStaticFieldOrProperty(name, value); | |||
return true; | |||
} | |||
} | |||
catch (MissingMethodException) | |||
{ | |||
// When getter only property name is given, the property is found but fails to set. | |||
return false; | |||
} | |||
return false; | |||
} | |||
/// <summary> | |||
/// Set to private (and any other) static field/property. | |||
/// </summary> | |||
/// <param name="type">The type to set to</param> | |||
/// <param name="name">The name of the field/property</param> | |||
/// <param name="value">The value to set for 'name'</param> | |||
/// <exception cref="ArgumentException">'name' is not found.</exception> | |||
/// <exception cref="ArgumentNullException">Arguments contain null.</exception> | |||
public static void SetPrivate<T>(this Type type, string name, T value) | |||
{ | |||
if (type == null) { throw new ArgumentNullException("type"); } | |||
if (name == null) { throw new ArgumentNullException("name"); } | |||
if (string.IsNullOrWhiteSpace(name)) { throw new ArgumentException("name is empty or white-space.", "name"); } | |||
if (TrySetPrivate(type, name, value)) { return; } | |||
// retry for the case of getter only property | |||
if (TrySetPrivate(type, GetBackingFieldName(name), value)) { return; } | |||
throw new ArgumentException($"{typeof(T)} {name} is not found."); | |||
} | |||
private static bool TrySetPrivate<T>(this Type type, string name, T value) | |||
{ | |||
var memberType = typeof(T); | |||
bool memberTypeMatching(Type actualType) => actualType.IsAssignableFrom(memberType); | |||
try | |||
{ | |||
if (type.ContainsFieldOrProperty(name, memberType, memberTypeMatching, Static)) | |||
{ | |||
new PrivateType(type).SetStaticFieldOrProperty(name, value); | |||
return true; | |||
} | |||
} | |||
catch (MissingMethodException) | |||
{ | |||
// When getter only property name is given, the property is found but fails to set. | |||
return false; | |||
} | |||
return false; | |||
} | |||
private static string GetBackingFieldName(string propertyName) | |||
=> $"<{propertyName}>k__BackingField"; // generated backing field name | |||
private static bool TryFindFieldOrPropertyOwnerType(Type objType, string name, Type memberType, Func<Type, bool> memberTypeMatching, BindingFlags bindingFlag, out Type ownerType) | |||
{ | |||
ownerType = FindFieldOrPropertyOwnerType(objType, name, memberType, memberTypeMatching, bindingFlag); | |||
return (ownerType != null); | |||
} | |||
private static Type FindFieldOrPropertyOwnerType(Type objectType, string name, Type memberType, Func<Type, bool> memberTypeMatching, BindingFlags bindingFlags) | |||
{ | |||
if (objectType == null) { return null; } | |||
if (objectType.ContainsFieldOrProperty(name, memberType, memberTypeMatching, bindingFlags)) | |||
{ | |||
return objectType; | |||
} | |||
return FindFieldOrPropertyOwnerType(objectType.BaseType, name, memberType, memberTypeMatching, bindingFlags); | |||
} | |||
private static bool ContainsFieldOrProperty(this Type objectType, string name, Type memberType, Func<Type, bool> memberTypeMatching, BindingFlags bindingFlags) | |||
{ | |||
var fields = objectType | |||
.GetFields(bindingFlags) | |||
.Select((x) => new { Type = x.FieldType, Member = x as MemberInfo }); | |||
var properties = objectType | |||
.GetProperties(bindingFlags) | |||
.Select((x) => new { Type = x.PropertyType, Member = x as MemberInfo }); | |||
var members = fields.Concat(properties); | |||
return members.Any((actual) => | |||
(memberType == null || memberTypeMatching.Invoke(actual.Type)) | |||
&& actual.Member.Name == name); | |||
} | |||
} | |||
} |
@@ -1,172 +0,0 @@ | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
namespace TensorFlowNET.UnitTest.control_flow_ops_test | |||
{ | |||
/// <summary> | |||
/// excerpt of tensorflow/python/framework/ops/control_flow_ops_test.py | |||
/// </summary> | |||
[TestClass] | |||
public class SwitchTestCase : PythonTest | |||
{ | |||
[Ignore("TODO")] | |||
[TestMethod] | |||
public void testResourceReadInLoop() | |||
{ | |||
//var embedding_matrix = variable_scope.get_variable( | |||
//"embedding_matrix", initializer: new double[,] { { 2.0 }, { 3.0 } }, use_resource: true); | |||
/* | |||
Tensor cond(Tensor it, Tensor _) | |||
{ | |||
return it < 5; | |||
} | |||
*/ | |||
// TODO: below code doesn't compile | |||
//(Tensor, Tensor) body(Tensor it, Tensor cost) | |||
//{ | |||
// var embedding = embedding_ops.embedding_lookup(embedding_matrix, new int[]{0}); | |||
// cost += math_ops.reduce_sum(embedding); | |||
// return (it + 1, cost); | |||
//} | |||
//var (_, cost1) = control_flow_ops.while_loop( | |||
// cond, body, new[] | |||
// { | |||
// constant_op.constant(0), | |||
// constant_op.constant(0.0) | |||
// }); | |||
//with<Session>(this.cached_session(), sess => | |||
//{ | |||
// self.evaluate(variables.global_variables_initializer()); | |||
// self.assertAllEqual(10.0, self.evaluate(cost1)); | |||
//}); | |||
} | |||
[Ignore("TODO")] | |||
[TestMethod] | |||
public void testIndexedSlicesGradientInCondInWhileLoop() | |||
{ | |||
doTestIndexedSlicesGradientInCondInWhileLoop(use_resource: false); | |||
} | |||
[Ignore("TODO")] | |||
[TestMethod] | |||
public void testIndexedSlicesGradientInCondInWhileLoopResource() | |||
{ | |||
doTestIndexedSlicesGradientInCondInWhileLoop(use_resource: true); | |||
} | |||
private void doTestIndexedSlicesGradientInCondInWhileLoop(bool use_resource = false) | |||
{ | |||
//def doTestIndexedSlicesGradientInCondInWhileLoop(self, use_resource=False): | |||
// embedding_matrix = variable_scope.get_variable( | |||
// "embedding_matrix", [5, 5], | |||
// initializer=init_ops.random_normal_initializer(), | |||
// use_resource=use_resource) | |||
// def cond(it, _): | |||
// return it < 5 | |||
// def body(it, cost): | |||
// embedding = embedding_ops.embedding_lookup(embedding_matrix, [0]) | |||
// cost = control_flow_ops.cond( | |||
// math_ops.equal(it, 3), lambda: math_ops.square(cost), | |||
// (lambda: cost + math_ops.reduce_sum(embedding))) | |||
// return it + 1, cost | |||
// _, cost = control_flow_ops.while_loop( | |||
// cond, body, [constant_op.constant(0), | |||
// constant_op.constant(0.0)]) | |||
// dynamic_grads = gradients_impl.gradients(cost, [embedding_matrix])[0] | |||
// dynamic_grads = math_ops.segment_sum(dynamic_grads.values, | |||
// dynamic_grads.indices) | |||
// embedding = embedding_ops.embedding_lookup(embedding_matrix, [0]) | |||
// static = math_ops.square( | |||
// math_ops.reduce_sum(embedding) + math_ops.reduce_sum(embedding) + | |||
// math_ops.reduce_sum(embedding)) + math_ops.reduce_sum(embedding) | |||
// static_grads = gradients_impl.gradients(static, [embedding_matrix])[0] | |||
// static_grads = math_ops.segment_sum(static_grads.values, | |||
// static_grads.indices) | |||
// with self.cached_session(): | |||
// self.evaluate(variables.global_variables_initializer()) | |||
// self.assertAllEqual(*self.evaluate([static_grads, dynamic_grads])) | |||
} | |||
[Ignore("TODO")] | |||
[TestMethod] | |||
public void testIndexedSlicesWithShapeGradientInWhileLoop() | |||
{ | |||
//@test_util.run_v1_only("b/120545219") | |||
//def testIndexedSlicesWithShapeGradientInWhileLoop(self): | |||
// for dtype in [dtypes.float32, dtypes.float64]: | |||
// with self.cached_session() as sess: | |||
// num_steps = 9 | |||
// inputs = array_ops.placeholder(dtype=dtype, shape=[num_steps]) | |||
// initial_outputs = tensor_array_ops.TensorArray( | |||
// dtype=dtype, size=num_steps) | |||
// initial_i = constant_op.constant(0, dtype=dtypes.int32) | |||
// def cond(i, _): | |||
// return i < num_steps # pylint: disable=cell-var-from-loop | |||
// def body(i, outputs): | |||
// x = array_ops.gather(inputs, i) # pylint: disable=cell-var-from-loop | |||
// outputs = outputs.write(i, x) | |||
// return i + 1, outputs | |||
// _, outputs = control_flow_ops.while_loop(cond, body, | |||
// [initial_i, initial_outputs]) | |||
// outputs = math_ops.reduce_sum(outputs.stack()) | |||
// r = gradients_impl.gradients([outputs], [inputs])[0] | |||
// grad_wr_inputs = ops.convert_to_tensor(r) | |||
// o, grad = sess.run([outputs, grad_wr_inputs], | |||
// feed_dict={inputs: [4, 6, 0, 7, 0, 0, 1, 2, 0]}) | |||
// self.assertEquals(o, 20) | |||
// self.assertAllEqual(grad, [1] * num_steps) | |||
} | |||
[Ignore("TODO")] | |||
[TestMethod] | |||
public void testIndexedSlicesWithDynamicShapeGradientInWhileLoop() | |||
{ | |||
//@test_util.run_v1_only("b/120545219") | |||
//def testIndexedSlicesWithDynamicShapeGradientInWhileLoop(self): | |||
// for dtype in [dtypes.float32, dtypes.float64]: | |||
// with self.cached_session() as sess: | |||
// inputs = array_ops.placeholder(dtype=dtype) | |||
// initial_outputs = tensor_array_ops.TensorArray( | |||
// dtype=dtype, dynamic_size=True, size=1) | |||
// initial_i = constant_op.constant(0, dtype=dtypes.int32) | |||
// def cond(i, _): | |||
// return i < array_ops.size(inputs) # pylint: disable=cell-var-from-loop | |||
// def body(i, outputs): | |||
// x = array_ops.gather(inputs, i) # pylint: disable=cell-var-from-loop | |||
// outputs = outputs.write(i, x) | |||
// return i + 1, outputs | |||
// _, outputs = control_flow_ops.while_loop(cond, body, | |||
// [initial_i, initial_outputs]) | |||
// outputs = math_ops.reduce_sum(outputs.stack()) | |||
// r = gradients_impl.gradients([outputs], [inputs])[0] | |||
// grad_wr_inputs = ops.convert_to_tensor(r) | |||
// o, grad = sess.run([outputs, grad_wr_inputs], | |||
// feed_dict={inputs: [1, 3, 2]}) | |||
// self.assertEquals(o, 6) | |||
// self.assertAllEqual(grad, [1] * 3) | |||
} | |||
} | |||
} |