Browse Source

added GraphTests, all Igored at the moment

tags/v0.9
Meinrad Recheis 6 years ago
parent
commit
e7707cbfad
4 changed files with 3300 additions and 1 deletions
  1. +45
    -0
      src/TensorFlowNET.Core/Graphs/Graph.cs
  2. +41
    -1
      src/TensorFlowNET.Core/ops.py.cs
  3. +200
    -0
      test/TensorFlowNET.UnitTest/ops_test/GraphTest.cs
  4. +3014
    -0
      test/TensorFlowNET.UnitTest/ops_test/ops_test_r1.13.py

+ 45
- 0
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -12,6 +12,51 @@ namespace Tensorflow
/// then create a TensorFlow session to run parts of the graph across a set of local and remote devices.
/// https://www.tensorflow.org/guide/graphs
/// </summary>
/*
A TensorFlow computation, represented as a dataflow graph.
A `Graph` contains a set of
`tf.Operation` objects,
which represent units of computation; and
`tf.Tensor` objects, which represent
the units of data that flow between operations.
A default `Graph` is always registered, and accessible by calling
`tf.get_default_graph`.
To add an operation to the default graph, simply call one of the functions
that defines a new `Operation`:
```python
c = tf.constant(4.0)
assert c.graph is tf.get_default_graph()
```
Another typical usage involves the
`tf.Graph.as_default`
context manager, which overrides the current default graph for the
lifetime of the context:
```python
g = tf.Graph()
with g.as_default():
# Define operations and tensors in `g`.
c = tf.constant(30.0)
assert c.graph is g
```
Important note: This class *is not* thread-safe for graph construction. All
operations should be created from a single thread, or external
synchronization must be provided. Unless otherwise specified, all methods
are not thread-safe.
A `Graph` instance supports an arbitrary number of "collections"
that are identified by name. For convenience when building a large
graph, collections can store groups of related objects: for
example, the `tf.Variable` uses a collection (named
`tf.GraphKeys.GLOBAL_VARIABLES`) for
all variables that are created during the construction of a graph. The caller
may define additional collections by specifying a new name.
*/
public partial class Graph : IPython, IDisposable
{
private IntPtr _handle;


+ 41
- 1
src/TensorFlowNET.Core/ops.py.cs View File

@@ -49,19 +49,59 @@ namespace Tensorflow
return get_default_graph().get_collection_ref(key);
}

private static Graph default_graph;
private static Graph default_graph;
/// <summary>
/// Returns the default graph for the current thread.
///
/// The returned graph will be the innermost graph on which a
/// `Graph.as_default()` context has been entered, or a global default
/// graph if none has been explicitly created.
///
/// NOTE: The default graph is a property of the current thread.If you
/// create a new thread, and wish to use the default graph in that
/// thread, you must explicitly add a `with g.as_default():` in that
/// thread's function.
/// </summary>
/// <returns></returns>
public static Graph get_default_graph()
{
//TODO: original source indicates there should be a _default_graph_stack!
//return _default_graph_stack.get_default()
if (default_graph == null)
default_graph = tf.Graph();
return default_graph;
}
public static Graph set_default_graph(Graph graph)
{
//TODO: original source does not have a 'set_default_graph' and indicates there should be a _default_graph_stack!
default_graph = graph;
return default_graph;
}
/// <summary>
/// Clears the default graph stack and resets the global default graph.
///
/// NOTE: The default graph is a property of the current thread.This
/// function applies only to the current thread.Calling this function while
/// a `tf.Session` or `tf.InteractiveSession` is active will result in undefined
/// behavior. Using any previously created `tf.Operation` or `tf.Tensor` objects
/// after calling this function will result in undefined behavior.
/// </summary>
/// <returns></returns>
public static void reset_default_graph()
{
//TODO: original source indicates there should be a _default_graph_stack!
//if (!_default_graph_stack.is_cleared())
// throw new InvalidOperationException("Do not use tf.reset_default_graph() to clear " +
// "nested graphs. If you need a cleared graph, " +
// "exit the nesting and create a new graph.");
//_default_graph_stack.reset();
if (default_graph!=null)
default_graph.Dispose();
default_graph = tf.Graph();
}


