Browse Source

add disable_eager_execution, clean unit test.

tags/v0.20
Oceania2018 5 years ago
parent
commit
ba8b0f37af
22 changed files with 121 additions and 123 deletions
  1. +1
    -1
      src/TensorFlowNET.Console/Program.cs
  2. +30
    -0
      src/TensorFlowNET.Core/APIs/tf.compat.cs
  3. +30
    -0
      src/TensorFlowNET.Core/APIs/tf.compat.v1.cs
  4. +4
    -3
      src/TensorFlowNET.Core/Graphs/Graph.cs
  5. +2
    -2
      src/TensorFlowNET.Core/Operations/Operation.Control.cs
  6. +2
    -2
      src/TensorFlowNET.Core/Operations/Operation.cs
  7. +0
    -6
      src/TensorFlowNET.Core/Tensors/c_api.tensor.cs
  8. +3
    -3
      src/TensorFlowNET.Core/ops.cs
  9. +2
    -2
      test/TensorFlowNET.UnitTest/Basics/QueueTest.cs
  10. +0
    -2
      test/TensorFlowNET.UnitTest/Basics/VariableTest.cs
  11. +24
    -0
      test/TensorFlowNET.UnitTest/GraphModeTestBase.cs
  12. +2
    -2
      test/TensorFlowNET.UnitTest/NameScopeTest.cs
  13. +0
    -0
      test/TensorFlowNET.UnitTest/NativeAPI/GraphTest.cs
  14. +3
    -2
      test/TensorFlowNET.UnitTest/OperationsTest.cs
  15. +1
    -17
      test/TensorFlowNET.UnitTest/TF_API/ConstantTest.cs
  16. +5
    -2
      test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs
  17. +2
    -2
      test/TensorFlowNET.UnitTest/control_flow_ops_test/ShapeTestCase.cs
  18. +3
    -3
      test/TensorFlowNET.UnitTest/control_flow_ops_test/WhileContextTestCase.cs
  19. +2
    -3
      test/TensorFlowNET.UnitTest/img_test/TestCrop.cs
  20. +2
    -2
      test/TensorFlowNET.UnitTest/layers_test/flatten.cs
  21. +2
    -68
      test/TensorFlowNET.UnitTest/ops_test/ControlDependenciesTest.cs
  22. +1
    -1
      test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs

+ 1
- 1
src/TensorFlowNET.Console/Program.cs View File

@@ -27,7 +27,7 @@ namespace Tensorflow
// 100K gradient 44M. // 100K gradient 44M.
mm.Execute(10, 10 * batchSize, cases.Gradient); mm.Execute(10, 10 * batchSize, cases.Gradient);


// 120M
// 95M
Console.WriteLine("Finished."); Console.WriteLine("Finished.");
Console.ReadLine(); Console.ReadLine();
} }


+ 30
- 0
src/TensorFlowNET.Core/APIs/tf.compat.cs View File

