Browse Source

ported while and switch/case testcases

tags/v0.9
Meinrad Recheis 6 years ago
parent
commit
b66e2586b8
9 changed files with 130 additions and 28 deletions
  1. +10
    -4
      src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
  2. +1
    -0
      src/TensorFlowNET.Core/Operations/ControlFlows/IControlFlowContext.cs
  3. +9
    -0
      src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs
  4. +8
    -3
      src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
  5. +6
    -0
      src/TensorFlowNET.Core/Variables/variable_scope.py.cs
  6. +5
    -0
      src/TensorFlowNET.Core/Variables/variables.py.cs
  7. +12
    -1
      test/TensorFlowNET.UnitTest/PythonTest.cs
  8. +30
    -20
      test/TensorFlowNET.UnitTest/control_flow_ops_test/SwitchTestCase.cs
  9. +49
    -0
      test/TensorFlowNET.UnitTest/control_flow_ops_test/WhileContextTestCase.cs

+ 10
- 4
src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs View File

@@ -119,13 +119,13 @@ namespace Tensorflow.Operations
return null;
}

/// <summary>
/// Notifies a scope about an operator added to an inner scope.
/// </summary>
/// <summary>
/// Notifies a scope about an operator added to an inner scope.
/// </summary>
/// <param name="op"></param>
public virtual void AddInnerOp(Operation op)
{
if (_outer_context != null)
if (_outer_context != null)
_outer_context.AddInnerOp(op);
}

@@ -164,6 +164,12 @@ namespace Tensorflow.Operations
var internal_control_inputs = op.control_inputs;
}

public object to_proto()
{
throw new NotImplementedException();
}


public void Dispose()
{
}


+ 1
- 0
src/TensorFlowNET.Core/Operations/ControlFlows/IControlFlowContext.cs View File

@@ -11,5 +11,6 @@ namespace Tensorflow
HashSet<string> values { get; }
Tensor AddValue(Tensor val);
void AddInnerOp(Operation resultOp);
object to_proto();
}
}

+ 9
- 0
src/TensorFlowNET.Core/Operations/ControlFlows/WhileContext.cs View File

@@ -6,5 +6,14 @@ namespace Tensorflow.Operations
{
public class WhileContext : ControlFlowContext
{
public static WhileContext from_proto(object proto)
{
throw new NotImplementedException();
}

public object to_proto()
{
throw new NotImplementedException();
}
}
}

+ 8
- 3
src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs View File

@@ -490,8 +490,13 @@ namespace Tensorflow
}

throw new NotImplementedException("ZerosLikeOutsideLoop");
}
}

// TODO
public static void while_loop(Func<Tensor, Tensor> func, Func<Tensor, Tensor> func1, Tensor[] tensors, int? i)
{
throw new NotImplementedException();
}

}
}

+ 6
- 0
src/TensorFlowNET.Core/Variables/variable_scope.py.cs View File

@@ -240,5 +240,11 @@ namespace Tensorflow
if (_current_name_scope != null)
_current_name_scope.Dispose();
}

// TODO for Switch/Case
public static RefVariable get_variable(string embeddingMatrix, double[,] initializer, bool use_resource)
{
throw new NotImplementedException();
}
}
}

+ 5
- 0
src/TensorFlowNET.Core/Variables/variables.py.cs View File

@@ -67,5 +67,10 @@ namespace Tensorflow
else
return gen_control_flow_ops.no_op(name: name);
}

public static Tensor global_variables_initializer()
{
throw new NotImplementedException();
}
}
}

+ 12
- 1
test/TensorFlowNET.UnitTest/PythonTest.cs View File

