@@ -12,6 +12,51 @@ namespace Tensorflow | |||
/// then create a TensorFlow session to run parts of the graph across a set of local and remote devices. | |||
/// https://www.tensorflow.org/guide/graphs | |||
/// </summary> | |||
/* | |||
A TensorFlow computation, represented as a dataflow graph. | |||
A `Graph` contains a set of | |||
`tf.Operation` objects, | |||
which represent units of computation; and | |||
`tf.Tensor` objects, which represent | |||
the units of data that flow between operations. | |||
A default `Graph` is always registered, and accessible by calling | |||
`tf.get_default_graph`. | |||
To add an operation to the default graph, simply call one of the functions | |||
that defines a new `Operation`: | |||
```python | |||
c = tf.constant(4.0) | |||
assert c.graph is tf.get_default_graph() | |||
``` | |||
Another typical usage involves the | |||
`tf.Graph.as_default` | |||
context manager, which overrides the current default graph for the | |||
lifetime of the context: | |||
```python | |||
g = tf.Graph() | |||
with g.as_default(): | |||
# Define operations and tensors in `g`. | |||
c = tf.constant(30.0) | |||
assert c.graph is g | |||
``` | |||
Important note: This class *is not* thread-safe for graph construction. All | |||
operations should be created from a single thread, or external | |||
synchronization must be provided. Unless otherwise specified, all methods | |||
are not thread-safe. | |||
A `Graph` instance supports an arbitrary number of "collections" | |||
that are identified by name. For convenience when building a large | |||
graph, collections can store groups of related objects: for | |||
example, the `tf.Variable` uses a collection (named | |||
`tf.GraphKeys.GLOBAL_VARIABLES`) for | |||
all variables that are created during the construction of a graph. The caller | |||
may define additional collections by specifying a new name. | |||
*/ | |||
public partial class Graph : IPython, IDisposable | |||
{ | |||
private IntPtr _handle; | |||
@@ -1,5 +1,6 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
namespace Tensorflow.Operations | |||
@@ -92,13 +93,15 @@ namespace Tensorflow.Operations | |||
switch (original_result) | |||
{ | |||
case Tensor result: | |||
return (original_result, _BuildCondTensor(new[] { result.op })); | |||
case Operation[] results: | |||
return (original_result, _BuildCondTensor(results)); | |||
case Tensor tensor: | |||
return (original_result, tensor); | |||
case float[] fv: | |||
{ | |||
var result = ops.convert_to_tensor(fv[0]); | |||
return (original_result, result ); | |||
} | |||
default: | |||
return (original_result, null); | |||
} | |||
@@ -114,7 +117,7 @@ namespace Tensorflow.Operations | |||
switch (original_result) | |||
{ | |||
case Tensor[] results: | |||
return (original_result, results); | |||
return (original_result, new Tensor[] { _BuildCondTensor(results.Select(t=>t.op).ToArray())}); | |||
case Operation[] results: | |||
return (original_result, new Tensor[] { _BuildCondTensor (results) }); | |||
case float[] fv: | |||
@@ -27,9 +27,9 @@ namespace Tensorflow | |||
for (int i = 0; i < NumInputs; i++) | |||
{ | |||
var tf_outpus = Input(i); | |||
var op = new Operation(tf_outpus.oper); | |||
retval[i] = op.outputs[tf_outpus.index]; | |||
var tf_outputs = Input(i); | |||
var op = new Operation(tf_outputs.oper); | |||
retval[i] = op.outputs[tf_outputs.index]; | |||
} | |||
_inputs = new InputList(retval); | |||
@@ -142,10 +142,29 @@ namespace Tensorflow | |||
return tpl.ToArray(); | |||
}); | |||
} | |||
} | |||
/// <summary> | |||
/// Produces the content of `output_tensor` only after `dependencies`. | |||
/// | |||
/// In some cases, a user may want the output of an operation to be | |||
/// consumed externally only after some other dependencies have run | |||
/// first.This function ensures returns `output_tensor`, but only after all | |||
/// operations in `dependencies` have run.Note that this means that there is | |||
/// no guarantee that `output_tensor` will be evaluated after any `dependencies` | |||
/// have run. | |||
/// | |||
/// See also `tf.tuple` and `tf.group`. | |||
/// </summary> | |||
/// <param name="dependencies">Iterable of operations to run before this op finishes.</param> | |||
/// <param name="output_tensor">A `Tensor` or `IndexedSlices` that will be returned.</param> | |||
/// <param name="name">(Optional) A name for this operation.</param> | |||
/// <returns>Same as `output_tensor`.</returns> | |||
public static Tensor with_dependencies(Operation[] dependencies, Tensor output_tensor, string name = null) | |||
{ | |||
//TODO: missing original code | |||
//if context.executing_eagerly(): | |||
// return output_tensor | |||
var values = new List<object>(); | |||
values.AddRange(dependencies); | |||
values.Add(output_tensor); | |||
@@ -153,12 +172,15 @@ namespace Tensorflow | |||
return with(ops.name_scope(name, "control_dependency", values), scope => | |||
{ | |||
name = scope; | |||
return with(ops.control_dependencies(dependencies), ctl => | |||
// TODO: missing original code | |||
//with ops.colocate_with(output_tensor): | |||
{ | |||
output_tensor = ops.convert_to_tensor_or_composite(output_tensor); | |||
return _Identity(output_tensor, name: name); | |||
}); | |||
return with(ops.control_dependencies(dependencies), ctl => | |||
{ | |||
output_tensor = ops.convert_to_tensor_or_composite(output_tensor); | |||
return _Identity(output_tensor, name: name); | |||
}); | |||
} | |||
}); | |||
} | |||
@@ -393,8 +415,27 @@ namespace Tensorflow | |||
return tensors_or_flows; | |||
} | |||
/// <summary> | |||
/// Returns the value of an available element of `inputs`. | |||
/// | |||
/// This op tests each of the tensors in `inputs` in turn to determine if any of | |||
/// them is available.If it finds an available tensor, it returns it and its | |||
/// index in `inputs`. | |||
/// | |||
/// It is an error if more than one tensor in `inputs` is available.If no tensor | |||
/// in `inputs` is available, the returned tensor and index are not set. | |||
/// | |||
/// This op handles both `Tensor`s and `IndexedSlices`. If inputs has a mix of | |||
/// `Tensor`s and `IndexedSlices`, all inputs are converted to IndexedSlices | |||
/// before merging. | |||
/// </summary> | |||
/// <param name="inputs">inputs: The input tensors, at most one of which is available.</param> | |||
/// <param name="name">A name for this operation (optional).</param> | |||
/// <returns></returns> | |||
public static Tensor merge(Tensor[] inputs, string name = null) | |||
{ | |||
if (inputs.Any(x => x == null)) | |||
throw new ValueError($"At least one of the merge inputs is null: {inputs}"); | |||
return with(ops.name_scope(name, "Merge", inputs), scope => | |||
{ | |||
name = scope; | |||
@@ -49,19 +49,59 @@ namespace Tensorflow | |||
return get_default_graph().get_collection_ref(key); | |||
} | |||
private static Graph default_graph; | |||
private static Graph default_graph; | |||
/// <summary> | |||
/// Returns the default graph for the current thread. | |||
/// | |||
/// The returned graph will be the innermost graph on which a | |||
/// `Graph.as_default()` context has been entered, or a global default | |||
/// graph if none has been explicitly created. | |||
/// | |||
/// NOTE: The default graph is a property of the current thread.If you | |||
/// create a new thread, and wish to use the default graph in that | |||
/// thread, you must explicitly add a `with g.as_default():` in that | |||
/// thread's function. | |||
/// </summary> | |||
/// <returns></returns> | |||
public static Graph get_default_graph() | |||
{ | |||
//TODO: original source indicates there should be a _default_graph_stack! | |||
//return _default_graph_stack.get_default() | |||
if (default_graph == null) | |||
default_graph = tf.Graph(); | |||
return default_graph; | |||
} | |||
public static Graph set_default_graph(Graph graph) | |||
{ | |||
//TODO: original source does not have a 'set_default_graph' and indicates there should be a _default_graph_stack! | |||
default_graph = graph; | |||
return default_graph; | |||
} | |||
/// <summary> | |||
/// Clears the default graph stack and resets the global default graph. | |||
/// | |||
/// NOTE: The default graph is a property of the current thread.This | |||
/// function applies only to the current thread.Calling this function while | |||
/// a `tf.Session` or `tf.InteractiveSession` is active will result in undefined | |||
/// behavior. Using any previously created `tf.Operation` or `tf.Tensor` objects | |||
/// after calling this function will result in undefined behavior. | |||
/// </summary> | |||
/// <returns></returns> | |||
public static void reset_default_graph() | |||
{ | |||
//TODO: original source indicates there should be a _default_graph_stack! | |||
//if (!_default_graph_stack.is_cleared()) | |||
// throw new InvalidOperationException("Do not use tf.reset_default_graph() to clear " + | |||
// "nested graphs. If you need a cleared graph, " + | |||
// "exit the nesting and create a new graph."); | |||
//_default_graph_stack.reset(); | |||
if (default_graph!=null) | |||
default_graph.Dispose(); | |||
default_graph = tf.Graph(); | |||
} | |||
public static Graph _get_graph_from_inputs(List<Tensor> op_input_list, Graph graph = null) | |||
{ | |||
foreach(var op_input in op_input_list) | |||
@@ -13,7 +13,7 @@ namespace TensorFlowNET.UnitTest | |||
/// </summary> | |||
public class PythonTest : Python | |||
{ | |||
public void assertItemsEqual(ICollection expected, ICollection given) | |||
public void assertItemsEqual(ICollection given, ICollection expected) | |||
{ | |||
Assert.IsNotNull(expected); | |||
Assert.IsNotNull(given); | |||
@@ -6,7 +6,7 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using Tensorflow; | |||
using Tensorflow.Eager; | |||
namespace TensorFlowNET.UnitTest | |||
namespace TensorFlowNET.UnitTest.ops_test | |||
{ | |||
/// <summary> | |||
/// excerpt of tensorflow/python/framework/ops_test.py | |||
@@ -157,8 +157,8 @@ namespace TensorFlowNET.UnitTest | |||
}); | |||
}); | |||
}); | |||
assertItemsEqual(new[] { a_1.op, a_2.op, a_3.op, a_4.op }, b_1.op.control_inputs); | |||
assertItemsEqual(b_1.op.control_inputs, b_2.op.control_inputs); | |||
assertItemsEqual(b_1.op.control_inputs, new[] { a_1.op, a_2.op, a_3.op, a_4.op }); | |||
assertItemsEqual(b_2.op.control_inputs, b_1.op.control_inputs); | |||
} | |||
[TestMethod] | |||
@@ -200,6 +200,7 @@ namespace TensorFlowNET.UnitTest | |||
b_none2 = constant_op.constant(12.0); | |||
}); | |||
}); | |||
// Note assertItemsEqual(given, expected), expected and given parameters should be swapped below | |||
assertItemsEqual(new[] { a_3.op, a_4.op }, b_3_4.op.control_inputs); | |||
assertItemsEqual(new[] { a_3.op }, b_3.op.control_inputs); | |||
assertItemsEqual(new object[0], b_none.op.control_inputs); | |||
@@ -256,6 +257,7 @@ namespace TensorFlowNET.UnitTest | |||
}); | |||
}); | |||
// Note assertItemsEqual(given, expected), expected and given parameters should be swapped below | |||
assertItemsEqual(new[] {a_1.op}, b_1.op.control_inputs); | |||
assertItemsEqual(new[] {a_1.op, a_2.op}, b_2.op.control_inputs); | |||
assertItemsEqual(new[] { a_1.op, a_2.op}, b_3.op.control_inputs); |
@@ -1,10 +1,12 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using Tensorflow; | |||
using Tensorflow.Operations; | |||
namespace TensorFlowNET.UnitTest | |||
namespace TensorFlowNET.UnitTest.ops_test | |||
{ | |||
/// <summary> | |||
/// excerpt of tensorflow/python/framework/ops_test.py | |||
@@ -19,21 +21,21 @@ namespace TensorFlowNET.UnitTest | |||
[TestClass] | |||
public class CreateOpFromTfOperationTest : PythonTest | |||
{ | |||
[TestMethod] | |||
public void TestShape() | |||
{ | |||
var graph = tf.Graph().as_default(); | |||
with<Graph>(graph, g => | |||
{ | |||
var x = constant_op.constant(new [,] { {1, 2, 3}, {4, 5, 6}}); | |||
var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), new[] {x}, new Operation[0]); | |||
var x = constant_op.constant(new[,] { { 1, 2, 3 }, { 4, 5, 6 } }); | |||
var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), new[] { x }, new Operation[0]); | |||
var op = g._create_op_from_tf_operation(c_op); | |||
Assert.AreEqual("myop", op.name); | |||
Assert.AreEqual("Identity", op.type); | |||
Assert.AreEqual(1, len(op.outputs)); | |||
assertItemsEqual(new []{2, 3}, op.outputs[0].shape); | |||
assertItemsEqual(new[] { 2, 3 }, op.outputs[0].shape); | |||
}); | |||
} | |||
@@ -47,7 +49,7 @@ namespace TensorFlowNET.UnitTest | |||
//var (c_op2, op_desc1) = ops._create_c_op(g, ops._NodeDef("Const", "myop_1"), new Tensor[0], new Operation[0]); | |||
//var op = g._create_op_from_tf_operation(c_op); | |||
//var op2 = g._create_op_from_tf_operation(c_op2); | |||
var op = constant_op.constant(0, name:"myop").op; | |||
var op = constant_op.constant(0, name: "myop").op; | |||
var op2 = constant_op.constant(0, name: "myop_1").op; | |||
// Create ops with same names as op1 and op2. We expect the new names to be | |||
@@ -62,7 +64,7 @@ namespace TensorFlowNET.UnitTest | |||
}); | |||
} | |||
[Ignore("Something is not right, Switch gets not inserted correctly?")] | |||
[Ignore("Switch op gets not inserted correctly in the graph")] | |||
[TestMethod] | |||
public void TestCond() | |||
{ | |||
@@ -91,8 +93,7 @@ namespace TensorFlowNET.UnitTest | |||
self.assertEqual(op_input.inputs[0], x); | |||
self.assertEqual(op.graph, g); | |||
self.assertIsNotNone(op._get_control_flow_context()); | |||
// TODO: op._get_control_flow_context().name not implemented | |||
//self.assertEqual(op._get_control_flow_context().name, "cond/cond_text"); | |||
self.assertEqual((op._get_control_flow_context() as ControlFlowContext).name, "cond/cond_text"); | |||
}); | |||
/* | |||
@test_util.run_v1_only("b/120545219") | |||
@@ -126,7 +127,39 @@ namespace TensorFlowNET.UnitTest | |||
# pylint: enable=protected-access | |||
*/ | |||
} | |||
/* | |||
[Ignore("Todo: Port")] | |||
[TestMethod] | |||
public void TestWhileLoop() | |||
{ | |||
var graph = tf.Graph().as_default(); | |||
Operation x=null; | |||
with<Graph>(graph, g => | |||
{ | |||
x = constant_op.constant(42); | |||
var body = new Func<int, int>(i => | |||
{ | |||
ops._create_c_op(ops.get_default_graph(), ops._NodeDef("Identity", "myloop/myop"), new[] {x}, | |||
new Operation[0]); | |||
var new_ops = g._add_new_tf_operations(); | |||
self.assertEqual(len(new_ops), 1); | |||
return i; | |||
}); | |||
// TODO: port control_flow_ops.while_loop | |||
//control_flow_ops.while_loop( i => i < 10, body, new int[]{0}, name = "myloop"); | |||
}); | |||
var op = graph.get_operation_by_name("myloop/myop"); | |||
self.assertIsNotNone(op); | |||
self.assertEqual(op.name, "myloop/myop"); | |||
self.assertEqual(op.type, "Identity"); | |||
self.assertEqual(op.outputs.Length, 0); | |||
var op_input = op.inputs[0].op; | |||
self.assertEqual(op_input.type, "Enter"); | |||
self.assertItemsEqual(op_input.inputs.OfType<Operation>().ToArray(), new[] {x}); | |||
self.assertEqual(op.graph, graph); | |||
self.assertIsNotNone(op._get_control_flow_context()); | |||
self.assertEqual(((ControlFlowContext)op._get_control_flow_context()).name, "myloop/while_context"); | |||
/* | |||
@test_util.run_v1_only("b/120545219") | |||
def testWhileLoop(self): | |||
g = ops.Graph() | |||
@@ -156,8 +189,15 @@ namespace TensorFlowNET.UnitTest | |||
self.assertEqual(op._get_control_flow_context().name, | |||
"myloop/while_context") | |||
# pylint: enable=protected-access | |||
*/ | |||
} | |||
@test_util.run_v1_only("b/120545219") | |||
[Ignore("Todo: Port")] | |||
[TestMethod] | |||
public void TestWhileLoopWithInternalControlDep() | |||
{ | |||
/* | |||
@test_util.run_v1_only("b/120545219") | |||
def testWhileLoopWithInternalControlDep(self): | |||
g = ops.Graph() | |||
with g.as_default(): | |||
@@ -180,7 +220,14 @@ namespace TensorFlowNET.UnitTest | |||
self.assertIsNotNone(c) | |||
# Internal control dep is preserved | |||
self.assertEqual(op.control_inputs, [c]) | |||
*/ | |||
} | |||
[Ignore("Todo: Port")] | |||
[TestMethod] | |||
public void TestWhileLoopWithExternalControlDep() | |||
{ | |||
/* | |||
@test_util.run_v1_only("b/120545219") | |||
def testWhileLoopWithExternalControlDep(self): | |||
g = ops.Graph() | |||
@@ -203,8 +250,8 @@ namespace TensorFlowNET.UnitTest | |||
# External control dep is removed and replaced with internal control dep | |||
self.assertNotEqual(op.control_inputs[0], c.op) | |||
self.assertIsNotNone(op.control_inputs[0]._get_control_flow_context()) | |||
*/ | |||
} | |||
*/ | |||
} | |||
} | |||
} |
@@ -0,0 +1,200 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using Tensorflow; | |||
using Tensorflow.Operations; | |||
namespace TensorFlowNET.UnitTest.ops_test | |||
{ | |||
/// <summary> | |||
/// excerpt of tensorflow/python/framework/ops_test.py | |||
/// </summary> | |||
[TestClass] | |||
public class GraphTest : PythonTest | |||
{ | |||
[TestInitialize] | |||
public void SetUp() | |||
{ | |||
ops.reset_default_graph(); | |||
} | |||
[TestCleanup] | |||
public void TearDown() | |||
{ | |||
ops.reset_default_graph(); | |||
} | |||
private void _AssertDefault(Graph expected) { | |||
Assert.AreSame(ops.get_default_graph(), expected); | |||
} | |||
[Ignore("Todo: Port")] | |||
[TestMethod] | |||
public void testResetDefaultGraphNesting() | |||
{ | |||
/* | |||
def testResetDefaultGraphNesting(self): | |||
g0 = ops.Graph() | |||
with self.assertRaises(AssertionError): | |||
with g0.as_default(): | |||
ops.reset_default_graph() | |||
*/ | |||
} | |||
[Ignore("Todo: Port")] | |||
[TestMethod] | |||
public void testGraphContextManagerCancelsEager() | |||
{ | |||
/* | |||
def testGraphContextManagerCancelsEager(self): | |||
with context.eager_mode(): | |||
with ops.Graph().as_default(): | |||
self.assertFalse(context.executing_eagerly()) | |||
*/ | |||
} | |||
[Ignore("Todo: Port")] | |||
[TestMethod] | |||
public void testGraphContextManager() | |||
{ | |||
/* | |||
def testGraphContextManager(self): | |||
g0 = ops.Graph() | |||
with g0.as_default() as g1: | |||
self.assertIs(g0, g1) | |||
*/ | |||
} | |||
[Ignore("Todo: Port")] | |||
[TestMethod] | |||
public void testDefaultGraph() | |||
{ | |||
/* | |||
def testDefaultGraph(self): | |||
orig = ops.get_default_graph() | |||
self._AssertDefault(orig) | |||
g0 = ops.Graph() | |||
self._AssertDefault(orig) | |||
context_manager_0 = g0.as_default() | |||
self._AssertDefault(orig) | |||
with context_manager_0 as g0: | |||
self._AssertDefault(g0) | |||
with ops.Graph().as_default() as g1: | |||
self._AssertDefault(g1) | |||
self._AssertDefault(g0) | |||
self._AssertDefault(orig) | |||
*/ | |||
} | |||
[Ignore("Todo: Port")] | |||
[TestMethod] | |||
public void testPreventFeeding() | |||
{ | |||
/* | |||
def testPreventFeeding(self): | |||
g = ops.Graph() | |||
a = constant_op.constant(2.0) | |||
self.assertTrue(g.is_feedable(a)) | |||
g.prevent_feeding(a) | |||
self.assertFalse(g.is_feedable(a)) | |||
*/ | |||
} | |||
[Ignore("Todo: Port")] | |||
[TestMethod] | |||
public void testAsGraphElementConversions() | |||
{ | |||
/* | |||
def testAsGraphElementConversions(self): | |||
class ConvertibleObj(object): | |||
def _as_graph_element(self): | |||
return "FloatOutput:0" | |||
class NonConvertibleObj(object): | |||
pass | |||
g = ops.Graph() | |||
a = _apply_op(g, "FloatOutput", [], [dtypes.float32]) | |||
self.assertEqual(a, g.as_graph_element(ConvertibleObj())) | |||
with self.assertRaises(TypeError): | |||
g.as_graph_element(NonConvertibleObj()) | |||
*/ | |||
} | |||
[Ignore("Todo: Port")] | |||
[TestMethod] | |||
public void testGarbageCollected() | |||
{ | |||
/* | |||
# Regression test against creating custom __del__ functions in classes | |||
# involved in cyclic references, e.g. Graph and Operation. (Python won't gc | |||
# cycles that require calling a __del__ method, because the __del__ method can | |||
# theoretically increase the object's refcount to "save" it from gc, and any | |||
# already-deleted objects in the cycle would have be to restored.) | |||
def testGarbageCollected(self): | |||
# Create a graph we can delete and a weak reference to monitor if it's gc'd | |||
g = ops.Graph() | |||
g_ref = weakref.ref(g) | |||
# Create some ops | |||
with g.as_default(): | |||
a = constant_op.constant(2.0) | |||
b = constant_op.constant(3.0) | |||
c = math_ops.add(a, b) | |||
# Create a session we can delete | |||
with session.Session(graph=g) as sess: | |||
self.evaluate(c) | |||
# Delete all references and trigger gc | |||
del g | |||
del a | |||
del b | |||
del c | |||
del sess | |||
gc.collect() | |||
self.assertIsNone(g_ref()) | |||
*/ | |||
} | |||
[Ignore("Todo: Port")] | |||
[TestMethod] | |||
public void testRunnableAfterInvalidShape() | |||
{ | |||
/* | |||
def testRunnableAfterInvalidShape(self): | |||
with ops.Graph().as_default(): | |||
with self.assertRaises(ValueError): | |||
math_ops.add([1, 2], [1, 2, 3]) | |||
a = constant_op.constant(1) | |||
with session.Session() as sess: | |||
self.evaluate(a) | |||
*/ | |||
} | |||
[Ignore("Todo: Port")] | |||
[TestMethod] | |||
public void testRunnableAfterInvalidShapeWithKernelLabelMap() | |||
{ | |||
/* | |||
def testRunnableAfterInvalidShapeWithKernelLabelMap(self): | |||
g = ops.Graph() | |||
with g.as_default(): | |||
with g._kernel_label_map({"KernelLabelRequired": "overload_1"}): | |||
with self.assertRaises(ValueError): | |||
test_ops.kernel_label_required(1) | |||
a = constant_op.constant(1) | |||
with session.Session() as sess: | |||
self.evaluate(a) | |||
*/ | |||
} | |||
} | |||
} |