diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs
index 03d58dbc..081893c2 100644
--- a/src/TensorFlowNET.Core/Graphs/Graph.cs
+++ b/src/TensorFlowNET.Core/Graphs/Graph.cs
@@ -12,6 +12,51 @@ namespace Tensorflow
/// then create a TensorFlow session to run parts of the graph across a set of local and remote devices.
/// https://www.tensorflow.org/guide/graphs
///
+ /*
+ A TensorFlow computation, represented as a dataflow graph.
+
+ A `Graph` contains a set of
+ `tf.Operation` objects,
+ which represent units of computation; and
+ `tf.Tensor` objects, which represent
+ the units of data that flow between operations.
+
+ A default `Graph` is always registered, and accessible by calling
+ `tf.get_default_graph`.
+ To add an operation to the default graph, simply call one of the functions
+ that defines a new `Operation`:
+
+ ```python
+ c = tf.constant(4.0)
+ assert c.graph is tf.get_default_graph()
+ ```
+
+ Another typical usage involves the
+ `tf.Graph.as_default`
+ context manager, which overrides the current default graph for the
+ lifetime of the context:
+
+ ```python
+ g = tf.Graph()
+ with g.as_default():
+ # Define operations and tensors in `g`.
+ c = tf.constant(30.0)
+ assert c.graph is g
+ ```
+
+ Important note: This class *is not* thread-safe for graph construction. All
+ operations should be created from a single thread, or external
+ synchronization must be provided. Unless otherwise specified, all methods
+ are not thread-safe.
+
+ A `Graph` instance supports an arbitrary number of "collections"
+ that are identified by name. For convenience when building a large
+ graph, collections can store groups of related objects: for
+ example, the `tf.Variable` uses a collection (named
+ `tf.GraphKeys.GLOBAL_VARIABLES`) for
+ all variables that are created during the construction of a graph. The caller
+ may define additional collections by specifying a new name.
+ */
public partial class Graph : IPython, IDisposable
{
private IntPtr _handle;
diff --git a/src/TensorFlowNET.Core/ops.py.cs b/src/TensorFlowNET.Core/ops.py.cs
index add752ea..7b208854 100644
--- a/src/TensorFlowNET.Core/ops.py.cs
+++ b/src/TensorFlowNET.Core/ops.py.cs
@@ -49,19 +49,59 @@ namespace Tensorflow
return get_default_graph().get_collection_ref(key);
}
- private static Graph default_graph;
+ private static Graph default_graph;
+ ///
+ /// Returns the default graph for the current thread.
+ ///
+ /// The returned graph will be the innermost graph on which a
+ /// `Graph.as_default()` context has been entered, or a global default
+ /// graph if none has been explicitly created.
+ ///
+ /// NOTE: The default graph is a property of the current thread.If you
+ /// create a new thread, and wish to use the default graph in that
+ /// thread, you must explicitly add a `with g.as_default():` in that
+ /// thread's function.
+ ///
+ ///
public static Graph get_default_graph()
{
+ //TODO: original source indicates there should be a _default_graph_stack!
+ //return _default_graph_stack.get_default()
if (default_graph == null)
default_graph = tf.Graph();
return default_graph;
}
public static Graph set_default_graph(Graph graph)
{
+ //TODO: original source does not have a 'set_default_graph' and indicates there should be a _default_graph_stack!
default_graph = graph;
return default_graph;
+ }
+
+ ///
+ /// Clears the default graph stack and resets the global default graph.
+ ///
+ /// NOTE: The default graph is a property of the current thread.This
+ /// function applies only to the current thread.Calling this function while
+ /// a `tf.Session` or `tf.InteractiveSession` is active will result in undefined
+ /// behavior. Using any previously created `tf.Operation` or `tf.Tensor` objects
+ /// after calling this function will result in undefined behavior.
+ ///
+ ///
+ public static void reset_default_graph()
+ {
+ //TODO: original source indicates there should be a _default_graph_stack!
+ //if (!_default_graph_stack.is_cleared())
+ // throw new InvalidOperationException("Do not use tf.reset_default_graph() to clear " +
+ // "nested graphs. If you need a cleared graph, " +
+ // "exit the nesting and create a new graph.");
+ //_default_graph_stack.reset();
+ if (default_graph!=null)
+ default_graph.Dispose();
+ default_graph = tf.Graph();
}
+
public static Graph _get_graph_from_inputs(List op_input_list, Graph graph = null)
{
foreach(var op_input in op_input_list)
diff --git a/test/TensorFlowNET.UnitTest/ops_test/GraphTest.cs b/test/TensorFlowNET.UnitTest/ops_test/GraphTest.cs
new file mode 100644
index 00000000..a4a8c299
--- /dev/null
+++ b/test/TensorFlowNET.UnitTest/ops_test/GraphTest.cs
@@ -0,0 +1,200 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Tensorflow;
+using Tensorflow.Operations;
+
+namespace TensorFlowNET.UnitTest.ops_test
+{
+ ///
+ /// excerpt of tensorflow/python/framework/ops_test.py
+ ///
+ [TestClass]
+ public class GraphTest : PythonTest
+ {
+
+ [TestInitialize]
+ public void SetUp()
+ {
+ ops.reset_default_graph();
+ }
+
+ [TestCleanup]
+ public void TearDown()
+ {
+ ops.reset_default_graph();
+ }
+
+ private void _AssertDefault(Graph expected) {
+ Assert.AreSame(ops.get_default_graph(), expected);
+ }
+
+
+ [Ignore("Todo: Port")]
+ [TestMethod]
+ public void testResetDefaultGraphNesting()
+ {
+/*
+ def testResetDefaultGraphNesting(self):
+ g0 = ops.Graph()
+ with self.assertRaises(AssertionError):
+ with g0.as_default():
+ ops.reset_default_graph()
+*/
+ }
+
+ [Ignore("Todo: Port")]
+ [TestMethod]
+ public void testGraphContextManagerCancelsEager()
+ {
+ /*
+ def testGraphContextManagerCancelsEager(self):
+ with context.eager_mode():
+ with ops.Graph().as_default():
+ self.assertFalse(context.executing_eagerly())
+ */
+ }
+
+
+ [Ignore("Todo: Port")]
+ [TestMethod]
+ public void testGraphContextManager()
+ {
+ /*
+ def testGraphContextManager(self):
+ g0 = ops.Graph()
+ with g0.as_default() as g1:
+ self.assertIs(g0, g1)
+ */
+ }
+
+ [Ignore("Todo: Port")]
+ [TestMethod]
+ public void testDefaultGraph()
+ {
+ /*
+ def testDefaultGraph(self):
+ orig = ops.get_default_graph()
+ self._AssertDefault(orig)
+ g0 = ops.Graph()
+ self._AssertDefault(orig)
+ context_manager_0 = g0.as_default()
+ self._AssertDefault(orig)
+ with context_manager_0 as g0:
+ self._AssertDefault(g0)
+ with ops.Graph().as_default() as g1:
+ self._AssertDefault(g1)
+ self._AssertDefault(g0)
+ self._AssertDefault(orig)
+ */
+ }
+
+ [Ignore("Todo: Port")]
+ [TestMethod]
+ public void testPreventFeeding()
+ {
+ /*
+ def testPreventFeeding(self):
+ g = ops.Graph()
+ a = constant_op.constant(2.0)
+ self.assertTrue(g.is_feedable(a))
+ g.prevent_feeding(a)
+ self.assertFalse(g.is_feedable(a))
+ */
+ }
+
+
+ [Ignore("Todo: Port")]
+ [TestMethod]
+ public void testAsGraphElementConversions()
+ {
+ /*
+ def testAsGraphElementConversions(self):
+
+ class ConvertibleObj(object):
+
+ def _as_graph_element(self):
+ return "FloatOutput:0"
+
+ class NonConvertibleObj(object):
+
+ pass
+
+ g = ops.Graph()
+ a = _apply_op(g, "FloatOutput", [], [dtypes.float32])
+ self.assertEqual(a, g.as_graph_element(ConvertibleObj()))
+ with self.assertRaises(TypeError):
+ g.as_graph_element(NonConvertibleObj())
+ */
+ }
+
+ [Ignore("Todo: Port")]
+ [TestMethod]
+ public void testGarbageCollected()
+ {
+ /*
+ # Regression test against creating custom __del__ functions in classes
+ # involved in cyclic references, e.g. Graph and Operation. (Python won't gc
+ # cycles that require calling a __del__ method, because the __del__ method can
+ # theoretically increase the object's refcount to "save" it from gc, and any
+ # already-deleted objects in the cycle would have be to restored.)
+ def testGarbageCollected(self):
+ # Create a graph we can delete and a weak reference to monitor if it's gc'd
+ g = ops.Graph()
+ g_ref = weakref.ref(g)
+ # Create some ops
+ with g.as_default():
+ a = constant_op.constant(2.0)
+ b = constant_op.constant(3.0)
+ c = math_ops.add(a, b)
+ # Create a session we can delete
+ with session.Session(graph=g) as sess:
+ self.evaluate(c)
+ # Delete all references and trigger gc
+ del g
+ del a
+ del b
+ del c
+ del sess
+ gc.collect()
+ self.assertIsNone(g_ref())
+ */
+ }
+
+ [Ignore("Todo: Port")]
+ [TestMethod]
+ public void testRunnableAfterInvalidShape()
+ {
+ /*
+ def testRunnableAfterInvalidShape(self):
+ with ops.Graph().as_default():
+ with self.assertRaises(ValueError):
+ math_ops.add([1, 2], [1, 2, 3])
+ a = constant_op.constant(1)
+ with session.Session() as sess:
+ self.evaluate(a)
+ */
+ }
+
+ [Ignore("Todo: Port")]
+ [TestMethod]
+ public void testRunnableAfterInvalidShapeWithKernelLabelMap()
+ {
+ /*
+ def testRunnableAfterInvalidShapeWithKernelLabelMap(self):
+ g = ops.Graph()
+ with g.as_default():
+ with g._kernel_label_map({"KernelLabelRequired": "overload_1"}):
+ with self.assertRaises(ValueError):
+ test_ops.kernel_label_required(1)
+ a = constant_op.constant(1)
+ with session.Session() as sess:
+ self.evaluate(a)
+ */
+ }
+
+
+ }
+}
diff --git a/test/TensorFlowNET.UnitTest/ops_test/ops_test_r1.13.py b/test/TensorFlowNET.UnitTest/ops_test/ops_test_r1.13.py
new file mode 100644
index 00000000..2d7ee1a9
--- /dev/null
+++ b/test/TensorFlowNET.UnitTest/ops_test/ops_test_r1.13.py
@@ -0,0 +1,3014 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.python.framework.ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gc
+import os
+import threading
+import weakref
+
+from tensorflow.core.framework import attr_value_pb2
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.client import session
+from tensorflow.python.eager import context
+from tensorflow.python.eager import function as eager_function
+from tensorflow.python.framework import common_shapes
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import device as pydev
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import function
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.framework import test_ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.framework import versions
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import resources
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+import tensorflow.python.ops.gradients # pylint: disable=unused-import
+from tensorflow.python.platform import googletest
+from tensorflow.python.util import compat
+
+ops._set_call_cpp_shape_fn(common_shapes.call_cpp_shape_fn)
+
+
+class ResourceTest(test_util.TensorFlowTestCase):
+
+ @test_util.run_deprecated_v1
+ def testBuildGraph(self):
+ with self.cached_session():
+ pt = test_ops.stub_resource_handle_op(container="a", shared_name="b")
+ test_ops.resource_create_op(pt).run()
+
+ @test_util.run_deprecated_v1
+ def testInitialize(self):
+ with self.cached_session():
+ handle = test_ops.stub_resource_handle_op(container="a", shared_name="b")
+ resources.register_resource(
+ handle=handle,
+ create_op=test_ops.resource_create_op(handle),
+ is_initialized_op=test_ops.resource_initialized_op(handle))
+ self.assertEquals(
+ len(
+ resources.report_uninitialized_resources(
+ resources.shared_resources()).eval()), 1)
+ resources.initialize_resources(resources.shared_resources()).run()
+ self.assertEquals(
+ len(
+ resources.report_uninitialized_resources(
+ resources.shared_resources()).eval()), 0)
+
+
+class TensorAndShapeTest(test_util.TensorFlowTestCase):
+
+ def testShape(self):
+ op = ops.Operation(
+ ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32])
+ t = op.outputs[0]
+ self.assertEqual(tensor_shape.unknown_shape(), t.get_shape())
+ t.set_shape([1, 2, 3])
+ self.assertEqual([1, 2, 3], t.get_shape())
+
+ def testIterable(self):
+ op = ops.Operation(
+ ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32])
+ t = op.outputs[0]
+ self.assertTrue(isinstance(t, ops.Tensor))
+ with self.assertRaisesRegexp(TypeError, "iter"):
+ for _ in t:
+ pass
+
+ def testAddShape(self):
+ with self.cached_session():
+ a = array_ops.zeros([2, 3])
+ b = array_ops.ones([1, 3])
+ c = a + b
+ self.assertEqual([2, 3], c.shape)
+
+ @test_util.run_deprecated_v1
+ def testUnknownDim(self):
+ with self.cached_session():
+ a = array_ops.placeholder(dtype=dtypes.float32, shape=[2, None, 3])
+ b = array_ops.placeholder(dtype=dtypes.float32, shape=[2, None, 3])
+ c = a + b
+ self.assertEqual([2, None, 3], c.shape.as_list())
+
+ @test_util.run_deprecated_v1
+ def testUnknownShape(self):
+ with self.cached_session():
+ a = array_ops.placeholder(dtype=dtypes.float32, shape=None)
+ b = array_ops.ones([1, 3])
+ c = a + b
+ self.assertEqual(tensor_shape.unknown_shape(), c.shape)
+
+ @test_util.run_deprecated_v1
+ def testScalarShape(self):
+ with self.cached_session():
+ a = array_ops.placeholder(dtype=dtypes.float32, shape=[])
+ b = array_ops.ones([])
+ c = a + b
+ self.assertEqual(tensor_shape.scalar(), c.shape)
+
+ @test_util.run_deprecated_v1
+ def testShapeFunctionError(self):
+ with self.cached_session():
+ a = array_ops.ones([1, 2, 3])
+ b = array_ops.ones([4, 5, 6])
+ with self.assertRaisesRegexp(
+ ValueError,
+ r"Dimensions must be equal, but are 2 and 5 for 'add' \(op: 'Add'\) "
+ r"with input shapes: \[1,2,3\], \[4,5,6\]."):
+ _ = a + b
+
+
+class IndexedSlicesTest(test_util.TensorFlowTestCase):
+
+ @test_util.run_in_graph_and_eager_modes
+ def testToTensor(self):
+ values = constant_op.constant([2, 3, 5, 7], shape=[2, 2])
+ indices = constant_op.constant([0, 2])
+ dense_shape = constant_op.constant([3, 2])
+ x = ops.IndexedSlices(values, indices, dense_shape)
+ tensor = ops.convert_to_tensor(x, name="tensor")
+ self.assertAllEqual(self.evaluate(tensor), [[2, 3], [0, 0], [5, 7]])
+
+ @test_util.run_deprecated_v1
+ def testNegation(self):
+ with self.cached_session():
+ values = constant_op.constant([2, 3, 5, 7], shape=[2, 2])
+ indices = constant_op.constant([0, 2])
+ x = -ops.IndexedSlices(values, indices)
+ self.assertAllEqual(x.values.eval(), [[-2, -3], [-5, -7]])
+ self.assertAllEqual(x.indices.eval(), [0, 2])
+
+ @test_util.run_deprecated_v1
+ def testScalarMul(self):
+ with self.cached_session():
+ values = constant_op.constant([2, 3, 5, 7], shape=[2, 2])
+ indices = constant_op.constant([0, 2])
+ x = math_ops.scalar_mul(-2, ops.IndexedSlices(values, indices))
+ self.assertAllEqual(x.values.eval(), [[-4, -6], [-10, -14]])
+ self.assertAllEqual(x.indices.eval(), [0, 2])
+
+
+class NodeDefConstructorTest(test_util.TensorFlowTestCase):
+
+ def testNoArgs(self):
+ nodedef = ops._NodeDef("None", "bar")
+ self.assertProtoEquals("op: 'None' name: 'bar'", nodedef)
+
+ def testArgs(self):
+ nodedef = ops._NodeDef("foo", "bar", device="/device:baz:*")
+ self.assertProtoEquals("op:'foo' name:'bar' device:'/device:baz:*'",
+ nodedef)
+ nodedef = ops._NodeDef("foo", "bar", device=pydev.DeviceSpec(job="j"))
+ self.assertProtoEquals("op:'foo' name:'bar' device:'/job:j'", nodedef)
+
+
+def _apply_op(g, *args, **kwargs):
+ op = g.create_op(*args, **kwargs)
+ if len(op.outputs) == 1:
+ return op.outputs[0]
+ else:
+ return op.outputs
+
+
+class OperationTest(test_util.TensorFlowTestCase):
+
+ @test_util.run_deprecated_v1
+ def testNoInputs(self):
+ op = test_ops.float_output_string_output(name="myop").a.op
+ self.assertEqual(2, len(op.values()))
+ self.assertEqual(0, len(op.inputs))
+ self.assertEqual("myop", op.name)
+
+ float_t, label_str_t = op.values()
+ self.assertEqual(dtypes.float32, float_t.dtype)
+ self.assertEqual(op, float_t.op)
+ self.assertEqual(0, float_t._value_index)
+ self.assertEqual(0, len(float_t.consumers()))
+ self.assertEqual("myop", float_t._as_node_def_input())
+
+ self.assertEqual(dtypes.string, label_str_t.dtype)
+ self.assertEqual(op, label_str_t.op)
+ self.assertEqual(1, label_str_t._value_index)
+ self.assertEqual(0, len(label_str_t.consumers()))
+ self.assertEqual("myop:1", label_str_t._as_node_def_input())
+
+ self.assertProtoEquals("op:'FloatOutputStringOutput' name:'myop'",
+ op.node_def)
+
+ @test_util.run_deprecated_v1
+ def testNoOutputs(self):
+ op1 = test_ops.float_output(name="myop1").op
+ float_t, = op1.values()
+ op2 = test_ops.float_input(float_t, name="myop2")
+ self.assertEqual(0, len(op2.values()))
+ self.assertEqual(1, len(op2.inputs))
+ self.assertIs(float_t, op2.inputs[0])
+
+ self.assertEqual(1, len(float_t.consumers()))
+ self.assertEqual(op2, float_t.consumers()[0])
+
+ self.assertProtoEquals("op:'FloatOutput' name:'myop1'", op1.node_def)
+ self.assertProtoEquals("op:'FloatInput' name:'myop2' input:'myop1'",
+ op2.node_def)
+
+ @test_util.run_deprecated_v1
+ def testInputsAndOutputs(self):
+ op1 = test_ops.float_output(name="myop1").op
+ self.assertEqual(1, len(op1.values()))
+ float1_t, = op1.values()
+
+ op2 = test_ops.float_output_string_output(name="myop2").a.op
+ self.assertEqual(2, len(op2.values()))
+ float2_t, label2_str_t = op2.values()
+
+ # Note that we consume label2_str_t twice here.
+ op3 = test_ops.foo2(float1_t, label2_str_t, label2_str_t, name="myop3").d.op
+ self.assertEqual(2, len(op3.values()))
+
+ self.assertEqual(1, len(float1_t.consumers()))
+ self.assertEqual(op3, float1_t.consumers()[0])
+
+ self.assertEqual(0, len(float2_t.consumers()))
+
+ self.assertEqual(2, len(label2_str_t.consumers()))
+ self.assertEqual(op3, label2_str_t.consumers()[0])
+ self.assertEqual(op3, label2_str_t.consumers()[1])
+
+ self.assertProtoEquals("""
+ op:'Foo2' name:'myop3'
+ input:'myop1' input:'myop2:1' input:'myop2:1'
+ """, op3.node_def)
+
+ def testDeviceFromNodeDef(self):
+ op = ops.Operation(
+ ops._NodeDef("None", "myop", device="/job:goo/device:GPU:0"),
+ ops.Graph(), [], [])
+ self.assertEqual("/job:goo/device:GPU:0", op.device)
+
+ def testDeviceObject(self):
+ op = ops.Operation(ops._NodeDef("None", "myop"), ops.Graph(), [], [])
+ op._set_device("/job:goo/device:GPU:0")
+ self.assertProtoEquals(
+ "op:'None' name:'myop' device:'/job:goo/device:GPU:0' ", op.node_def)
+ op = ops.Operation(ops._NodeDef("None", "op2"), ops.Graph(), [], [])
+ op._set_device(
+ pydev.DeviceSpec(
+ job="muu", device_type="CPU", device_index=0))
+ self.assertProtoEquals(
+ "op:'None' name:'op2' device:'/job:muu/device:CPU:0'", op.node_def)
+
+ def testReferenceInput(self):
+ g = ops.Graph()
+ op1 = ops.Operation(
+ ops._NodeDef("RefOutputFloatOutput", "op1"), g, [],
+ [dtypes.float32_ref, dtypes.float32])
+ self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'", op1.node_def)
+ self.assertEquals([], list(op1.inputs))
+ ref_t, nonref_t = op1.values()
+ # NOTE(mrry): Must specify input_types to preserve ref-typed input.
+ op2 = ops.Operation(
+ ops._NodeDef("RefInputFloatInput", "op2"),
+ g, [ref_t, nonref_t], [],
+ input_types=[dtypes.float32_ref, dtypes.float32])
+ self.assertProtoEquals(
+ "op:'RefInputFloatInput' name:'op2' input:'op1' input:'op1:1'",
+ op2.node_def)
+ self.assertEquals([ref_t, nonref_t], list(op2.inputs))
+ op3 = ops.Operation(
+ ops._NodeDef("TwoFloatInputs", "op3"), g, [ref_t, nonref_t], [])
+ self.assertProtoEquals(
+ "op:'TwoFloatInputs' name:'op3' input:'op1' input:'op1:1'",
+ op3.node_def)
+
+ def testInvalidNames(self):
+ g = ops.Graph()
+ with self.assertRaises(ValueError):
+ ops.Operation(ops._NodeDef("op", ""), g)
+ with self.assertRaises(ValueError):
+ ops.Operation(ops._NodeDef("op", "_invalid"), g)
+ with self.assertRaises(ValueError):
+ ops.Operation(ops._NodeDef("op", "-invalid"), g)
+ with self.assertRaises(ValueError):
+ ops.Operation(ops._NodeDef("op", "/invalid"), g)
+ with self.assertRaises(ValueError):
+ ops.Operation(ops._NodeDef("op", "invalid:0"), g)
+
+ @test_util.run_deprecated_v1
+ def testNoShapeFunction(self):
+ op = test_ops.a()
+ self.assertEqual(tensor_shape.unknown_shape(), op.get_shape())
+
+ @test_util.run_in_graph_and_eager_modes
+ def testConvertToTensorNestedArray(self):
+ values = [[2], [3], [5], [7]]
+ tensor = ops.convert_to_tensor(values)
+ self.assertAllEqual((4, 1), tensor.get_shape().as_list())
+ self.assertAllEqual(values, self.evaluate(tensor))
+
+ def testShapeTuple(self):
+ with self.cached_session():
+ c = constant_op.constant(1)
+ self.assertEqual(c._shape_tuple(), ()) # pylint: disable=protected-access
+
+ def testConvertToTensorEager(self):
+ with context.eager_mode():
+ t = constant_op.constant(1)
+ self.assertTrue(isinstance(t, ops.EagerTensor))
+ converted = ops.convert_to_tensor(t)
+ self.assertTrue(isinstance(converted, ops.EagerTensor))
+ converted = ops.convert_to_tensor(1)
+ self.assertTrue(isinstance(converted, ops.EagerTensor))
+
+ @test_util.run_in_graph_and_eager_modes
+ def testConvertToTensorNestedTuple(self):
+ values = ((2,), (3,), (5,), (7,))
+ tensor = ops.convert_to_tensor(values)
+ self.assertAllEqual((4, 1), tensor.get_shape().as_list())
+ self.assertAllEqual(values, self.evaluate(ops.convert_to_tensor(values)))
+
+ @test_util.run_in_graph_and_eager_modes
+ def testConvertToTensorNestedTensors(self):
+ values = ((2,), (3,), (5,), (7,))
+ tensor = ops.convert_to_tensor(
+ [constant_op.constant(row) for row in values])
+ self.assertAllEqual((4, 1), tensor.get_shape().as_list())
+ self.assertAllEqual(values, self.evaluate(tensor))
+ tensor = ops.convert_to_tensor(
+ [[constant_op.constant(v) for v in row] for row in values])
+ self.assertAllEqual((4, 1), tensor.get_shape().as_list())
+ self.assertAllEqual(values, self.evaluate(tensor))
+
+ @test_util.run_in_graph_and_eager_modes
+ def testConvertToTensorNestedMix(self):
+ values = ([2], (3,), [constant_op.constant(5)], constant_op.constant([7]))
+ tensor = ops.convert_to_tensor(values)
+ self.assertAllEqual((4, 1), tensor.get_shape().as_list())
+ self.assertAllEqual(((2,), (3,), (5,), (7,)), self.evaluate(tensor))
+
+ @test_util.run_in_graph_and_eager_modes
+ def testConvertToTensorPreferred(self):
+ values = [2, 3, 5, 7]
+ tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.float32)
+ self.assertEqual(dtypes.float32, tensor.dtype)
+
+ # Convert empty tensor to anything.
+ values = []
+ tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.int64)
+ self.assertEqual(dtypes.int64, tensor.dtype)
+
+ # The preferred dtype is a type error and will convert to
+ # float32 instead.
+ values = [1.23]
+ tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.int64)
+ self.assertEqual(dtypes.float32, tensor.dtype)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testConvertToInvalidTensorType(self):
+ with self.assertRaises(TypeError):
+ # Forcing an invalid dtype should fail with a type error.
+ values = [1.23]
+ ops.convert_to_tensor(values, dtype=dtypes.int64)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testConvertToTensorFromInvalidTensor(self):
+ tensor = constant_op.constant(42.0, dtype=dtypes.float32)
+ with self.assertRaises(ValueError):
+ ops.convert_to_tensor(tensor, dtype=dtypes.int32)
+
+ @test_util.run_deprecated_v1
+ def testNoConvert(self):
+ # Operation cannot be converted to Tensor.
+ op = control_flow_ops.no_op()
+ with self.assertRaisesRegexp(TypeError,
+ r"Can't convert Operation '.*' to Tensor"):
+ ops.convert_to_tensor(op)
+
+ def testStr(self):
+ node_def = ops._NodeDef("None", "op1")
+ op = ops.Operation(node_def, ops.Graph(), [], [dtypes.float32])
+ self.assertEqual(str(node_def), str(op))
+
+ def testRepr(self):
+ op = ops.Operation(
+ ops._NodeDef("None", "op1"), ops.Graph(), [], [dtypes.float32])
+ self.assertEqual("", repr(op))
+
+ @test_util.run_deprecated_v1
+ def testGetAttr(self):
+ op = test_ops.default_attrs()
+ self.assertEqual(op.get_attr("string_val"), b"abc")
+ self.assertEqual(op.get_attr("string_list_val"), [b"abc", b""])
+ self.assertEqual(op.get_attr("int_val"), 123)
+ self.assertEqual(op.get_attr("int_list_val"), [1, 2, 3])
+ self.assertEqual(op.get_attr("float_val"), 10.0)
+ self.assertEqual(op.get_attr("float_list_val"), [10.0])
+ self.assertEqual(op.get_attr("bool_val"), True)
+ self.assertEqual(op.get_attr("bool_list_val"), [True, False])
+ self.assertEqual(op.get_attr("shape_val"),
+ tensor_shape.as_shape([2, 1]).as_proto())
+ self.assertEqual(op.get_attr("shape_list_val"),
+ [tensor_shape.as_shape([]).as_proto(),
+ tensor_shape.as_shape([1]).as_proto()])
+ self.assertEqual(op.get_attr("tensor_val"),
+ tensor_util.make_tensor_proto(1, dtypes.int32))
+ self.assertEqual(op.get_attr("tensor_list_val"),
+ [tensor_util.make_tensor_proto(1, dtypes.int32)])
+
+ type_val = op.get_attr("type_val")
+ # First check that type_val is a DType, because the assertEquals will work
+ # no matter what since DType overrides __eq__
+ self.assertIsInstance(type_val, dtypes.DType)
+ self.assertEqual(type_val, dtypes.int32)
+
+ type_list_val = op.get_attr("type_list_val")
+ self.assertTrue(all(isinstance(x, dtypes.DType) for x in type_list_val))
+ self.assertEqual(type_list_val, [dtypes.int32, dtypes.float32])
+
+ @function.Defun(dtypes.float32, func_name="MyFunc")
+ def func(x):
+ return x
+
+ op = test_ops.func_attr(func)
+ self.assertEqual(op.get_attr("f"),
+ attr_value_pb2.NameAttrList(name="MyFunc"))
+
+ # Try fetching missing attr
+ with self.assertRaisesRegexp(
+ ValueError, "Operation 'FuncAttr' has no attr named 'FakeAttr'."):
+ op.get_attr("FakeAttr")
+
+ # TODO(b/65162920): remove this test when users who are directly mutating the
+ # node_def have been updated to proper usage.
+ @test_util.run_deprecated_v1
+ def testSetAttr(self):
+ op = test_ops.int_attr().op
+ op._set_attr("foo", attr_value_pb2.AttrValue(i=2))
+ # TODO(skyewm): add node_def check
+ self.assertEqual(op.get_attr("foo"), 2)
+
+ # TODO(nolivia): test all error cases
+ def testAddControlInput(self):
+ with ops.Graph().as_default():
+ x = constant_op.constant(1).op
+ y = constant_op.constant(2).op
+ z = constant_op.constant(3).op
+ z._add_control_input(x) # pylint: disable=protected-access
+ self.assertEqual(z.control_inputs, [x])
+ z._add_control_input(x) # pylint: disable=protected-access
+ self.assertEqual(z.control_inputs, [x])
+ z._add_control_inputs([x, y, y]) # pylint: disable=protected-access
+ self.assertEqual(z.control_inputs, [x, y])
+ self.assertEqual(x._control_outputs, [z])
+
+ @test_util.run_deprecated_v1
+ def testRemoveAllControlInputs(self):
+ a = constant_op.constant(1)
+ with ops.control_dependencies([a]):
+ b = constant_op.constant(2)
+ c = constant_op.constant(3)
+ d = constant_op.constant(4)
+ e = constant_op.constant(5)
+ with ops.control_dependencies([a, c]):
+ f = d + e
+
+ self.assertEqual(a.op.control_inputs, [])
+ self.assertEqual(b.op.control_inputs, [a.op])
+ self.assertEqual(f.op.control_inputs, [a.op, c.op])
+
+ a.op._remove_all_control_inputs() # pylint: disable=protected-access
+ self.assertEqual(a.op.control_inputs, [])
+
+ b.op._remove_all_control_inputs() # pylint: disable=protected-access
+ self.assertEqual(b.op.control_inputs, [])
+
+ f.op._remove_all_control_inputs() # pylint: disable=protected-access
+ self.assertEqual(f.op.control_inputs, [])
+ self.assertEqual(list(f.op.inputs), [d, e])
+
+ @test_util.run_deprecated_v1
+ def testControlInputCycle(self):
+ graph = ops.Graph()
+ with graph.as_default():
+ z = constant_op.constant(0)
+ x = constant_op.constant(1)
+ y = constant_op.constant(2)
+ y.op._add_control_input(z.op) # pylint: disable=protected-access
+ y.op._add_control_input(x.op) # pylint: disable=protected-access
+ x.op._add_control_input(y.op) # pylint: disable=protected-access
+ with self.session(graph=graph) as sess:
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ "Graph is invalid, contains a cycle with 2 nodes"):
+ self.evaluate(x)
+
+ def testUpdateInput(self):
+ g = ops.Graph()
+ with g.as_default():
+ x = constant_op.constant(1)
+ y = constant_op.constant(2)
+ z = x + y
+
+ z.op._update_input(0, y) # pylint: disable=protected-access
+ self.assertEquals(list(z.op.inputs), [y, y])
+ self.assertEquals(x.consumers(), [])
+ self.assertEquals(y.consumers(), [z.op, z.op])
+ with session.Session(graph=g) as sess:
+ self.assertEquals(self.evaluate(z), 4)
+
+ z.op._update_input(0, x) # pylint: disable=protected-access
+ self.assertEquals(list(z.op.inputs), [x, y])
+ self.assertEquals(x.consumers(), [z.op])
+ self.assertEquals(y.consumers(), [z.op])
+ with session.Session(graph=g) as sess:
+ self.assertEquals(self.evaluate(z), 3)
+
+ z.op._update_input(1, y) # pylint: disable=protected-access
+ self.assertEquals(list(z.op.inputs), [x, y])
+ self.assertEquals(x.consumers(), [z.op])
+ self.assertEquals(y.consumers(), [z.op])
+ with session.Session(graph=g) as sess:
+ self.assertEquals(self.evaluate(z), 3)
+
+ def testUpdateInputGraphError(self):
+ g_0 = ops.Graph()
+ g_1 = ops.Graph()
+ with g_0.as_default():
+ x = constant_op.constant(1)
+ with g_1.as_default():
+ y = constant_op.constant(2)
+ z = y * 2
+ with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
+ z.op._update_input(0, x) # pylint: disable=protected-access
+
+ def testUpdateInputTypeError(self):
+ g = ops.Graph()
+ with g.as_default():
+ w = constant_op.constant(0)
+ x = constant_op.constant("")
+ y = constant_op.constant(1)
+ z = y + w
+ z.op._update_input(0, x) # pylint: disable=protected-access
+ with session.Session(graph=g) as sess:
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ "Input 0 of node add was passed string from Const_1:0 incompatible "
+ "with expected int32"):
+ self.evaluate(z)
+
+ def testUpdateInputShapeError(self):
+ g = ops.Graph()
+ with g.as_default():
+ w = constant_op.constant(2, shape=[3, 1])
+ x = constant_op.constant(0, shape=[3, 1])
+ y = constant_op.constant(1, shape=[2, 2])
+ z = w + x
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ r"Cannot update edge, incompatible shapes: \[2,2\] and \[3,1\]"):
+ z.op._update_input(0, y) # pylint: disable=protected-access
+
+ def testUpdateInputOutOfRange(self):
+ g = ops.Graph()
+ with g.as_default():
+ x = constant_op.constant(1)
+ with self.assertRaisesRegexp(
+ errors.OutOfRangeError,
+ r"Cannot update edge. Input index \[1\] is greater than the number of "
+ r"total inputs \[0\]."
+ ):
+ x.op._update_input(1, x) # pylint: disable=protected-access
+
+ @test_util.enable_control_flow_v2
+ @test_util.run_v1_only("b/120545219")
+ def testAddWhileInput(self):
+ @eager_function.defun
+ def test():
+ output = control_flow_ops.while_loop(lambda x: x < 3, lambda x: x + 1,
+ [1])
+ while_op = output.op.inputs[0].op
+ self.assertEqual(while_op.type, "While")
+ orig_num_inputs = len(while_op.inputs)
+
+ # Make sure we can handle the while op having a control input.
+ while_op._add_control_input(constant_op.constant(0).op)
+
+ new_input1 = constant_op.constant(1.0)
+ new_input2 = constant_op.constant(True)
+
+ while_op._set_type_list_attr("T",
+ [t.dtype for t in while_op.inputs] +
+ [new_input1.dtype, new_input2.dtype])
+
+ while_op._add_while_inputs([new_input1, new_input2])
+ # Can't add an edge beyond what's specified by "T"
+ with self.assertRaises(errors.OutOfRangeError):
+ while_op._add_while_inputs([new_input2])
+ self.assertEqual(len(while_op.inputs), orig_num_inputs + 2) # pylint: disable=g-deprecated-assert
+
+ test()
+
+ @test_util.run_deprecated_v1
+ def testOpDef(self):
+ x = constant_op.constant(0)
+ y = constant_op.constant(1)
+ z = x + y
+
+ self.assertEqual(x.op.op_def.name, "Const")
+ self.assertEqual(len(x.op.op_def.input_arg), 0)
+ self.assertEqual(len(x.op.op_def.output_arg), 1)
+
+ self.assertEqual(z.op.op_def.name, "Add")
+ self.assertEqual(len(z.op.op_def.input_arg), 2)
+ self.assertEqual(len(z.op.op_def.output_arg), 1)
+
+ def testInputFromDifferentGraphError(self):
+ g_0 = ops.Graph()
+ g_1 = ops.Graph()
+ with g_0.as_default():
+ x = constant_op.constant(1)
+ with g_1.as_default():
+ y = constant_op.constant(2)
+ with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
+ y * x # pylint: disable=pointless-statement
+
+ def testInputsAreImmutable(self):
+ g = ops.Graph()
+ with g.as_default():
+ x = test_ops.int_output()
+ op = test_ops.int_input_int_output(x, name="myop").op
+ with self.assertRaisesRegexp(
+ AttributeError, "'_InputList' object has no attribute 'append'"):
+ op.inputs.append(None)
+
+
+class CreateOpTest(test_util.TensorFlowTestCase):
+
+ def testNodeDefArgs(self):
+ g = ops.Graph()
+ op1 = g.create_op("FloatOutput", [], [dtypes.float32], None, name="myop1")
+ with g.device("/device:GPU:0"):
+ op2 = g.create_op(
+ "FloatOutputStringOutput", [], [dtypes.float32, dtypes.string], None,
+ name="myop2")
+ op3 = g.create_op(
+ "Foo3",
+ [list(op1.values())[0], list(op2.values())[1], list(op2.values())[0]],
+ [dtypes.float32, dtypes.int32],
+ None,
+ name="myop3")
+ self.assertDeviceEqual(None, op1.device)
+ self.assertDeviceEqual("/device:GPU:0", op2.device)
+ self.assertDeviceEqual(None, op3.device)
+ self.assertProtoEquals("name:'myop1' op:'FloatOutput'", op1.node_def)
+ self.assertProtoEquals(
+ "name:'myop2' op:'FloatOutputStringOutput' device:'/device:GPU:0'",
+ op2.node_def)
+ self.assertProtoEquals(
+ "name:'myop3' input:'myop1' input:'myop2:1' input:'myop2' op:'Foo3'",
+ op3.node_def)
+
+ def testReferenceInput(self):
+ g = ops.Graph()
+ op1 = g.create_op(
+ "RefOutputFloatOutput", [], [dtypes.float32_ref, dtypes.float32],
+ name="op1")
+ self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'", op1.node_def)
+ ref_t, nonref_t = op1.values()
+ # NOTE(mrry): Must specify input_types to preserve ref-typed input.
+ op2 = g.create_op(
+ "RefInputFloatInput", [ref_t, nonref_t], [],
+ input_types=[dtypes.float32_ref, dtypes.float32],
+ name="op2")
+ self.assertProtoEquals(
+ "op:'RefInputFloatInput' name:'op2' input:'op1' input:'op1:1'",
+ op2.node_def)
+ op3 = g.create_op("TwoFloatInputs", [ref_t, nonref_t], [], name="op3")
+ self.assertProtoEquals(
+ "op:'TwoFloatInputs' name:'op3' input:'op1' input:'op1:1'",
+ op3.node_def)
+
+ def testFinalized(self):
+ g = ops.Graph()
+ g.finalize()
+ with self.assertRaises(RuntimeError):
+ g.create_op("FloatOutput", [], [dtypes.float32], None, name="myop1")
+
+ # Test unfinalize.
+ g._unsafe_unfinalize()
+ g.create_op("FloatOutput", [], [dtypes.float32], None, name="myop1")
+
+
+# NOTE(skyewm): these cases test the private Graph._create_op_from_tf_operation
+# method. Arguably we should only test the public APIs that depend on this
+# method. However, this logic is complex and tricky, and it can be difficult to
+# ascertain if we have adequate coverage (e.g. a graph may run successfully if
+# the control flow context isn't set properly, but a more complicated use case
+# that might not be obvious to test will fail). Thus we instead explicitly test
+# the low-level behavior.
+class CreateOpFromTFOperationTest(test_util.TensorFlowTestCase):
+
+ @test_util.run_deprecated_v1
+ def testBasic(self):
+ g = ops.Graph()
+ with g.as_default():
+ x = test_ops.int_output()
+ c_op = ops._create_c_op(
+ g, ops._NodeDef("IntInputIntOutput", "myop"), [x], [])
+ op = g._create_op_from_tf_operation(c_op)
+
+ self.assertEqual(op.name, "myop")
+ self.assertEqual(op.type, "IntInputIntOutput")
+ self.assertEqual(len(op.outputs), 1)
+ self.assertEqual(op.outputs[0].shape, tensor_shape.unknown_shape())
+ self.assertEqual(list(op.inputs), [x])
+ self.assertEqual(op.control_inputs, [])
+ self.assertEqual(op.graph, g)
+ self.assertEqual(x.consumers(), [op])
+ self.assertIsNotNone(op.traceback)
+ self.assertEqual(g.get_operation_by_name("myop"), op)
+ self.assertEqual(g.get_tensor_by_name("myop:0"), op.outputs[0])
+
+ def testShape(self):
+ g = ops.Graph()
+ with g.as_default():
+ x = constant_op.constant([[1, 2, 3], [4, 5, 6]])
+ c_op = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), [x], [])
+ op = g._create_op_from_tf_operation(c_op)
+
+ self.assertEqual(op.name, "myop")
+ self.assertEqual(op.type, "Identity")
+ self.assertEqual(len(op.outputs), 1)
+ self.assertEqual(op.outputs[0].shape, tensor_shape.matrix(2, 3))
+
+ def testUniqueName(self):
+ g = ops.Graph()
+ with g.as_default():
+ c_op = ops._create_c_op(g, ops._NodeDef("IntOutput", "myop"), [], [])
+ c_op2 = ops._create_c_op(g, ops._NodeDef("IntOutput", "myop_1"), [], [])
+ op = g._create_op_from_tf_operation(c_op)
+ op2 = g._create_op_from_tf_operation(c_op2)
+
+ # Create ops with same names as op1 and op2. We expect the new names to be
+ # uniquified.
+ op3 = test_ops.int_output(name="myop").op
+ op4 = test_ops.int_output(name="myop_1").op
+
+ self.assertEqual(op.name, "myop")
+ self.assertEqual(op2.name, "myop_1")
+ self.assertEqual(op3.name, "myop_2")
+ self.assertEqual(op4.name, "myop_1_1")
+
+ @test_util.run_v1_only("b/120545219")
+ def testCond(self):
+ g = ops.Graph()
+ with g.as_default():
+ x = test_ops.int_output()
+
+ def true_fn():
+ ops._create_c_op(ops.get_default_graph(),
+ ops._NodeDef("IntInput", "cond/myop"), [x], [])
+ new_ops = g._add_new_tf_operations()
+ self.assertEqual(len(new_ops), 1)
+ return x
+
+ control_flow_ops.cond(x < 10, true_fn, lambda: x)
+
+ op = g.get_operation_by_name("cond/myop")
+ self.assertIsNotNone(op)
+ self.assertEqual(op.name, "cond/myop")
+ self.assertEqual(op.type, "IntInput")
+ self.assertEqual(op.outputs, [])
+ op_input = op.inputs[0].op
+ self.assertEqual(op_input.type, "Switch")
+ self.assertEqual(op_input.inputs[0], x)
+ self.assertEqual(op.graph, g)
+ # pylint: disable=protected-access
+ self.assertIsNotNone(op._get_control_flow_context())
+ self.assertEqual(op._get_control_flow_context().name,
+ "cond/cond_text")
+ # pylint: enable=protected-access
+
+ @test_util.run_v1_only("b/120545219")
+ def testWhileLoop(self):
+ g = ops.Graph()
+ with g.as_default():
+ x = test_ops.int_output()
+
+ def body(i):
+ ops._create_c_op(ops.get_default_graph(),
+ ops._NodeDef("IntInput", "myloop/myop"), [x], [])
+ new_ops = g._add_new_tf_operations()
+ self.assertEqual(len(new_ops), 1)
+ return i
+
+ control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")
+
+ op = g.get_operation_by_name("myloop/myop")
+ self.assertIsNotNone(op)
+ self.assertEqual(op.name, "myloop/myop")
+ self.assertEqual(op.type, "IntInput")
+ self.assertEqual(op.outputs, [])
+ op_input = op.inputs[0].op
+ self.assertEqual(op_input.type, "Enter")
+ self.assertEqual(list(op_input.inputs), [x])
+ self.assertEqual(op.graph, g)
+ # pylint: disable=protected-access
+ self.assertIsNotNone(op._get_control_flow_context())
+ self.assertEqual(op._get_control_flow_context().name,
+ "myloop/while_context")
+ # pylint: enable=protected-access
+
+ @test_util.run_v1_only("b/120545219")
+ def testWhileLoopWithInternalControlDep(self):
+ g = ops.Graph()
+ with g.as_default():
+ x = test_ops.int_output()
+
+ def body(i):
+ c = constant_op.constant(1.0, name="c")
+ ops._create_c_op(ops.get_default_graph(),
+ ops._NodeDef("IntInput", "myloop/myop"), [x], [])
+ with ops.control_dependencies([c]):
+ new_ops = g._add_new_tf_operations()
+ self.assertEqual(len(new_ops), 1)
+ return i
+
+ control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")
+
+ op = g.get_operation_by_name("myloop/myop")
+ self.assertIsNotNone(op)
+ c = g.get_operation_by_name("myloop/c")
+ self.assertIsNotNone(c)
+ # Internal control dep is preserved
+ self.assertEqual(op.control_inputs, [c])
+
+ @test_util.run_v1_only("b/120545219")
+ def testWhileLoopWithExternalControlDep(self):
+ g = ops.Graph()
+ with g.as_default():
+ x = test_ops.int_output()
+ c = constant_op.constant(1.0)
+
+ def body(i):
+ ops._create_c_op(ops.get_default_graph(),
+ ops._NodeDef("IntInput", "myloop/myop"), [x], [])
+ with ops.control_dependencies([c]):
+ new_ops = g._add_new_tf_operations()
+ self.assertEqual(len(new_ops), 1)
+ return i
+
+ control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")
+
+ op = g.get_operation_by_name("myloop/myop")
+ self.assertIsNotNone(op)
+ # External control dep is removed and replaced with internal control dep
+ self.assertNotEqual(op.control_inputs[0], c.op)
+ self.assertIsNotNone(op.control_inputs[0]._get_control_flow_context())
+
+
+class ApplyOpTest(test_util.TensorFlowTestCase):
+
+ def testNodeDefArgs(self):
+ g = ops.Graph()
+ t1 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop1")
+ with g.device("/device:GPU:0"):
+ t2 = _apply_op(
+ g, "TwoIntOutputs", [], [dtypes.int32, dtypes.int32], name="myop2")
+ t3 = _apply_op(
+ g,
+ "Foo1", [t1, t2[1], t2[0]], [dtypes.float32, dtypes.int32],
+ name="myop3")
+ self.assertTrue(isinstance(t1, ops.Tensor))
+ self.assertTrue(isinstance(t2, list))
+ self.assertTrue(isinstance(t3, list))
+ self.assertTrue(isinstance(t3[0], ops.Tensor))
+ self.assertEqual("myop1", t1._as_node_def_input())
+ self.assertEqual("myop2", t2[0]._as_node_def_input())
+ self.assertEqual("myop2:1", t2[1]._as_node_def_input())
+ self.assertEqual("myop3", t3[0]._as_node_def_input())
+ # Validate that we got the right ops as well
+ self.assertProtoEquals("name:'myop1' op:'FloatOutput'", t1.op.node_def)
+ self.assertProtoEquals(
+ "name:'myop2' op:'TwoIntOutputs' device:'/device:GPU:0'",
+ t2[0].op.node_def)
+ self.assertProtoEquals(
+ "name:'myop3' input:'myop1' input:'myop2:1' input:'myop2' op:'Foo1'",
+ t3[0].op.node_def)
+
+ def testReferenceInput(self):
+ g = ops.Graph()
+ ref_t, nonref_t = _apply_op(
+ g, "RefOutputFloatOutput", [], [dtypes.float32_ref, dtypes.float32],
+ name="op1")
+ self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'",
+ ref_t.op.node_def)
+ # NOTE(mrry): Must specify input_types to preserve ref-typed input.
+ out_2 = _apply_op(
+ g,
+ "RefInputFloatInputIntOutput", [ref_t, nonref_t], [dtypes.int32],
+ input_types=[dtypes.float32_ref, dtypes.float32],
+ name="op2")
+ self.assertProtoEquals(
+ "op:'RefInputFloatInputIntOutput' name:'op2' input:'op1' input:'op1:1'",
+ out_2.op.node_def)
+ out_3 = _apply_op(
+ g, "TwoFloatInputsIntOutput", [ref_t, nonref_t], [dtypes.int32],
+ name="op3")
+ self.assertProtoEquals(
+ "op:'TwoFloatInputsIntOutput' name:'op3' input:'op1' input:'op1:1'",
+ out_3.op.node_def)
+
+
+class NameStackTest(test_util.TensorFlowTestCase):
+
+ def testBasics(self):
+ g = ops.Graph()
+ self.assertEqual("foo", g.unique_name("foo", mark_as_used=False))
+ self.assertEqual("foo", g.unique_name("foo", mark_as_used=False))
+ self.assertEqual("foo", g.unique_name("foo"))
+ self.assertEqual("foo_1", g.unique_name("foo", mark_as_used=False))
+ self.assertEqual("foo_1", g.unique_name("foo"))
+ self.assertEqual("foo_2", g.unique_name("foo", mark_as_used=False))
+ self.assertEqual("foo_2", g.unique_name("foo"))
+ self.assertEqual("foo_1_1", g.unique_name("foo_1", mark_as_used=False))
+ self.assertEqual("foo_1_1", g.unique_name("foo_1"))
+ self.assertEqual("foo_1_2", g.unique_name("foo_1", mark_as_used=False))
+ self.assertEqual("foo_1_2", g.unique_name("foo_1"))
+ self.assertEqual("foo_1_2_1", g.unique_name("foo_1_2", mark_as_used=False))
+ self.assertEqual("foo_1_2_1", g.unique_name("foo_1_2"))
+ with g.name_scope("bar"):
+ self.assertEqual("bar/foo", g.unique_name("foo", mark_as_used=False))
+ self.assertEqual("bar/foo", g.unique_name("foo"))
+ self.assertEqual("bar/foo_1", g.unique_name("foo", mark_as_used=False))
+ self.assertEqual("bar/foo_1", g.unique_name("foo"))
+ with g.name_scope(None):
+ self.assertEqual("foo_3", g.unique_name("foo", mark_as_used=False))
+ self.assertEqual("foo_3", g.unique_name("foo"))
+ with g.name_scope("baz"):
+ self.assertEqual(
+ "bar/baz/foo", g.unique_name(
+ "foo", mark_as_used=False))
+ self.assertEqual("bar/baz/foo", g.unique_name("foo"))
+ self.assertEqual(
+ "bar/baz/foo_1", g.unique_name(
+ "foo", mark_as_used=False))
+ self.assertEqual("bar/baz/foo_1", g.unique_name("foo"))
+ with g.name_scope("baz"):
+ self.assertEqual(
+ "bar/baz_1/foo", g.unique_name(
+ "foo", mark_as_used=False))
+ self.assertEqual("bar/baz_1/foo", g.unique_name("foo"))
+ self.assertEqual(
+ "bar/baz_1/foo_1", g.unique_name(
+ "foo", mark_as_used=False))
+ self.assertEqual("bar/baz_1/foo_1", g.unique_name("foo"))
+ with g.name_scope("quux"):
+ self.assertEqual("quux/foo", g.unique_name("foo", mark_as_used=False))
+ self.assertEqual("quux/foo", g.unique_name("foo"))
+ with g.name_scope("bar"):
+ with g.name_scope("baz"):
+ self.assertEqual(
+ "bar_1/baz/foo", g.unique_name(
+ "foo", mark_as_used=False))
+ self.assertEqual("bar_1/baz/foo", g.unique_name("foo"))
+ self.assertEqual("foo_4", g.unique_name("foo", mark_as_used=False))
+ self.assertEqual("foo_4", g.unique_name("foo"))
+ self.assertEqual("bar_2", g.unique_name("bar", mark_as_used=False))
+ self.assertEqual("bar_2", g.unique_name("bar"))
+
+ @test_util.run_deprecated_v1
+ def testNameAndVariableScope(self):
+ with self.cached_session() as sess:
+ with sess.graph.name_scope("l0"):
+ with variable_scope.variable_scope("l1"):
+ with sess.graph.name_scope("l1") as scope:
+ self.assertEqual("l0/l1/l1/", scope)
+ self.assertEqual(
+ "l0/l1/l1/foo",
+ sess.graph.unique_name(
+ "foo", mark_as_used=False))
+ self.assertEqual("l0/l1/l1/foo", sess.graph.unique_name("foo"))
+ with sess.graph.name_scope("l2") as scope:
+ self.assertEqual("l0/l1/l2/", scope)
+ self.assertEqual(
+ "l0/l1/l2/foo",
+ sess.graph.unique_name(
+ "foo", mark_as_used=False))
+ self.assertEqual("l0/l1/l2/foo", sess.graph.unique_name("foo"))
+
+ def testOutOfOrderUniqueName(self):
+ g = ops.Graph()
+ self.assertEqual("foo_2", g.unique_name("foo_2"))
+ self.assertEqual("foo", g.unique_name("foo"))
+ self.assertEqual("foo_1", g.unique_name("foo"))
+ self.assertEqual("foo_3", g.unique_name("foo"))
+
+ def testUniqueNameCaseInsensitivity(self):
+ g = ops.Graph()
+ self.assertEqual("foo", g.unique_name("foo"))
+ self.assertEqual("Foo_1", g.unique_name("Foo"))
+ with g.name_scope("bar"):
+ self.assertEqual("bar/foo", g.unique_name("foo"))
+ with g.name_scope("Bar"):
+ self.assertEqual("Bar_1/foo", g.unique_name("foo"))
+
+ def testInvalidNameRaisesError(self):
+ g = ops.Graph()
+ with g.name_scope(""): # Should not raise
+ pass
+ with g.name_scope("foo/"): # Should not raise
+ with g.name_scope("_bar"): # Should not raise
+ pass
+ with self.assertRaises(ValueError):
+ with g.name_scope("foo:0"):
+ pass
+ with self.assertRaises(ValueError):
+ with g.name_scope("_bar"):
+ pass
+
+
+class NameTest(test_util.TensorFlowTestCase):
+
+ def testGenerateName(self):
+ g = ops.Graph()
+ op0 = g.create_op("TwoFloatOutputs", [], [dtypes.float32, dtypes.float32])
+ self.assertEqual("TwoFloatOutputs", op0.name)
+ self.assertEqual("TwoFloatOutputs:0", op0.outputs[0].name)
+ self.assertEqual("TwoFloatOutputs:1", op0.outputs[1].name)
+
+ op1 = g.create_op("FloatOutput", [], [dtypes.float32])
+ self.assertEqual("FloatOutput", op1.name)
+ self.assertEqual("FloatOutput:0", op1.outputs[0].name)
+
+ op2 = g.create_op("FloatOutput", [], [dtypes.float32])
+ self.assertEqual("FloatOutput_1", op2.name)
+ self.assertEqual("FloatOutput_1:0", op2.outputs[0].name)
+
+ op3 = g.create_op("FloatOutput", [], [dtypes.float32], name="my_op")
+ self.assertEqual("my_op", op3.name)
+ self.assertEqual("my_op:0", op3.outputs[0].name)
+
+ def testNameScope(self):
+ g = ops.Graph()
+
+ with g.name_scope("foo") as foo:
+ self.assertEqual("foo/", foo)
+ with g.name_scope("foo2") as foo2:
+ self.assertEqual("foo/foo2/", foo2)
+ with g.name_scope(None) as empty1:
+ self.assertEqual("", empty1)
+ with g.name_scope("foo3") as foo3:
+ self.assertEqual("foo3/", foo3)
+ with g.name_scope("") as empty2:
+ self.assertEqual("", empty2)
+
+ self.assertEqual("FloatOutput",
+ g.create_op("FloatOutput", [], [dtypes.float32]).name)
+ with g.name_scope("bar") as scope:
+ self.assertEqual("bar/FloatOutput",
+ g.create_op("FloatOutput", [], [dtypes.float32]).name)
+ self.assertEqual("bar/FloatOutput_1",
+ g.create_op("FloatOutput", [], [dtypes.float32]).name)
+ # If you use the value from "with .. as", that values is used as-is.
+ self.assertEqual(
+ "bar", g.create_op(
+ "FloatOutput", [], [dtypes.float32], name=scope).name)
+ with g.name_scope("baz") as scope:
+ with g.name_scope("quux"):
+ self.assertEqual("baz/quux/FloatOutput",
+ g.create_op("FloatOutput", [], [dtypes.float32]).name)
+ # If you use the value from the enclosing "with .. as", nothing is pushed.
+ with g.name_scope(scope):
+ self.assertEqual("baz/FloatOutput",
+ g.create_op("FloatOutput", [], [dtypes.float32]).name)
+ self.assertEqual(
+ "baz", g.create_op(
+ "FloatOutput", [], [dtypes.float32], name=scope).name)
+ self.assertEqual(
+ "trailing",
+ g.create_op(
+ "FloatOutput", [], [dtypes.float32], name="trailing/").name)
+ with g.name_scope("bar"):
+ self.assertEqual("bar_1/FloatOutput",
+ g.create_op("FloatOutput", [], [dtypes.float32]).name)
+ with g.name_scope("bar/"):
+ self.assertEqual("bar/FloatOutput_2",
+ g.create_op("FloatOutput", [], [dtypes.float32]).name)
+
+
+class DeviceTest(test_util.TensorFlowTestCase):
+
+ def testNoDevice(self):
+ g = ops.Graph()
+ op = g.create_op("FloatOutput", [], [dtypes.float32])
+ self.assertDeviceEqual(None, op.device)
+ gd = g.as_graph_def()
+ self.assertProtoEqualsVersion("""
+ node { name: "FloatOutput" op: "FloatOutput" }
+ """, gd)
+
+ def testEagerBackingDevice(self):
+ with context.eager_mode():
+ with ops.device("/device:CPU:0"):
+ t = constant_op.constant(1.0)
+ self.assertRegexpMatches(t.device, "/device:CPU:0")
+ self.assertRegexpMatches(t.backing_device, "/device:CPU:0")
+
+ def testDevicePartialString(self):
+ g = ops.Graph()
+ with g.device("/job:worker/replica:2"):
+ g.create_op("FloatOutput", [], [dtypes.float32])
+ gd = g.as_graph_def()
+ self.assertProtoEqualsVersion("""
+ node { name: "FloatOutput" op: "FloatOutput"
+ device: "/job:worker/replica:2" }
+ """, gd)
+
+ def testDeviceFull(self):
+ g = ops.Graph()
+ with g.device(
+ pydev.DeviceSpec(
+ job="worker", replica=2, task=0, device_type="CPU",
+ device_index=3)):
+ g.create_op("FloatOutput", [], [dtypes.float32])
+ gd = g.as_graph_def()
+ self.assertProtoEqualsVersion("""
+ node { name: "FloatOutput" op: "FloatOutput"
+ device: "/job:worker/replica:2/task:0/device:CPU:3" }
+ """, gd)
+
+ def testNesting(self):
+ g = ops.Graph()
+ with g.device("/job:worker/replica:2"):
+ g.create_op("FloatOutput", [], [dtypes.float32])
+ with g.device("/job:worker/replica:3/task:0"):
+ g.create_op("FloatOutput", [], [dtypes.float32])
+ g.create_op("FloatOutput", [], [dtypes.float32])
+ gd = g.as_graph_def()
+ self.assertProtoEqualsVersion("""
+ node { name: "FloatOutput" op: "FloatOutput"
+ device: "/job:worker/replica:2" }
+ node { name: "FloatOutput_1" op: "FloatOutput"
+ device: "/job:worker/replica:3/task:0" }
+ node { name: "FloatOutput_2" op: "FloatOutput"
+ device: "/job:worker/replica:2" }
+ """, gd)
+
+ def testNestingString(self):
+ g = ops.Graph()
+ with g.device("/job:worker/replica:2"):
+ g.create_op("FloatOutput", [], [dtypes.float32])
+ with g.device("/job:worker/replica:3/task:0"):
+ g.create_op("FloatOutput", [], [dtypes.float32])
+ g.create_op("FloatOutput", [], [dtypes.float32])
+ gd = g.as_graph_def()
+ self.assertProtoEqualsVersion("""
+ node { name: "FloatOutput" op: "FloatOutput"
+ device: "/job:worker/replica:2" }
+ node { name: "FloatOutput_1" op: "FloatOutput"
+ device: "/job:worker/replica:3/task:0" }
+ node { name: "FloatOutput_2" op: "FloatOutput"
+ device: "/job:worker/replica:2" }
+ """, gd)
+
+ def testNestingOverrideGpuCpu(self):
+ g = ops.Graph()
+ with g.device("/job:worker/replica:2/device:CPU:1"):
+ g.create_op("FloatOutput", [], [dtypes.float32])
+ with g.device("/job:worker/replica:2/device:GPU:2"):
+ g.create_op("FloatOutput", [], [dtypes.float32])
+ g.create_op("FloatOutput", [], [dtypes.float32])
+ gd = g.as_graph_def()
+ self.assertProtoEqualsVersion("""
+ node { name: "FloatOutput" op: "FloatOutput"
+ device: "/job:worker/replica:2/device:CPU:1" }
+ node { name: "FloatOutput_1" op: "FloatOutput"
+ device: "/job:worker/replica:2/device:GPU:2" }
+ node { name: "FloatOutput_2" op: "FloatOutput"
+ device: "/job:worker/replica:2/device:CPU:1" }
+ """, gd)
+
+ def testNestingWithMergeDeviceFunction(self):
+ g = ops.Graph()
+
+ with g.device(pydev.merge_device("/device:GPU:0")):
+ g.create_op("FloatOutput", [], [dtypes.float32])
+ with g.device(pydev.merge_device("/job:worker")):
+ g.create_op("FloatOutput", [], [dtypes.float32])
+ with g.device(pydev.merge_device("/device:CPU:0")):
+ g.create_op("FloatOutput", [], [dtypes.float32])
+ with g.device(pydev.merge_device("/job:ps")):
+ g.create_op("FloatOutput", [], [dtypes.float32])
+ with g.device(pydev.merge_device(None)):
+ g.create_op("FloatOutput", [], [dtypes.float32])
+
+ gd = g.as_graph_def()
+ self.assertProtoEqualsVersion("""
+ node { name: "FloatOutput" op: "FloatOutput"
+ device: "/device:GPU:0" }
+ node { name: "FloatOutput_1" op: "FloatOutput"
+ device: "/job:worker/device:GPU:0" }
+ node { name: "FloatOutput_2" op: "FloatOutput"
+ device: "/job:worker/device:CPU:0" }
+ node { name: "FloatOutput_3" op: "FloatOutput"
+ device: "/job:ps/device:CPU:0" }
+ node { name: "FloatOutput_4" op: "FloatOutput"
+ device: "/job:ps/device:CPU:0" }
+ """, gd)
+
+ def testNestingWithDeviceStrings(self):
+ g = ops.Graph()
+
+ with g.device("/device:GPU:0"):
+ g.create_op("FloatOutput", [], [dtypes.float32])
+ with g.device("/job:worker"):
+ g.create_op("FloatOutput", [], [dtypes.float32])
+ with g.device("/device:CPU:0"):
+ g.create_op("FloatOutput", [], [dtypes.float32])
+ with g.device("/job:ps"):
+ g.create_op("FloatOutput", [], [dtypes.float32])
+ with g.device(""):
+ g.create_op("FloatOutput", [], [dtypes.float32])
+
+ gd = g.as_graph_def()
+ self.assertProtoEqualsVersion("""
+ node { name: "FloatOutput" op: "FloatOutput"
+ device: "/device:GPU:0" }
+ node { name: "FloatOutput_1" op: "FloatOutput"
+ device: "/job:worker/device:GPU:0" }
+ node { name: "FloatOutput_2" op: "FloatOutput"
+ device: "/job:worker/device:CPU:0" }
+ node { name: "FloatOutput_3" op: "FloatOutput"
+ device: "/job:ps/device:CPU:0" }
+ node { name: "FloatOutput_4" op: "FloatOutput"
+ device: "/job:ps/device:CPU:0" }
+ """, gd)
+
+ def testNestingWithDeviceStringWildcard(self):
+ g = ops.Graph()
+
+ with g.device("/device:GPU:7"):
+ g.create_op("FloatOutput", [], [dtypes.float32])
+ with g.device("/device:GPU:*"):
+ g.create_op("FloatOutput", [], [dtypes.float32])
+
+ with g.device("/device:CPU:*"):
+ g.create_op("FloatOutput", [], [dtypes.float32])
+ with g.device("/device:CPU:5"):
+ g.create_op("FloatOutput", [], [dtypes.float32])
+
+ gd = g.as_graph_def()
+ self.assertProtoEqualsVersion("""
+ node { name: "FloatOutput" op: "FloatOutput"
+ device: "/device:GPU:7" }
+ node { name: "FloatOutput_1" op: "FloatOutput"
+ device: "/device:GPU:7" }
+ node { name: "FloatOutput_2" op: "FloatOutput"
+ device: "/device:CPU:*" }
+ node { name: "FloatOutput_3" op: "FloatOutput"
+ device: "/device:CPU:5" }
+ """, gd)
+
+ def testNoneClearsDefault(self):
+ g = ops.Graph()
+ with g.device("/job:worker/replica:2/device:CPU:1"):
+ g.create_op("FloatOutput", [], [dtypes.float32])
+ with g.device(None):
+ g.create_op("FloatOutput", [], [dtypes.float32])
+ g.create_op("FloatOutput", [], [dtypes.float32])
+ gd = g.as_graph_def()
+ self.assertProtoEqualsVersion("""
+ node { name: "FloatOutput" op: "FloatOutput"
+ device: "/job:worker/replica:2/device:CPU:1" }
+ node { name: "FloatOutput_1" op: "FloatOutput" }
+ node { name: "FloatOutput_2" op: "FloatOutput"
+ device: "/job:worker/replica:2/device:CPU:1" }
+ """, gd)
+
+ def testNoneIgnoresOuterDeviceFunction(self):
+ g = ops.Graph()
+ with g.device(lambda op: "/job:worker/replica:2/device:CPU:1"):
+ g.create_op("FloatOutput", [], [dtypes.float32])
+ with g.device(None):
+ g.create_op("FloatOutput", [], [dtypes.float32])
+ g.create_op("FloatOutput", [], [dtypes.float32])
+ gd = g.as_graph_def()
+ self.assertProtoEqualsVersion("""
+ node { name: "FloatOutput" op: "FloatOutput"
+ device: "/job:worker/replica:2/device:CPU:1" }
+ node { name: "FloatOutput_1" op: "FloatOutput" }
+ node { name: "FloatOutput_2" op: "FloatOutput"
+ device: "/job:worker/replica:2/device:CPU:1" }
+ """, gd)
+
+ def _overwritingDeviceFunction(self, unused_op):
+ # This device function unconditionally overwrites the device of ops.
+ #
+ # NOTE(mrry): Writing device functions like this is not
+ # recommended. Instead, in most cases you should use
+ # `pydev.merge_device("/job:ps")` or simply `"/job:ps"` as the
+ # argument to `tf.device()` and the device component will be merged in.
+ return "/job:overwrite"
+
+ def testOverwritingBehavior(self):
+ g = ops.Graph()
+ with g.device(self._overwritingDeviceFunction):
+ g.create_op("FloatOutput", [], [dtypes.float32])
+ with g.device("/job:ps"): # Will be overwritten.
+ g.create_op("FloatOutput", [], [dtypes.float32])
+ with g.device(pydev.merge_device("/job:ps")): # Will be overwritten.
+ g.create_op("FloatOutput", [], [dtypes.float32])
+ with g.device(None): # Disables overwriting device function
+ with g.device("/job:ps"):
+ g.create_op("FloatOutput", [], [dtypes.float32])
+ with g.device(None): # Disables overwriting device function
+ with g.device(pydev.merge_device("/job:ps")):
+ g.create_op("FloatOutput", [], [dtypes.float32])
+ gd = g.as_graph_def()
+ self.assertProtoEqualsVersion("""
+ node { name: "FloatOutput" op: "FloatOutput"
+ device: "/job:overwrite" }
+ node { name: "FloatOutput_1" op: "FloatOutput"
+ device: "/job:overwrite" }
+ node { name: "FloatOutput_2" op: "FloatOutput"
+ device: "/job:overwrite" }
+ node { name: "FloatOutput_3" op: "FloatOutput"
+ device: "/job:ps" }
+ node { name: "FloatOutput_4" op: "FloatOutput"
+ device: "/job:ps" }
+ """, gd)
+
+
+class MultithreadedGraphStateTest(test_util.TensorFlowTestCase):
+
+ class TestThread(threading.Thread):
+
+ def __init__(self, graph, replica_id):
+ super(MultithreadedGraphStateTest.TestThread, self).__init__()
+ self._graph = graph
+ self._replica_id = replica_id
+ # This thread sets this event when it mutated the graph. The caller can
+ # wait for that.
+ self.has_mutated_graph = threading.Event()
+ # This thread waits for when it should continue. The caller can set this
+ # event.
+ self.should_continue = threading.Event()
+
+ def run(self):
+ # Mutate a graph's stack, then set `has_mutated_graph`, then wait for
+ # `should_continue`, then add an op to the graph affected by the graph's
+ # stack.
+ raise NotImplementedError("must be implemented in descendants")
+
+ def testDeviceFunctionStack(self):
+
+ class DeviceSettingThread(self.TestThread):
+
+ def run(self):
+ with g.device("/job:worker/replica:{}".format(self._replica_id)):
+ self.has_mutated_graph.set()
+ self.should_continue.wait()
+ self.should_continue.clear()
+ g.create_op(
+ "FloatOutput", [], [dtypes.float32],
+ name="FloatOutput_{}".format(self._replica_id))
+
+ g = ops.Graph()
+ # If `switch_to_thread` isn't called, then device placement of the ops
+ # below is not deterministic.
+ g.switch_to_thread_local()
+ threads = [DeviceSettingThread(g, i) for i in range(3)]
+ for t in threads:
+ t.start()
+ t.has_mutated_graph.wait()
+ t.has_mutated_graph.clear()
+ for t in threads:
+ t.should_continue.set()
+ t.join()
+
+ gd = g.as_graph_def()
+ self.assertProtoEqualsVersion("""
+ node { name: "FloatOutput_0" op: "FloatOutput"
+ device: "/job:worker/replica:0" }
+ node { name: "FloatOutput_1" op: "FloatOutput"
+ device: "/job:worker/replica:1" }
+ node { name: "FloatOutput_2" op: "FloatOutput"
+ device: "/job:worker/replica:2" }
+ """, gd)
+
+ def testColocateWith(self):
+
+ class ColocatingThread(self.TestThread):
+
+ def __init__(self, graph, replica_id, op_to_colocate_with):
+ super(ColocatingThread, self).__init__(graph, replica_id)
+ self._op_to_colocate_with = op_to_colocate_with
+
+ def run(self):
+ with g.colocate_with(self._op_to_colocate_with):
+ self.has_mutated_graph.set()
+ self.should_continue.wait()
+ self.should_continue.clear()
+ g.create_op(
+ "FloatOutput", [], [dtypes.float32],
+ name="FloatOutput_{}".format(self._replica_id))
+
+ g = ops.Graph()
+ ops_to_colocate_with = []
+ for i in range(3):
+ with g.device("/job:worker/replica:{}".format(i)):
+ ops_to_colocate_with.append(
+ g.create_op(
+ "FloatOutput", [], [dtypes.float32],
+ name="ColocateWithMe_{}".format(i)))
+
+ # If `switch_to_thread` isn't called, then `device` and `attr` values for
+ # the ops below are not deterministic.
+ g.switch_to_thread_local()
+ threads = [
+ ColocatingThread(g, i, ops_to_colocate_with[i]) for i in range(3)
+ ]
+ for t in threads:
+ t.start()
+ t.has_mutated_graph.wait()
+ t.has_mutated_graph.clear()
+ for t in threads:
+ t.should_continue.set()
+ t.join()
+
+ gd = g.as_graph_def()
+ self.assertProtoEqualsVersion("""
+ node { name: "ColocateWithMe_0" op: "FloatOutput"
+ device: "/job:worker/replica:0" }
+ node { name: "ColocateWithMe_1" op: "FloatOutput"
+ device: "/job:worker/replica:1" }
+ node { name: "ColocateWithMe_2" op: "FloatOutput"
+ device: "/job:worker/replica:2" }
+ node { name: "FloatOutput_0" op: "FloatOutput"
+ device: "/job:worker/replica:0"
+ attr { key: "_class"
+ value { list {
+ s: "loc:@ColocateWithMe_0"}}}}
+ node { name: "FloatOutput_1" op: "FloatOutput"
+ device: "/job:worker/replica:1"
+ attr { key: "_class"
+ value { list {
+ s: "loc:@ColocateWithMe_1"}}}}
+ node { name: "FloatOutput_2" op: "FloatOutput"
+ device: "/job:worker/replica:2"
+ attr { key: "_class"
+ value { list {
+ s: "loc:@ColocateWithMe_2"}}}}
+ """, gd)
+
+ def testControlDependencies(self):
+
+ class DependingThread(self.TestThread):
+
+ def __init__(self, graph, replica_id, dependency_op):
+ super(DependingThread, self).__init__(graph, replica_id)
+ self._dependency_op = dependency_op
+
+ def run(self):
+ with g.control_dependencies([self._dependency_op]):
+ self.has_mutated_graph.set()
+ self.should_continue.wait()
+ self.should_continue.clear()
+ g.create_op(
+ "FloatOutput", [], [dtypes.float32],
+ name="FloatOutput_{}".format(self._replica_id))
+
+ g = ops.Graph()
+ dependency_ops = []
+ for i in range(3):
+ dependency_ops.append(
+ g.create_op(
+ "FloatOutput", [], [dtypes.float32],
+ name="ColocateWithMe_{}".format(i)))
+
+ # If `switch_to_thread` isn't called, then `input` values for the ops below
+ # are not deterministic.
+ g.switch_to_thread_local()
+ threads = [DependingThread(g, i, dependency_ops[i]) for i in range(3)]
+ for t in threads:
+ t.start()
+ t.has_mutated_graph.wait()
+ t.has_mutated_graph.clear()
+ for t in threads:
+ t.should_continue.set()
+ t.join()
+
+ gd = g.as_graph_def()
+ self.assertProtoEqualsVersion("""
+ node { name: "ColocateWithMe_0" op: "FloatOutput" }
+ node { name: "ColocateWithMe_1" op: "FloatOutput" }
+ node { name: "ColocateWithMe_2" op: "FloatOutput" }
+ node { name: "FloatOutput_0" op: "FloatOutput"
+ input: "^ColocateWithMe_0" }
+ node { name: "FloatOutput_1" op: "FloatOutput"
+ input: "^ColocateWithMe_1" }
+ node { name: "FloatOutput_2" op: "FloatOutput"
+ input: "^ColocateWithMe_2" }
+ """, gd)
+
+ def testNameStack(self):
+
+ class NameSettingThread(self.TestThread):
+
+ def run(self):
+ with g.name_scope("foo"):
+ op1 = g.create_op("FloatOutput", [], [dtypes.float32])
+ self.has_mutated_graph.set()
+ self.should_continue.wait()
+ self.should_continue.clear()
+ op2 = g.create_op("FloatOutput", [], [dtypes.float32])
+ self.result = (op1, op2)
+
+ g = ops.Graph()
+ threads = [NameSettingThread(g, i) for i in range(3)]
+ for t in threads:
+ t.start()
+ t.has_mutated_graph.wait()
+ t.has_mutated_graph.clear()
+
+ for t in threads:
+ t.should_continue.set()
+ t.join()
+
+ suffixes = ["", "_1", "_2"]
+ for t, s in zip(threads, suffixes):
+ self.assertEquals("foo" + s + "/FloatOutput", t.result[0].name)
+ self.assertEquals("foo" + s + "/FloatOutput_1", t.result[1].name)
+
+
+class ObjectWithName(object):
+
+ def __init__(self, name):
+ self._name = name
+
+ @property
+ def name(self):
+ return self._name
+
+
+class CollectionTest(test_util.TensorFlowTestCase):
+
+ def test_get_collections(self):
+ g = ops.Graph()
+ self.assertSequenceEqual(g.collections, [])
+ g.add_to_collection("key", 12)
+ g.add_to_collection("key", 15)
+ self.assertSequenceEqual(g.collections, ["key"])
+ g.add_to_collection("other", "foo")
+ self.assertSequenceEqual(sorted(g.collections), ["key", "other"])
+
+ def test_add_to_collection(self):
+ g = ops.Graph()
+ g.add_to_collection("key", 12)
+ g.add_to_collection("other", "foo")
+ g.add_to_collection("key", 34)
+
+ # Note that only blank1 is returned.
+ g.add_to_collection("blah", 27)
+ blank1 = ObjectWithName("prefix/foo")
+ g.add_to_collection("blah", blank1)
+ blank2 = ObjectWithName("junk/foo")
+ g.add_to_collection("blah", blank2)
+
+ self.assertEqual([12, 34], g.get_collection("key"))
+ self.assertEqual([], g.get_collection("nothing"))
+ self.assertEqual([27, blank1, blank2], g.get_collection("blah"))
+ self.assertEqual([blank1], g.get_collection("blah", "prefix"))
+ self.assertEqual([blank1], g.get_collection("blah", ".*x"))
+
+ # Make sure that get_collection() returns a first-level
+ # copy of the collection, while get_collection_ref() returns
+ # the original list.
+ other_collection_snapshot = g.get_collection("other")
+ other_collection_ref = g.get_collection_ref("other")
+ self.assertEqual(["foo"], other_collection_snapshot)
+ self.assertEqual(["foo"], other_collection_ref)
+ g.add_to_collection("other", "bar")
+ self.assertEqual(["foo"], other_collection_snapshot)
+ self.assertEqual(["foo", "bar"], other_collection_ref)
+ self.assertEqual(["foo", "bar"], g.get_collection("other"))
+ self.assertTrue(other_collection_ref is g.get_collection_ref("other"))
+
+ # Verify that getting an empty collection ref returns a modifiable list.
+ empty_coll_ref = g.get_collection_ref("empty")
+ self.assertEqual([], empty_coll_ref)
+ empty_coll = g.get_collection("empty")
+ self.assertEqual([], empty_coll)
+ self.assertFalse(empty_coll is empty_coll_ref)
+ empty_coll_ref2 = g.get_collection_ref("empty")
+ self.assertTrue(empty_coll_ref2 is empty_coll_ref)
+ # Add to the collection.
+ empty_coll_ref.append("something")
+ self.assertEqual(["something"], empty_coll_ref)
+ self.assertEqual(["something"], empty_coll_ref2)
+ self.assertEqual([], empty_coll)
+ self.assertEqual(["something"], g.get_collection("empty"))
+ empty_coll_ref3 = g.get_collection_ref("empty")
+ self.assertTrue(empty_coll_ref3 is empty_coll_ref)
+
+ def test_add_to_collections_uniquify(self):
+ g = ops.Graph()
+ g.add_to_collections([1, 2, 1], "key")
+ # Make sure "key" is not added twice
+ self.assertEqual(["key"], g.get_collection(1))
+
+ def test_add_to_collections_from_list(self):
+ g = ops.Graph()
+ g.add_to_collections(["abc", "123"], "key")
+ self.assertEqual(["key"], g.get_collection("abc"))
+ self.assertEqual(["key"], g.get_collection("123"))
+
+ def test_add_to_collections_from_tuple(self):
+ g = ops.Graph()
+ g.add_to_collections(("abc", "123"), "key")
+ self.assertEqual(["key"], g.get_collection("abc"))
+ self.assertEqual(["key"], g.get_collection("123"))
+
+ def test_add_to_collections_from_generator(self):
+ g = ops.Graph()
+
+ def generator():
+ yield "abc"
+ yield "123"
+
+ g.add_to_collections(generator(), "key")
+ self.assertEqual(["key"], g.get_collection("abc"))
+ self.assertEqual(["key"], g.get_collection("123"))
+
+ def test_add_to_collections_from_set(self):
+ g = ops.Graph()
+ g.add_to_collections(set(["abc", "123"]), "key")
+ self.assertEqual(["key"], g.get_collection("abc"))
+ self.assertEqual(["key"], g.get_collection("123"))
+
+ def test_add_to_collections_from_string(self):
+ g = ops.Graph()
+ g.add_to_collections("abc", "key")
+ self.assertEqual(["key"], g.get_collection("abc"))
+
+ def test_default_graph(self):
+ with ops.Graph().as_default():
+ ops.add_to_collection("key", 90)
+ ops.add_to_collection("key", 100)
+ # Collections are ordered.
+ self.assertEqual([90, 100], ops.get_collection("key"))
+
+ def test_defun(self):
+ with context.eager_mode():
+
+ @eager_function.defun
+ def defun():
+ ops.add_to_collection("int", 1)
+ ops.add_to_collection("tensor", constant_op.constant(2))
+
+ @eager_function.defun
+ def inner_defun():
+ self.assertEqual(ops.get_collection("int"), [1])
+ three = ops.get_collection("tensor")[0] + ops.get_collection("int")[0]
+ ops.add_to_collection("int", 2)
+ self.assertEqual(ops.get_collection("int"), [1, 2])
+ ops.add_to_collection("foo", "bar")
+ self.assertEqual(ops.get_collection("foo"), ["bar"])
+ return three
+
+ self.assertEqual(ops.get_collection("int"), [1])
+ three = inner_defun()
+ self.assertEqual(ops.get_collection("int"), [1])
+ self.assertEqual(ops.get_collection("foo"), [])
+ return three
+
+ three = defun()
+ self.assertEqual(three.numpy(), 3)
+
+
+ops.NotDifferentiable("FloatOutput")
+
+
+@ops.RegisterGradient("CopyOp")
+def _CopyGrad(op, x_grad): # pylint: disable=invalid-name
+ _ = op
+ return x_grad
+
+
+@ops.RegisterGradient("copy_override")
+def _CopyOverrideGrad(op, x_grad): # pylint: disable=invalid-name
+ _ = op
+ return x_grad
+
+
+class RegistrationTest(test_util.TensorFlowTestCase):
+
+ @test_util.run_deprecated_v1
+ def testRegisterGradients(self):
+ x = test_ops.float_output()
+ y = test_ops.copy_op(x)
+ fn = ops.get_gradient_function(y.op)
+ self.assertEqual(_CopyGrad, fn)
+
+ def testOverrideGradients(self):
+ g = ops.Graph()
+ with g.as_default():
+ x = test_ops.float_output()
+ with g.gradient_override_map({"CopyOp": "copy_override"}):
+ y = test_ops.copy_op(x)
+ fn = ops.get_gradient_function(y.op)
+ self.assertEqual(_CopyOverrideGrad, fn)
+
+ def testNonExistentOverride(self):
+ g = ops.Graph()
+ with g.as_default():
+ x = test_ops.float_output()
+ with g.gradient_override_map({"CopyOp": "unknown_override"}):
+ y = test_ops.copy_op(x)
+ with self.assertRaisesRegexp(LookupError, "unknown_override"):
+ ops.get_gradient_function(y.op)
+
+
+class ComparisonTest(test_util.TensorFlowTestCase):
+
+ def testMembershipAllowed(self):
+ g = ops.Graph()
+ t1 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop1")
+ t2 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop2")
+ self.assertTrue(isinstance(t1, ops.Tensor))
+ self.assertTrue(isinstance(t2, ops.Tensor))
+ self.assertTrue(t1 in [t1])
+ self.assertTrue(t1 not in [t2])
+
+
+class ControlDependenciesTest(test_util.TensorFlowTestCase):
+
+ @test_util.run_deprecated_v1
+ def testBasic(self):
+ g = ops.Graph()
+ with g.as_default():
+ # Creating unregistered ops with _apply_op() doesn't work with the C API
+ # TODO(skyewm): address this more consistently. Possible solutions are
+ # to use registered ops in all tests, create a way to register ops in
+ # Python tests, or conditionally disable the op registration check in
+ # the C API.
+ a = constant_op.constant(1.0)
+ b = constant_op.constant(1.0)
+ with g.control_dependencies([a]):
+ c = constant_op.constant(1.0)
+ d = array_ops.identity(b)
+ e = array_ops.identity(c)
+
+ self.assertEqual(c.op.control_inputs, [a.op])
+ self.assertEqual(d.op.control_inputs, [a.op])
+ # e should be dominated by c.
+ self.assertEqual(e.op.control_inputs, [])
+
+ @test_util.run_in_graph_and_eager_modes
+ def testEager(self):
+ def future():
+ future.calls += 1
+ return constant_op.constant(2.0)
+ future.calls = 0
+
+ if context.executing_eagerly():
+ a = constant_op.constant(1.0)
+ b = future
+ with ops.control_dependencies([a, b]):
+ c = constant_op.constant(3.0)
+ self.assertEqual(future.calls, 1)
+ else:
+ g = ops.Graph()
+ with g.as_default():
+ a = constant_op.constant(1.0)
+ b = future()
+ with g.control_dependencies([a, b]):
+ c = constant_op.constant(3.0)
+ self.assertEqual(c.op.control_inputs, [a.op, b.op])
+ self.assertEqual(future.calls, 1)
+
+ def testBasicWithConversion(self):
+ g = ops.Graph()
+ a = _apply_op(g, "FloatOutput", [], [dtypes.float32])
+
+ class ConvertibleObj(object):
+
+ def _as_graph_element(self):
+ return a
+
+ with g.control_dependencies([ConvertibleObj()]):
+ c = _apply_op(g, "FloatOutput", [], [dtypes.float32])
+
+ self.assertEqual(c.op.control_inputs, [a.op])
+
+ def testNested(self):
+ g = ops.Graph()
+ a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
+ a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
+ a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
+ a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
+
+ with g.control_dependencies([a_1, a_2, a_3, a_4]):
+ b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
+
+ with g.control_dependencies([a_1]):
+ with g.control_dependencies([a_2]):
+ with g.control_dependencies([a_3]):
+ with g.control_dependencies([a_4]):
+ b_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
+
+ self.assertItemsEqual([a_1.op, a_2.op, a_3.op, a_4.op],
+ b_1.op.control_inputs)
+ self.assertItemsEqual(b_1.op.control_inputs, b_2.op.control_inputs)
+
+ def testClear(self):
+ g = ops.Graph()
+ a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
+ a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
+ a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
+ a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
+
+ with g.control_dependencies([a_1]):
+ with g.control_dependencies([a_2]):
+ with g.control_dependencies(None):
+ with g.control_dependencies([a_3]):
+ with g.control_dependencies([a_4]):
+ # deps [a_3, a_4]
+ b_3_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
+ # deps = [a_3]
+ b_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
+ # deps back to None
+ b_none = _apply_op(g, "FloatOutput", [], [dtypes.float32])
+ # deps back to [a_1, a_2]
+ b_1_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
+ # deps back to [a_1]
+ b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
+ with g.control_dependencies(None):
+ # deps are None again
+ b_none2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
+
+ self.assertItemsEqual([a_3.op, a_4.op], b_3_4.op.control_inputs)
+ self.assertItemsEqual([a_3.op], b_3.op.control_inputs)
+ self.assertItemsEqual([], b_none.op.control_inputs)
+ self.assertItemsEqual([a_1.op, a_2.op], b_1_2.op.control_inputs)
+ self.assertItemsEqual([a_1.op], b_1.op.control_inputs)
+ self.assertItemsEqual([], b_none2.op.control_inputs)
+
+ def testComplex(self):
+ g = ops.Graph()
+
+ # Usage pattern:
+ # * Nodes a_i are constants defined at the outermost scope, and are used
+ # as control inputs for the ith nested scope.
+ # * Nodes b_i are defined as Mul(a_3, a_4) at each scope.
+ # * Nodes c_i are defined as Mul(a_1, b_1) at each scope.
+ # * Nodes d_i are defined as Mul(b_i, c_i) at each scope.
+ # * Nodes e_i are defined as Mul(e_i-1, e_i-1) at each scope i > 1.
+
+ a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
+ a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
+ a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
+ a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
+
+ with g.control_dependencies([a_1]):
+ b_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4],
+ [dtypes.float32])
+ c_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1],
+ [dtypes.float32])
+ d_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_1, c_1],
+ [dtypes.float32])
+ e_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
+ with g.control_dependencies([a_2]):
+ b_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4],
+ [dtypes.float32])
+ c_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1],
+ [dtypes.float32])
+ d_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_2, c_2],
+ [dtypes.float32])
+ e_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_1, e_1],
+ [dtypes.float32])
+ with g.control_dependencies([a_3]):
+ b_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4],
+ [dtypes.float32])
+ c_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1],
+ [dtypes.float32])
+ d_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_3, c_3],
+ [dtypes.float32])
+ e_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_2, e_2],
+ [dtypes.float32])
+ with g.control_dependencies([a_4]):
+ b_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4],
+ [dtypes.float32])
+ c_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1],
+ [dtypes.float32])
+ d_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_4, c_4],
+ [dtypes.float32])
+ e_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_3, e_3],
+ [dtypes.float32])
+
+ self.assertItemsEqual([a_1.op], b_1.op.control_inputs)
+ self.assertItemsEqual([a_1.op, a_2.op], b_2.op.control_inputs)
+ self.assertItemsEqual([a_1.op, a_2.op], b_3.op.control_inputs)
+ self.assertItemsEqual([a_1.op, a_2.op], b_4.op.control_inputs)
+
+ self.assertItemsEqual([], c_1.op.control_inputs)
+ self.assertItemsEqual([a_2.op], c_2.op.control_inputs)
+ self.assertItemsEqual([a_2.op, a_3.op], c_3.op.control_inputs)
+ self.assertItemsEqual([a_2.op, a_3.op, a_4.op], c_4.op.control_inputs)
+
+ self.assertItemsEqual([], d_1.op.control_inputs)
+ self.assertItemsEqual([], d_2.op.control_inputs)
+ self.assertItemsEqual([], d_3.op.control_inputs)
+ self.assertItemsEqual([], d_4.op.control_inputs)
+
+ self.assertItemsEqual([a_1.op], e_1.op.control_inputs)
+ self.assertItemsEqual([a_2.op], e_2.op.control_inputs)
+ self.assertItemsEqual([a_3.op], e_3.op.control_inputs)
+ self.assertItemsEqual([a_4.op], e_4.op.control_inputs)
+
+ def testRepeatedDependency(self):
+ g = ops.Graph()
+ a = g.create_op("TwoFloatOutputs", [], [dtypes.float32, dtypes.float32])
+ a_0, a_1 = a.outputs
+ with g.control_dependencies([a_0]):
+ b = _apply_op(g, "FloatOutput", [], [dtypes.float32])
+ with g.control_dependencies([a_1]):
+ c = _apply_op(g, "FloatOutput", [], [dtypes.float32])
+
+ self.assertEqual(b.op.control_inputs, [a])
+ self.assertEqual(c.op.control_inputs, [a])
+
+ def testNoControlDependencyWithDataDependency(self):
+ g = ops.Graph()
+ a = _apply_op(g, "FloatOutput", [], [dtypes.float32])
+ with g.control_dependencies([a]):
+ b = _apply_op(g, "Identity", [a], [dtypes.float32])
+
+ self.assertEqual(b.op.control_inputs, [])
+
+
+class OpScopeTest(test_util.TensorFlowTestCase):
+
+ @test_util.run_in_graph_and_eager_modes
+ def testNames(self):
+ with ops.name_scope("foo") as foo:
+ self.assertEqual("foo/", foo)
+ with ops.name_scope("foo2") as foo2:
+ self.assertEqual("foo/foo2/", foo2)
+ with ops.name_scope(None) as empty1:
+ self.assertEqual("", empty1)
+ with ops.name_scope("foo3") as foo3:
+ self.assertEqual("foo3/", foo3)
+ with ops.name_scope("") as empty2:
+ self.assertEqual("", empty2)
+ with ops.name_scope("foo/") as outer_foo:
+ self.assertEqual("foo/", outer_foo)
+ with ops.name_scope("") as empty3:
+ self.assertEqual("", empty3)
+ with ops.name_scope("foo4") as foo4:
+ self.assertEqual("foo/foo4/", foo4)
+ with ops.name_scope("foo5//") as foo5:
+ self.assertEqual("foo5//", foo5)
+ with ops.name_scope("foo6") as foo6:
+ self.assertEqual("foo5//foo6/", foo6)
+ with ops.name_scope("/") as foo7:
+ self.assertEqual("/", foo7)
+ with ops.name_scope("//") as foo8:
+ self.assertEqual("//", foo8)
+ with ops.name_scope("a//b/c") as foo9:
+ self.assertEqual("foo/a//b/c/", foo9)
+ with ops.name_scope("a//b/c") as foo10:
+ self.assertEqual("a//b/c/", foo10)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testEagerDefaultScopeName(self):
+ with ops.name_scope(None, "default") as scope:
+ self.assertEqual(scope, "default/")
+ with ops.name_scope(None, "default2") as scope2:
+ self.assertEqual(scope2, "default/default2/")
+
+ @test_util.run_deprecated_v1
+ def testNoScopeName(self):
+ g0 = ops.Graph()
+ values = [
+ g0.create_op("A", [], [dtypes.float32]),
+ g0.create_op("B", [], [dtypes.float32])
+ ]
+ with self.assertRaises(ValueError):
+ with ops.name_scope(None, values=values):
+ pass
+ with self.assertRaises(ValueError):
+ with ops.name_scope(None, None, values):
+ pass
+
+ @test_util.run_deprecated_v1
+ def testEmptyScopeName(self):
+ g0 = ops.Graph()
+ a = g0.create_op("A", [], [dtypes.float32])
+ b = g0.create_op("B", [], [dtypes.float32])
+ with ops.name_scope("", values=[a, b]) as scope:
+ self.assertEqual("", scope)
+ self.assertEqual(g0, ops.get_default_graph())
+ with ops.name_scope("", "my_default_scope", [a, b]) as scope:
+ self.assertEqual("", scope)
+ self.assertEqual(g0, ops.get_default_graph())
+
+ @test_util.run_deprecated_v1
+ def testDefaultScopeName(self):
+ g0 = ops.Graph()
+ a = g0.create_op("A", [], [dtypes.float32])
+ b = g0.create_op("B", [], [dtypes.float32])
+ scope_name = "my_scope"
+ default_scope_name = "my_default_scope"
+ with ops.name_scope(scope_name, default_scope_name, [a, b]) as scope:
+ self.assertEqual("%s/" % scope_name, scope)
+ self.assertEqual(g0, ops.get_default_graph())
+ with ops.name_scope(None, default_scope_name, [a, b]) as scope:
+ self.assertEqual("%s/" % default_scope_name, scope)
+ self.assertEqual(g0, ops.get_default_graph())
+
+ def _testGraphElements(self, graph_elements):
+ scope_name = "my_scope"
+ with ops.name_scope(scope_name, values=graph_elements) as scope:
+ self.assertEqual("%s/" % scope_name, scope)
+ self.assertEqual(graph_elements[0].graph, ops.get_default_graph())
+ g1 = ops.Graph()
+ a = g1.create_op("A", [], [dtypes.float32])
+ with self.assertRaises(ValueError):
+ with ops.name_scope(scope_name, values=graph_elements + [a]):
+ pass
+
+ @test_util.run_deprecated_v1
+ def testTensor(self):
+ g0 = ops.Graph()
+ a = g0.create_op("A", [], [dtypes.float32])
+ b = g0.create_op("B", [], [dtypes.float32])
+ self._testGraphElements([a, b])
+
+ @test_util.run_deprecated_v1
+ def testSparseTensor(self):
+ g0 = ops.Graph()
+ a = g0.create_op("A", [], [dtypes.float32])
+ b = g0.create_op("B", [], [dtypes.float32])
+ sparse = sparse_tensor.SparseTensor(
+ _apply_op(g0, "Int64Output", [], [dtypes.int64]),
+ _apply_op(g0, "FloatOutput", [], [dtypes.float32]),
+ _apply_op(g0, "Int64Output", [], [dtypes.int64]))
+ self._testGraphElements([a, sparse, b])
+
+ @test_util.run_deprecated_v1
+ def testVariable(self):
+ g0 = ops.Graph()
+ with g0.as_default():
+ variable = variables.Variable([1.0])
+ a = g0.create_op("A", [], [dtypes.float32])
+ b = g0.create_op("B", [], [dtypes.float32])
+ self._testGraphElements([a, variable, b])
+
+
+class InitScopeTest(test_util.TensorFlowTestCase):
+
+ def testClearsControlDependencies(self):
+ g = ops.Graph()
+ a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
+ a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
+ a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
+ a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
+
+ with g.as_default():
+ with g.control_dependencies([a_1]):
+ with g.control_dependencies([a_2]):
+ with ops.init_scope():
+ with g.control_dependencies([a_3]):
+ with g.control_dependencies([a_4]):
+ # deps [a_3, a_4]
+ b_3_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
+ # deps = [a_3]
+ b_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
+ # deps back to None
+ b_none = _apply_op(g, "FloatOutput", [], [dtypes.float32])
+ # deps back to [a_1, a_2]
+ b_1_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
+ # deps back to [a_1]
+ b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
+ with ops.init_scope():
+ # deps are None again
+ b_none2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
+
+ self.assertItemsEqual([a_3.op, a_4.op], b_3_4.op.control_inputs)
+ self.assertItemsEqual([a_3.op], b_3.op.control_inputs)
+ self.assertItemsEqual([], b_none.op.control_inputs)
+ self.assertItemsEqual([a_1.op, a_2.op], b_1_2.op.control_inputs)
+ self.assertItemsEqual([a_1.op], b_1.op.control_inputs)
+ self.assertItemsEqual([], b_none2.op.control_inputs)
+
+ def testLiftsOpsFromFunctions(self):
+ g0 = ops.Graph()
+ g1 = ops.Graph()
+ g1._building_function = True # pylint: disable=protected-access
+ g2 = ops.Graph()
+ g2._building_function = True # pylint: disable=protected-access
+
+ with g0.as_default():
+ with g1.as_default():
+ with g2.as_default():
+ with ops.init_scope():
+ _ = constant_op.constant(1.0)
+
+ self.assertEqual(len(g2.get_operations()), 0)
+ self.assertEqual(len(g1.get_operations()), 0)
+ self.assertEqual(len(g0.get_operations()), 1)
+
+ def testPreservesDevices(self):
+ g0 = ops.Graph()
+ with g0.as_default(), ops.device("CPU:0"):
+ g1 = ops.Graph()
+ g1._building_function = True # pylint: disable=protected-access
+ with g1.as_default(), ops.device("GPU:0"):
+ with ops.init_scope():
+ # init_scope should preserve device set under `g1`.
+ on_gpu = constant_op.constant(1.0)
+ self.assertEqual(on_gpu.device, "/device:GPU:0")
+ still_on_gpu = constant_op.constant(1.0)
+ self.assertEqual(still_on_gpu.device, "/device:GPU:0")
+ on_cpu = constant_op.constant(1.0)
+ self.assertEqual(on_cpu.device, "/device:CPU:0")
+
+ def testComposes(self):
+ g0 = ops.Graph()
+ g1 = ops.Graph()
+ g1._building_function = True # pylint: disable=protected-access
+ g2 = ops.Graph()
+ g2._building_function = True # pylint: disable=protected-access
+ g3 = ops.Graph()
+ g3._building_function = False # pylint: disable=protected-access
+
+ with g0.as_default():
+ with g1.as_default():
+ with ops.init_scope():
+ # This op should be lifted into g0.
+ _ = constant_op.constant(1.0)
+ self.assertIs(g0, ops.get_default_graph())
+ self.assertEqual(len(g2.get_operations()), 0)
+ self.assertEqual(len(g1.get_operations()), 0)
+ self.assertEqual(len(g0.get_operations()), 1)
+ with g2.as_default():
+ with ops.init_scope():
+ # This op should be lifted into g0.
+ _ = constant_op.constant(1.0)
+ self.assertIs(g0, ops.get_default_graph())
+ with g3.as_default():
+ with ops.init_scope():
+ # This op should be lifted into g3, because g3 is not building a
+ # function.
+ _ = constant_op.constant(1.0)
+ self.assertIs(g3, ops.get_default_graph())
+
+ self.assertEqual(len(g3.get_operations()), 1)
+ self.assertEqual(len(g2.get_operations()), 0)
+ self.assertEqual(len(g1.get_operations()), 0)
+ self.assertEqual(len(g0.get_operations()), 2)
+
+ def testEscapesToEagerContext(self):
+ g = ops.Graph()
+ g._building_function = True # pylint: disable=protected-access
+ with context.eager_mode():
+ with context.graph_mode():
+ with g.as_default():
+ with ops.init_scope():
+ # Because g is building a function, init_scope should
+ # escape out to the eager context.
+ self.assertTrue(context.executing_eagerly())
+ # g should be reinstated as the default graph, and the
+ # graph context should be re-entered.
+ self.assertIs(g, ops.get_default_graph())
+ self.assertFalse(context.executing_eagerly())
+
+ def testStaysInEagerWhenOnlyEagerContextActive(self):
+ with context.eager_mode():
+ with ops.init_scope():
+ self.assertTrue(context.eager_mode())
+ self.assertTrue(context.eager_mode())
+
+ def testEscapesDefunWhenInEagerMode(self):
+
+ def function_with_variables():
+ with ops.init_scope():
+ self.v = resource_variable_ops.ResourceVariable(3)
+ return self.v.assign_add(1)
+
+ with context.eager_mode():
+ # Each invocation of function_with_variables recreates a variable.
+ self.assertEqual(4, int(function_with_variables()))
+ self.assertEqual(4, int(function_with_variables()))
+
+ compiled = eager_function.defun(function_with_variables)
+ # The init_scope in function_with_variables lifts the variable out
+ # of the graph function constructed by defun; hence,
+ # compiled now appears to be stateful.
+ self.assertEqual(4, int(compiled()))
+ self.assertEqual(5, int(compiled()))
+
+ def testEscapesDefunWhenInGraphMode(self):
+ def function_with_variables(name):
+ with ops.init_scope():
+ _ = variable_scope.get_variable(name, shape=(1,))
+
+ g = ops.Graph()
+ with g.as_default():
+ with self.cached_session():
+ # First ensure that graphs that are not building functions are
+ # not escaped.
+ function_with_variables("foo")
+ with self.assertRaisesRegexp(ValueError,
+ r"Variable foo already exists.*"):
+ # This will fail because reuse is not set to True.
+ function_with_variables("foo")
+
+ compiled = eager_function.defun(function_with_variables)
+ compiled("bar")
+ self.assertEqual(
+ len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)), 2)
+
+ # The second call to `compiled` should not create variables: the
+ # init_scope has lifted the variable creation code out of the defun.
+ compiled("bar")
+ self.assertEqual(
+ len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)), 2)
+
+ def testEscapesNestedDefun(self):
+
+ def inner_function():
+ with ops.init_scope():
+ self.v = resource_variable_ops.ResourceVariable(1)
+ return self.v.assign_add(2)
+
+ def outer_function(inner=None):
+ with ops.init_scope():
+ self.v0 = resource_variable_ops.ResourceVariable(0)
+ return self.v0.assign_add(1) + inner()
+
+ with context.eager_mode():
+ # Each invocation of outer_function recreates variables.
+ self.assertEqual(4, int(outer_function(inner=inner_function)))
+ self.assertEqual(4, int(outer_function(inner=inner_function)))
+
+ compiled_inner = eager_function.defun(inner_function)
+ compiled_outer = eager_function.defun(outer_function)
+ # The init_scope lifts variables out of the graph functions
+ # constructed by defun; hence, compiled_outer should now appear to be
+ # stateful.
+ self.assertEqual(4, int(compiled_outer(inner=compiled_inner)))
+ self.assertEqual(7, int(compiled_outer(inner=compiled_inner)))
+
+ @test_util.run_v1_only("b/120545219")
+ def testFallsBackToGlobalGraphWhenAllGraphsAreBuildingFunctions(self):
+ with context.graph_mode():
+ ops.reset_default_graph()
+ # This doesn't push anything onto the graph stack, but it does
+ # set the stack's global graph.
+ global_graph = ops.get_default_graph()
+ fn_graph = ops.Graph()
+
+ # pylint: disable=protected-access
+ fn_graph._building_function = True
+ self.assertEqual(len(ops._default_graph_stack.stack), 0)
+ with fn_graph.as_default():
+ self.assertEqual(len(ops._default_graph_stack.stack), 1)
+ with ops.init_scope():
+ self.assertGreater(len(ops._default_graph_stack.stack), 1)
+ dummy = constant_op.constant(1.0)
+ self.assertEqual(len(ops._default_graph_stack.stack), 1)
+ # Note that the global graph is _not_ on the graph stack.
+ self.assertEqual(len(ops._default_graph_stack.stack), 0)
+ # Ensure that `dummy` was added to the global graph.
+ self.assertEqual(global_graph, dummy.graph)
+ # pylint: enable=protected-access
+
+ def testInstallsDefaultGraphWhenGraphStackIsEmptyInGraphMode(self):
+ with context.graph_mode():
+ # pylint: disable=protected-access
+ self.assertEqual(len(ops._default_graph_stack.stack), 0)
+ with ops.init_scope():
+ self.assertGreater(len(ops._default_graph_stack.stack), 0)
+ self.assertEqual(len(ops._default_graph_stack.stack), 0)
+ # pylint: enable=protected-access
+
+ def testPreservesNameScopeInGraphConstruction(self):
+ with ops.Graph().as_default():
+ function_graph = ops.Graph()
+ with function_graph.as_default():
+ with ops.name_scope("inner"), ops.init_scope():
+ self.assertEqual(ops.get_name_scope(), "inner")
+ self.assertEqual(ops.get_name_scope(), "")
+
+ def testEnteringGraphFromEagerIsSticky(self):
+ with context.eager_mode():
+ g = ops.Graph()
+ with g.as_default():
+ with ops.init_scope():
+ self.assertFalse(context.executing_eagerly())
+ self.assertEqual(g, ops.get_default_graph())
+
+ def testMixGraphEager(self):
+ with context.eager_mode():
+ c = constant_op.constant(1.0)
+ with ops.Graph().as_default():
+ with self.assertRaisesRegexp(
+ RuntimeError, "Attempting to capture an EagerTensor"):
+ math_ops.add(c, c)
+ c2 = constant_op.constant(2.0)
+ with self.assertRaisesRegexp(
+ TypeError, "contains objects other than 'EagerTensor'"):
+ math_ops.add(c2, c2)
+
+ def testPreservesNameScopeInEagerExecution(self):
+ with context.eager_mode():
+ def foo():
+ with ops.name_scope("inner"), ops.init_scope():
+ if context.executing_eagerly():
+ # A trailing slash is always appended when eager execution is
+ # enabled.
+ self.assertEqual(context.context().scope_name, "inner/")
+ else:
+ self.assertEqual(ops.get_name_scope(), "inner")
+
+ foo()
+ self.assertEqual(ops.get_name_scope(), "")
+ foo_compiled = eager_function.defun(foo)
+ foo_compiled()
+ self.assertEqual(ops.get_name_scope(), "")
+
+ def testExecutingEagerlyOutsideFunctions(self):
+
+ @eager_function.defun
+ def f():
+ return ops.executing_eagerly_outside_functions()
+
+ with context.eager_mode():
+ self.assertTrue(ops.executing_eagerly_outside_functions())
+ self.assertTrue(f())
+ g = ops.Graph()
+ with g.as_default():
+ self.assertFalse(ops.executing_eagerly_outside_functions())
+
+
+class GraphTest(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ ops.reset_default_graph()
+
+ def _AssertDefault(self, expected):
+ self.assertIs(expected, ops.get_default_graph())
+
+ def testResetDefaultGraphNesting(self):
+ g0 = ops.Graph()
+ with self.assertRaises(AssertionError):
+ with g0.as_default():
+ ops.reset_default_graph()
+
+ def testGraphContextManagerCancelsEager(self):
+ with context.eager_mode():
+ with ops.Graph().as_default():
+ self.assertFalse(context.executing_eagerly())
+
+ def testGraphContextManager(self):
+ g0 = ops.Graph()
+ with g0.as_default() as g1:
+ self.assertIs(g0, g1)
+
+ def testDefaultGraph(self):
+ orig = ops.get_default_graph()
+ self._AssertDefault(orig)
+ g0 = ops.Graph()
+ self._AssertDefault(orig)
+ context_manager_0 = g0.as_default()
+ self._AssertDefault(orig)
+ with context_manager_0 as g0:
+ self._AssertDefault(g0)
+ with ops.Graph().as_default() as g1:
+ self._AssertDefault(g1)
+ self._AssertDefault(g0)
+ self._AssertDefault(orig)
+
+ def testPreventFeeding(self):
+ g = ops.Graph()
+ a = constant_op.constant(2.0)
+ self.assertTrue(g.is_feedable(a))
+ g.prevent_feeding(a)
+ self.assertFalse(g.is_feedable(a))
+
+ @test_util.run_deprecated_v1
+ def testPreventFetching(self):
+ g = ops.Graph()
+ a = constant_op.constant(2.0)
+ self.assertTrue(g.is_fetchable(a))
+ g.prevent_fetching(a.op)
+ self.assertFalse(g.is_fetchable(a))
+
+ def testAsGraphElementConversions(self):
+
+ class ConvertibleObj(object):
+
+ def _as_graph_element(self):
+ return "FloatOutput:0"
+
+ class NonConvertibleObj(object):
+
+ pass
+
+ g = ops.Graph()
+ a = _apply_op(g, "FloatOutput", [], [dtypes.float32])
+ self.assertEqual(a, g.as_graph_element(ConvertibleObj()))
+ with self.assertRaises(TypeError):
+ g.as_graph_element(NonConvertibleObj())
+
+ # Regression test against creating custom __del__ functions in classes
+ # involved in cyclic references, e.g. Graph and Operation. (Python won't gc
+ # cycles that require calling a __del__ method, because the __del__ method can
+ # theoretically increase the object's refcount to "save" it from gc, and any
+ # already-deleted objects in the cycle would have be to restored.)
+ def testGarbageCollected(self):
+ # Create a graph we can delete and a weak reference to monitor if it's gc'd
+ g = ops.Graph()
+ g_ref = weakref.ref(g)
+ # Create some ops
+ with g.as_default():
+ a = constant_op.constant(2.0)
+ b = constant_op.constant(3.0)
+ c = math_ops.add(a, b)
+ # Create a session we can delete
+ with session.Session(graph=g) as sess:
+ self.evaluate(c)
+ # Delete all references and trigger gc
+ del g
+ del a
+ del b
+ del c
+ del sess
+ gc.collect()
+ self.assertIsNone(g_ref())
+
+ def testRunnableAfterInvalidShape(self):
+ with ops.Graph().as_default():
+ with self.assertRaises(ValueError):
+ math_ops.add([1, 2], [1, 2, 3])
+ a = constant_op.constant(1)
+ with session.Session() as sess:
+ self.evaluate(a)
+
+ def testRunnableAfterInvalidShapeWithKernelLabelMap(self):
+ g = ops.Graph()
+ with g.as_default():
+ with g._kernel_label_map({"KernelLabelRequired": "overload_1"}):
+ with self.assertRaises(ValueError):
+ test_ops.kernel_label_required(1)
+ a = constant_op.constant(1)
+ with session.Session() as sess:
+ self.evaluate(a)
+
+
+class AttrScopeTest(test_util.TensorFlowTestCase):
+
+ def _get_test_attrs(self):
+ x = control_flow_ops.no_op()
+ try:
+ a = compat.as_text(x.get_attr("_A"))
+ except ValueError:
+ a = None
+ try:
+ b = compat.as_text(x.get_attr("_B"))
+ except ValueError:
+ b = None
+ return (a, b)
+
+ @test_util.run_deprecated_v1
+ def testNoLabel(self):
+ with self.cached_session():
+ self.assertAllEqual((None, None), self._get_test_attrs())
+
+ @test_util.run_deprecated_v1
+ def testLabelMap(self):
+ with self.cached_session() as sess:
+ a1 = self._get_test_attrs()
+ with sess.graph._attr_scope({
+ "_A": attr_value_pb2.AttrValue(s=compat.as_bytes("foo"))
+ }):
+ a2 = self._get_test_attrs()
+ with sess.graph._attr_scope({
+ "_A": None,
+ "_B": attr_value_pb2.AttrValue(s=compat.as_bytes("bar"))
+ }):
+ a3 = self._get_test_attrs()
+ with sess.graph._attr_scope({
+ "_A": attr_value_pb2.AttrValue(s=compat.as_bytes("baz"))
+ }):
+ a4 = self._get_test_attrs()
+ a5 = self._get_test_attrs()
+ a6 = self._get_test_attrs()
+ a7 = self._get_test_attrs()
+
+ self.assertAllEqual((None, None), a1)
+ self.assertAllEqual(("foo", None), a2)
+ self.assertAllEqual((None, "bar"), a3)
+ self.assertAllEqual(("baz", "bar"), a4)
+ self.assertAllEqual((None, "bar"), a5)
+ self.assertAllEqual(("foo", None), a6)
+ self.assertAllEqual((None, None), a7)
+
+
+ops.RegisterShape("KernelLabel")(common_shapes.scalar_shape)
+
+
+class KernelLabelTest(test_util.TensorFlowTestCase):
+
+ @test_util.run_deprecated_v1
+ def testNoLabel(self):
+ with self.cached_session():
+ self.assertAllEqual(b"My label is: default",
+ test_ops.kernel_label().eval())
+
+ @test_util.run_deprecated_v1
+ def testLabelMap(self):
+ with self.cached_session() as sess:
+ default_1 = test_ops.kernel_label()
+ # pylint: disable=protected-access
+ with sess.graph._kernel_label_map({"KernelLabel": "overload_1"}):
+ overload_1_1 = test_ops.kernel_label()
+ with sess.graph._kernel_label_map({"KernelLabel": "overload_2"}):
+ overload_2 = test_ops.kernel_label()
+ with sess.graph._kernel_label_map({"KernelLabel": ""}):
+ default_2 = test_ops.kernel_label()
+ overload_1_2 = test_ops.kernel_label()
+ # pylint: enable=protected-access
+ default_3 = test_ops.kernel_label()
+
+ self.assertAllEqual(b"My label is: default", self.evaluate(default_1))
+ self.assertAllEqual(b"My label is: default", self.evaluate(default_2))
+ self.assertAllEqual(b"My label is: default", self.evaluate(default_3))
+ self.assertAllEqual(b"My label is: overload_1",
+ self.evaluate(overload_1_1))
+ self.assertAllEqual(b"My label is: overload_1",
+ self.evaluate(overload_1_2))
+ self.assertAllEqual(b"My label is: overload_2", self.evaluate(overload_2))
+
+
+class AsGraphDefTest(test_util.TensorFlowTestCase):
+
+ def testGraphDefVersion(self):
+ """Test that the graphdef version is plumbed through to kernels."""
+ with ops.Graph().as_default() as g:
+ version = g.graph_def_versions.producer
+ with self.session(graph=g):
+ v = test_ops.graph_def_version().eval()
+ self.assertEqual(version, v)
+
+ def testAddShapes(self):
+ with ops.Graph().as_default() as g:
+ t1, t2, t3, t4, t5 = _apply_op(g, "FiveFloatOutputs", [],
+ [dtypes.float32] * 5)
+ t1.set_shape(None)
+ t2.set_shape([])
+ t3.set_shape([None])
+ t4.set_shape([43, 37])
+ t5.set_shape([43, None])
+
+ b = constant_op.constant(1.0) # pylint: disable=unused-variable
+
+ gd = g.as_graph_def(add_shapes=True)
+ self.assertProtoEqualsVersion("""
+ node { name: "FiveFloatOutputs" op: "FiveFloatOutputs"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape { unknown_rank: true }
+ shape { }
+ shape { dim { size: -1 } }
+ shape { dim { size: 43 } dim { size: 37 } }
+ shape { dim { size: 43 } dim { size: -1 } }
+ }
+ }
+ }
+ }
+ node { name: "Const" op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape { }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value { type: DT_FLOAT }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape { }
+ float_val: 1.0 } } } }
+ """, gd)
+
+
+@ops.RegisterStatistics("a", "flops")
+def _calc_a_forward_flops(unused_graph, unused_node):
+ return ops.OpStats("flops", 20)
+
+
+class StatisticsTest(test_util.TensorFlowTestCase):
+
+ def testRegisteredNode(self):
+ graph = ops.Graph()
+ node = ops._NodeDef("a", "an_a")
+ flops = ops.get_stats_for_node_def(graph, node, "flops")
+ self.assertEqual(20, flops.value)
+ missing_stat = ops.get_stats_for_node_def(graph, node, "missing_stat")
+ self.assertEqual(None, missing_stat.value)
+
+ def testUnregisteredNode(self):
+ graph = ops.Graph()
+ node = ops._NodeDef("b", "a_b")
+ weight_params = ops.get_stats_for_node_def(graph, node, "weight_params")
+ self.assertEqual(None, weight_params.value)
+
+ def testAccumulateStatistics(self):
+ flops_total = ops.OpStats("flops")
+ self.assertEqual(None, flops_total.value)
+ second_flops = ops.OpStats("flops", 3)
+ flops_total += second_flops
+ self.assertEqual(3, flops_total.value)
+
+
+class DeviceStackTest(test_util.TensorFlowTestCase):
+
+ @test_util.run_deprecated_v1
+ def testBasicDeviceAssignmentMetadata(self):
+
+ def device_func(unused_op):
+ return "/cpu:*"
+
+ const_zero = constant_op.constant([0.0], name="zero")
+ with ops.device("/cpu"):
+ const_one = constant_op.constant([1.0], name="one")
+ with ops.device("/cpu:0"):
+ const_two = constant_op.constant([2.0], name="two")
+ with ops.device(device_func):
+ const_three = constant_op.constant(3.0, name="three")
+
+ self.assertEqual(0, len(const_zero.op._device_assignments))
+
+ one_list = const_one.op._device_assignments
+ self.assertEqual(1, len(one_list))
+ self.assertEqual("/cpu", one_list[0].obj)
+ self.assertEqual("ops_test.py", os.path.basename(one_list[0].filename))
+
+ two_list = const_two.op._device_assignments
+ self.assertEqual(2, len(two_list))
+ devices = [t.obj for t in two_list]
+ self.assertEqual(set(["/cpu", "/cpu:0"]), set(devices))
+
+ three_list = const_three.op._device_assignments
+ self.assertEqual(1, len(three_list))
+ func_description = three_list[0].obj
+ expected_regex = r"device_func<.*ops_test.py, [0-9]+"
+ self.assertRegexpMatches(func_description, expected_regex)
+
+ @test_util.run_deprecated_v1
+ def testDeviceAssignmentMetadataForGraphDeviceAndTfDeviceFunctions(self):
+
+ with ops.device("/cpu"):
+ const_one = constant_op.constant([1.0], name="one")
+ with ops.get_default_graph().device("/cpu"):
+ const_two = constant_op.constant([2.0], name="two")
+
+ one_metadata = const_one.op._device_assignments[0]
+ two_metadata = const_two.op._device_assignments[0]
+
+ # Verify both types of device assignment return the right stack info.
+ self.assertRegexpMatches("ops_test.py",
+ os.path.basename(one_metadata.filename))
+ self.assertEqual(one_metadata.filename, two_metadata.filename)
+ self.assertEqual(one_metadata.lineno + 2, two_metadata.lineno)
+
+
+class ColocationGroupTest(test_util.TensorFlowTestCase):
+
+ @test_util.run_deprecated_v1
+ def testBasic(self):
+ a = constant_op.constant([2.0], name="a")
+ with ops.colocate_with(a.op):
+ b = constant_op.constant(3.0)
+ c = constant_op.constant(4.0)
+ self.assertEqual([b"loc:@a"], a.op.colocation_groups())
+ self.assertEqual([b"loc:@a"], b.op.colocation_groups())
+ with self.assertRaises(ValueError):
+ c.op.get_attr("_class")
+
+ @test_util.run_deprecated_v1
+ def testBasicColocationMetadata(self):
+ const_two = constant_op.constant([2.0], name="two")
+ with ops.colocate_with(const_two.op):
+ const_three = constant_op.constant(3.0, name="three")
+ locations_dict = const_three.op._colocation_dict
+ self.assertIn("two", locations_dict)
+ metadata = locations_dict["two"]
+ self.assertIsNone(metadata.obj)
+ # Check that this test's filename is recorded as the file containing the
+ # colocation statement.
+ self.assertEqual("ops_test.py", os.path.basename(metadata.filename))
+
+ @test_util.run_deprecated_v1
+ def testColocationDeviceInteraction(self):
+ with ops.device("/cpu:0"):
+ with ops.device("/device:GPU:0"):
+ a = constant_op.constant([2.0], name="a")
+ with ops.colocate_with(a.op):
+ # 'b' is created in the scope of /cpu:0, but it is
+ # colocated with 'a', which is on '/device:GPU:0'. colocate_with
+ # overrides devices because it is a stronger constraint.
+ b = constant_op.constant(3.0)
+ self.assertEqual([b"loc:@a"], b.op.colocation_groups())
+ self.assertEqual(a.op.device, b.op.device)
+
+ @test_util.run_deprecated_v1
+ def testColocationCanonicalization(self):
+ with ops.device("/device:GPU:0"):
+ _ = constant_op.constant(2.0)
+ with ops.device(lambda op: "/device:GPU:0"):
+ b = constant_op.constant(3.0)
+ with ops.get_default_graph().colocate_with(b):
+ with ops.device("/device:GPU:0"):
+ c = constant_op.constant(4.0)
+
+ # A's device will be /device:GPU:0
+ # B's device will be /device:GPU:0
+ # C's device will be /device:GPU:0 because it
+ # inherits B's device name, after canonicalizing the names.
+ self.assertEqual(b.op.device, c.op.device)
+
+ @test_util.run_deprecated_v1
+ def testLocationOverrides(self):
+ with ops.device("/cpu:0"):
+ with ops.device("/device:GPU:0"):
+ a = constant_op.constant([2.0], name="a")
+ # Note that this colocation is "redundant", since we are
+ # within the scope of "/device:GPU:0". However, we would like to
+ # preserve in the GraphDef that these two ops should be
+ # colocated in a portable way.
+ with ops.colocate_with(a.op):
+ b = constant_op.constant(3.0)
+ c = constant_op.constant(4.0)
+ d = constant_op.constant(5.0)
+
+ self.assertEqual([b"loc:@a"], b.op.colocation_groups())
+ self.assertEqual("/device:GPU:0", a.op.device)
+ self.assertEqual(a.op.device, b.op.device)
+
+ # Test that device function stack is restored.
+ self.assertEqual("/device:GPU:0", c.op.device)
+ self.assertEqual("/device:CPU:0", d.op.device)
+
+ @test_util.run_deprecated_v1
+ def testNestedColocateWith(self):
+ a = constant_op.constant([2.0], name="a")
+ with ops.colocate_with(a.op):
+ b = constant_op.constant(3.0)
+ with ops.colocate_with(b.op):
+ c = constant_op.constant(4.0)
+ self.assertEqual([b"loc:@a"], b.op.colocation_groups())
+ self.assertEqual([b"loc:@a"], c.op.colocation_groups())
+
+ @test_util.run_deprecated_v1
+ def testMultiColocationGroups(self):
+ a = constant_op.constant([2.0], name="a")
+ b = constant_op.constant(3.0, name="b")
+ with ops.colocate_with(a.op):
+ with ops.colocate_with(b.op):
+ c = constant_op.constant(4.0)
+ self.assertEqual(set([b"loc:@a", b"loc:@b"]), set(c.op.colocation_groups()))
+
+ @test_util.run_deprecated_v1
+ def testColocationIgnoreStack(self):
+ a = constant_op.constant([2.0], name="a")
+ b = constant_op.constant(3.0, name="b")
+ with ops.colocate_with(a.op):
+ with ops.colocate_with(b.op, ignore_existing=True):
+ c = constant_op.constant(4.0)
+ self.assertEqual(set([b"loc:@b"]), set(c.op.colocation_groups()))
+
+ @test_util.run_deprecated_v1
+ def testColocateWithReset(self):
+ a = constant_op.constant([2.0], name="a")
+ with ops.colocate_with(a.op):
+ b = constant_op.constant(3.0, name="b")
+ with ops.colocate_with(None, ignore_existing=True):
+ c = constant_op.constant(4.0, name="c")
+ self.assertEqual([b"loc:@a"], b.op.colocation_groups())
+ self.assertEqual([b"loc:@c"], c.op.colocation_groups())
+
+ @test_util.run_deprecated_v1
+ def testColocateWithInitialNoneThenNested(self):
+ a = constant_op.constant([2.0], name="a")
+ with ops.colocate_with(a.op):
+ with ops.colocate_with(None, ignore_existing=True):
+ b = constant_op.constant(3.0, name="b")
+ with ops.colocate_with(b.op):
+ c = constant_op.constant(4.0, name="c")
+ self.assertEqual([b"loc:@b"], b.op.colocation_groups())
+ self.assertEqual([b"loc:@b"], c.op.colocation_groups())
+
+ @test_util.run_deprecated_v1
+ def testColocateVariables(self):
+ a = variables.Variable([2.0], name="a")
+ with ops.colocate_with(a.op):
+ b = variables.Variable([3.0], name="b")
+ self.assertEqual([b"loc:@a"], b.op.colocation_groups())
+
+
+class DeprecatedTest(test_util.TensorFlowTestCase):
+
+ def testSuccess(self):
+ with ops.Graph().as_default() as g:
+ test_util.set_producer_version(g, 7)
+ old = test_ops.old()
+ with self.session(graph=g):
+ old.run()
+
+ def _error(self):
+ return ((r"Op Old is not available in GraphDef version %d\. "
+ r"It has been removed in version 8\. For reasons\.") %
+ versions.GRAPH_DEF_VERSION)
+
+ def testGraphConstructionFail(self):
+ with ops.Graph().as_default():
+ with self.assertRaisesRegexp(NotImplementedError, self._error()):
+ test_ops.old()
+
+
+class DenseTensorLikeTypeTest(test_util.TensorFlowTestCase):
+
+ def testSuccess(self):
+ op = ops.Operation(
+ ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32])
+ t = op.outputs[0]
+ self.assertTrue(ops.is_dense_tensor_like(t))
+
+ v = variables.Variable([17])
+ self.assertTrue(ops.is_dense_tensor_like(v))
+
+ class BadClassNoName(object):
+ pass
+
+ class BadClassBadName(object):
+
+ def name(self):
+ pass
+
+ class BadClassNoDtype(object):
+
+ @property
+ def name(self):
+ pass
+
+ class BadClassBadDtype(object):
+
+ @property
+ def name(self):
+ pass
+
+ def dtype(self):
+ pass
+
+ def testBadClass(self):
+ with self.assertRaisesRegexp(TypeError, "`name`"):
+ ops.register_dense_tensor_like_type(
+ DenseTensorLikeTypeTest.BadClassNoName)
+ with self.assertRaisesRegexp(TypeError, "`name`"):
+ ops.register_dense_tensor_like_type(
+ DenseTensorLikeTypeTest.BadClassBadName)
+ with self.assertRaisesRegexp(TypeError, "`dtype`"):
+ ops.register_dense_tensor_like_type(
+ DenseTensorLikeTypeTest.BadClassNoDtype)
+ with self.assertRaisesRegexp(TypeError, "`dtype`"):
+ ops.register_dense_tensor_like_type(
+ DenseTensorLikeTypeTest.BadClassBadDtype)
+
+
+class NameScopeTest(test_util.TensorFlowTestCase):
+
+ def testStripAndPrependScope(self):
+ strs = [
+ "hidden1/hidden1/weights", # Same prefix. Should strip.
+ "hidden1///hidden1/weights", # Extra "/". Should strip.
+ "^hidden1/hidden1/weights", # Same prefix. Should strip.
+ "loc:@hidden1/hidden1/weights", # Same prefix. Should strip.
+ "hhidden1/hidden1/weights", # Different prefix. Should keep.
+ "hidden1"
+ ] # Not a prefix. Should keep.
+ expected_striped = [
+ "hidden1/weights", "hidden1/weights", "^hidden1/weights",
+ "loc:@hidden1/weights", "hhidden1/hidden1/weights", "hidden1"
+ ]
+ expected_prepended = [
+ "hidden2/hidden1/weights", "hidden2/hidden1/weights",
+ "^hidden2/hidden1/weights", "loc:@hidden2/hidden1/weights",
+ "hidden2/hhidden1/hidden1/weights", "hidden2/hidden1"
+ ]
+ name_scope_to_strip = "hidden1"
+ name_scope_to_add = "hidden2"
+ for es, ep, s in zip(expected_striped, expected_prepended, strs):
+ striped = ops.strip_name_scope(s, name_scope_to_strip)
+ self.assertEqual(es, striped)
+ self.assertEqual(ep, ops.prepend_name_scope(striped, name_scope_to_add))
+
+ def testGetNameScope(self):
+ with ops.Graph().as_default() as g:
+ with ops.name_scope("scope1"):
+ with ops.name_scope("scope2"):
+ with ops.name_scope("scope3"):
+ self.assertEqual("scope1/scope2/scope3", g.get_name_scope())
+ self.assertEqual("scope1/scope2", g.get_name_scope())
+ self.assertEqual("scope1", g.get_name_scope())
+ self.assertEqual("", g.get_name_scope())
+
+ def testTwoGraphs(self):
+
+ def f():
+ g1 = ops.Graph()
+ g2 = ops.Graph()
+ with g1.as_default():
+ with g2.as_default():
+ with ops.name_scope("_"):
+ pass
+
+ self.assertRaisesRegexp(ValueError, "'_' is not a valid scope name", f)
+
+
+class TracebackTest(test_util.TensorFlowTestCase):
+
+ @test_util.run_deprecated_v1
+ def testTracebackWithStartLines(self):
+ with self.cached_session() as sess:
+ a = constant_op.constant(2.0)
+ sess.run(
+ a,
+ options=config_pb2.RunOptions(
+ trace_level=config_pb2.RunOptions.FULL_TRACE))
+ self.assertTrue(sess.graph.get_operations())
+
+ # Tests that traceback_with_start_lines is the same as traceback
+ # but includes one more element at the end.
+ for op in sess.graph.get_operations():
+ self.assertEquals(len(op.traceback), len(op.traceback_with_start_lines))
+ for frame, frame_with_start_line in zip(
+ op.traceback, op.traceback_with_start_lines):
+ self.assertEquals(5, len(frame_with_start_line))
+ self.assertEquals(frame, frame_with_start_line[:-1])
+
+
+class EnableEagerExecutionTest(test_util.TensorFlowTestCase):
+
+ @test_util.run_v1_only("b/120545219")
+ def testBadArgumentsToEnableEagerExecution(self):
+ with self.assertRaisesRegexp(TypeError, "config must be a tf.ConfigProto"):
+ ops.enable_eager_execution(context.DEVICE_PLACEMENT_SILENT)
+ with self.assertRaisesRegexp(ValueError, "device_policy must be one of"):
+ c = config_pb2.ConfigProto()
+ ops.enable_eager_execution(c, c)
+ with self.assertRaisesRegexp(ValueError, "execution_mode must be one of"):
+ c = config_pb2.ConfigProto()
+ ops.enable_eager_execution(c, execution_mode=c)
+
+
+if __name__ == "__main__":
+ googletest.main()