@@ -95,9 +95,15 @@ namespace TensorFlowNET.UnitTest
Assert.IsTrue(cond);
}
public void assertProtoEquals(object toProto, object o)
{
throw new NotImplementedException();
}
#endregion
#region tensor evaluation
#region tensor evaluation and test session
protected object _eval_helper(Tensor[] tensors)
{
@@ -166,6 +172,11 @@ namespace TensorFlowNET.UnitTest
}
protected Session cached_session()
{
throw new NotImplementedException();
}
//Returns a TensorFlow Session for use in executing tests.
public Session session(Graph graph = null, object config = null, bool use_gpu = false, bool force_gpu = false)
{


+ 30
- 20
test/TensorFlowNET.UnitTest/control_flow_ops_test/SwitchTestCase.cs View File

@@ -1,4 +1,5 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow;
namespace TensorFlowNET.UnitTest.control_flow_ops_test
@@ -14,24 +15,33 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test
[TestMethod]
public void testResourceReadInLoop()
{
//def testResourceReadInLoop(self):
// embedding_matrix = variable_scope.get_variable(
// "embedding_matrix", initializer=[[2.0], [3.0]], use_resource=True)
//
// def cond(it, _):
// return it < 5
//
// def body(it, cost):
// embedding = embedding_ops.embedding_lookup(embedding_matrix, [0])
// cost += math_ops.reduce_sum(embedding)
// return it + 1, cost
//
// _, cost = control_flow_ops.while_loop(
// cond, body, [constant_op.constant(0),
// constant_op.constant(0.0)])
// with self.cached_session():
// self.evaluate(variables.global_variables_initializer())
// self.assertAllEqual(10.0, self.evaluate(cost))
var embedding_matrix = variable_scope.get_variable(
"embedding_matrix", initializer: new double[,] { { 2.0 }, { 3.0 } }, use_resource: true);
Tensor cond(Tensor it, Tensor _)
{
return it < 5;
}
// TODO: below code doesn't compile
//(Tensor, Tensor) body(Tensor it, Tensor cost)
//{
// var embedding = embedding_ops.embedding_lookup(embedding_matrix, new int[]{0});
// cost += math_ops.reduce_sum(embedding);
// return (it + 1, cost);
//}
//var (_, cost1) = control_flow_ops.while_loop(
// cond, body, new[]
// {
// constant_op.constant(0),
// constant_op.constant(0.0)
// });
//with<Session>(this.cached_session(), sess =>
//{
// self.evaluate(variables.global_variables_initializer());
// self.assertAllEqual(10.0, self.evaluate(cost1));
//});
}
@@ -49,7 +59,7 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test
doTestIndexedSlicesGradientInCondInWhileLoop(use_resource: true);
}
private void doTestIndexedSlicesGradientInCondInWhileLoop(bool use_resource= false)
private void doTestIndexedSlicesGradientInCondInWhileLoop(bool use_resource = false)
{
//def doTestIndexedSlicesGradientInCondInWhileLoop(self, use_resource=False):
// embedding_matrix = variable_scope.get_variable(


+ 49
- 0
test/TensorFlowNET.UnitTest/control_flow_ops_test/WhileContextTestCase.cs View File

@@ -0,0 +1,49 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow;
using Tensorflow.Operations;
namespace TensorFlowNET.UnitTest.control_flow_ops_test
{
[TestClass]
public class WhileContextTestCase : PythonTest
{
private void _testWhileContextHelper(int? maximum_iterations = null)
{
// TODO: implement missing code dependencies
with<Session>(this.cached_session(), sess =>
{
var i = constant_op.constant(0, name: "i");
var c = new Func<Tensor, Tensor>(x => gen_math_ops.less(x, 10, name: "c"));
var b = new Func<Tensor, Tensor>(x => gen_math_ops.add(x, 1, name: "c"));
control_flow_ops.while_loop(
c, b, new[] { i }, maximum_iterations = maximum_iterations);
foreach (Operation op in sess.graph.get_operations())
{
var control_flow_context = op._get_control_flow_context();
if (control_flow_context != null)
self.assertProtoEquals(control_flow_context.to_proto(),
WhileContext.from_proto(
control_flow_context.to_proto()).to_proto());
}
});
}
[Ignore("TODO")]
[TestMethod]
public void testWhileContext()
{
_testWhileContextHelper();
}
[Ignore("TODO")]
[TestMethod]
public void testWhileContextWithMaximumIterations()
{
_testWhileContextHelper(maximum_iterations: 10);
}
}
}

Loading…
Cancel
Save