public static Graph _get_graph_from_inputs(List<Tensor> op_input_list, Graph graph = null)
{
foreach(var op_input in op_input_list)


+ 200
- 0
test/TensorFlowNET.UnitTest/ops_test/GraphTest.cs View File

@@ -0,0 +1,200 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow;
using Tensorflow.Operations;
namespace TensorFlowNET.UnitTest.ops_test
{
/// <summary>
/// excerpt of tensorflow/python/framework/ops_test.py
/// </summary>
[TestClass]
public class GraphTest : PythonTest
{
[TestInitialize]
public void SetUp()
{
ops.reset_default_graph();
}
[TestCleanup]
public void TearDown()
{
ops.reset_default_graph();
}
private void _AssertDefault(Graph expected) {
Assert.AreSame(ops.get_default_graph(), expected);
}
[Ignore("Todo: Port")]
[TestMethod]
public void testResetDefaultGraphNesting()
{
/*
def testResetDefaultGraphNesting(self):
g0 = ops.Graph()
with self.assertRaises(AssertionError):
with g0.as_default():
ops.reset_default_graph()
*/
}
[Ignore("Todo: Port")]
[TestMethod]
public void testGraphContextManagerCancelsEager()
{
/*
def testGraphContextManagerCancelsEager(self):
with context.eager_mode():
with ops.Graph().as_default():
self.assertFalse(context.executing_eagerly())
*/
}
[Ignore("Todo: Port")]
[TestMethod]
public void testGraphContextManager()
{
/*
def testGraphContextManager(self):
g0 = ops.Graph()
with g0.as_default() as g1:
self.assertIs(g0, g1)
*/
}
[Ignore("Todo: Port")]
[TestMethod]
public void testDefaultGraph()
{
/*
def testDefaultGraph(self):
orig = ops.get_default_graph()
self._AssertDefault(orig)
g0 = ops.Graph()
self._AssertDefault(orig)
context_manager_0 = g0.as_default()
self._AssertDefault(orig)
with context_manager_0 as g0:
self._AssertDefault(g0)
with ops.Graph().as_default() as g1:
self._AssertDefault(g1)
self._AssertDefault(g0)
self._AssertDefault(orig)
*/
}
[Ignore("Todo: Port")]
[TestMethod]
public void testPreventFeeding()
{
/*
def testPreventFeeding(self):
g = ops.Graph()
a = constant_op.constant(2.0)
self.assertTrue(g.is_feedable(a))
g.prevent_feeding(a)
self.assertFalse(g.is_feedable(a))
*/
}
[Ignore("Todo: Port")]
[TestMethod]
public void testAsGraphElementConversions()
{
/*
def testAsGraphElementConversions(self):
class ConvertibleObj(object):
def _as_graph_element(self):
return "FloatOutput:0"
class NonConvertibleObj(object):
pass
g = ops.Graph()
a = _apply_op(g, "FloatOutput", [], [dtypes.float32])
self.assertEqual(a, g.as_graph_element(ConvertibleObj()))
with self.assertRaises(TypeError):
g.as_graph_element(NonConvertibleObj())
*/
}
[Ignore("Todo: Port")]
[TestMethod]
public void testGarbageCollected()
{
/*
# Regression test against creating custom __del__ functions in classes
# involved in cyclic references, e.g. Graph and Operation. (Python won't gc
# cycles that require calling a __del__ method, because the __del__ method can
# theoretically increase the object's refcount to "save" it from gc, and any
# already-deleted objects in the cycle would have be to restored.)
def testGarbageCollected(self):
# Create a graph we can delete and a weak reference to monitor if it's gc'd
g = ops.Graph()
g_ref = weakref.ref(g)
# Create some ops
with g.as_default():
a = constant_op.constant(2.0)
b = constant_op.constant(3.0)
c = math_ops.add(a, b)
# Create a session we can delete
with session.Session(graph=g) as sess:
self.evaluate(c)
# Delete all references and trigger gc
del g
del a
del b
del c
del sess
gc.collect()
self.assertIsNone(g_ref())
*/
}
[Ignore("Todo: Port")]
[TestMethod]
public void testRunnableAfterInvalidShape()
{
/*
def testRunnableAfterInvalidShape(self):
with ops.Graph().as_default():
with self.assertRaises(ValueError):
math_ops.add([1, 2], [1, 2, 3])
a = constant_op.constant(1)
with session.Session() as sess:
self.evaluate(a)
*/
}
[Ignore("Todo: Port")]
[TestMethod]
public void testRunnableAfterInvalidShapeWithKernelLabelMap()
{
/*
def testRunnableAfterInvalidShapeWithKernelLabelMap(self):
g = ops.Graph()
with g.as_default():
with g._kernel_label_map({"KernelLabelRequired": "overload_1"}):
with self.assertRaises(ValueError):
test_ops.kernel_label_required(1)
a = constant_op.constant(1)
with session.Session() as sess:
self.evaluate(a)
*/
}
}
}

+ 3014
- 0
test/TensorFlowNET.UnitTest/ops_test/ops_test_r1.13.py
File diff suppressed because it is too large
View File


Loading…
Cancel
Save