@@ -12,6 +12,51 @@ namespace Tensorflow | |||||
/// then create a TensorFlow session to run parts of the graph across a set of local and remote devices. | /// then create a TensorFlow session to run parts of the graph across a set of local and remote devices. | ||||
/// https://www.tensorflow.org/guide/graphs | /// https://www.tensorflow.org/guide/graphs | ||||
/// </summary> | /// </summary> | ||||
/* | |||||
A TensorFlow computation, represented as a dataflow graph. | |||||
A `Graph` contains a set of | |||||
`tf.Operation` objects, | |||||
which represent units of computation; and | |||||
`tf.Tensor` objects, which represent | |||||
the units of data that flow between operations. | |||||
A default `Graph` is always registered, and accessible by calling | |||||
`tf.get_default_graph`. | |||||
To add an operation to the default graph, simply call one of the functions | |||||
that defines a new `Operation`: | |||||
```python | |||||
c = tf.constant(4.0) | |||||
assert c.graph is tf.get_default_graph() | |||||
``` | |||||
Another typical usage involves the | |||||
`tf.Graph.as_default` | |||||
context manager, which overrides the current default graph for the | |||||
lifetime of the context: | |||||
```python | |||||
g = tf.Graph() | |||||
with g.as_default(): | |||||
# Define operations and tensors in `g`. | |||||
c = tf.constant(30.0) | |||||
assert c.graph is g | |||||
``` | |||||
Important note: This class *is not* thread-safe for graph construction. All | |||||
operations should be created from a single thread, or external | |||||
synchronization must be provided. Unless otherwise specified, all methods | |||||
are not thread-safe. | |||||
A `Graph` instance supports an arbitrary number of "collections" | |||||
that are identified by name. For convenience when building a large | |||||
graph, collections can store groups of related objects: for | |||||
example, the `tf.Variable` uses a collection (named | |||||
`tf.GraphKeys.GLOBAL_VARIABLES`) for | |||||
all variables that are created during the construction of a graph. The caller | |||||
may define additional collections by specifying a new name. | |||||
*/ | |||||
public partial class Graph : IPython, IDisposable | public partial class Graph : IPython, IDisposable | ||||
{ | { | ||||
private IntPtr _handle; | private IntPtr _handle; | ||||
@@ -49,19 +49,59 @@ namespace Tensorflow | |||||
return get_default_graph().get_collection_ref(key); | return get_default_graph().get_collection_ref(key); | ||||
} | } | ||||
private static Graph default_graph; | |||||
private static Graph default_graph; | |||||
/// <summary> | |||||
/// Returns the default graph for the current thread. | |||||
/// | |||||
/// The returned graph will be the innermost graph on which a | |||||
/// `Graph.as_default()` context has been entered, or a global default | |||||
/// graph if none has been explicitly created. | |||||
/// | |||||
/// NOTE: The default graph is a property of the current thread.If you | |||||
/// create a new thread, and wish to use the default graph in that | |||||
/// thread, you must explicitly add a `with g.as_default():` in that | |||||
/// thread's function. | |||||
/// </summary> | |||||
/// <returns></returns> | |||||
public static Graph get_default_graph() | public static Graph get_default_graph() | ||||
{ | { | ||||
//TODO: original source indicates there should be a _default_graph_stack! | |||||
//return _default_graph_stack.get_default() | |||||
if (default_graph == null) | if (default_graph == null) | ||||
default_graph = tf.Graph(); | default_graph = tf.Graph(); | ||||
return default_graph; | return default_graph; | ||||
} | } | ||||
public static Graph set_default_graph(Graph graph) | public static Graph set_default_graph(Graph graph) | ||||
{ | { | ||||
//TODO: original source does not have a 'set_default_graph' and indicates there should be a _default_graph_stack! | |||||
default_graph = graph; | default_graph = graph; | ||||
return default_graph; | return default_graph; | ||||
} | |||||
/// <summary> | |||||
/// Clears the default graph stack and resets the global default graph. | |||||
/// | |||||
/// NOTE: The default graph is a property of the current thread.This | |||||
/// function applies only to the current thread.Calling this function while | |||||
/// a `tf.Session` or `tf.InteractiveSession` is active will result in undefined | |||||
/// behavior. Using any previously created `tf.Operation` or `tf.Tensor` objects | |||||
/// after calling this function will result in undefined behavior. | |||||
/// </summary> | |||||
/// <returns></returns> | |||||
public static void reset_default_graph() | |||||
{ | |||||
//TODO: original source indicates there should be a _default_graph_stack! | |||||
//if (!_default_graph_stack.is_cleared()) | |||||
// throw new InvalidOperationException("Do not use tf.reset_default_graph() to clear " + | |||||
// "nested graphs. If you need a cleared graph, " + | |||||
// "exit the nesting and create a new graph."); | |||||
//_default_graph_stack.reset(); | |||||
if (default_graph!=null) | |||||
default_graph.Dispose(); | |||||
default_graph = tf.Graph(); | |||||
} | } | ||||
public static Graph _get_graph_from_inputs(List<Tensor> op_input_list, Graph graph = null) | public static Graph _get_graph_from_inputs(List<Tensor> op_input_list, Graph graph = null) | ||||
{ | { | ||||
foreach(var op_input in op_input_list) | foreach(var op_input in op_input_list) | ||||
@@ -0,0 +1,200 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using System.Text; | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
using Tensorflow; | |||||
using Tensorflow.Operations; | |||||
namespace TensorFlowNET.UnitTest.ops_test | |||||
{ | |||||
/// <summary> | |||||
/// excerpt of tensorflow/python/framework/ops_test.py | |||||
/// </summary> | |||||
[TestClass] | |||||
public class GraphTest : PythonTest | |||||
{ | |||||
[TestInitialize] | |||||
public void SetUp() | |||||
{ | |||||
ops.reset_default_graph(); | |||||
} | |||||
[TestCleanup] | |||||
public void TearDown() | |||||
{ | |||||
ops.reset_default_graph(); | |||||
} | |||||
private void _AssertDefault(Graph expected) { | |||||
Assert.AreSame(ops.get_default_graph(), expected); | |||||
} | |||||
[Ignore("Todo: Port")] | |||||
[TestMethod] | |||||
public void testResetDefaultGraphNesting() | |||||
{ | |||||
/* | |||||
def testResetDefaultGraphNesting(self): | |||||
g0 = ops.Graph() | |||||
with self.assertRaises(AssertionError): | |||||
with g0.as_default(): | |||||
ops.reset_default_graph() | |||||
*/ | |||||
} | |||||
[Ignore("Todo: Port")] | |||||
[TestMethod] | |||||
public void testGraphContextManagerCancelsEager() | |||||
{ | |||||
/* | |||||
def testGraphContextManagerCancelsEager(self): | |||||
with context.eager_mode(): | |||||
with ops.Graph().as_default(): | |||||
self.assertFalse(context.executing_eagerly()) | |||||
*/ | |||||
} | |||||
[Ignore("Todo: Port")] | |||||
[TestMethod] | |||||
public void testGraphContextManager() | |||||
{ | |||||
/* | |||||
def testGraphContextManager(self): | |||||
g0 = ops.Graph() | |||||
with g0.as_default() as g1: | |||||
self.assertIs(g0, g1) | |||||
*/ | |||||
} | |||||
[Ignore("Todo: Port")] | |||||
[TestMethod] | |||||
public void testDefaultGraph() | |||||
{ | |||||
/* | |||||
def testDefaultGraph(self): | |||||
orig = ops.get_default_graph() | |||||
self._AssertDefault(orig) | |||||
g0 = ops.Graph() | |||||
self._AssertDefault(orig) | |||||
context_manager_0 = g0.as_default() | |||||
self._AssertDefault(orig) | |||||
with context_manager_0 as g0: | |||||
self._AssertDefault(g0) | |||||
with ops.Graph().as_default() as g1: | |||||
self._AssertDefault(g1) | |||||
self._AssertDefault(g0) | |||||
self._AssertDefault(orig) | |||||
*/ | |||||
} | |||||
[Ignore("Todo: Port")] | |||||
[TestMethod] | |||||
public void testPreventFeeding() | |||||
{ | |||||
/* | |||||
def testPreventFeeding(self): | |||||
g = ops.Graph() | |||||
a = constant_op.constant(2.0) | |||||
self.assertTrue(g.is_feedable(a)) | |||||
g.prevent_feeding(a) | |||||
self.assertFalse(g.is_feedable(a)) | |||||
*/ | |||||
} | |||||
[Ignore("Todo: Port")] | |||||
[TestMethod] | |||||
public void testAsGraphElementConversions() | |||||
{ | |||||
/* | |||||
def testAsGraphElementConversions(self): | |||||
class ConvertibleObj(object): | |||||
def _as_graph_element(self): | |||||
return "FloatOutput:0" | |||||
class NonConvertibleObj(object): | |||||
pass | |||||
g = ops.Graph() | |||||
a = _apply_op(g, "FloatOutput", [], [dtypes.float32]) | |||||
self.assertEqual(a, g.as_graph_element(ConvertibleObj())) | |||||
with self.assertRaises(TypeError): | |||||
g.as_graph_element(NonConvertibleObj()) | |||||
*/ | |||||
} | |||||
[Ignore("Todo: Port")] | |||||
[TestMethod] | |||||
public void testGarbageCollected() | |||||
{ | |||||
/* | |||||
# Regression test against creating custom __del__ functions in classes | |||||
# involved in cyclic references, e.g. Graph and Operation. (Python won't gc | |||||
# cycles that require calling a __del__ method, because the __del__ method can | |||||
# theoretically increase the object's refcount to "save" it from gc, and any | |||||
# already-deleted objects in the cycle would have be to restored.) | |||||
def testGarbageCollected(self): | |||||
# Create a graph we can delete and a weak reference to monitor if it's gc'd | |||||
g = ops.Graph() | |||||
g_ref = weakref.ref(g) | |||||
# Create some ops | |||||
with g.as_default(): | |||||
a = constant_op.constant(2.0) | |||||
b = constant_op.constant(3.0) | |||||
c = math_ops.add(a, b) | |||||
# Create a session we can delete | |||||
with session.Session(graph=g) as sess: | |||||
self.evaluate(c) | |||||
# Delete all references and trigger gc | |||||
del g | |||||
del a | |||||
del b | |||||
del c | |||||
del sess | |||||
gc.collect() | |||||
self.assertIsNone(g_ref()) | |||||
*/ | |||||
} | |||||
[Ignore("Todo: Port")] | |||||
[TestMethod] | |||||
public void testRunnableAfterInvalidShape() | |||||
{ | |||||
/* | |||||
def testRunnableAfterInvalidShape(self): | |||||
with ops.Graph().as_default(): | |||||
with self.assertRaises(ValueError): | |||||
math_ops.add([1, 2], [1, 2, 3]) | |||||
a = constant_op.constant(1) | |||||
with session.Session() as sess: | |||||
self.evaluate(a) | |||||
*/ | |||||
} | |||||
[Ignore("Todo: Port")] | |||||
[TestMethod] | |||||
public void testRunnableAfterInvalidShapeWithKernelLabelMap() | |||||
{ | |||||
/* | |||||
def testRunnableAfterInvalidShapeWithKernelLabelMap(self): | |||||
g = ops.Graph() | |||||
with g.as_default(): | |||||
with g._kernel_label_map({"KernelLabelRequired": "overload_1"}): | |||||
with self.assertRaises(ValueError): | |||||
test_ops.kernel_label_required(1) | |||||
a = constant_op.constant(1) | |||||
with session.Session() as sess: | |||||
self.evaluate(a) | |||||
*/ | |||||
} | |||||
} | |||||
} |