@@ -27,7 +27,7 @@ namespace Tensorflow | |||||
// 100K gradient 44M. | // 100K gradient 44M. | ||||
mm.Execute(10, 10 * batchSize, cases.Gradient); | mm.Execute(10, 10 * batchSize, cases.Gradient); | ||||
// 120M | |||||
// 95M | |||||
Console.WriteLine("Finished."); | Console.WriteLine("Finished."); | ||||
Console.ReadLine(); | Console.ReadLine(); | ||||
} | } | ||||
@@ -0,0 +1,30 @@ | |||||
/***************************************************************************** | |||||
Copyright 2020 The TensorFlow.NET Authors. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
******************************************************************************/ | |||||
using NumSharp; | |||||
namespace Tensorflow | |||||
{ | |||||
public partial class tensorflow | |||||
{ | |||||
public CompatApi compat { get; } = new CompatApi(); | |||||
public class CompatApi | |||||
{ | |||||
public CompatV1Api v1 { get; } = new CompatV1Api(); | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,30 @@ | |||||
/***************************************************************************** | |||||
Copyright 2020 The TensorFlow.NET Authors. All Rights Reserved. | |||||
Licensed under the Apache License, Version 2.0 (the "License"); | |||||
you may not use this file except in compliance with the License. | |||||
You may obtain a copy of the License at | |||||
http://www.apache.org/licenses/LICENSE-2.0 | |||||
Unless required by applicable law or agreed to in writing, software | |||||
distributed under the License is distributed on an "AS IS" BASIS, | |||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
See the License for the specific language governing permissions and | |||||
limitations under the License. | |||||
******************************************************************************/ | |||||
using System; | |||||
using Tensorflow.Eager; | |||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow | |||||
{ | |||||
public class CompatV1Api | |||||
{ | |||||
public void disable_eager_execution() | |||||
{ | |||||
tf.context.default_execution_mode = Context.GRAPH_MODE; | |||||
} | |||||
} | |||||
} |
@@ -259,7 +259,8 @@ namespace Tensorflow | |||||
public Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes, | public Operation create_op(string op_type, Tensor[] inputs, TF_DataType[] dtypes, | ||||
TF_DataType[] input_types = null, string name = null, | TF_DataType[] input_types = null, string name = null, | ||||
Dictionary<string, AttrValue> attrs = null, OpDef op_def = null) | |||||
Dictionary<string, AttrValue> attrs = null, OpDef op_def = null, | |||||
bool compute_device = true) | |||||
{ | { | ||||
if (inputs == null) | if (inputs == null) | ||||
inputs = new Tensor[0]; | inputs = new Tensor[0]; | ||||
@@ -270,7 +271,7 @@ namespace Tensorflow | |||||
// If a names ends with a '/' it is a "name scope" and we use it as-is, | // If a names ends with a '/' it is a "name scope" and we use it as-is, | ||||
// after removing the trailing '/'. | // after removing the trailing '/'. | ||||
name = name.EndsWith("/") ? ops.name_from_scope_name(name) : unique_name(name); | name = name.EndsWith("/") ? ops.name_from_scope_name(name) : unique_name(name); | ||||
var node_def = ops._NodeDef(op_type, name, device: "", attrs: attrs); | |||||
var node_def = ops._NodeDef(op_type, name, attrs: attrs); | |||||
var input_ops = inputs.Select(x => x.op).ToArray(); | var input_ops = inputs.Select(x => x.op).ToArray(); | ||||
var control_inputs = _control_dependencies_for_inputs(input_ops); | var control_inputs = _control_dependencies_for_inputs(input_ops); | ||||
@@ -284,7 +285,7 @@ namespace Tensorflow | |||||
original_op: null, | original_op: null, | ||||
op_def: op_def); | op_def: op_def); | ||||
_create_op_helper(op, true); | |||||
_create_op_helper(op, compute_device); | |||||
/*Console.Write($"create_op: {op_type} '{node_def.Name}'"); | /*Console.Write($"create_op: {op_type} '{node_def.Name}'"); | ||||
Console.Write($", inputs: {(inputs.Length == 0 ? "empty" : String.Join(", ", inputs.Select(x => x.name)))}"); | Console.Write($", inputs: {(inputs.Length == 0 ? "empty" : String.Join(", ", inputs.Select(x => x.name)))}"); | ||||
@@ -40,8 +40,8 @@ namespace Tensorflow | |||||
public void _add_control_input(Operation op) | public void _add_control_input(Operation op) | ||||
{ | { | ||||
//c_api.TF_AddControlInput(_operDesc, op); | |||||
c_api.AddControlInput(graph, _handle, op); | |||||
c_api.TF_AddControlInput(OpDesc, op); | |||||
//c_api.AddControlInput(graph, _handle, op); | |||||
} | } | ||||
public void _add_control_inputs(Operation[] ops) | public void _add_control_inputs(Operation[] ops) | ||||
@@ -64,7 +64,7 @@ namespace Tensorflow | |||||
public string Device => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationDevice(_handle)); | public string Device => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationDevice(_handle)); | ||||
bool _is_stateful; | bool _is_stateful; | ||||
public OperationDescription OpDesc { get; set; } | |||||
public NodeDef node_def | public NodeDef node_def | ||||
{ | { | ||||
@@ -170,7 +170,7 @@ namespace Tensorflow | |||||
op_def = g.GetOpDef(node_def.Op); | op_def = g.GetOpDef(node_def.Op); | ||||
var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); | var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); | ||||
_handle = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray()); | |||||
(_handle, OpDesc) = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray()); | |||||
_is_stateful = op_def.IsStateful; | _is_stateful = op_def.IsStateful; | ||||
// Initialize self._outputs. | // Initialize self._outputs. | ||||
@@ -187,9 +187,6 @@ namespace Tensorflow | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern unsafe ulong TF_StringEncode(byte* src, ulong src_len, sbyte* dst, ulong dst_len, SafeStatusHandle status); | public static extern unsafe ulong TF_StringEncode(byte* src, ulong src_len, sbyte* dst, ulong dst_len, SafeStatusHandle status); | ||||
[DllImport(TensorFlowLibName)] | |||||
public static extern unsafe ulong TF_StringEncode(IntPtr src, ulong src_len, IntPtr dst, ulong dst_len, SafeStatusHandle status); | |||||
/// <summary> | /// <summary> | ||||
/// Decode a string encoded using TF_StringEncode. | /// Decode a string encoded using TF_StringEncode. | ||||
/// </summary> | /// </summary> | ||||
@@ -199,9 +196,6 @@ namespace Tensorflow | |||||
/// <param name="dst_len">size_t*</param> | /// <param name="dst_len">size_t*</param> | ||||
/// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
[DllImport(TensorFlowLibName)] | |||||
public static extern ulong TF_StringDecode(IntPtr src, ulong src_len, IntPtr dst, ref ulong dst_len, SafeStatusHandle status); | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern unsafe ulong TF_StringDecode(byte* src, ulong src_len, byte** dst, ref ulong dst_len, SafeStatusHandle status); | public static extern unsafe ulong TF_StringDecode(byte* src, ulong src_len, byte** dst, ref ulong dst_len, SafeStatusHandle status); | ||||
@@ -155,7 +155,7 @@ namespace Tensorflow | |||||
/// </param> | /// </param> | ||||
/// <param name="control_inputs">A list of `Operation`s to set as control dependencies.</param> | /// <param name="control_inputs">A list of `Operation`s to set as control dependencies.</param> | ||||
/// <returns>A wrapped TF_Operation*.</returns> | /// <returns>A wrapped TF_Operation*.</returns> | ||||
public static IntPtr _create_c_op<T>(Graph graph, NodeDef node_def, T[] inputs, Operation[] control_inputs) | |||||
public static (IntPtr, OperationDescription) _create_c_op<T>(Graph graph, NodeDef node_def, T[] inputs, Operation[] control_inputs) | |||||
{ | { | ||||
lock (Locks.ProcessWide) | lock (Locks.ProcessWide) | ||||
{ | { | ||||
@@ -198,7 +198,7 @@ namespace Tensorflow | |||||
status.Check(true); | status.Check(true); | ||||
return c_op; | |||||
return (c_op, op_desc); | |||||
} | } | ||||
} | } | ||||
@@ -207,7 +207,7 @@ namespace Tensorflow | |||||
return graph.GetOpDef(type); | return graph.GetOpDef(type); | ||||
} | } | ||||
public static NodeDef _NodeDef(string op_type, string name, string device = "", Dictionary<string, AttrValue> attrs = null) | |||||
public static NodeDef _NodeDef(string op_type, string name, Dictionary<string, AttrValue> attrs = null) | |||||
{ | { | ||||
var node_def = new NodeDef(); | var node_def = new NodeDef(); | ||||
node_def.Op = op_type; | node_def.Op = op_type; | ||||
@@ -4,13 +4,13 @@ using System.Collections.Generic; | |||||
using System.Linq; | using System.Linq; | ||||
using System.Text; | using System.Text; | ||||
using Tensorflow; | using Tensorflow; | ||||
using Tensorflow.UnitTest; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace TensorFlowNET.UnitTest.Basics | namespace TensorFlowNET.UnitTest.Basics | ||||
{ | { | ||||
[Ignore] | |||||
[TestClass] | [TestClass] | ||||
public class QueueTest | |||||
public class QueueTest : GraphModeTestBase | |||||
{ | { | ||||
[TestMethod] | [TestMethod] | ||||
public void PaddingFIFOQueue() | public void PaddingFIFOQueue() | ||||
@@ -10,7 +10,6 @@ namespace TensorFlowNET.UnitTest.Basics | |||||
[TestClass] | [TestClass] | ||||
public class VariableTest | public class VariableTest | ||||
{ | { | ||||
[Ignore] | |||||
[TestMethod] | [TestMethod] | ||||
public void NewVariable() | public void NewVariable() | ||||
{ | { | ||||
@@ -34,7 +33,6 @@ namespace TensorFlowNET.UnitTest.Basics | |||||
Assert.AreEqual(4, (int)y.numpy()); | Assert.AreEqual(4, (int)y.numpy()); | ||||
} | } | ||||
[Ignore] | |||||
[TestMethod] | [TestMethod] | ||||
public void Assign1() | public void Assign1() | ||||
{ | { | ||||
@@ -0,0 +1,24 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using TensorFlowNET.UnitTest; | |||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow.UnitTest | |||||
{ | |||||
public class GraphModeTestBase : PythonTest | |||||
{ | |||||
[TestInitialize] | |||||
public void TestInit() | |||||
{ | |||||
tf.compat.v1.disable_eager_execution(); | |||||
} | |||||
[TestCleanup] | |||||
public void TestClean() | |||||
{ | |||||
tf.enable_eager_execution(); | |||||
} | |||||
} | |||||
} |
@@ -1,16 +1,16 @@ | |||||
using System; | using System; | ||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using Tensorflow; | using Tensorflow; | ||||
using Tensorflow.UnitTest; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace TensorFlowNET.UnitTest.Basics | namespace TensorFlowNET.UnitTest.Basics | ||||
{ | { | ||||
[TestClass] | [TestClass] | ||||
public class NameScopeTest | |||||
public class NameScopeTest : GraphModeTestBase | |||||
{ | { | ||||
string name = ""; | string name = ""; | ||||
[Ignore] | |||||
[TestMethod] | [TestMethod] | ||||
public void NestedNameScope() | public void NestedNameScope() | ||||
{ | { | ||||
@@ -7,12 +7,12 @@ using Tensorflow; | |||||
using Tensorflow.Util; | using Tensorflow.Util; | ||||
using Buffer = Tensorflow.Buffer; | using Buffer = Tensorflow.Buffer; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
using Tensorflow.UnitTest; | |||||
namespace TensorFlowNET.UnitTest.Basics | namespace TensorFlowNET.UnitTest.Basics | ||||
{ | { | ||||
[Ignore] | |||||
[TestClass] | [TestClass] | ||||
public class OperationsTest | |||||
public class OperationsTest : GraphModeTestBase | |||||
{ | { | ||||
/// <summary> | /// <summary> | ||||
/// Port from tensorflow\c\c_api_test.cc | /// Port from tensorflow\c\c_api_test.cc | ||||
@@ -726,6 +726,7 @@ namespace TensorFlowNET.UnitTest.Basics | |||||
#endregion | #endregion | ||||
} | } | ||||
[Ignore] | |||||
[TestMethod] | [TestMethod] | ||||
public void divOpTests() | public void divOpTests() | ||||
{ | { | ||||
@@ -3,6 +3,7 @@ using NumSharp; | |||||
using System; | using System; | ||||
using System.Linq; | using System.Linq; | ||||
using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
using System.Text; | |||||
using Tensorflow; | using Tensorflow; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
@@ -160,23 +161,6 @@ namespace TensorFlowNET.UnitTest.Basics | |||||
Assert.AreEqual(6.0, (double)c); | Assert.AreEqual(6.0, (double)c); | ||||
} | } | ||||
[TestMethod] | |||||
public void StringEncode() | |||||
{ | |||||
string str = "Hello, TensorFlow.NET!"; | |||||
var handle = Marshal.StringToHGlobalAnsi(str); | |||||
var dst_len = c_api.TF_StringEncodedSize((ulong)str.Length); | |||||
Assert.AreEqual(dst_len, (ulong)23); | |||||
IntPtr dst = Marshal.AllocHGlobal((int)dst_len); | |||||
var encoded_len = c_api.TF_StringEncode(handle, (ulong)str.Length, dst, dst_len, status.Handle); | |||||
Assert.AreEqual((ulong)23, encoded_len); | |||||
Assert.AreEqual(status.Code, TF_Code.TF_OK); | |||||
string encoded_str = Marshal.PtrToStringUTF8(dst + sizeof(byte)); | |||||
Assert.AreEqual(encoded_str, str); | |||||
Assert.AreEqual(str.Length, Marshal.ReadByte(dst)); | |||||
// c_api.TF_StringDecode(dst, (ulong)str.Length, IntPtr.Zero, ref dst_len, status.Handle); | |||||
} | |||||
[TestMethod] | [TestMethod] | ||||
public void Reshape() | public void Reshape() | ||||
{ | { |
@@ -1,5 +1,6 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using Tensorflow; | using Tensorflow; | ||||
using Tensorflow.UnitTest; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace TensorFlowNET.UnitTest.control_flow_ops_test | namespace TensorFlowNET.UnitTest.control_flow_ops_test | ||||
@@ -7,10 +8,10 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test | |||||
/// <summary> | /// <summary> | ||||
/// excerpt of tensorflow/python/framework/ops/control_flow_ops_test.py | /// excerpt of tensorflow/python/framework/ops/control_flow_ops_test.py | ||||
/// </summary> | /// </summary> | ||||
[Ignore] | |||||
[TestClass] | [TestClass] | ||||
public class CondTestCases : PythonTest | |||||
public class CondTestCases : GraphModeTestBase | |||||
{ | { | ||||
[Ignore("Dependent on UpdateEdge")] | |||||
[TestMethod] | [TestMethod] | ||||
public void testCondTrue_ConstOnly() | public void testCondTrue_ConstOnly() | ||||
{ | { | ||||
@@ -49,6 +50,7 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test | |||||
} | } | ||||
} | } | ||||
[Ignore("Dependent on UpdateEdge")] | |||||
[TestMethod] | [TestMethod] | ||||
public void testCondTrue() | public void testCondTrue() | ||||
{ | { | ||||
@@ -65,6 +67,7 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test | |||||
assertEquals(result, 34); | assertEquals(result, 34); | ||||
} | } | ||||
[Ignore("Dependent on UpdateEdge")] | |||||
[TestMethod] | [TestMethod] | ||||
public void testCondFalse() | public void testCondFalse() | ||||
{ | { | ||||
@@ -1,14 +1,14 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using Tensorflow; | using Tensorflow; | ||||
using Tensorflow.UnitTest; | |||||
namespace TensorFlowNET.UnitTest.control_flow_ops_test | namespace TensorFlowNET.UnitTest.control_flow_ops_test | ||||
{ | { | ||||
/// <summary> | /// <summary> | ||||
/// excerpt of tensorflow/python/framework/ops/control_flow_ops_test.py | /// excerpt of tensorflow/python/framework/ops/control_flow_ops_test.py | ||||
/// </summary> | /// </summary> | ||||
[Ignore] | |||||
[TestClass] | [TestClass] | ||||
public class ShapeTestCase : PythonTest | |||||
public class ShapeTestCase : GraphModeTestBase | |||||
{ | { | ||||
[TestMethod] | [TestMethod] | ||||
@@ -1,24 +1,24 @@ | |||||
using System; | using System; | ||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using Tensorflow; | using Tensorflow; | ||||
using Tensorflow.UnitTest; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace TensorFlowNET.UnitTest.control_flow_ops_test | namespace TensorFlowNET.UnitTest.control_flow_ops_test | ||||
{ | { | ||||
[TestClass] | [TestClass] | ||||
public class WhileContextTestCase : PythonTest | |||||
public class WhileContextTestCase : GraphModeTestBase | |||||
{ | { | ||||
/// <summary> | /// <summary> | ||||
/// https://www.tensorflow.org/api_docs/python/tf/while_loop | /// https://www.tensorflow.org/api_docs/python/tf/while_loop | ||||
/// </summary> | /// </summary> | ||||
[Ignore] | |||||
[TestMethod] | [TestMethod] | ||||
public void SimpleWhileLoop() | public void SimpleWhileLoop() | ||||
{ | { | ||||
var i = constant_op.constant(0, name: "i"); | var i = constant_op.constant(0, name: "i"); | ||||
var c = new Func<Tensor, Tensor>(x => tf.less(x, 10, name: "c")); | var c = new Func<Tensor, Tensor>(x => tf.less(x, 10, name: "c")); | ||||
var b = new Func<Tensor, Tensor>(x => tf.add(x, 1, name: "c")); | var b = new Func<Tensor, Tensor>(x => tf.add(x, 1, name: "c")); | ||||
//var r = control_flow_ops.while_loop(c, b, i); | |||||
// var r = control_flow_ops.while_loop(c, b, i); | |||||
} | } | ||||
private void _testWhileContextHelper(int maximum_iterations) | private void _testWhileContextHelper(int maximum_iterations) | ||||
@@ -2,15 +2,14 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using NumSharp; | using NumSharp; | ||||
using Tensorflow; | using Tensorflow; | ||||
using Tensorflow.UnitTest; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace TensorFlowNET.UnitTest.img_test | namespace TensorFlowNET.UnitTest.img_test | ||||
{ | { | ||||
[Ignore] | |||||
[TestClass] | [TestClass] | ||||
public class TestCrop | |||||
public class TestCrop : GraphModeTestBase | |||||
{ | { | ||||
[TestMethod] | [TestMethod] | ||||
public void TestCropAndResize() | public void TestCropAndResize() | ||||
{ | { | ||||
@@ -3,13 +3,13 @@ using FluentAssertions; | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using NumSharp; | using NumSharp; | ||||
using Tensorflow; | using Tensorflow; | ||||
using Tensorflow.UnitTest; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace TensorFlowNET.UnitTest.layers_test | namespace TensorFlowNET.UnitTest.layers_test | ||||
{ | { | ||||
[Ignore] | |||||
[TestClass] | [TestClass] | ||||
public class flatten | |||||
public class flatten : GraphModeTestBase | |||||
{ | { | ||||
[TestMethod] | [TestMethod] | ||||
public void Case1() | public void Case1() | ||||
@@ -3,6 +3,7 @@ using System.Linq; | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using Tensorflow; | using Tensorflow; | ||||
using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
using Tensorflow.UnitTest; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace TensorFlowNET.UnitTest.ops_test | namespace TensorFlowNET.UnitTest.ops_test | ||||
@@ -10,9 +11,8 @@ namespace TensorFlowNET.UnitTest.ops_test | |||||
/// <summary> | /// <summary> | ||||
/// excerpt of tensorflow/python/framework/ops_test.py | /// excerpt of tensorflow/python/framework/ops_test.py | ||||
/// </summary> | /// </summary> | ||||
[Ignore] | |||||
[TestClass] | [TestClass] | ||||
public class ControlDependenciesTest : PythonTest | |||||
public class ControlDependenciesTest : GraphModeTestBase | |||||
{ | { | ||||
[TestMethod] | [TestMethod] | ||||
public void TestBasic() | public void TestBasic() | ||||
@@ -35,72 +35,6 @@ namespace TensorFlowNET.UnitTest.ops_test | |||||
Assert.AreEqual(0, e.op.control_inputs.Length); | Assert.AreEqual(0, e.op.control_inputs.Length); | ||||
} | } | ||||
[Ignore("Future is not supported yet")] | |||||
[TestMethod] | |||||
public void TestEager() | |||||
{ | |||||
Tensor a = null, c = null; | |||||
object b = null; | |||||
var calls = 0; | |||||
Func<Tensor> future = () => | |||||
{ | |||||
calls += 1; | |||||
return constant_op.constant(2.0); | |||||
}; | |||||
using (var opts = new ContextOptions()) | |||||
using (var status = new Status()) | |||||
using (var context = new Context(opts, status)) | |||||
{ | |||||
if (context.executing_eagerly()) | |||||
{ | |||||
// TODO: make this compile (see original Python code below) | |||||
a = constant_op.constant(1.0); | |||||
b = future; // <--- {henon} obviously, this doesn't compile, looks like control_dependencies needs to be able to take callables as well. | |||||
tf_with(ops.control_dependencies(new object[] { a, b }), ctrl => | |||||
{ | |||||
return c = constant_op.constant(3.0); | |||||
}); | |||||
Assert.AreEqual(calls, 1); | |||||
} | |||||
else | |||||
{ | |||||
var g = tf.Graph().as_default(); | |||||
a = constant_op.constant(1.0); | |||||
var b1 = future(); | |||||
tf_with(g.control_dependencies(new[] { a, b }), ctrl => | |||||
{ | |||||
c = constant_op.constant(3.0); | |||||
}); | |||||
Assert.IsTrue(Enumerable.SequenceEqual(c.op.control_inputs, new[] { a.op, b1.op })); | |||||
Assert.AreEqual(1, calls); | |||||
} | |||||
} | |||||
/* | |||||
def testEager(self): | |||||
def future(): | |||||
future.calls += 1 | |||||
return constant_op.constant(2.0) | |||||
future.calls = 0 | |||||
if context.executing_eagerly(): | |||||
a = constant_op.constant(1.0) | |||||
b = future | |||||
with ops.control_dependencies([a, b]): | |||||
c = constant_op.constant(3.0) | |||||
self.assertEqual(future.calls, 1) | |||||
else: | |||||
g = ops.Graph() | |||||
with g.as_default(): | |||||
a = constant_op.constant(1.0) | |||||
b = future() | |||||
with g.control_dependencies([a, b]): | |||||
c = constant_op.constant(3.0) | |||||
self.assertEqual(c.op.control_inputs, [a.op, b.op]) | |||||
self.assertEqual(future.calls, 1) | |||||
*/ | |||||
} | |||||
[Ignore("How to port the ConvertibleObj?")] | [Ignore("How to port the ConvertibleObj?")] | ||||
[TestMethod] | [TestMethod] | ||||
public void TestBasicWithConversion() | public void TestBasicWithConversion() | ||||
@@ -28,7 +28,7 @@ namespace TensorFlowNET.UnitTest.ops_test | |||||
using (var g = tf.Graph().as_default()) | using (var g = tf.Graph().as_default()) | ||||
{ | { | ||||
var x = constant_op.constant(new[,] {{1, 2, 3}, {4, 5, 6}}); | var x = constant_op.constant(new[,] {{1, 2, 3}, {4, 5, 6}}); | ||||
var c_op = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), new[] {x}, new Operation[0]); | |||||
var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), new[] {x}, new Operation[0]); | |||||
var op = g._create_op_from_tf_operation(c_op); | var op = g._create_op_from_tf_operation(c_op); | ||||
Assert.AreEqual("myop", op.name); | Assert.AreEqual("myop", op.name); | ||||