@@ -0,0 +1,30 @@
/*****************************************************************************
Copyright 2020 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using NumSharp;

namespace Tensorflow
{
public partial class tensorflow
{
public CompatApi compat { get; } = new CompatApi();

public class CompatApi
{
public CompatV1Api v1 { get; } = new CompatV1Api();
}
}
}

+ 30
- 0
src/TensorFlowNET.Core/APIs/tf.compat.v1.cs View File

@@ -0,0 +1,30 @@
/*****************************************************************************
Copyright 2020 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using Tensorflow.Eager;
using static Tensorflow.Binding;

namespace Tensorflow
{
public class CompatV1Api
{
public void disable_eager_execution()
{
tf.context.default_execution_mode = Context.GRAPH_MODE;
}
}
}

+ 4
- 3
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -259,7 +259,8 @@ namespace Tensorflow


public Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes, public Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes,
TF_DataType[] input_types = null, string name = null, TF_DataType[] input_types = null, string name = null,
Dictionary<string, AttrValue> attrs = null, OpDef op_def = null)
Dictionary<string, AttrValue> attrs = null, OpDef op_def = null,
bool compute_device = true)
{ {
if (inputs == null) if (inputs == null)
inputs = new Tensor[0]; inputs = new Tensor[0];
@@ -270,7 +271,7 @@ namespace Tensorflow
// If a names ends with a '/' it is a "name scope" and we use it as-is, // If a names ends with a '/' it is a "name scope" and we use it as-is,
// after removing the trailing '/'. // after removing the trailing '/'.
name = name.EndsWith("/") ? ops.name_from_scope_name(name) : unique_name(name); name = name.EndsWith("/") ? ops.name_from_scope_name(name) : unique_name(name);
var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs);
var node_def = ops._NodeDef(op_type, name, attrs: attrs);


var input_ops = inputs.Select(x => x.op).ToArray(); var input_ops = inputs.Select(x => x.op).ToArray();
var control_inputs = _control_dependencies_for_inputs(input_ops); var control_inputs = _control_dependencies_for_inputs(input_ops);
@@ -284,7 +285,7 @@ namespace Tensorflow
original_op: null, original_op: null,
op_def: op_def); op_def: op_def);


_create_op_helper(op, true);
_create_op_helper(op, compute_device);


/*Console.Write($"create_op: {op_type} '{node_def.Name}'"); /*Console.Write($"create_op: {op_type} '{node_def.Name}'");
Console.Write($", inputs: {(inputs.Length == 0 ? "empty" : String.Join(", ", inputs.Select(x => x.name)))}"); Console.Write($", inputs: {(inputs.Length == 0 ? "empty" : String.Join(", ", inputs.Select(x => x.name)))}");


+ 2
- 2
src/TensorFlowNET.Core/Operations/Operation.Control.cs View File

@@ -40,8 +40,8 @@ namespace Tensorflow


public void _add_control_input(Operation op) public void _add_control_input(Operation op)
{ {
//c_api.TF_AddControlInput(_operDesc, op);
c_api.AddControlInput(graph, _handle, op);
c_api.TF_AddControlInput(OpDesc, op);
//c_api.AddControlInput(graph, _handle, op);
} }


public void _add_control_inputs(Operation[] ops) public void _add_control_inputs(Operation[] ops)


+ 2
- 2
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -64,7 +64,7 @@ namespace Tensorflow
public string Device => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationDevice(_handle)); public string Device => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationDevice(_handle));


bool _is_stateful; bool _is_stateful;
public OperationDescription OpDesc { get; set; }


public NodeDef node_def public NodeDef node_def
{ {
@@ -170,7 +170,7 @@ namespace Tensorflow
op_def = g.GetOpDef(node_def.Op); op_def = g.GetOpDef(node_def.Op);


var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr);
_handle = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray());
(_handle, OpDesc) = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray());
_is_stateful = op_def.IsStateful; _is_stateful = op_def.IsStateful;


// Initialize self._outputs. // Initialize self._outputs.


+ 0
- 6
src/TensorFlowNET.Core/Tensors/c_api.tensor.cs View File

@@ -187,9 +187,6 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern unsafe ulong TF_StringEncode(byte* src, ulong src_len, sbyte* dst, ulong dst_len, SafeStatusHandle status); public static extern unsafe ulong TF_StringEncode(byte* src, ulong src_len, sbyte* dst, ulong dst_len, SafeStatusHandle status);


[DllImport(TensorFlowLibName)]
public static extern unsafe ulong TF_StringEncode(IntPtr src, ulong src_len, IntPtr dst, ulong dst_len, SafeStatusHandle status);

/// <summary> /// <summary>
/// Decode a string encoded using TF_StringEncode. /// Decode a string encoded using TF_StringEncode.
/// </summary> /// </summary>
@@ -199,9 +196,6 @@ namespace Tensorflow
/// <param name="dst_len">size_t*</param> /// <param name="dst_len">size_t*</param>
/// <param name="status">TF_Status*</param> /// <param name="status">TF_Status*</param>
/// <returns></returns> /// <returns></returns>
[DllImport(TensorFlowLibName)]
public static extern ulong TF_StringDecode(IntPtr src, ulong src_len, IntPtr dst, ref ulong dst_len, SafeStatusHandle status);

[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern unsafe ulong TF_StringDecode(byte* src, ulong src_len, byte** dst, ref ulong dst_len, SafeStatusHandle status); public static extern unsafe ulong TF_StringDecode(byte* src, ulong src_len, byte** dst, ref ulong dst_len, SafeStatusHandle status);




+ 3
- 3
src/TensorFlowNET.Core/ops.cs View File

@@ -155,7 +155,7 @@ namespace Tensorflow
/// </param> /// </param>
/// <param name="control_inputs">A list of `Operation`s to set as control dependencies.</param> /// <param name="control_inputs">A list of `Operation`s to set as control dependencies.</param>
/// <returns>A wrapped TF_Operation*.</returns> /// <returns>A wrapped TF_Operation*.</returns>
public static IntPtr _create_c_op<T>(Graph graph, NodeDef node_def, T[] inputs, Operation[] control_inputs)
public static (IntPtr, OperationDescription) _create_c_op<T>(Graph graph, NodeDef node_def, T[] inputs, Operation[] control_inputs)
{ {
lock (Locks.ProcessWide) lock (Locks.ProcessWide)
{ {
@@ -198,7 +198,7 @@ namespace Tensorflow


status.Check(true); status.Check(true);


return c_op;
return (c_op, op_desc);
} }
} }


@@ -207,7 +207,7 @@ namespace Tensorflow
return graph.GetOpDef(type); return graph.GetOpDef(type);
} }


public static NodeDef _NodeDef(string op_type, string name, string device = "", Dictionary<string, AttrValue> attrs = null)
public static NodeDef _NodeDef(string op_type, string name, Dictionary<string, AttrValue> attrs = null)
{ {
var node_def = new NodeDef(); var node_def = new NodeDef();
node_def.Op = op_type; node_def.Op = op_type;


+ 2
- 2
test/TensorFlowNET.UnitTest/Basics/QueueTest.cs View File

@@ -4,13 +4,13 @@ using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Text; using System.Text;
using Tensorflow; using Tensorflow;
using Tensorflow.UnitTest;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace TensorFlowNET.UnitTest.Basics namespace TensorFlowNET.UnitTest.Basics
{ {
[Ignore]
[TestClass] [TestClass]
public class QueueTest
public class QueueTest : GraphModeTestBase
{ {
[TestMethod] [TestMethod]
public void PaddingFIFOQueue() public void PaddingFIFOQueue()


+ 0
- 2
test/TensorFlowNET.UnitTest/Basics/VariableTest.cs View File

@@ -10,7 +10,6 @@ namespace TensorFlowNET.UnitTest.Basics
[TestClass] [TestClass]
public class VariableTest public class VariableTest
{ {
[Ignore]
[TestMethod] [TestMethod]
public void NewVariable() public void NewVariable()
{ {
@@ -34,7 +33,6 @@ namespace TensorFlowNET.UnitTest.Basics
Assert.AreEqual(4, (int)y.numpy()); Assert.AreEqual(4, (int)y.numpy());
} }


[Ignore]
[TestMethod] [TestMethod]
public void Assign1() public void Assign1()
{ {


+ 24
- 0
test/TensorFlowNET.UnitTest/GraphModeTestBase.cs View File

@@ -0,0 +1,24 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using System.Text;
using TensorFlowNET.UnitTest;
using static Tensorflow.Binding;

namespace Tensorflow.UnitTest
{
public class GraphModeTestBase : PythonTest
{
[TestInitialize]
public void TestInit()
{
tf.compat.v1.disable_eager_execution();
}

[TestCleanup]
public void TestClean()
{
tf.enable_eager_execution();
}
}
}

+ 2
- 2
test/TensorFlowNET.UnitTest/NameScopeTest.cs View File

@@ -1,16 +1,16 @@
using System; using System;
using Microsoft.VisualStudio.TestTools.UnitTesting; using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow; using Tensorflow;
using Tensorflow.UnitTest;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace TensorFlowNET.UnitTest.Basics namespace TensorFlowNET.UnitTest.Basics
{ {
[TestClass] [TestClass]
public class NameScopeTest
public class NameScopeTest : GraphModeTestBase
{ {
string name = ""; string name = "";


[Ignore]
[TestMethod] [TestMethod]
public void NestedNameScope() public void NestedNameScope()
{ {


test/TensorFlowNET.UnitTest/GraphTest.cs → test/TensorFlowNET.UnitTest/NativeAPI/GraphTest.cs View File


+ 3
- 2
test/TensorFlowNET.UnitTest/OperationsTest.cs View File

@@ -7,12 +7,12 @@ using Tensorflow;
using Tensorflow.Util; using Tensorflow.Util;
using Buffer = Tensorflow.Buffer; using Buffer = Tensorflow.Buffer;
using static Tensorflow.Binding; using static Tensorflow.Binding;
using Tensorflow.UnitTest;


namespace TensorFlowNET.UnitTest.Basics namespace TensorFlowNET.UnitTest.Basics
{ {
[Ignore]
[TestClass] [TestClass]
public class OperationsTest
public class OperationsTest : GraphModeTestBase
{ {
/// <summary> /// <summary>
/// Port from tensorflow\c\c_api_test.cc /// Port from tensorflow\c\c_api_test.cc
@@ -726,6 +726,7 @@ namespace TensorFlowNET.UnitTest.Basics
#endregion #endregion
} }


[Ignore]
[TestMethod] [TestMethod]
public void divOpTests() public void divOpTests()
{ {


test/TensorFlowNET.UnitTest/ConstantTest.cs → test/TensorFlowNET.UnitTest/TF_API/ConstantTest.cs View File

@@ -3,6 +3,7 @@ using NumSharp;
using System; using System;
using System.Linq; using System.Linq;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
using System.Text;
using Tensorflow; using Tensorflow;
using static Tensorflow.Binding; using static Tensorflow.Binding;


@@ -160,23 +161,6 @@ namespace TensorFlowNET.UnitTest.Basics
Assert.AreEqual(6.0, (double)c); Assert.AreEqual(6.0, (double)c);
} }


[TestMethod]
public void StringEncode()
{
string str = "Hello, TensorFlow.NET!";
var handle = Marshal.StringToHGlobalAnsi(str);
var dst_len = c_api.TF_StringEncodedSize((ulong)str.Length);
Assert.AreEqual(dst_len, (ulong)23);
IntPtr dst = Marshal.AllocHGlobal((int)dst_len);
var encoded_len = c_api.TF_StringEncode(handle, (ulong)str.Length, dst, dst_len, status.Handle);
Assert.AreEqual((ulong)23, encoded_len);
Assert.AreEqual(status.Code, TF_Code.TF_OK);
string encoded_str = Marshal.PtrToStringUTF8(dst + sizeof(byte));
Assert.AreEqual(encoded_str, str);
Assert.AreEqual(str.Length, Marshal.ReadByte(dst));
// c_api.TF_StringDecode(dst, (ulong)str.Length, IntPtr.Zero, ref dst_len, status.Handle);
}

[TestMethod] [TestMethod]
public void Reshape() public void Reshape()
{ {

+ 5
- 2
test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs View File

@@ -1,5 +1,6 @@
using Microsoft.VisualStudio.TestTools.UnitTesting; using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow; using Tensorflow;
using Tensorflow.UnitTest;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace TensorFlowNET.UnitTest.control_flow_ops_test namespace TensorFlowNET.UnitTest.control_flow_ops_test
@@ -7,10 +8,10 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test
/// <summary> /// <summary>
/// excerpt of tensorflow/python/framework/ops/control_flow_ops_test.py /// excerpt of tensorflow/python/framework/ops/control_flow_ops_test.py
/// </summary> /// </summary>
[Ignore]
[TestClass] [TestClass]
public class CondTestCases : PythonTest
public class CondTestCases : GraphModeTestBase
{ {
[Ignore("Dependent on UpdateEdge")]
[TestMethod] [TestMethod]
public void testCondTrue_ConstOnly() public void testCondTrue_ConstOnly()
{ {
@@ -49,6 +50,7 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test
} }
} }


[Ignore("Dependent on UpdateEdge")]
[TestMethod] [TestMethod]
public void testCondTrue() public void testCondTrue()
{ {
@@ -65,6 +67,7 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test
assertEquals(result, 34); assertEquals(result, 34);
} }


[Ignore("Dependent on UpdateEdge")]
[TestMethod] [TestMethod]
public void testCondFalse() public void testCondFalse()
{ {


+ 2
- 2
test/TensorFlowNET.UnitTest/control_flow_ops_test/ShapeTestCase.cs View File

@@ -1,14 +1,14 @@
using Microsoft.VisualStudio.TestTools.UnitTesting; using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow; using Tensorflow;
using Tensorflow.UnitTest;


namespace TensorFlowNET.UnitTest.control_flow_ops_test namespace TensorFlowNET.UnitTest.control_flow_ops_test
{ {
/// <summary> /// <summary>
/// excerpt of tensorflow/python/framework/ops/control_flow_ops_test.py /// excerpt of tensorflow/python/framework/ops/control_flow_ops_test.py
/// </summary> /// </summary>
[Ignore]
[TestClass] [TestClass]
public class ShapeTestCase : PythonTest
public class ShapeTestCase : GraphModeTestBase
{ {


[TestMethod] [TestMethod]


+ 3
- 3
test/TensorFlowNET.UnitTest/control_flow_ops_test/WhileContextTestCase.cs View File

@@ -1,24 +1,24 @@
using System; using System;
using Microsoft.VisualStudio.TestTools.UnitTesting; using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow; using Tensorflow;
using Tensorflow.UnitTest;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace TensorFlowNET.UnitTest.control_flow_ops_test namespace TensorFlowNET.UnitTest.control_flow_ops_test
{ {
[TestClass] [TestClass]
public class WhileContextTestCase : PythonTest
public class WhileContextTestCase : GraphModeTestBase
{ {
/// <summary> /// <summary>
/// https://www.tensorflow.org/api_docs/python/tf/while_loop /// https://www.tensorflow.org/api_docs/python/tf/while_loop
/// </summary> /// </summary>
[Ignore]
[TestMethod] [TestMethod]
public void SimpleWhileLoop() public void SimpleWhileLoop()
{ {
var i = constant_op.constant(0, name: "i"); var i = constant_op.constant(0, name: "i");
var c = new Func<Tensor, Tensor>(x => tf.less(x, 10, name: "c")); var c = new Func<Tensor, Tensor>(x => tf.less(x, 10, name: "c"));
var b = new Func<Tensor, Tensor>(x => tf.add(x, 1, name: "c")); var b = new Func<Tensor, Tensor>(x => tf.add(x, 1, name: "c"));
//var r = control_flow_ops.while_loop(c, b, i);
// var r = control_flow_ops.while_loop(c, b, i);
} }


private void _testWhileContextHelper(int maximum_iterations) private void _testWhileContextHelper(int maximum_iterations)


+ 2
- 3
test/TensorFlowNET.UnitTest/img_test/TestCrop.cs View File

@@ -2,15 +2,14 @@
using Microsoft.VisualStudio.TestTools.UnitTesting; using Microsoft.VisualStudio.TestTools.UnitTesting;
using NumSharp; using NumSharp;
using Tensorflow; using Tensorflow;
using Tensorflow.UnitTest;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace TensorFlowNET.UnitTest.img_test namespace TensorFlowNET.UnitTest.img_test
{ {
[Ignore]
[TestClass] [TestClass]
public class TestCrop
public class TestCrop : GraphModeTestBase
{ {

[TestMethod] [TestMethod]
public void TestCropAndResize() public void TestCropAndResize()
{ {


+ 2
- 2
test/TensorFlowNET.UnitTest/layers_test/flatten.cs View File

@@ -3,13 +3,13 @@ using FluentAssertions;
using Microsoft.VisualStudio.TestTools.UnitTesting; using Microsoft.VisualStudio.TestTools.UnitTesting;
using NumSharp; using NumSharp;
using Tensorflow; using Tensorflow;
using Tensorflow.UnitTest;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace TensorFlowNET.UnitTest.layers_test namespace TensorFlowNET.UnitTest.layers_test
{ {
[Ignore]
[TestClass] [TestClass]
public class flatten
public class flatten : GraphModeTestBase
{ {
[TestMethod] [TestMethod]
public void Case1() public void Case1()


+ 2
- 68
test/TensorFlowNET.UnitTest/ops_test/ControlDependenciesTest.cs View File

@@ -3,6 +3,7 @@ using System.Linq;
using Microsoft.VisualStudio.TestTools.UnitTesting; using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow; using Tensorflow;
using Tensorflow.Eager; using Tensorflow.Eager;
using Tensorflow.UnitTest;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace TensorFlowNET.UnitTest.ops_test namespace TensorFlowNET.UnitTest.ops_test
@@ -10,9 +11,8 @@ namespace TensorFlowNET.UnitTest.ops_test
/// <summary> /// <summary>
/// excerpt of tensorflow/python/framework/ops_test.py /// excerpt of tensorflow/python/framework/ops_test.py
/// </summary> /// </summary>
[Ignore]
[TestClass] [TestClass]
public class ControlDependenciesTest : PythonTest
public class ControlDependenciesTest : GraphModeTestBase
{ {
[TestMethod] [TestMethod]
public void TestBasic() public void TestBasic()
@@ -35,72 +35,6 @@ namespace TensorFlowNET.UnitTest.ops_test
Assert.AreEqual(0, e.op.control_inputs.Length); Assert.AreEqual(0, e.op.control_inputs.Length);
} }


[Ignore("Future is not supported yet")]
[TestMethod]
public void TestEager()
{
Tensor a = null, c = null;
object b = null;
var calls = 0;
Func<Tensor> future = () =>
{
calls += 1;
return constant_op.constant(2.0);
};
using (var opts = new ContextOptions())
using (var status = new Status())
using (var context = new Context(opts, status))
{
if (context.executing_eagerly())
{
// TODO: make this compile (see original Python code below)
a = constant_op.constant(1.0);
b = future; // <--- {henon} obviously, this doesn't compile, looks like control_dependencies needs to be able to take callables as well.
tf_with(ops.control_dependencies(new object[] { a, b }), ctrl =>
{
return c = constant_op.constant(3.0);
});
Assert.AreEqual(calls, 1);
}
else
{
var g = tf.Graph().as_default();
a = constant_op.constant(1.0);
var b1 = future();
tf_with(g.control_dependencies(new[] { a, b }), ctrl =>
{
c = constant_op.constant(3.0);
});
Assert.IsTrue(Enumerable.SequenceEqual(c.op.control_inputs, new[] { a.op, b1.op }));
Assert.AreEqual(1, calls);
}
}
/*
def testEager(self):
def future():
future.calls += 1
return constant_op.constant(2.0)
future.calls = 0

if context.executing_eagerly():
a = constant_op.constant(1.0)
b = future
with ops.control_dependencies([a, b]):
c = constant_op.constant(3.0)
self.assertEqual(future.calls, 1)
else:
g = ops.Graph()
with g.as_default():
a = constant_op.constant(1.0)
b = future()
with g.control_dependencies([a, b]):
c = constant_op.constant(3.0)
self.assertEqual(c.op.control_inputs, [a.op, b.op])
self.assertEqual(future.calls, 1)
*/
}


[Ignore("How to port the ConvertibleObj?")] [Ignore("How to port the ConvertibleObj?")]
[TestMethod] [TestMethod]
public void TestBasicWithConversion() public void TestBasicWithConversion()


+ 1
- 1
test/TensorFlowNET.UnitTest/ops_test/CreateOpFromTfOperationTest.cs View File

@@ -28,7 +28,7 @@ namespace TensorFlowNET.UnitTest.ops_test
using (var g = tf.Graph().as_default()) using (var g = tf.Graph().as_default())
{ {
var x = constant_op.constant(new[,] {{1, 2, 3}, {4, 5, 6}}); var x = constant_op.constant(new[,] {{1, 2, 3}, {4, 5, 6}});
var c_op = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), new[] {x}, new Operation[0]);
var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), new[] {x}, new Operation[0]);
var op = g._create_op_from_tf_operation(c_op); var op = g._create_op_from_tf_operation(c_op);


Assert.AreEqual("myop", op.name); Assert.AreEqual("myop", op.name);


Loading…
Cancel
Save