diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 03d58dbc..081893c2 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -12,6 +12,51 @@ namespace Tensorflow /// then create a TensorFlow session to run parts of the graph across a set of local and remote devices. /// https://www.tensorflow.org/guide/graphs /// + /* + A TensorFlow computation, represented as a dataflow graph. + + A `Graph` contains a set of + `tf.Operation` objects, + which represent units of computation; and + `tf.Tensor` objects, which represent + the units of data that flow between operations. + + A default `Graph` is always registered, and accessible by calling + `tf.get_default_graph`. + To add an operation to the default graph, simply call one of the functions + that defines a new `Operation`: + + ```python + c = tf.constant(4.0) + assert c.graph is tf.get_default_graph() + ``` + + Another typical usage involves the + `tf.Graph.as_default` + context manager, which overrides the current default graph for the + lifetime of the context: + + ```python + g = tf.Graph() + with g.as_default(): + # Define operations and tensors in `g`. + c = tf.constant(30.0) + assert c.graph is g + ``` + + Important note: This class *is not* thread-safe for graph construction. All + operations should be created from a single thread, or external + synchronization must be provided. Unless otherwise specified, all methods + are not thread-safe. + + A `Graph` instance supports an arbitrary number of "collections" + that are identified by name. For convenience when building a large + graph, collections can store groups of related objects: for + example, the `tf.Variable` uses a collection (named + `tf.GraphKeys.GLOBAL_VARIABLES`) for + all variables that are created during the construction of a graph. The caller + may define additional collections by specifying a new name. + */ public partial class Graph : IPython, IDisposable { private IntPtr _handle; diff --git a/src/TensorFlowNET.Core/ops.py.cs b/src/TensorFlowNET.Core/ops.py.cs index add752ea..7b208854 100644 --- a/src/TensorFlowNET.Core/ops.py.cs +++ b/src/TensorFlowNET.Core/ops.py.cs @@ -49,19 +49,59 @@ namespace Tensorflow return get_default_graph().get_collection_ref(key); } - private static Graph default_graph; + private static Graph default_graph; + /// + /// Returns the default graph for the current thread. + /// + /// The returned graph will be the innermost graph on which a + /// `Graph.as_default()` context has been entered, or a global default + /// graph if none has been explicitly created. + /// + /// NOTE: The default graph is a property of the current thread.If you + /// create a new thread, and wish to use the default graph in that + /// thread, you must explicitly add a `with g.as_default():` in that + /// thread's function. + /// + /// public static Graph get_default_graph() { + //TODO: original source indicates there should be a _default_graph_stack! + //return _default_graph_stack.get_default() if (default_graph == null) default_graph = tf.Graph(); return default_graph; } public static Graph set_default_graph(Graph graph) { + //TODO: original source does not have a 'set_default_graph' and indicates there should be a _default_graph_stack! default_graph = graph; return default_graph; + } + + /// + /// Clears the default graph stack and resets the global default graph. + /// + /// NOTE: The default graph is a property of the current thread.This + /// function applies only to the current thread.Calling this function while + /// a `tf.Session` or `tf.InteractiveSession` is active will result in undefined + /// behavior. Using any previously created `tf.Operation` or `tf.Tensor` objects + /// after calling this function will result in undefined behavior. + /// + /// + public static void reset_default_graph() + { + //TODO: original source indicates there should be a _default_graph_stack! + //if (!_default_graph_stack.is_cleared()) + // throw new InvalidOperationException("Do not use tf.reset_default_graph() to clear " + + // "nested graphs. If you need a cleared graph, " + + // "exit the nesting and create a new graph."); + //_default_graph_stack.reset(); + if (default_graph!=null) + default_graph.Dispose(); + default_graph = tf.Graph(); } + public static Graph _get_graph_from_inputs(List op_input_list, Graph graph = null) { foreach(var op_input in op_input_list) diff --git a/test/TensorFlowNET.UnitTest/ops_test/GraphTest.cs b/test/TensorFlowNET.UnitTest/ops_test/GraphTest.cs new file mode 100644 index 00000000..a4a8c299 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/ops_test/GraphTest.cs @@ -0,0 +1,200 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow; +using Tensorflow.Operations; + +namespace TensorFlowNET.UnitTest.ops_test +{ + /// + /// excerpt of tensorflow/python/framework/ops_test.py + /// + [TestClass] + public class GraphTest : PythonTest + { + + [TestInitialize] + public void SetUp() + { + ops.reset_default_graph(); + } + + [TestCleanup] + public void TearDown() + { + ops.reset_default_graph(); + } + + private void _AssertDefault(Graph expected) { + Assert.AreSame(ops.get_default_graph(), expected); + } + + + [Ignore("Todo: Port")] + [TestMethod] + public void testResetDefaultGraphNesting() + { +/* + def testResetDefaultGraphNesting(self): + g0 = ops.Graph() + with self.assertRaises(AssertionError): + with g0.as_default(): + ops.reset_default_graph() +*/ + } + + [Ignore("Todo: Port")] + [TestMethod] + public void testGraphContextManagerCancelsEager() + { + /* + def testGraphContextManagerCancelsEager(self): + with context.eager_mode(): + with ops.Graph().as_default(): + self.assertFalse(context.executing_eagerly()) + */ + } + + + [Ignore("Todo: Port")] + [TestMethod] + public void testGraphContextManager() + { + /* + def testGraphContextManager(self): + g0 = ops.Graph() + with g0.as_default() as g1: + self.assertIs(g0, g1) + */ + } + + [Ignore("Todo: Port")] + [TestMethod] + public void testDefaultGraph() + { + /* + def testDefaultGraph(self): + orig = ops.get_default_graph() + self._AssertDefault(orig) + g0 = ops.Graph() + self._AssertDefault(orig) + context_manager_0 = g0.as_default() + self._AssertDefault(orig) + with context_manager_0 as g0: + self._AssertDefault(g0) + with ops.Graph().as_default() as g1: + self._AssertDefault(g1) + self._AssertDefault(g0) + self._AssertDefault(orig) + */ + } + + [Ignore("Todo: Port")] + [TestMethod] + public void testPreventFeeding() + { + /* + def testPreventFeeding(self): + g = ops.Graph() + a = constant_op.constant(2.0) + self.assertTrue(g.is_feedable(a)) + g.prevent_feeding(a) + self.assertFalse(g.is_feedable(a)) + */ + } + + + [Ignore("Todo: Port")] + [TestMethod] + public void testAsGraphElementConversions() + { + /* + def testAsGraphElementConversions(self): + + class ConvertibleObj(object): + + def _as_graph_element(self): + return "FloatOutput:0" + + class NonConvertibleObj(object): + + pass + + g = ops.Graph() + a = _apply_op(g, "FloatOutput", [], [dtypes.float32]) + self.assertEqual(a, g.as_graph_element(ConvertibleObj())) + with self.assertRaises(TypeError): + g.as_graph_element(NonConvertibleObj()) + */ + } + + [Ignore("Todo: Port")] + [TestMethod] + public void testGarbageCollected() + { + /* + # Regression test against creating custom __del__ functions in classes + # involved in cyclic references, e.g. Graph and Operation. (Python won't gc + # cycles that require calling a __del__ method, because the __del__ method can + # theoretically increase the object's refcount to "save" it from gc, and any + # already-deleted objects in the cycle would have be to restored.) + def testGarbageCollected(self): + # Create a graph we can delete and a weak reference to monitor if it's gc'd + g = ops.Graph() + g_ref = weakref.ref(g) + # Create some ops + with g.as_default(): + a = constant_op.constant(2.0) + b = constant_op.constant(3.0) + c = math_ops.add(a, b) + # Create a session we can delete + with session.Session(graph=g) as sess: + self.evaluate(c) + # Delete all references and trigger gc + del g + del a + del b + del c + del sess + gc.collect() + self.assertIsNone(g_ref()) + */ + } + + [Ignore("Todo: Port")] + [TestMethod] + public void testRunnableAfterInvalidShape() + { + /* + def testRunnableAfterInvalidShape(self): + with ops.Graph().as_default(): + with self.assertRaises(ValueError): + math_ops.add([1, 2], [1, 2, 3]) + a = constant_op.constant(1) + with session.Session() as sess: + self.evaluate(a) + */ + } + + [Ignore("Todo: Port")] + [TestMethod] + public void testRunnableAfterInvalidShapeWithKernelLabelMap() + { + /* + def testRunnableAfterInvalidShapeWithKernelLabelMap(self): + g = ops.Graph() + with g.as_default(): + with g._kernel_label_map({"KernelLabelRequired": "overload_1"}): + with self.assertRaises(ValueError): + test_ops.kernel_label_required(1) + a = constant_op.constant(1) + with session.Session() as sess: + self.evaluate(a) + */ + } + + + } +} diff --git a/test/TensorFlowNET.UnitTest/ops_test/ops_test_r1.13.py b/test/TensorFlowNET.UnitTest/ops_test/ops_test_r1.13.py new file mode 100644 index 00000000..2d7ee1a9 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/ops_test/ops_test_r1.13.py @@ -0,0 +1,3014 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.python.framework.ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gc +import os +import threading +import weakref + +from tensorflow.core.framework import attr_value_pb2 +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.client import session +from tensorflow.python.eager import context +from tensorflow.python.eager import function as eager_function +from tensorflow.python.framework import common_shapes +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import device as pydev +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import function +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_util +from tensorflow.python.framework import test_ops +from tensorflow.python.framework import test_util +from tensorflow.python.framework import versions +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import resources +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +import tensorflow.python.ops.gradients # pylint: disable=unused-import +from tensorflow.python.platform import googletest +from tensorflow.python.util import compat + +ops._set_call_cpp_shape_fn(common_shapes.call_cpp_shape_fn) + + +class ResourceTest(test_util.TensorFlowTestCase): + + @test_util.run_deprecated_v1 + def testBuildGraph(self): + with self.cached_session(): + pt = test_ops.stub_resource_handle_op(container="a", shared_name="b") + test_ops.resource_create_op(pt).run() + + @test_util.run_deprecated_v1 + def testInitialize(self): + with self.cached_session(): + handle = test_ops.stub_resource_handle_op(container="a", shared_name="b") + resources.register_resource( + handle=handle, + create_op=test_ops.resource_create_op(handle), + is_initialized_op=test_ops.resource_initialized_op(handle)) + self.assertEquals( + len( + resources.report_uninitialized_resources( + resources.shared_resources()).eval()), 1) + resources.initialize_resources(resources.shared_resources()).run() + self.assertEquals( + len( + resources.report_uninitialized_resources( + resources.shared_resources()).eval()), 0) + + +class TensorAndShapeTest(test_util.TensorFlowTestCase): + + def testShape(self): + op = ops.Operation( + ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32]) + t = op.outputs[0] + self.assertEqual(tensor_shape.unknown_shape(), t.get_shape()) + t.set_shape([1, 2, 3]) + self.assertEqual([1, 2, 3], t.get_shape()) + + def testIterable(self): + op = ops.Operation( + ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32]) + t = op.outputs[0] + self.assertTrue(isinstance(t, ops.Tensor)) + with self.assertRaisesRegexp(TypeError, "iter"): + for _ in t: + pass + + def testAddShape(self): + with self.cached_session(): + a = array_ops.zeros([2, 3]) + b = array_ops.ones([1, 3]) + c = a + b + self.assertEqual([2, 3], c.shape) + + @test_util.run_deprecated_v1 + def testUnknownDim(self): + with self.cached_session(): + a = array_ops.placeholder(dtype=dtypes.float32, shape=[2, None, 3]) + b = array_ops.placeholder(dtype=dtypes.float32, shape=[2, None, 3]) + c = a + b + self.assertEqual([2, None, 3], c.shape.as_list()) + + @test_util.run_deprecated_v1 + def testUnknownShape(self): + with self.cached_session(): + a = array_ops.placeholder(dtype=dtypes.float32, shape=None) + b = array_ops.ones([1, 3]) + c = a + b + self.assertEqual(tensor_shape.unknown_shape(), c.shape) + + @test_util.run_deprecated_v1 + def testScalarShape(self): + with self.cached_session(): + a = array_ops.placeholder(dtype=dtypes.float32, shape=[]) + b = array_ops.ones([]) + c = a + b + self.assertEqual(tensor_shape.scalar(), c.shape) + + @test_util.run_deprecated_v1 + def testShapeFunctionError(self): + with self.cached_session(): + a = array_ops.ones([1, 2, 3]) + b = array_ops.ones([4, 5, 6]) + with self.assertRaisesRegexp( + ValueError, + r"Dimensions must be equal, but are 2 and 5 for 'add' \(op: 'Add'\) " + r"with input shapes: \[1,2,3\], \[4,5,6\]."): + _ = a + b + + +class IndexedSlicesTest(test_util.TensorFlowTestCase): + + @test_util.run_in_graph_and_eager_modes + def testToTensor(self): + values = constant_op.constant([2, 3, 5, 7], shape=[2, 2]) + indices = constant_op.constant([0, 2]) + dense_shape = constant_op.constant([3, 2]) + x = ops.IndexedSlices(values, indices, dense_shape) + tensor = ops.convert_to_tensor(x, name="tensor") + self.assertAllEqual(self.evaluate(tensor), [[2, 3], [0, 0], [5, 7]]) + + @test_util.run_deprecated_v1 + def testNegation(self): + with self.cached_session(): + values = constant_op.constant([2, 3, 5, 7], shape=[2, 2]) + indices = constant_op.constant([0, 2]) + x = -ops.IndexedSlices(values, indices) + self.assertAllEqual(x.values.eval(), [[-2, -3], [-5, -7]]) + self.assertAllEqual(x.indices.eval(), [0, 2]) + + @test_util.run_deprecated_v1 + def testScalarMul(self): + with self.cached_session(): + values = constant_op.constant([2, 3, 5, 7], shape=[2, 2]) + indices = constant_op.constant([0, 2]) + x = math_ops.scalar_mul(-2, ops.IndexedSlices(values, indices)) + self.assertAllEqual(x.values.eval(), [[-4, -6], [-10, -14]]) + self.assertAllEqual(x.indices.eval(), [0, 2]) + + +class NodeDefConstructorTest(test_util.TensorFlowTestCase): + + def testNoArgs(self): + nodedef = ops._NodeDef("None", "bar") + self.assertProtoEquals("op: 'None' name: 'bar'", nodedef) + + def testArgs(self): + nodedef = ops._NodeDef("foo", "bar", device="/device:baz:*") + self.assertProtoEquals("op:'foo' name:'bar' device:'/device:baz:*'", + nodedef) + nodedef = ops._NodeDef("foo", "bar", device=pydev.DeviceSpec(job="j")) + self.assertProtoEquals("op:'foo' name:'bar' device:'/job:j'", nodedef) + + +def _apply_op(g, *args, **kwargs): + op = g.create_op(*args, **kwargs) + if len(op.outputs) == 1: + return op.outputs[0] + else: + return op.outputs + + +class OperationTest(test_util.TensorFlowTestCase): + + @test_util.run_deprecated_v1 + def testNoInputs(self): + op = test_ops.float_output_string_output(name="myop").a.op + self.assertEqual(2, len(op.values())) + self.assertEqual(0, len(op.inputs)) + self.assertEqual("myop", op.name) + + float_t, label_str_t = op.values() + self.assertEqual(dtypes.float32, float_t.dtype) + self.assertEqual(op, float_t.op) + self.assertEqual(0, float_t._value_index) + self.assertEqual(0, len(float_t.consumers())) + self.assertEqual("myop", float_t._as_node_def_input()) + + self.assertEqual(dtypes.string, label_str_t.dtype) + self.assertEqual(op, label_str_t.op) + self.assertEqual(1, label_str_t._value_index) + self.assertEqual(0, len(label_str_t.consumers())) + self.assertEqual("myop:1", label_str_t._as_node_def_input()) + + self.assertProtoEquals("op:'FloatOutputStringOutput' name:'myop'", + op.node_def) + + @test_util.run_deprecated_v1 + def testNoOutputs(self): + op1 = test_ops.float_output(name="myop1").op + float_t, = op1.values() + op2 = test_ops.float_input(float_t, name="myop2") + self.assertEqual(0, len(op2.values())) + self.assertEqual(1, len(op2.inputs)) + self.assertIs(float_t, op2.inputs[0]) + + self.assertEqual(1, len(float_t.consumers())) + self.assertEqual(op2, float_t.consumers()[0]) + + self.assertProtoEquals("op:'FloatOutput' name:'myop1'", op1.node_def) + self.assertProtoEquals("op:'FloatInput' name:'myop2' input:'myop1'", + op2.node_def) + + @test_util.run_deprecated_v1 + def testInputsAndOutputs(self): + op1 = test_ops.float_output(name="myop1").op + self.assertEqual(1, len(op1.values())) + float1_t, = op1.values() + + op2 = test_ops.float_output_string_output(name="myop2").a.op + self.assertEqual(2, len(op2.values())) + float2_t, label2_str_t = op2.values() + + # Note that we consume label2_str_t twice here. + op3 = test_ops.foo2(float1_t, label2_str_t, label2_str_t, name="myop3").d.op + self.assertEqual(2, len(op3.values())) + + self.assertEqual(1, len(float1_t.consumers())) + self.assertEqual(op3, float1_t.consumers()[0]) + + self.assertEqual(0, len(float2_t.consumers())) + + self.assertEqual(2, len(label2_str_t.consumers())) + self.assertEqual(op3, label2_str_t.consumers()[0]) + self.assertEqual(op3, label2_str_t.consumers()[1]) + + self.assertProtoEquals(""" + op:'Foo2' name:'myop3' + input:'myop1' input:'myop2:1' input:'myop2:1' + """, op3.node_def) + + def testDeviceFromNodeDef(self): + op = ops.Operation( + ops._NodeDef("None", "myop", device="/job:goo/device:GPU:0"), + ops.Graph(), [], []) + self.assertEqual("/job:goo/device:GPU:0", op.device) + + def testDeviceObject(self): + op = ops.Operation(ops._NodeDef("None", "myop"), ops.Graph(), [], []) + op._set_device("/job:goo/device:GPU:0") + self.assertProtoEquals( + "op:'None' name:'myop' device:'/job:goo/device:GPU:0' ", op.node_def) + op = ops.Operation(ops._NodeDef("None", "op2"), ops.Graph(), [], []) + op._set_device( + pydev.DeviceSpec( + job="muu", device_type="CPU", device_index=0)) + self.assertProtoEquals( + "op:'None' name:'op2' device:'/job:muu/device:CPU:0'", op.node_def) + + def testReferenceInput(self): + g = ops.Graph() + op1 = ops.Operation( + ops._NodeDef("RefOutputFloatOutput", "op1"), g, [], + [dtypes.float32_ref, dtypes.float32]) + self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'", op1.node_def) + self.assertEquals([], list(op1.inputs)) + ref_t, nonref_t = op1.values() + # NOTE(mrry): Must specify input_types to preserve ref-typed input. + op2 = ops.Operation( + ops._NodeDef("RefInputFloatInput", "op2"), + g, [ref_t, nonref_t], [], + input_types=[dtypes.float32_ref, dtypes.float32]) + self.assertProtoEquals( + "op:'RefInputFloatInput' name:'op2' input:'op1' input:'op1:1'", + op2.node_def) + self.assertEquals([ref_t, nonref_t], list(op2.inputs)) + op3 = ops.Operation( + ops._NodeDef("TwoFloatInputs", "op3"), g, [ref_t, nonref_t], []) + self.assertProtoEquals( + "op:'TwoFloatInputs' name:'op3' input:'op1' input:'op1:1'", + op3.node_def) + + def testInvalidNames(self): + g = ops.Graph() + with self.assertRaises(ValueError): + ops.Operation(ops._NodeDef("op", ""), g) + with self.assertRaises(ValueError): + ops.Operation(ops._NodeDef("op", "_invalid"), g) + with self.assertRaises(ValueError): + ops.Operation(ops._NodeDef("op", "-invalid"), g) + with self.assertRaises(ValueError): + ops.Operation(ops._NodeDef("op", "/invalid"), g) + with self.assertRaises(ValueError): + ops.Operation(ops._NodeDef("op", "invalid:0"), g) + + @test_util.run_deprecated_v1 + def testNoShapeFunction(self): + op = test_ops.a() + self.assertEqual(tensor_shape.unknown_shape(), op.get_shape()) + + @test_util.run_in_graph_and_eager_modes + def testConvertToTensorNestedArray(self): + values = [[2], [3], [5], [7]] + tensor = ops.convert_to_tensor(values) + self.assertAllEqual((4, 1), tensor.get_shape().as_list()) + self.assertAllEqual(values, self.evaluate(tensor)) + + def testShapeTuple(self): + with self.cached_session(): + c = constant_op.constant(1) + self.assertEqual(c._shape_tuple(), ()) # pylint: disable=protected-access + + def testConvertToTensorEager(self): + with context.eager_mode(): + t = constant_op.constant(1) + self.assertTrue(isinstance(t, ops.EagerTensor)) + converted = ops.convert_to_tensor(t) + self.assertTrue(isinstance(converted, ops.EagerTensor)) + converted = ops.convert_to_tensor(1) + self.assertTrue(isinstance(converted, ops.EagerTensor)) + + @test_util.run_in_graph_and_eager_modes + def testConvertToTensorNestedTuple(self): + values = ((2,), (3,), (5,), (7,)) + tensor = ops.convert_to_tensor(values) + self.assertAllEqual((4, 1), tensor.get_shape().as_list()) + self.assertAllEqual(values, self.evaluate(ops.convert_to_tensor(values))) + + @test_util.run_in_graph_and_eager_modes + def testConvertToTensorNestedTensors(self): + values = ((2,), (3,), (5,), (7,)) + tensor = ops.convert_to_tensor( + [constant_op.constant(row) for row in values]) + self.assertAllEqual((4, 1), tensor.get_shape().as_list()) + self.assertAllEqual(values, self.evaluate(tensor)) + tensor = ops.convert_to_tensor( + [[constant_op.constant(v) for v in row] for row in values]) + self.assertAllEqual((4, 1), tensor.get_shape().as_list()) + self.assertAllEqual(values, self.evaluate(tensor)) + + @test_util.run_in_graph_and_eager_modes + def testConvertToTensorNestedMix(self): + values = ([2], (3,), [constant_op.constant(5)], constant_op.constant([7])) + tensor = ops.convert_to_tensor(values) + self.assertAllEqual((4, 1), tensor.get_shape().as_list()) + self.assertAllEqual(((2,), (3,), (5,), (7,)), self.evaluate(tensor)) + + @test_util.run_in_graph_and_eager_modes + def testConvertToTensorPreferred(self): + values = [2, 3, 5, 7] + tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.float32) + self.assertEqual(dtypes.float32, tensor.dtype) + + # Convert empty tensor to anything. + values = [] + tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.int64) + self.assertEqual(dtypes.int64, tensor.dtype) + + # The preferred dtype is a type error and will convert to + # float32 instead. + values = [1.23] + tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.int64) + self.assertEqual(dtypes.float32, tensor.dtype) + + @test_util.run_in_graph_and_eager_modes + def testConvertToInvalidTensorType(self): + with self.assertRaises(TypeError): + # Forcing an invalid dtype should fail with a type error. + values = [1.23] + ops.convert_to_tensor(values, dtype=dtypes.int64) + + @test_util.run_in_graph_and_eager_modes + def testConvertToTensorFromInvalidTensor(self): + tensor = constant_op.constant(42.0, dtype=dtypes.float32) + with self.assertRaises(ValueError): + ops.convert_to_tensor(tensor, dtype=dtypes.int32) + + @test_util.run_deprecated_v1 + def testNoConvert(self): + # Operation cannot be converted to Tensor. + op = control_flow_ops.no_op() + with self.assertRaisesRegexp(TypeError, + r"Can't convert Operation '.*' to Tensor"): + ops.convert_to_tensor(op) + + def testStr(self): + node_def = ops._NodeDef("None", "op1") + op = ops.Operation(node_def, ops.Graph(), [], [dtypes.float32]) + self.assertEqual(str(node_def), str(op)) + + def testRepr(self): + op = ops.Operation( + ops._NodeDef("None", "op1"), ops.Graph(), [], [dtypes.float32]) + self.assertEqual("", repr(op)) + + @test_util.run_deprecated_v1 + def testGetAttr(self): + op = test_ops.default_attrs() + self.assertEqual(op.get_attr("string_val"), b"abc") + self.assertEqual(op.get_attr("string_list_val"), [b"abc", b""]) + self.assertEqual(op.get_attr("int_val"), 123) + self.assertEqual(op.get_attr("int_list_val"), [1, 2, 3]) + self.assertEqual(op.get_attr("float_val"), 10.0) + self.assertEqual(op.get_attr("float_list_val"), [10.0]) + self.assertEqual(op.get_attr("bool_val"), True) + self.assertEqual(op.get_attr("bool_list_val"), [True, False]) + self.assertEqual(op.get_attr("shape_val"), + tensor_shape.as_shape([2, 1]).as_proto()) + self.assertEqual(op.get_attr("shape_list_val"), + [tensor_shape.as_shape([]).as_proto(), + tensor_shape.as_shape([1]).as_proto()]) + self.assertEqual(op.get_attr("tensor_val"), + tensor_util.make_tensor_proto(1, dtypes.int32)) + self.assertEqual(op.get_attr("tensor_list_val"), + [tensor_util.make_tensor_proto(1, dtypes.int32)]) + + type_val = op.get_attr("type_val") + # First check that type_val is a DType, because the assertEquals will work + # no matter what since DType overrides __eq__ + self.assertIsInstance(type_val, dtypes.DType) + self.assertEqual(type_val, dtypes.int32) + + type_list_val = op.get_attr("type_list_val") + self.assertTrue(all(isinstance(x, dtypes.DType) for x in type_list_val)) + self.assertEqual(type_list_val, [dtypes.int32, dtypes.float32]) + + @function.Defun(dtypes.float32, func_name="MyFunc") + def func(x): + return x + + op = test_ops.func_attr(func) + self.assertEqual(op.get_attr("f"), + attr_value_pb2.NameAttrList(name="MyFunc")) + + # Try fetching missing attr + with self.assertRaisesRegexp( + ValueError, "Operation 'FuncAttr' has no attr named 'FakeAttr'."): + op.get_attr("FakeAttr") + + # TODO(b/65162920): remove this test when users who are directly mutating the + # node_def have been updated to proper usage. + @test_util.run_deprecated_v1 + def testSetAttr(self): + op = test_ops.int_attr().op + op._set_attr("foo", attr_value_pb2.AttrValue(i=2)) + # TODO(skyewm): add node_def check + self.assertEqual(op.get_attr("foo"), 2) + + # TODO(nolivia): test all error cases + def testAddControlInput(self): + with ops.Graph().as_default(): + x = constant_op.constant(1).op + y = constant_op.constant(2).op + z = constant_op.constant(3).op + z._add_control_input(x) # pylint: disable=protected-access + self.assertEqual(z.control_inputs, [x]) + z._add_control_input(x) # pylint: disable=protected-access + self.assertEqual(z.control_inputs, [x]) + z._add_control_inputs([x, y, y]) # pylint: disable=protected-access + self.assertEqual(z.control_inputs, [x, y]) + self.assertEqual(x._control_outputs, [z]) + + @test_util.run_deprecated_v1 + def testRemoveAllControlInputs(self): + a = constant_op.constant(1) + with ops.control_dependencies([a]): + b = constant_op.constant(2) + c = constant_op.constant(3) + d = constant_op.constant(4) + e = constant_op.constant(5) + with ops.control_dependencies([a, c]): + f = d + e + + self.assertEqual(a.op.control_inputs, []) + self.assertEqual(b.op.control_inputs, [a.op]) + self.assertEqual(f.op.control_inputs, [a.op, c.op]) + + a.op._remove_all_control_inputs() # pylint: disable=protected-access + self.assertEqual(a.op.control_inputs, []) + + b.op._remove_all_control_inputs() # pylint: disable=protected-access + self.assertEqual(b.op.control_inputs, []) + + f.op._remove_all_control_inputs() # pylint: disable=protected-access + self.assertEqual(f.op.control_inputs, []) + self.assertEqual(list(f.op.inputs), [d, e]) + + @test_util.run_deprecated_v1 + def testControlInputCycle(self): + graph = ops.Graph() + with graph.as_default(): + z = constant_op.constant(0) + x = constant_op.constant(1) + y = constant_op.constant(2) + y.op._add_control_input(z.op) # pylint: disable=protected-access + y.op._add_control_input(x.op) # pylint: disable=protected-access + x.op._add_control_input(y.op) # pylint: disable=protected-access + with self.session(graph=graph) as sess: + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + "Graph is invalid, contains a cycle with 2 nodes"): + self.evaluate(x) + + def testUpdateInput(self): + g = ops.Graph() + with g.as_default(): + x = constant_op.constant(1) + y = constant_op.constant(2) + z = x + y + + z.op._update_input(0, y) # pylint: disable=protected-access + self.assertEquals(list(z.op.inputs), [y, y]) + self.assertEquals(x.consumers(), []) + self.assertEquals(y.consumers(), [z.op, z.op]) + with session.Session(graph=g) as sess: + self.assertEquals(self.evaluate(z), 4) + + z.op._update_input(0, x) # pylint: disable=protected-access + self.assertEquals(list(z.op.inputs), [x, y]) + self.assertEquals(x.consumers(), [z.op]) + self.assertEquals(y.consumers(), [z.op]) + with session.Session(graph=g) as sess: + self.assertEquals(self.evaluate(z), 3) + + z.op._update_input(1, y) # pylint: disable=protected-access + self.assertEquals(list(z.op.inputs), [x, y]) + self.assertEquals(x.consumers(), [z.op]) + self.assertEquals(y.consumers(), [z.op]) + with session.Session(graph=g) as sess: + self.assertEquals(self.evaluate(z), 3) + + def testUpdateInputGraphError(self): + g_0 = ops.Graph() + g_1 = ops.Graph() + with g_0.as_default(): + x = constant_op.constant(1) + with g_1.as_default(): + y = constant_op.constant(2) + z = y * 2 + with self.assertRaisesRegexp(ValueError, "must be from the same graph"): + z.op._update_input(0, x) # pylint: disable=protected-access + + def testUpdateInputTypeError(self): + g = ops.Graph() + with g.as_default(): + w = constant_op.constant(0) + x = constant_op.constant("") + y = constant_op.constant(1) + z = y + w + z.op._update_input(0, x) # pylint: disable=protected-access + with session.Session(graph=g) as sess: + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + "Input 0 of node add was passed string from Const_1:0 incompatible " + "with expected int32"): + self.evaluate(z) + + def testUpdateInputShapeError(self): + g = ops.Graph() + with g.as_default(): + w = constant_op.constant(2, shape=[3, 1]) + x = constant_op.constant(0, shape=[3, 1]) + y = constant_op.constant(1, shape=[2, 2]) + z = w + x + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + r"Cannot update edge, incompatible shapes: \[2,2\] and \[3,1\]"): + z.op._update_input(0, y) # pylint: disable=protected-access + + def testUpdateInputOutOfRange(self): + g = ops.Graph() + with g.as_default(): + x = constant_op.constant(1) + with self.assertRaisesRegexp( + errors.OutOfRangeError, + r"Cannot update edge. Input index \[1\] is greater than the number of " + r"total inputs \[0\]." + ): + x.op._update_input(1, x) # pylint: disable=protected-access + + @test_util.enable_control_flow_v2 + @test_util.run_v1_only("b/120545219") + def testAddWhileInput(self): + @eager_function.defun + def test(): + output = control_flow_ops.while_loop(lambda x: x < 3, lambda x: x + 1, + [1]) + while_op = output.op.inputs[0].op + self.assertEqual(while_op.type, "While") + orig_num_inputs = len(while_op.inputs) + + # Make sure we can handle the while op having a control input. + while_op._add_control_input(constant_op.constant(0).op) + + new_input1 = constant_op.constant(1.0) + new_input2 = constant_op.constant(True) + + while_op._set_type_list_attr("T", + [t.dtype for t in while_op.inputs] + + [new_input1.dtype, new_input2.dtype]) + + while_op._add_while_inputs([new_input1, new_input2]) + # Can't add an edge beyond what's specified by "T" + with self.assertRaises(errors.OutOfRangeError): + while_op._add_while_inputs([new_input2]) + self.assertEqual(len(while_op.inputs), orig_num_inputs + 2) # pylint: disable=g-deprecated-assert + + test() + + @test_util.run_deprecated_v1 + def testOpDef(self): + x = constant_op.constant(0) + y = constant_op.constant(1) + z = x + y + + self.assertEqual(x.op.op_def.name, "Const") + self.assertEqual(len(x.op.op_def.input_arg), 0) + self.assertEqual(len(x.op.op_def.output_arg), 1) + + self.assertEqual(z.op.op_def.name, "Add") + self.assertEqual(len(z.op.op_def.input_arg), 2) + self.assertEqual(len(z.op.op_def.output_arg), 1) + + def testInputFromDifferentGraphError(self): + g_0 = ops.Graph() + g_1 = ops.Graph() + with g_0.as_default(): + x = constant_op.constant(1) + with g_1.as_default(): + y = constant_op.constant(2) + with self.assertRaisesRegexp(ValueError, "must be from the same graph"): + y * x # pylint: disable=pointless-statement + + def testInputsAreImmutable(self): + g = ops.Graph() + with g.as_default(): + x = test_ops.int_output() + op = test_ops.int_input_int_output(x, name="myop").op + with self.assertRaisesRegexp( + AttributeError, "'_InputList' object has no attribute 'append'"): + op.inputs.append(None) + + +class CreateOpTest(test_util.TensorFlowTestCase): + + def testNodeDefArgs(self): + g = ops.Graph() + op1 = g.create_op("FloatOutput", [], [dtypes.float32], None, name="myop1") + with g.device("/device:GPU:0"): + op2 = g.create_op( + "FloatOutputStringOutput", [], [dtypes.float32, dtypes.string], None, + name="myop2") + op3 = g.create_op( + "Foo3", + [list(op1.values())[0], list(op2.values())[1], list(op2.values())[0]], + [dtypes.float32, dtypes.int32], + None, + name="myop3") + self.assertDeviceEqual(None, op1.device) + self.assertDeviceEqual("/device:GPU:0", op2.device) + self.assertDeviceEqual(None, op3.device) + self.assertProtoEquals("name:'myop1' op:'FloatOutput'", op1.node_def) + self.assertProtoEquals( + "name:'myop2' op:'FloatOutputStringOutput' device:'/device:GPU:0'", + op2.node_def) + self.assertProtoEquals( + "name:'myop3' input:'myop1' input:'myop2:1' input:'myop2' op:'Foo3'", + op3.node_def) + + def testReferenceInput(self): + g = ops.Graph() + op1 = g.create_op( + "RefOutputFloatOutput", [], [dtypes.float32_ref, dtypes.float32], + name="op1") + self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'", op1.node_def) + ref_t, nonref_t = op1.values() + # NOTE(mrry): Must specify input_types to preserve ref-typed input. + op2 = g.create_op( + "RefInputFloatInput", [ref_t, nonref_t], [], + input_types=[dtypes.float32_ref, dtypes.float32], + name="op2") + self.assertProtoEquals( + "op:'RefInputFloatInput' name:'op2' input:'op1' input:'op1:1'", + op2.node_def) + op3 = g.create_op("TwoFloatInputs", [ref_t, nonref_t], [], name="op3") + self.assertProtoEquals( + "op:'TwoFloatInputs' name:'op3' input:'op1' input:'op1:1'", + op3.node_def) + + def testFinalized(self): + g = ops.Graph() + g.finalize() + with self.assertRaises(RuntimeError): + g.create_op("FloatOutput", [], [dtypes.float32], None, name="myop1") + + # Test unfinalize. + g._unsafe_unfinalize() + g.create_op("FloatOutput", [], [dtypes.float32], None, name="myop1") + + +# NOTE(skyewm): these cases test the private Graph._create_op_from_tf_operation +# method. Arguably we should only test the public APIs that depend on this +# method. However, this logic is complex and tricky, and it can be difficult to +# ascertain if we have adequate coverage (e.g. a graph may run successfully if +# the control flow context isn't set properly, but a more complicated use case +# that might not be obvious to test will fail). Thus we instead explicitly test +# the low-level behavior. +class CreateOpFromTFOperationTest(test_util.TensorFlowTestCase): + + @test_util.run_deprecated_v1 + def testBasic(self): + g = ops.Graph() + with g.as_default(): + x = test_ops.int_output() + c_op = ops._create_c_op( + g, ops._NodeDef("IntInputIntOutput", "myop"), [x], []) + op = g._create_op_from_tf_operation(c_op) + + self.assertEqual(op.name, "myop") + self.assertEqual(op.type, "IntInputIntOutput") + self.assertEqual(len(op.outputs), 1) + self.assertEqual(op.outputs[0].shape, tensor_shape.unknown_shape()) + self.assertEqual(list(op.inputs), [x]) + self.assertEqual(op.control_inputs, []) + self.assertEqual(op.graph, g) + self.assertEqual(x.consumers(), [op]) + self.assertIsNotNone(op.traceback) + self.assertEqual(g.get_operation_by_name("myop"), op) + self.assertEqual(g.get_tensor_by_name("myop:0"), op.outputs[0]) + + def testShape(self): + g = ops.Graph() + with g.as_default(): + x = constant_op.constant([[1, 2, 3], [4, 5, 6]]) + c_op = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), [x], []) + op = g._create_op_from_tf_operation(c_op) + + self.assertEqual(op.name, "myop") + self.assertEqual(op.type, "Identity") + self.assertEqual(len(op.outputs), 1) + self.assertEqual(op.outputs[0].shape, tensor_shape.matrix(2, 3)) + + def testUniqueName(self): + g = ops.Graph() + with g.as_default(): + c_op = ops._create_c_op(g, ops._NodeDef("IntOutput", "myop"), [], []) + c_op2 = ops._create_c_op(g, ops._NodeDef("IntOutput", "myop_1"), [], []) + op = g._create_op_from_tf_operation(c_op) + op2 = g._create_op_from_tf_operation(c_op2) + + # Create ops with same names as op1 and op2. We expect the new names to be + # uniquified. + op3 = test_ops.int_output(name="myop").op + op4 = test_ops.int_output(name="myop_1").op + + self.assertEqual(op.name, "myop") + self.assertEqual(op2.name, "myop_1") + self.assertEqual(op3.name, "myop_2") + self.assertEqual(op4.name, "myop_1_1") + + @test_util.run_v1_only("b/120545219") + def testCond(self): + g = ops.Graph() + with g.as_default(): + x = test_ops.int_output() + + def true_fn(): + ops._create_c_op(ops.get_default_graph(), + ops._NodeDef("IntInput", "cond/myop"), [x], []) + new_ops = g._add_new_tf_operations() + self.assertEqual(len(new_ops), 1) + return x + + control_flow_ops.cond(x < 10, true_fn, lambda: x) + + op = g.get_operation_by_name("cond/myop") + self.assertIsNotNone(op) + self.assertEqual(op.name, "cond/myop") + self.assertEqual(op.type, "IntInput") + self.assertEqual(op.outputs, []) + op_input = op.inputs[0].op + self.assertEqual(op_input.type, "Switch") + self.assertEqual(op_input.inputs[0], x) + self.assertEqual(op.graph, g) + # pylint: disable=protected-access + self.assertIsNotNone(op._get_control_flow_context()) + self.assertEqual(op._get_control_flow_context().name, + "cond/cond_text") + # pylint: enable=protected-access + + @test_util.run_v1_only("b/120545219") + def testWhileLoop(self): + g = ops.Graph() + with g.as_default(): + x = test_ops.int_output() + + def body(i): + ops._create_c_op(ops.get_default_graph(), + ops._NodeDef("IntInput", "myloop/myop"), [x], []) + new_ops = g._add_new_tf_operations() + self.assertEqual(len(new_ops), 1) + return i + + control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop") + + op = g.get_operation_by_name("myloop/myop") + self.assertIsNotNone(op) + self.assertEqual(op.name, "myloop/myop") + self.assertEqual(op.type, "IntInput") + self.assertEqual(op.outputs, []) + op_input = op.inputs[0].op + self.assertEqual(op_input.type, "Enter") + self.assertEqual(list(op_input.inputs), [x]) + self.assertEqual(op.graph, g) + # pylint: disable=protected-access + self.assertIsNotNone(op._get_control_flow_context()) + self.assertEqual(op._get_control_flow_context().name, + "myloop/while_context") + # pylint: enable=protected-access + + @test_util.run_v1_only("b/120545219") + def testWhileLoopWithInternalControlDep(self): + g = ops.Graph() + with g.as_default(): + x = test_ops.int_output() + + def body(i): + c = constant_op.constant(1.0, name="c") + ops._create_c_op(ops.get_default_graph(), + ops._NodeDef("IntInput", "myloop/myop"), [x], []) + with ops.control_dependencies([c]): + new_ops = g._add_new_tf_operations() + self.assertEqual(len(new_ops), 1) + return i + + control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop") + + op = g.get_operation_by_name("myloop/myop") + self.assertIsNotNone(op) + c = g.get_operation_by_name("myloop/c") + self.assertIsNotNone(c) + # Internal control dep is preserved + self.assertEqual(op.control_inputs, [c]) + + @test_util.run_v1_only("b/120545219") + def testWhileLoopWithExternalControlDep(self): + g = ops.Graph() + with g.as_default(): + x = test_ops.int_output() + c = constant_op.constant(1.0) + + def body(i): + ops._create_c_op(ops.get_default_graph(), + ops._NodeDef("IntInput", "myloop/myop"), [x], []) + with ops.control_dependencies([c]): + new_ops = g._add_new_tf_operations() + self.assertEqual(len(new_ops), 1) + return i + + control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop") + + op = g.get_operation_by_name("myloop/myop") + self.assertIsNotNone(op) + # External control dep is removed and replaced with internal control dep + self.assertNotEqual(op.control_inputs[0], c.op) + self.assertIsNotNone(op.control_inputs[0]._get_control_flow_context()) + + +class ApplyOpTest(test_util.TensorFlowTestCase): + + def testNodeDefArgs(self): + g = ops.Graph() + t1 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop1") + with g.device("/device:GPU:0"): + t2 = _apply_op( + g, "TwoIntOutputs", [], [dtypes.int32, dtypes.int32], name="myop2") + t3 = _apply_op( + g, + "Foo1", [t1, t2[1], t2[0]], [dtypes.float32, dtypes.int32], + name="myop3") + self.assertTrue(isinstance(t1, ops.Tensor)) + self.assertTrue(isinstance(t2, list)) + self.assertTrue(isinstance(t3, list)) + self.assertTrue(isinstance(t3[0], ops.Tensor)) + self.assertEqual("myop1", t1._as_node_def_input()) + self.assertEqual("myop2", t2[0]._as_node_def_input()) + self.assertEqual("myop2:1", t2[1]._as_node_def_input()) + self.assertEqual("myop3", t3[0]._as_node_def_input()) + # Validate that we got the right ops as well + self.assertProtoEquals("name:'myop1' op:'FloatOutput'", t1.op.node_def) + self.assertProtoEquals( + "name:'myop2' op:'TwoIntOutputs' device:'/device:GPU:0'", + t2[0].op.node_def) + self.assertProtoEquals( + "name:'myop3' input:'myop1' input:'myop2:1' input:'myop2' op:'Foo1'", + t3[0].op.node_def) + + def testReferenceInput(self): + g = ops.Graph() + ref_t, nonref_t = _apply_op( + g, "RefOutputFloatOutput", [], [dtypes.float32_ref, dtypes.float32], + name="op1") + self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'", + ref_t.op.node_def) + # NOTE(mrry): Must specify input_types to preserve ref-typed input. + out_2 = _apply_op( + g, + "RefInputFloatInputIntOutput", [ref_t, nonref_t], [dtypes.int32], + input_types=[dtypes.float32_ref, dtypes.float32], + name="op2") + self.assertProtoEquals( + "op:'RefInputFloatInputIntOutput' name:'op2' input:'op1' input:'op1:1'", + out_2.op.node_def) + out_3 = _apply_op( + g, "TwoFloatInputsIntOutput", [ref_t, nonref_t], [dtypes.int32], + name="op3") + self.assertProtoEquals( + "op:'TwoFloatInputsIntOutput' name:'op3' input:'op1' input:'op1:1'", + out_3.op.node_def) + + +class NameStackTest(test_util.TensorFlowTestCase): + + def testBasics(self): + g = ops.Graph() + self.assertEqual("foo", g.unique_name("foo", mark_as_used=False)) + self.assertEqual("foo", g.unique_name("foo", mark_as_used=False)) + self.assertEqual("foo", g.unique_name("foo")) + self.assertEqual("foo_1", g.unique_name("foo", mark_as_used=False)) + self.assertEqual("foo_1", g.unique_name("foo")) + self.assertEqual("foo_2", g.unique_name("foo", mark_as_used=False)) + self.assertEqual("foo_2", g.unique_name("foo")) + self.assertEqual("foo_1_1", g.unique_name("foo_1", mark_as_used=False)) + self.assertEqual("foo_1_1", g.unique_name("foo_1")) + self.assertEqual("foo_1_2", g.unique_name("foo_1", mark_as_used=False)) + self.assertEqual("foo_1_2", g.unique_name("foo_1")) + self.assertEqual("foo_1_2_1", g.unique_name("foo_1_2", mark_as_used=False)) + self.assertEqual("foo_1_2_1", g.unique_name("foo_1_2")) + with g.name_scope("bar"): + self.assertEqual("bar/foo", g.unique_name("foo", mark_as_used=False)) + self.assertEqual("bar/foo", g.unique_name("foo")) + self.assertEqual("bar/foo_1", g.unique_name("foo", mark_as_used=False)) + self.assertEqual("bar/foo_1", g.unique_name("foo")) + with g.name_scope(None): + self.assertEqual("foo_3", g.unique_name("foo", mark_as_used=False)) + self.assertEqual("foo_3", g.unique_name("foo")) + with g.name_scope("baz"): + self.assertEqual( + "bar/baz/foo", g.unique_name( + "foo", mark_as_used=False)) + self.assertEqual("bar/baz/foo", g.unique_name("foo")) + self.assertEqual( + "bar/baz/foo_1", g.unique_name( + "foo", mark_as_used=False)) + self.assertEqual("bar/baz/foo_1", g.unique_name("foo")) + with g.name_scope("baz"): + self.assertEqual( + "bar/baz_1/foo", g.unique_name( + "foo", mark_as_used=False)) + self.assertEqual("bar/baz_1/foo", g.unique_name("foo")) + self.assertEqual( + "bar/baz_1/foo_1", g.unique_name( + "foo", mark_as_used=False)) + self.assertEqual("bar/baz_1/foo_1", g.unique_name("foo")) + with g.name_scope("quux"): + self.assertEqual("quux/foo", g.unique_name("foo", mark_as_used=False)) + self.assertEqual("quux/foo", g.unique_name("foo")) + with g.name_scope("bar"): + with g.name_scope("baz"): + self.assertEqual( + "bar_1/baz/foo", g.unique_name( + "foo", mark_as_used=False)) + self.assertEqual("bar_1/baz/foo", g.unique_name("foo")) + self.assertEqual("foo_4", g.unique_name("foo", mark_as_used=False)) + self.assertEqual("foo_4", g.unique_name("foo")) + self.assertEqual("bar_2", g.unique_name("bar", mark_as_used=False)) + self.assertEqual("bar_2", g.unique_name("bar")) + + @test_util.run_deprecated_v1 + def testNameAndVariableScope(self): + with self.cached_session() as sess: + with sess.graph.name_scope("l0"): + with variable_scope.variable_scope("l1"): + with sess.graph.name_scope("l1") as scope: + self.assertEqual("l0/l1/l1/", scope) + self.assertEqual( + "l0/l1/l1/foo", + sess.graph.unique_name( + "foo", mark_as_used=False)) + self.assertEqual("l0/l1/l1/foo", sess.graph.unique_name("foo")) + with sess.graph.name_scope("l2") as scope: + self.assertEqual("l0/l1/l2/", scope) + self.assertEqual( + "l0/l1/l2/foo", + sess.graph.unique_name( + "foo", mark_as_used=False)) + self.assertEqual("l0/l1/l2/foo", sess.graph.unique_name("foo")) + + def testOutOfOrderUniqueName(self): + g = ops.Graph() + self.assertEqual("foo_2", g.unique_name("foo_2")) + self.assertEqual("foo", g.unique_name("foo")) + self.assertEqual("foo_1", g.unique_name("foo")) + self.assertEqual("foo_3", g.unique_name("foo")) + + def testUniqueNameCaseInsensitivity(self): + g = ops.Graph() + self.assertEqual("foo", g.unique_name("foo")) + self.assertEqual("Foo_1", g.unique_name("Foo")) + with g.name_scope("bar"): + self.assertEqual("bar/foo", g.unique_name("foo")) + with g.name_scope("Bar"): + self.assertEqual("Bar_1/foo", g.unique_name("foo")) + + def testInvalidNameRaisesError(self): + g = ops.Graph() + with g.name_scope(""): # Should not raise + pass + with g.name_scope("foo/"): # Should not raise + with g.name_scope("_bar"): # Should not raise + pass + with self.assertRaises(ValueError): + with g.name_scope("foo:0"): + pass + with self.assertRaises(ValueError): + with g.name_scope("_bar"): + pass + + +class NameTest(test_util.TensorFlowTestCase): + + def testGenerateName(self): + g = ops.Graph() + op0 = g.create_op("TwoFloatOutputs", [], [dtypes.float32, dtypes.float32]) + self.assertEqual("TwoFloatOutputs", op0.name) + self.assertEqual("TwoFloatOutputs:0", op0.outputs[0].name) + self.assertEqual("TwoFloatOutputs:1", op0.outputs[1].name) + + op1 = g.create_op("FloatOutput", [], [dtypes.float32]) + self.assertEqual("FloatOutput", op1.name) + self.assertEqual("FloatOutput:0", op1.outputs[0].name) + + op2 = g.create_op("FloatOutput", [], [dtypes.float32]) + self.assertEqual("FloatOutput_1", op2.name) + self.assertEqual("FloatOutput_1:0", op2.outputs[0].name) + + op3 = g.create_op("FloatOutput", [], [dtypes.float32], name="my_op") + self.assertEqual("my_op", op3.name) + self.assertEqual("my_op:0", op3.outputs[0].name) + + def testNameScope(self): + g = ops.Graph() + + with g.name_scope("foo") as foo: + self.assertEqual("foo/", foo) + with g.name_scope("foo2") as foo2: + self.assertEqual("foo/foo2/", foo2) + with g.name_scope(None) as empty1: + self.assertEqual("", empty1) + with g.name_scope("foo3") as foo3: + self.assertEqual("foo3/", foo3) + with g.name_scope("") as empty2: + self.assertEqual("", empty2) + + self.assertEqual("FloatOutput", + g.create_op("FloatOutput", [], [dtypes.float32]).name) + with g.name_scope("bar") as scope: + self.assertEqual("bar/FloatOutput", + g.create_op("FloatOutput", [], [dtypes.float32]).name) + self.assertEqual("bar/FloatOutput_1", + g.create_op("FloatOutput", [], [dtypes.float32]).name) + # If you use the value from "with .. as", that values is used as-is. + self.assertEqual( + "bar", g.create_op( + "FloatOutput", [], [dtypes.float32], name=scope).name) + with g.name_scope("baz") as scope: + with g.name_scope("quux"): + self.assertEqual("baz/quux/FloatOutput", + g.create_op("FloatOutput", [], [dtypes.float32]).name) + # If you use the value from the enclosing "with .. as", nothing is pushed. + with g.name_scope(scope): + self.assertEqual("baz/FloatOutput", + g.create_op("FloatOutput", [], [dtypes.float32]).name) + self.assertEqual( + "baz", g.create_op( + "FloatOutput", [], [dtypes.float32], name=scope).name) + self.assertEqual( + "trailing", + g.create_op( + "FloatOutput", [], [dtypes.float32], name="trailing/").name) + with g.name_scope("bar"): + self.assertEqual("bar_1/FloatOutput", + g.create_op("FloatOutput", [], [dtypes.float32]).name) + with g.name_scope("bar/"): + self.assertEqual("bar/FloatOutput_2", + g.create_op("FloatOutput", [], [dtypes.float32]).name) + + +class DeviceTest(test_util.TensorFlowTestCase): + + def testNoDevice(self): + g = ops.Graph() + op = g.create_op("FloatOutput", [], [dtypes.float32]) + self.assertDeviceEqual(None, op.device) + gd = g.as_graph_def() + self.assertProtoEqualsVersion(""" + node { name: "FloatOutput" op: "FloatOutput" } + """, gd) + + def testEagerBackingDevice(self): + with context.eager_mode(): + with ops.device("/device:CPU:0"): + t = constant_op.constant(1.0) + self.assertRegexpMatches(t.device, "/device:CPU:0") + self.assertRegexpMatches(t.backing_device, "/device:CPU:0") + + def testDevicePartialString(self): + g = ops.Graph() + with g.device("/job:worker/replica:2"): + g.create_op("FloatOutput", [], [dtypes.float32]) + gd = g.as_graph_def() + self.assertProtoEqualsVersion(""" + node { name: "FloatOutput" op: "FloatOutput" + device: "/job:worker/replica:2" } + """, gd) + + def testDeviceFull(self): + g = ops.Graph() + with g.device( + pydev.DeviceSpec( + job="worker", replica=2, task=0, device_type="CPU", + device_index=3)): + g.create_op("FloatOutput", [], [dtypes.float32]) + gd = g.as_graph_def() + self.assertProtoEqualsVersion(""" + node { name: "FloatOutput" op: "FloatOutput" + device: "/job:worker/replica:2/task:0/device:CPU:3" } + """, gd) + + def testNesting(self): + g = ops.Graph() + with g.device("/job:worker/replica:2"): + g.create_op("FloatOutput", [], [dtypes.float32]) + with g.device("/job:worker/replica:3/task:0"): + g.create_op("FloatOutput", [], [dtypes.float32]) + g.create_op("FloatOutput", [], [dtypes.float32]) + gd = g.as_graph_def() + self.assertProtoEqualsVersion(""" + node { name: "FloatOutput" op: "FloatOutput" + device: "/job:worker/replica:2" } + node { name: "FloatOutput_1" op: "FloatOutput" + device: "/job:worker/replica:3/task:0" } + node { name: "FloatOutput_2" op: "FloatOutput" + device: "/job:worker/replica:2" } + """, gd) + + def testNestingString(self): + g = ops.Graph() + with g.device("/job:worker/replica:2"): + g.create_op("FloatOutput", [], [dtypes.float32]) + with g.device("/job:worker/replica:3/task:0"): + g.create_op("FloatOutput", [], [dtypes.float32]) + g.create_op("FloatOutput", [], [dtypes.float32]) + gd = g.as_graph_def() + self.assertProtoEqualsVersion(""" + node { name: "FloatOutput" op: "FloatOutput" + device: "/job:worker/replica:2" } + node { name: "FloatOutput_1" op: "FloatOutput" + device: "/job:worker/replica:3/task:0" } + node { name: "FloatOutput_2" op: "FloatOutput" + device: "/job:worker/replica:2" } + """, gd) + + def testNestingOverrideGpuCpu(self): + g = ops.Graph() + with g.device("/job:worker/replica:2/device:CPU:1"): + g.create_op("FloatOutput", [], [dtypes.float32]) + with g.device("/job:worker/replica:2/device:GPU:2"): + g.create_op("FloatOutput", [], [dtypes.float32]) + g.create_op("FloatOutput", [], [dtypes.float32]) + gd = g.as_graph_def() + self.assertProtoEqualsVersion(""" + node { name: "FloatOutput" op: "FloatOutput" + device: "/job:worker/replica:2/device:CPU:1" } + node { name: "FloatOutput_1" op: "FloatOutput" + device: "/job:worker/replica:2/device:GPU:2" } + node { name: "FloatOutput_2" op: "FloatOutput" + device: "/job:worker/replica:2/device:CPU:1" } + """, gd) + + def testNestingWithMergeDeviceFunction(self): + g = ops.Graph() + + with g.device(pydev.merge_device("/device:GPU:0")): + g.create_op("FloatOutput", [], [dtypes.float32]) + with g.device(pydev.merge_device("/job:worker")): + g.create_op("FloatOutput", [], [dtypes.float32]) + with g.device(pydev.merge_device("/device:CPU:0")): + g.create_op("FloatOutput", [], [dtypes.float32]) + with g.device(pydev.merge_device("/job:ps")): + g.create_op("FloatOutput", [], [dtypes.float32]) + with g.device(pydev.merge_device(None)): + g.create_op("FloatOutput", [], [dtypes.float32]) + + gd = g.as_graph_def() + self.assertProtoEqualsVersion(""" + node { name: "FloatOutput" op: "FloatOutput" + device: "/device:GPU:0" } + node { name: "FloatOutput_1" op: "FloatOutput" + device: "/job:worker/device:GPU:0" } + node { name: "FloatOutput_2" op: "FloatOutput" + device: "/job:worker/device:CPU:0" } + node { name: "FloatOutput_3" op: "FloatOutput" + device: "/job:ps/device:CPU:0" } + node { name: "FloatOutput_4" op: "FloatOutput" + device: "/job:ps/device:CPU:0" } + """, gd) + + def testNestingWithDeviceStrings(self): + g = ops.Graph() + + with g.device("/device:GPU:0"): + g.create_op("FloatOutput", [], [dtypes.float32]) + with g.device("/job:worker"): + g.create_op("FloatOutput", [], [dtypes.float32]) + with g.device("/device:CPU:0"): + g.create_op("FloatOutput", [], [dtypes.float32]) + with g.device("/job:ps"): + g.create_op("FloatOutput", [], [dtypes.float32]) + with g.device(""): + g.create_op("FloatOutput", [], [dtypes.float32]) + + gd = g.as_graph_def() + self.assertProtoEqualsVersion(""" + node { name: "FloatOutput" op: "FloatOutput" + device: "/device:GPU:0" } + node { name: "FloatOutput_1" op: "FloatOutput" + device: "/job:worker/device:GPU:0" } + node { name: "FloatOutput_2" op: "FloatOutput" + device: "/job:worker/device:CPU:0" } + node { name: "FloatOutput_3" op: "FloatOutput" + device: "/job:ps/device:CPU:0" } + node { name: "FloatOutput_4" op: "FloatOutput" + device: "/job:ps/device:CPU:0" } + """, gd) + + def testNestingWithDeviceStringWildcard(self): + g = ops.Graph() + + with g.device("/device:GPU:7"): + g.create_op("FloatOutput", [], [dtypes.float32]) + with g.device("/device:GPU:*"): + g.create_op("FloatOutput", [], [dtypes.float32]) + + with g.device("/device:CPU:*"): + g.create_op("FloatOutput", [], [dtypes.float32]) + with g.device("/device:CPU:5"): + g.create_op("FloatOutput", [], [dtypes.float32]) + + gd = g.as_graph_def() + self.assertProtoEqualsVersion(""" + node { name: "FloatOutput" op: "FloatOutput" + device: "/device:GPU:7" } + node { name: "FloatOutput_1" op: "FloatOutput" + device: "/device:GPU:7" } + node { name: "FloatOutput_2" op: "FloatOutput" + device: "/device:CPU:*" } + node { name: "FloatOutput_3" op: "FloatOutput" + device: "/device:CPU:5" } + """, gd) + + def testNoneClearsDefault(self): + g = ops.Graph() + with g.device("/job:worker/replica:2/device:CPU:1"): + g.create_op("FloatOutput", [], [dtypes.float32]) + with g.device(None): + g.create_op("FloatOutput", [], [dtypes.float32]) + g.create_op("FloatOutput", [], [dtypes.float32]) + gd = g.as_graph_def() + self.assertProtoEqualsVersion(""" + node { name: "FloatOutput" op: "FloatOutput" + device: "/job:worker/replica:2/device:CPU:1" } + node { name: "FloatOutput_1" op: "FloatOutput" } + node { name: "FloatOutput_2" op: "FloatOutput" + device: "/job:worker/replica:2/device:CPU:1" } + """, gd) + + def testNoneIgnoresOuterDeviceFunction(self): + g = ops.Graph() + with g.device(lambda op: "/job:worker/replica:2/device:CPU:1"): + g.create_op("FloatOutput", [], [dtypes.float32]) + with g.device(None): + g.create_op("FloatOutput", [], [dtypes.float32]) + g.create_op("FloatOutput", [], [dtypes.float32]) + gd = g.as_graph_def() + self.assertProtoEqualsVersion(""" + node { name: "FloatOutput" op: "FloatOutput" + device: "/job:worker/replica:2/device:CPU:1" } + node { name: "FloatOutput_1" op: "FloatOutput" } + node { name: "FloatOutput_2" op: "FloatOutput" + device: "/job:worker/replica:2/device:CPU:1" } + """, gd) + + def _overwritingDeviceFunction(self, unused_op): + # This device function unconditionally overwrites the device of ops. + # + # NOTE(mrry): Writing device functions like this is not + # recommended. Instead, in most cases you should use + # `pydev.merge_device("/job:ps")` or simply `"/job:ps"` as the + # argument to `tf.device()` and the device component will be merged in. + return "/job:overwrite" + + def testOverwritingBehavior(self): + g = ops.Graph() + with g.device(self._overwritingDeviceFunction): + g.create_op("FloatOutput", [], [dtypes.float32]) + with g.device("/job:ps"): # Will be overwritten. + g.create_op("FloatOutput", [], [dtypes.float32]) + with g.device(pydev.merge_device("/job:ps")): # Will be overwritten. + g.create_op("FloatOutput", [], [dtypes.float32]) + with g.device(None): # Disables overwriting device function + with g.device("/job:ps"): + g.create_op("FloatOutput", [], [dtypes.float32]) + with g.device(None): # Disables overwriting device function + with g.device(pydev.merge_device("/job:ps")): + g.create_op("FloatOutput", [], [dtypes.float32]) + gd = g.as_graph_def() + self.assertProtoEqualsVersion(""" + node { name: "FloatOutput" op: "FloatOutput" + device: "/job:overwrite" } + node { name: "FloatOutput_1" op: "FloatOutput" + device: "/job:overwrite" } + node { name: "FloatOutput_2" op: "FloatOutput" + device: "/job:overwrite" } + node { name: "FloatOutput_3" op: "FloatOutput" + device: "/job:ps" } + node { name: "FloatOutput_4" op: "FloatOutput" + device: "/job:ps" } + """, gd) + + +class MultithreadedGraphStateTest(test_util.TensorFlowTestCase): + + class TestThread(threading.Thread): + + def __init__(self, graph, replica_id): + super(MultithreadedGraphStateTest.TestThread, self).__init__() + self._graph = graph + self._replica_id = replica_id + # This thread sets this event when it mutated the graph. The caller can + # wait for that. + self.has_mutated_graph = threading.Event() + # This thread waits for when it should continue. The caller can set this + # event. + self.should_continue = threading.Event() + + def run(self): + # Mutate a graph's stack, then set `has_mutated_graph`, then wait for + # `should_continue`, then add an op to the graph affected by the graph's + # stack. + raise NotImplementedError("must be implemented in descendants") + + def testDeviceFunctionStack(self): + + class DeviceSettingThread(self.TestThread): + + def run(self): + with g.device("/job:worker/replica:{}".format(self._replica_id)): + self.has_mutated_graph.set() + self.should_continue.wait() + self.should_continue.clear() + g.create_op( + "FloatOutput", [], [dtypes.float32], + name="FloatOutput_{}".format(self._replica_id)) + + g = ops.Graph() + # If `switch_to_thread` isn't called, then device placement of the ops + # below is not deterministic. + g.switch_to_thread_local() + threads = [DeviceSettingThread(g, i) for i in range(3)] + for t in threads: + t.start() + t.has_mutated_graph.wait() + t.has_mutated_graph.clear() + for t in threads: + t.should_continue.set() + t.join() + + gd = g.as_graph_def() + self.assertProtoEqualsVersion(""" + node { name: "FloatOutput_0" op: "FloatOutput" + device: "/job:worker/replica:0" } + node { name: "FloatOutput_1" op: "FloatOutput" + device: "/job:worker/replica:1" } + node { name: "FloatOutput_2" op: "FloatOutput" + device: "/job:worker/replica:2" } + """, gd) + + def testColocateWith(self): + + class ColocatingThread(self.TestThread): + + def __init__(self, graph, replica_id, op_to_colocate_with): + super(ColocatingThread, self).__init__(graph, replica_id) + self._op_to_colocate_with = op_to_colocate_with + + def run(self): + with g.colocate_with(self._op_to_colocate_with): + self.has_mutated_graph.set() + self.should_continue.wait() + self.should_continue.clear() + g.create_op( + "FloatOutput", [], [dtypes.float32], + name="FloatOutput_{}".format(self._replica_id)) + + g = ops.Graph() + ops_to_colocate_with = [] + for i in range(3): + with g.device("/job:worker/replica:{}".format(i)): + ops_to_colocate_with.append( + g.create_op( + "FloatOutput", [], [dtypes.float32], + name="ColocateWithMe_{}".format(i))) + + # If `switch_to_thread` isn't called, then `device` and `attr` values for + # the ops below are not deterministic. + g.switch_to_thread_local() + threads = [ + ColocatingThread(g, i, ops_to_colocate_with[i]) for i in range(3) + ] + for t in threads: + t.start() + t.has_mutated_graph.wait() + t.has_mutated_graph.clear() + for t in threads: + t.should_continue.set() + t.join() + + gd = g.as_graph_def() + self.assertProtoEqualsVersion(""" + node { name: "ColocateWithMe_0" op: "FloatOutput" + device: "/job:worker/replica:0" } + node { name: "ColocateWithMe_1" op: "FloatOutput" + device: "/job:worker/replica:1" } + node { name: "ColocateWithMe_2" op: "FloatOutput" + device: "/job:worker/replica:2" } + node { name: "FloatOutput_0" op: "FloatOutput" + device: "/job:worker/replica:0" + attr { key: "_class" + value { list { + s: "loc:@ColocateWithMe_0"}}}} + node { name: "FloatOutput_1" op: "FloatOutput" + device: "/job:worker/replica:1" + attr { key: "_class" + value { list { + s: "loc:@ColocateWithMe_1"}}}} + node { name: "FloatOutput_2" op: "FloatOutput" + device: "/job:worker/replica:2" + attr { key: "_class" + value { list { + s: "loc:@ColocateWithMe_2"}}}} + """, gd) + + def testControlDependencies(self): + + class DependingThread(self.TestThread): + + def __init__(self, graph, replica_id, dependency_op): + super(DependingThread, self).__init__(graph, replica_id) + self._dependency_op = dependency_op + + def run(self): + with g.control_dependencies([self._dependency_op]): + self.has_mutated_graph.set() + self.should_continue.wait() + self.should_continue.clear() + g.create_op( + "FloatOutput", [], [dtypes.float32], + name="FloatOutput_{}".format(self._replica_id)) + + g = ops.Graph() + dependency_ops = [] + for i in range(3): + dependency_ops.append( + g.create_op( + "FloatOutput", [], [dtypes.float32], + name="ColocateWithMe_{}".format(i))) + + # If `switch_to_thread` isn't called, then `input` values for the ops below + # are not deterministic. + g.switch_to_thread_local() + threads = [DependingThread(g, i, dependency_ops[i]) for i in range(3)] + for t in threads: + t.start() + t.has_mutated_graph.wait() + t.has_mutated_graph.clear() + for t in threads: + t.should_continue.set() + t.join() + + gd = g.as_graph_def() + self.assertProtoEqualsVersion(""" + node { name: "ColocateWithMe_0" op: "FloatOutput" } + node { name: "ColocateWithMe_1" op: "FloatOutput" } + node { name: "ColocateWithMe_2" op: "FloatOutput" } + node { name: "FloatOutput_0" op: "FloatOutput" + input: "^ColocateWithMe_0" } + node { name: "FloatOutput_1" op: "FloatOutput" + input: "^ColocateWithMe_1" } + node { name: "FloatOutput_2" op: "FloatOutput" + input: "^ColocateWithMe_2" } + """, gd) + + def testNameStack(self): + + class NameSettingThread(self.TestThread): + + def run(self): + with g.name_scope("foo"): + op1 = g.create_op("FloatOutput", [], [dtypes.float32]) + self.has_mutated_graph.set() + self.should_continue.wait() + self.should_continue.clear() + op2 = g.create_op("FloatOutput", [], [dtypes.float32]) + self.result = (op1, op2) + + g = ops.Graph() + threads = [NameSettingThread(g, i) for i in range(3)] + for t in threads: + t.start() + t.has_mutated_graph.wait() + t.has_mutated_graph.clear() + + for t in threads: + t.should_continue.set() + t.join() + + suffixes = ["", "_1", "_2"] + for t, s in zip(threads, suffixes): + self.assertEquals("foo" + s + "/FloatOutput", t.result[0].name) + self.assertEquals("foo" + s + "/FloatOutput_1", t.result[1].name) + + +class ObjectWithName(object): + + def __init__(self, name): + self._name = name + + @property + def name(self): + return self._name + + +class CollectionTest(test_util.TensorFlowTestCase): + + def test_get_collections(self): + g = ops.Graph() + self.assertSequenceEqual(g.collections, []) + g.add_to_collection("key", 12) + g.add_to_collection("key", 15) + self.assertSequenceEqual(g.collections, ["key"]) + g.add_to_collection("other", "foo") + self.assertSequenceEqual(sorted(g.collections), ["key", "other"]) + + def test_add_to_collection(self): + g = ops.Graph() + g.add_to_collection("key", 12) + g.add_to_collection("other", "foo") + g.add_to_collection("key", 34) + + # Note that only blank1 is returned. + g.add_to_collection("blah", 27) + blank1 = ObjectWithName("prefix/foo") + g.add_to_collection("blah", blank1) + blank2 = ObjectWithName("junk/foo") + g.add_to_collection("blah", blank2) + + self.assertEqual([12, 34], g.get_collection("key")) + self.assertEqual([], g.get_collection("nothing")) + self.assertEqual([27, blank1, blank2], g.get_collection("blah")) + self.assertEqual([blank1], g.get_collection("blah", "prefix")) + self.assertEqual([blank1], g.get_collection("blah", ".*x")) + + # Make sure that get_collection() returns a first-level + # copy of the collection, while get_collection_ref() returns + # the original list. + other_collection_snapshot = g.get_collection("other") + other_collection_ref = g.get_collection_ref("other") + self.assertEqual(["foo"], other_collection_snapshot) + self.assertEqual(["foo"], other_collection_ref) + g.add_to_collection("other", "bar") + self.assertEqual(["foo"], other_collection_snapshot) + self.assertEqual(["foo", "bar"], other_collection_ref) + self.assertEqual(["foo", "bar"], g.get_collection("other")) + self.assertTrue(other_collection_ref is g.get_collection_ref("other")) + + # Verify that getting an empty collection ref returns a modifiable list. + empty_coll_ref = g.get_collection_ref("empty") + self.assertEqual([], empty_coll_ref) + empty_coll = g.get_collection("empty") + self.assertEqual([], empty_coll) + self.assertFalse(empty_coll is empty_coll_ref) + empty_coll_ref2 = g.get_collection_ref("empty") + self.assertTrue(empty_coll_ref2 is empty_coll_ref) + # Add to the collection. + empty_coll_ref.append("something") + self.assertEqual(["something"], empty_coll_ref) + self.assertEqual(["something"], empty_coll_ref2) + self.assertEqual([], empty_coll) + self.assertEqual(["something"], g.get_collection("empty")) + empty_coll_ref3 = g.get_collection_ref("empty") + self.assertTrue(empty_coll_ref3 is empty_coll_ref) + + def test_add_to_collections_uniquify(self): + g = ops.Graph() + g.add_to_collections([1, 2, 1], "key") + # Make sure "key" is not added twice + self.assertEqual(["key"], g.get_collection(1)) + + def test_add_to_collections_from_list(self): + g = ops.Graph() + g.add_to_collections(["abc", "123"], "key") + self.assertEqual(["key"], g.get_collection("abc")) + self.assertEqual(["key"], g.get_collection("123")) + + def test_add_to_collections_from_tuple(self): + g = ops.Graph() + g.add_to_collections(("abc", "123"), "key") + self.assertEqual(["key"], g.get_collection("abc")) + self.assertEqual(["key"], g.get_collection("123")) + + def test_add_to_collections_from_generator(self): + g = ops.Graph() + + def generator(): + yield "abc" + yield "123" + + g.add_to_collections(generator(), "key") + self.assertEqual(["key"], g.get_collection("abc")) + self.assertEqual(["key"], g.get_collection("123")) + + def test_add_to_collections_from_set(self): + g = ops.Graph() + g.add_to_collections(set(["abc", "123"]), "key") + self.assertEqual(["key"], g.get_collection("abc")) + self.assertEqual(["key"], g.get_collection("123")) + + def test_add_to_collections_from_string(self): + g = ops.Graph() + g.add_to_collections("abc", "key") + self.assertEqual(["key"], g.get_collection("abc")) + + def test_default_graph(self): + with ops.Graph().as_default(): + ops.add_to_collection("key", 90) + ops.add_to_collection("key", 100) + # Collections are ordered. + self.assertEqual([90, 100], ops.get_collection("key")) + + def test_defun(self): + with context.eager_mode(): + + @eager_function.defun + def defun(): + ops.add_to_collection("int", 1) + ops.add_to_collection("tensor", constant_op.constant(2)) + + @eager_function.defun + def inner_defun(): + self.assertEqual(ops.get_collection("int"), [1]) + three = ops.get_collection("tensor")[0] + ops.get_collection("int")[0] + ops.add_to_collection("int", 2) + self.assertEqual(ops.get_collection("int"), [1, 2]) + ops.add_to_collection("foo", "bar") + self.assertEqual(ops.get_collection("foo"), ["bar"]) + return three + + self.assertEqual(ops.get_collection("int"), [1]) + three = inner_defun() + self.assertEqual(ops.get_collection("int"), [1]) + self.assertEqual(ops.get_collection("foo"), []) + return three + + three = defun() + self.assertEqual(three.numpy(), 3) + + +ops.NotDifferentiable("FloatOutput") + + +@ops.RegisterGradient("CopyOp") +def _CopyGrad(op, x_grad): # pylint: disable=invalid-name + _ = op + return x_grad + + +@ops.RegisterGradient("copy_override") +def _CopyOverrideGrad(op, x_grad): # pylint: disable=invalid-name + _ = op + return x_grad + + +class RegistrationTest(test_util.TensorFlowTestCase): + + @test_util.run_deprecated_v1 + def testRegisterGradients(self): + x = test_ops.float_output() + y = test_ops.copy_op(x) + fn = ops.get_gradient_function(y.op) + self.assertEqual(_CopyGrad, fn) + + def testOverrideGradients(self): + g = ops.Graph() + with g.as_default(): + x = test_ops.float_output() + with g.gradient_override_map({"CopyOp": "copy_override"}): + y = test_ops.copy_op(x) + fn = ops.get_gradient_function(y.op) + self.assertEqual(_CopyOverrideGrad, fn) + + def testNonExistentOverride(self): + g = ops.Graph() + with g.as_default(): + x = test_ops.float_output() + with g.gradient_override_map({"CopyOp": "unknown_override"}): + y = test_ops.copy_op(x) + with self.assertRaisesRegexp(LookupError, "unknown_override"): + ops.get_gradient_function(y.op) + + +class ComparisonTest(test_util.TensorFlowTestCase): + + def testMembershipAllowed(self): + g = ops.Graph() + t1 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop1") + t2 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop2") + self.assertTrue(isinstance(t1, ops.Tensor)) + self.assertTrue(isinstance(t2, ops.Tensor)) + self.assertTrue(t1 in [t1]) + self.assertTrue(t1 not in [t2]) + + +class ControlDependenciesTest(test_util.TensorFlowTestCase): + + @test_util.run_deprecated_v1 + def testBasic(self): + g = ops.Graph() + with g.as_default(): + # Creating unregistered ops with _apply_op() doesn't work with the C API + # TODO(skyewm): address this more consistently. Possible solutions are + # to use registered ops in all tests, create a way to register ops in + # Python tests, or conditionally disable the op registration check in + # the C API. + a = constant_op.constant(1.0) + b = constant_op.constant(1.0) + with g.control_dependencies([a]): + c = constant_op.constant(1.0) + d = array_ops.identity(b) + e = array_ops.identity(c) + + self.assertEqual(c.op.control_inputs, [a.op]) + self.assertEqual(d.op.control_inputs, [a.op]) + # e should be dominated by c. + self.assertEqual(e.op.control_inputs, []) + + @test_util.run_in_graph_and_eager_modes + def testEager(self): + def future(): + future.calls += 1 + return constant_op.constant(2.0) + future.calls = 0 + + if context.executing_eagerly(): + a = constant_op.constant(1.0) + b = future + with ops.control_dependencies([a, b]): + c = constant_op.constant(3.0) + self.assertEqual(future.calls, 1) + else: + g = ops.Graph() + with g.as_default(): + a = constant_op.constant(1.0) + b = future() + with g.control_dependencies([a, b]): + c = constant_op.constant(3.0) + self.assertEqual(c.op.control_inputs, [a.op, b.op]) + self.assertEqual(future.calls, 1) + + def testBasicWithConversion(self): + g = ops.Graph() + a = _apply_op(g, "FloatOutput", [], [dtypes.float32]) + + class ConvertibleObj(object): + + def _as_graph_element(self): + return a + + with g.control_dependencies([ConvertibleObj()]): + c = _apply_op(g, "FloatOutput", [], [dtypes.float32]) + + self.assertEqual(c.op.control_inputs, [a.op]) + + def testNested(self): + g = ops.Graph() + a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) + a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) + a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) + a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) + + with g.control_dependencies([a_1, a_2, a_3, a_4]): + b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) + + with g.control_dependencies([a_1]): + with g.control_dependencies([a_2]): + with g.control_dependencies([a_3]): + with g.control_dependencies([a_4]): + b_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) + + self.assertItemsEqual([a_1.op, a_2.op, a_3.op, a_4.op], + b_1.op.control_inputs) + self.assertItemsEqual(b_1.op.control_inputs, b_2.op.control_inputs) + + def testClear(self): + g = ops.Graph() + a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) + a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) + a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) + a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) + + with g.control_dependencies([a_1]): + with g.control_dependencies([a_2]): + with g.control_dependencies(None): + with g.control_dependencies([a_3]): + with g.control_dependencies([a_4]): + # deps [a_3, a_4] + b_3_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) + # deps = [a_3] + b_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) + # deps back to None + b_none = _apply_op(g, "FloatOutput", [], [dtypes.float32]) + # deps back to [a_1, a_2] + b_1_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) + # deps back to [a_1] + b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) + with g.control_dependencies(None): + # deps are None again + b_none2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) + + self.assertItemsEqual([a_3.op, a_4.op], b_3_4.op.control_inputs) + self.assertItemsEqual([a_3.op], b_3.op.control_inputs) + self.assertItemsEqual([], b_none.op.control_inputs) + self.assertItemsEqual([a_1.op, a_2.op], b_1_2.op.control_inputs) + self.assertItemsEqual([a_1.op], b_1.op.control_inputs) + self.assertItemsEqual([], b_none2.op.control_inputs) + + def testComplex(self): + g = ops.Graph() + + # Usage pattern: + # * Nodes a_i are constants defined at the outermost scope, and are used + # as control inputs for the ith nested scope. + # * Nodes b_i are defined as Mul(a_3, a_4) at each scope. + # * Nodes c_i are defined as Mul(a_1, b_1) at each scope. + # * Nodes d_i are defined as Mul(b_i, c_i) at each scope. + # * Nodes e_i are defined as Mul(e_i-1, e_i-1) at each scope i > 1. + + a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) + a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) + a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) + a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) + + with g.control_dependencies([a_1]): + b_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4], + [dtypes.float32]) + c_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1], + [dtypes.float32]) + d_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_1, c_1], + [dtypes.float32]) + e_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) + with g.control_dependencies([a_2]): + b_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4], + [dtypes.float32]) + c_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1], + [dtypes.float32]) + d_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_2, c_2], + [dtypes.float32]) + e_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_1, e_1], + [dtypes.float32]) + with g.control_dependencies([a_3]): + b_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4], + [dtypes.float32]) + c_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1], + [dtypes.float32]) + d_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_3, c_3], + [dtypes.float32]) + e_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_2, e_2], + [dtypes.float32]) + with g.control_dependencies([a_4]): + b_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4], + [dtypes.float32]) + c_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1], + [dtypes.float32]) + d_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_4, c_4], + [dtypes.float32]) + e_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_3, e_3], + [dtypes.float32]) + + self.assertItemsEqual([a_1.op], b_1.op.control_inputs) + self.assertItemsEqual([a_1.op, a_2.op], b_2.op.control_inputs) + self.assertItemsEqual([a_1.op, a_2.op], b_3.op.control_inputs) + self.assertItemsEqual([a_1.op, a_2.op], b_4.op.control_inputs) + + self.assertItemsEqual([], c_1.op.control_inputs) + self.assertItemsEqual([a_2.op], c_2.op.control_inputs) + self.assertItemsEqual([a_2.op, a_3.op], c_3.op.control_inputs) + self.assertItemsEqual([a_2.op, a_3.op, a_4.op], c_4.op.control_inputs) + + self.assertItemsEqual([], d_1.op.control_inputs) + self.assertItemsEqual([], d_2.op.control_inputs) + self.assertItemsEqual([], d_3.op.control_inputs) + self.assertItemsEqual([], d_4.op.control_inputs) + + self.assertItemsEqual([a_1.op], e_1.op.control_inputs) + self.assertItemsEqual([a_2.op], e_2.op.control_inputs) + self.assertItemsEqual([a_3.op], e_3.op.control_inputs) + self.assertItemsEqual([a_4.op], e_4.op.control_inputs) + + def testRepeatedDependency(self): + g = ops.Graph() + a = g.create_op("TwoFloatOutputs", [], [dtypes.float32, dtypes.float32]) + a_0, a_1 = a.outputs + with g.control_dependencies([a_0]): + b = _apply_op(g, "FloatOutput", [], [dtypes.float32]) + with g.control_dependencies([a_1]): + c = _apply_op(g, "FloatOutput", [], [dtypes.float32]) + + self.assertEqual(b.op.control_inputs, [a]) + self.assertEqual(c.op.control_inputs, [a]) + + def testNoControlDependencyWithDataDependency(self): + g = ops.Graph() + a = _apply_op(g, "FloatOutput", [], [dtypes.float32]) + with g.control_dependencies([a]): + b = _apply_op(g, "Identity", [a], [dtypes.float32]) + + self.assertEqual(b.op.control_inputs, []) + + +class OpScopeTest(test_util.TensorFlowTestCase): + + @test_util.run_in_graph_and_eager_modes + def testNames(self): + with ops.name_scope("foo") as foo: + self.assertEqual("foo/", foo) + with ops.name_scope("foo2") as foo2: + self.assertEqual("foo/foo2/", foo2) + with ops.name_scope(None) as empty1: + self.assertEqual("", empty1) + with ops.name_scope("foo3") as foo3: + self.assertEqual("foo3/", foo3) + with ops.name_scope("") as empty2: + self.assertEqual("", empty2) + with ops.name_scope("foo/") as outer_foo: + self.assertEqual("foo/", outer_foo) + with ops.name_scope("") as empty3: + self.assertEqual("", empty3) + with ops.name_scope("foo4") as foo4: + self.assertEqual("foo/foo4/", foo4) + with ops.name_scope("foo5//") as foo5: + self.assertEqual("foo5//", foo5) + with ops.name_scope("foo6") as foo6: + self.assertEqual("foo5//foo6/", foo6) + with ops.name_scope("/") as foo7: + self.assertEqual("/", foo7) + with ops.name_scope("//") as foo8: + self.assertEqual("//", foo8) + with ops.name_scope("a//b/c") as foo9: + self.assertEqual("foo/a//b/c/", foo9) + with ops.name_scope("a//b/c") as foo10: + self.assertEqual("a//b/c/", foo10) + + @test_util.run_in_graph_and_eager_modes + def testEagerDefaultScopeName(self): + with ops.name_scope(None, "default") as scope: + self.assertEqual(scope, "default/") + with ops.name_scope(None, "default2") as scope2: + self.assertEqual(scope2, "default/default2/") + + @test_util.run_deprecated_v1 + def testNoScopeName(self): + g0 = ops.Graph() + values = [ + g0.create_op("A", [], [dtypes.float32]), + g0.create_op("B", [], [dtypes.float32]) + ] + with self.assertRaises(ValueError): + with ops.name_scope(None, values=values): + pass + with self.assertRaises(ValueError): + with ops.name_scope(None, None, values): + pass + + @test_util.run_deprecated_v1 + def testEmptyScopeName(self): + g0 = ops.Graph() + a = g0.create_op("A", [], [dtypes.float32]) + b = g0.create_op("B", [], [dtypes.float32]) + with ops.name_scope("", values=[a, b]) as scope: + self.assertEqual("", scope) + self.assertEqual(g0, ops.get_default_graph()) + with ops.name_scope("", "my_default_scope", [a, b]) as scope: + self.assertEqual("", scope) + self.assertEqual(g0, ops.get_default_graph()) + + @test_util.run_deprecated_v1 + def testDefaultScopeName(self): + g0 = ops.Graph() + a = g0.create_op("A", [], [dtypes.float32]) + b = g0.create_op("B", [], [dtypes.float32]) + scope_name = "my_scope" + default_scope_name = "my_default_scope" + with ops.name_scope(scope_name, default_scope_name, [a, b]) as scope: + self.assertEqual("%s/" % scope_name, scope) + self.assertEqual(g0, ops.get_default_graph()) + with ops.name_scope(None, default_scope_name, [a, b]) as scope: + self.assertEqual("%s/" % default_scope_name, scope) + self.assertEqual(g0, ops.get_default_graph()) + + def _testGraphElements(self, graph_elements): + scope_name = "my_scope" + with ops.name_scope(scope_name, values=graph_elements) as scope: + self.assertEqual("%s/" % scope_name, scope) + self.assertEqual(graph_elements[0].graph, ops.get_default_graph()) + g1 = ops.Graph() + a = g1.create_op("A", [], [dtypes.float32]) + with self.assertRaises(ValueError): + with ops.name_scope(scope_name, values=graph_elements + [a]): + pass + + @test_util.run_deprecated_v1 + def testTensor(self): + g0 = ops.Graph() + a = g0.create_op("A", [], [dtypes.float32]) + b = g0.create_op("B", [], [dtypes.float32]) + self._testGraphElements([a, b]) + + @test_util.run_deprecated_v1 + def testSparseTensor(self): + g0 = ops.Graph() + a = g0.create_op("A", [], [dtypes.float32]) + b = g0.create_op("B", [], [dtypes.float32]) + sparse = sparse_tensor.SparseTensor( + _apply_op(g0, "Int64Output", [], [dtypes.int64]), + _apply_op(g0, "FloatOutput", [], [dtypes.float32]), + _apply_op(g0, "Int64Output", [], [dtypes.int64])) + self._testGraphElements([a, sparse, b]) + + @test_util.run_deprecated_v1 + def testVariable(self): + g0 = ops.Graph() + with g0.as_default(): + variable = variables.Variable([1.0]) + a = g0.create_op("A", [], [dtypes.float32]) + b = g0.create_op("B", [], [dtypes.float32]) + self._testGraphElements([a, variable, b]) + + +class InitScopeTest(test_util.TensorFlowTestCase): + + def testClearsControlDependencies(self): + g = ops.Graph() + a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) + a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) + a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) + a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) + + with g.as_default(): + with g.control_dependencies([a_1]): + with g.control_dependencies([a_2]): + with ops.init_scope(): + with g.control_dependencies([a_3]): + with g.control_dependencies([a_4]): + # deps [a_3, a_4] + b_3_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) + # deps = [a_3] + b_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) + # deps back to None + b_none = _apply_op(g, "FloatOutput", [], [dtypes.float32]) + # deps back to [a_1, a_2] + b_1_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) + # deps back to [a_1] + b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) + with ops.init_scope(): + # deps are None again + b_none2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) + + self.assertItemsEqual([a_3.op, a_4.op], b_3_4.op.control_inputs) + self.assertItemsEqual([a_3.op], b_3.op.control_inputs) + self.assertItemsEqual([], b_none.op.control_inputs) + self.assertItemsEqual([a_1.op, a_2.op], b_1_2.op.control_inputs) + self.assertItemsEqual([a_1.op], b_1.op.control_inputs) + self.assertItemsEqual([], b_none2.op.control_inputs) + + def testLiftsOpsFromFunctions(self): + g0 = ops.Graph() + g1 = ops.Graph() + g1._building_function = True # pylint: disable=protected-access + g2 = ops.Graph() + g2._building_function = True # pylint: disable=protected-access + + with g0.as_default(): + with g1.as_default(): + with g2.as_default(): + with ops.init_scope(): + _ = constant_op.constant(1.0) + + self.assertEqual(len(g2.get_operations()), 0) + self.assertEqual(len(g1.get_operations()), 0) + self.assertEqual(len(g0.get_operations()), 1) + + def testPreservesDevices(self): + g0 = ops.Graph() + with g0.as_default(), ops.device("CPU:0"): + g1 = ops.Graph() + g1._building_function = True # pylint: disable=protected-access + with g1.as_default(), ops.device("GPU:0"): + with ops.init_scope(): + # init_scope should preserve device set under `g1`. + on_gpu = constant_op.constant(1.0) + self.assertEqual(on_gpu.device, "/device:GPU:0") + still_on_gpu = constant_op.constant(1.0) + self.assertEqual(still_on_gpu.device, "/device:GPU:0") + on_cpu = constant_op.constant(1.0) + self.assertEqual(on_cpu.device, "/device:CPU:0") + + def testComposes(self): + g0 = ops.Graph() + g1 = ops.Graph() + g1._building_function = True # pylint: disable=protected-access + g2 = ops.Graph() + g2._building_function = True # pylint: disable=protected-access + g3 = ops.Graph() + g3._building_function = False # pylint: disable=protected-access + + with g0.as_default(): + with g1.as_default(): + with ops.init_scope(): + # This op should be lifted into g0. + _ = constant_op.constant(1.0) + self.assertIs(g0, ops.get_default_graph()) + self.assertEqual(len(g2.get_operations()), 0) + self.assertEqual(len(g1.get_operations()), 0) + self.assertEqual(len(g0.get_operations()), 1) + with g2.as_default(): + with ops.init_scope(): + # This op should be lifted into g0. + _ = constant_op.constant(1.0) + self.assertIs(g0, ops.get_default_graph()) + with g3.as_default(): + with ops.init_scope(): + # This op should be lifted into g3, because g3 is not building a + # function. + _ = constant_op.constant(1.0) + self.assertIs(g3, ops.get_default_graph()) + + self.assertEqual(len(g3.get_operations()), 1) + self.assertEqual(len(g2.get_operations()), 0) + self.assertEqual(len(g1.get_operations()), 0) + self.assertEqual(len(g0.get_operations()), 2) + + def testEscapesToEagerContext(self): + g = ops.Graph() + g._building_function = True # pylint: disable=protected-access + with context.eager_mode(): + with context.graph_mode(): + with g.as_default(): + with ops.init_scope(): + # Because g is building a function, init_scope should + # escape out to the eager context. + self.assertTrue(context.executing_eagerly()) + # g should be reinstated as the default graph, and the + # graph context should be re-entered. + self.assertIs(g, ops.get_default_graph()) + self.assertFalse(context.executing_eagerly()) + + def testStaysInEagerWhenOnlyEagerContextActive(self): + with context.eager_mode(): + with ops.init_scope(): + self.assertTrue(context.eager_mode()) + self.assertTrue(context.eager_mode()) + + def testEscapesDefunWhenInEagerMode(self): + + def function_with_variables(): + with ops.init_scope(): + self.v = resource_variable_ops.ResourceVariable(3) + return self.v.assign_add(1) + + with context.eager_mode(): + # Each invocation of function_with_variables recreates a variable. + self.assertEqual(4, int(function_with_variables())) + self.assertEqual(4, int(function_with_variables())) + + compiled = eager_function.defun(function_with_variables) + # The init_scope in function_with_variables lifts the variable out + # of the graph function constructed by defun; hence, + # compiled now appears to be stateful. + self.assertEqual(4, int(compiled())) + self.assertEqual(5, int(compiled())) + + def testEscapesDefunWhenInGraphMode(self): + def function_with_variables(name): + with ops.init_scope(): + _ = variable_scope.get_variable(name, shape=(1,)) + + g = ops.Graph() + with g.as_default(): + with self.cached_session(): + # First ensure that graphs that are not building functions are + # not escaped. + function_with_variables("foo") + with self.assertRaisesRegexp(ValueError, + r"Variable foo already exists.*"): + # This will fail because reuse is not set to True. + function_with_variables("foo") + + compiled = eager_function.defun(function_with_variables) + compiled("bar") + self.assertEqual( + len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)), 2) + + # The second call to `compiled` should not create variables: the + # init_scope has lifted the variable creation code out of the defun. + compiled("bar") + self.assertEqual( + len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)), 2) + + def testEscapesNestedDefun(self): + + def inner_function(): + with ops.init_scope(): + self.v = resource_variable_ops.ResourceVariable(1) + return self.v.assign_add(2) + + def outer_function(inner=None): + with ops.init_scope(): + self.v0 = resource_variable_ops.ResourceVariable(0) + return self.v0.assign_add(1) + inner() + + with context.eager_mode(): + # Each invocation of outer_function recreates variables. + self.assertEqual(4, int(outer_function(inner=inner_function))) + self.assertEqual(4, int(outer_function(inner=inner_function))) + + compiled_inner = eager_function.defun(inner_function) + compiled_outer = eager_function.defun(outer_function) + # The init_scope lifts variables out of the graph functions + # constructed by defun; hence, compiled_outer should now appear to be + # stateful. + self.assertEqual(4, int(compiled_outer(inner=compiled_inner))) + self.assertEqual(7, int(compiled_outer(inner=compiled_inner))) + + @test_util.run_v1_only("b/120545219") + def testFallsBackToGlobalGraphWhenAllGraphsAreBuildingFunctions(self): + with context.graph_mode(): + ops.reset_default_graph() + # This doesn't push anything onto the graph stack, but it does + # set the stack's global graph. + global_graph = ops.get_default_graph() + fn_graph = ops.Graph() + + # pylint: disable=protected-access + fn_graph._building_function = True + self.assertEqual(len(ops._default_graph_stack.stack), 0) + with fn_graph.as_default(): + self.assertEqual(len(ops._default_graph_stack.stack), 1) + with ops.init_scope(): + self.assertGreater(len(ops._default_graph_stack.stack), 1) + dummy = constant_op.constant(1.0) + self.assertEqual(len(ops._default_graph_stack.stack), 1) + # Note that the global graph is _not_ on the graph stack. + self.assertEqual(len(ops._default_graph_stack.stack), 0) + # Ensure that `dummy` was added to the global graph. + self.assertEqual(global_graph, dummy.graph) + # pylint: enable=protected-access + + def testInstallsDefaultGraphWhenGraphStackIsEmptyInGraphMode(self): + with context.graph_mode(): + # pylint: disable=protected-access + self.assertEqual(len(ops._default_graph_stack.stack), 0) + with ops.init_scope(): + self.assertGreater(len(ops._default_graph_stack.stack), 0) + self.assertEqual(len(ops._default_graph_stack.stack), 0) + # pylint: enable=protected-access + + def testPreservesNameScopeInGraphConstruction(self): + with ops.Graph().as_default(): + function_graph = ops.Graph() + with function_graph.as_default(): + with ops.name_scope("inner"), ops.init_scope(): + self.assertEqual(ops.get_name_scope(), "inner") + self.assertEqual(ops.get_name_scope(), "") + + def testEnteringGraphFromEagerIsSticky(self): + with context.eager_mode(): + g = ops.Graph() + with g.as_default(): + with ops.init_scope(): + self.assertFalse(context.executing_eagerly()) + self.assertEqual(g, ops.get_default_graph()) + + def testMixGraphEager(self): + with context.eager_mode(): + c = constant_op.constant(1.0) + with ops.Graph().as_default(): + with self.assertRaisesRegexp( + RuntimeError, "Attempting to capture an EagerTensor"): + math_ops.add(c, c) + c2 = constant_op.constant(2.0) + with self.assertRaisesRegexp( + TypeError, "contains objects other than 'EagerTensor'"): + math_ops.add(c2, c2) + + def testPreservesNameScopeInEagerExecution(self): + with context.eager_mode(): + def foo(): + with ops.name_scope("inner"), ops.init_scope(): + if context.executing_eagerly(): + # A trailing slash is always appended when eager execution is + # enabled. + self.assertEqual(context.context().scope_name, "inner/") + else: + self.assertEqual(ops.get_name_scope(), "inner") + + foo() + self.assertEqual(ops.get_name_scope(), "") + foo_compiled = eager_function.defun(foo) + foo_compiled() + self.assertEqual(ops.get_name_scope(), "") + + def testExecutingEagerlyOutsideFunctions(self): + + @eager_function.defun + def f(): + return ops.executing_eagerly_outside_functions() + + with context.eager_mode(): + self.assertTrue(ops.executing_eagerly_outside_functions()) + self.assertTrue(f()) + g = ops.Graph() + with g.as_default(): + self.assertFalse(ops.executing_eagerly_outside_functions()) + + +class GraphTest(test_util.TensorFlowTestCase): + + def setUp(self): + ops.reset_default_graph() + + def _AssertDefault(self, expected): + self.assertIs(expected, ops.get_default_graph()) + + def testResetDefaultGraphNesting(self): + g0 = ops.Graph() + with self.assertRaises(AssertionError): + with g0.as_default(): + ops.reset_default_graph() + + def testGraphContextManagerCancelsEager(self): + with context.eager_mode(): + with ops.Graph().as_default(): + self.assertFalse(context.executing_eagerly()) + + def testGraphContextManager(self): + g0 = ops.Graph() + with g0.as_default() as g1: + self.assertIs(g0, g1) + + def testDefaultGraph(self): + orig = ops.get_default_graph() + self._AssertDefault(orig) + g0 = ops.Graph() + self._AssertDefault(orig) + context_manager_0 = g0.as_default() + self._AssertDefault(orig) + with context_manager_0 as g0: + self._AssertDefault(g0) + with ops.Graph().as_default() as g1: + self._AssertDefault(g1) + self._AssertDefault(g0) + self._AssertDefault(orig) + + def testPreventFeeding(self): + g = ops.Graph() + a = constant_op.constant(2.0) + self.assertTrue(g.is_feedable(a)) + g.prevent_feeding(a) + self.assertFalse(g.is_feedable(a)) + + @test_util.run_deprecated_v1 + def testPreventFetching(self): + g = ops.Graph() + a = constant_op.constant(2.0) + self.assertTrue(g.is_fetchable(a)) + g.prevent_fetching(a.op) + self.assertFalse(g.is_fetchable(a)) + + def testAsGraphElementConversions(self): + + class ConvertibleObj(object): + + def _as_graph_element(self): + return "FloatOutput:0" + + class NonConvertibleObj(object): + + pass + + g = ops.Graph() + a = _apply_op(g, "FloatOutput", [], [dtypes.float32]) + self.assertEqual(a, g.as_graph_element(ConvertibleObj())) + with self.assertRaises(TypeError): + g.as_graph_element(NonConvertibleObj()) + + # Regression test against creating custom __del__ functions in classes + # involved in cyclic references, e.g. Graph and Operation. (Python won't gc + # cycles that require calling a __del__ method, because the __del__ method can + # theoretically increase the object's refcount to "save" it from gc, and any + # already-deleted objects in the cycle would have be to restored.) + def testGarbageCollected(self): + # Create a graph we can delete and a weak reference to monitor if it's gc'd + g = ops.Graph() + g_ref = weakref.ref(g) + # Create some ops + with g.as_default(): + a = constant_op.constant(2.0) + b = constant_op.constant(3.0) + c = math_ops.add(a, b) + # Create a session we can delete + with session.Session(graph=g) as sess: + self.evaluate(c) + # Delete all references and trigger gc + del g + del a + del b + del c + del sess + gc.collect() + self.assertIsNone(g_ref()) + + def testRunnableAfterInvalidShape(self): + with ops.Graph().as_default(): + with self.assertRaises(ValueError): + math_ops.add([1, 2], [1, 2, 3]) + a = constant_op.constant(1) + with session.Session() as sess: + self.evaluate(a) + + def testRunnableAfterInvalidShapeWithKernelLabelMap(self): + g = ops.Graph() + with g.as_default(): + with g._kernel_label_map({"KernelLabelRequired": "overload_1"}): + with self.assertRaises(ValueError): + test_ops.kernel_label_required(1) + a = constant_op.constant(1) + with session.Session() as sess: + self.evaluate(a) + + +class AttrScopeTest(test_util.TensorFlowTestCase): + + def _get_test_attrs(self): + x = control_flow_ops.no_op() + try: + a = compat.as_text(x.get_attr("_A")) + except ValueError: + a = None + try: + b = compat.as_text(x.get_attr("_B")) + except ValueError: + b = None + return (a, b) + + @test_util.run_deprecated_v1 + def testNoLabel(self): + with self.cached_session(): + self.assertAllEqual((None, None), self._get_test_attrs()) + + @test_util.run_deprecated_v1 + def testLabelMap(self): + with self.cached_session() as sess: + a1 = self._get_test_attrs() + with sess.graph._attr_scope({ + "_A": attr_value_pb2.AttrValue(s=compat.as_bytes("foo")) + }): + a2 = self._get_test_attrs() + with sess.graph._attr_scope({ + "_A": None, + "_B": attr_value_pb2.AttrValue(s=compat.as_bytes("bar")) + }): + a3 = self._get_test_attrs() + with sess.graph._attr_scope({ + "_A": attr_value_pb2.AttrValue(s=compat.as_bytes("baz")) + }): + a4 = self._get_test_attrs() + a5 = self._get_test_attrs() + a6 = self._get_test_attrs() + a7 = self._get_test_attrs() + + self.assertAllEqual((None, None), a1) + self.assertAllEqual(("foo", None), a2) + self.assertAllEqual((None, "bar"), a3) + self.assertAllEqual(("baz", "bar"), a4) + self.assertAllEqual((None, "bar"), a5) + self.assertAllEqual(("foo", None), a6) + self.assertAllEqual((None, None), a7) + + +ops.RegisterShape("KernelLabel")(common_shapes.scalar_shape) + + +class KernelLabelTest(test_util.TensorFlowTestCase): + + @test_util.run_deprecated_v1 + def testNoLabel(self): + with self.cached_session(): + self.assertAllEqual(b"My label is: default", + test_ops.kernel_label().eval()) + + @test_util.run_deprecated_v1 + def testLabelMap(self): + with self.cached_session() as sess: + default_1 = test_ops.kernel_label() + # pylint: disable=protected-access + with sess.graph._kernel_label_map({"KernelLabel": "overload_1"}): + overload_1_1 = test_ops.kernel_label() + with sess.graph._kernel_label_map({"KernelLabel": "overload_2"}): + overload_2 = test_ops.kernel_label() + with sess.graph._kernel_label_map({"KernelLabel": ""}): + default_2 = test_ops.kernel_label() + overload_1_2 = test_ops.kernel_label() + # pylint: enable=protected-access + default_3 = test_ops.kernel_label() + + self.assertAllEqual(b"My label is: default", self.evaluate(default_1)) + self.assertAllEqual(b"My label is: default", self.evaluate(default_2)) + self.assertAllEqual(b"My label is: default", self.evaluate(default_3)) + self.assertAllEqual(b"My label is: overload_1", + self.evaluate(overload_1_1)) + self.assertAllEqual(b"My label is: overload_1", + self.evaluate(overload_1_2)) + self.assertAllEqual(b"My label is: overload_2", self.evaluate(overload_2)) + + +class AsGraphDefTest(test_util.TensorFlowTestCase): + + def testGraphDefVersion(self): + """Test that the graphdef version is plumbed through to kernels.""" + with ops.Graph().as_default() as g: + version = g.graph_def_versions.producer + with self.session(graph=g): + v = test_ops.graph_def_version().eval() + self.assertEqual(version, v) + + def testAddShapes(self): + with ops.Graph().as_default() as g: + t1, t2, t3, t4, t5 = _apply_op(g, "FiveFloatOutputs", [], + [dtypes.float32] * 5) + t1.set_shape(None) + t2.set_shape([]) + t3.set_shape([None]) + t4.set_shape([43, 37]) + t5.set_shape([43, None]) + + b = constant_op.constant(1.0) # pylint: disable=unused-variable + + gd = g.as_graph_def(add_shapes=True) + self.assertProtoEqualsVersion(""" + node { name: "FiveFloatOutputs" op: "FiveFloatOutputs" + attr { + key: "_output_shapes" + value { + list { + shape { unknown_rank: true } + shape { } + shape { dim { size: -1 } } + shape { dim { size: 43 } dim { size: 37 } } + shape { dim { size: 43 } dim { size: -1 } } + } + } + } + } + node { name: "Const" op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { } + } + } + } + attr { + key: "dtype" + value { type: DT_FLOAT } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { } + float_val: 1.0 } } } } + """, gd) + + +@ops.RegisterStatistics("a", "flops") +def _calc_a_forward_flops(unused_graph, unused_node): + return ops.OpStats("flops", 20) + + +class StatisticsTest(test_util.TensorFlowTestCase): + + def testRegisteredNode(self): + graph = ops.Graph() + node = ops._NodeDef("a", "an_a") + flops = ops.get_stats_for_node_def(graph, node, "flops") + self.assertEqual(20, flops.value) + missing_stat = ops.get_stats_for_node_def(graph, node, "missing_stat") + self.assertEqual(None, missing_stat.value) + + def testUnregisteredNode(self): + graph = ops.Graph() + node = ops._NodeDef("b", "a_b") + weight_params = ops.get_stats_for_node_def(graph, node, "weight_params") + self.assertEqual(None, weight_params.value) + + def testAccumulateStatistics(self): + flops_total = ops.OpStats("flops") + self.assertEqual(None, flops_total.value) + second_flops = ops.OpStats("flops", 3) + flops_total += second_flops + self.assertEqual(3, flops_total.value) + + +class DeviceStackTest(test_util.TensorFlowTestCase): + + @test_util.run_deprecated_v1 + def testBasicDeviceAssignmentMetadata(self): + + def device_func(unused_op): + return "/cpu:*" + + const_zero = constant_op.constant([0.0], name="zero") + with ops.device("/cpu"): + const_one = constant_op.constant([1.0], name="one") + with ops.device("/cpu:0"): + const_two = constant_op.constant([2.0], name="two") + with ops.device(device_func): + const_three = constant_op.constant(3.0, name="three") + + self.assertEqual(0, len(const_zero.op._device_assignments)) + + one_list = const_one.op._device_assignments + self.assertEqual(1, len(one_list)) + self.assertEqual("/cpu", one_list[0].obj) + self.assertEqual("ops_test.py", os.path.basename(one_list[0].filename)) + + two_list = const_two.op._device_assignments + self.assertEqual(2, len(two_list)) + devices = [t.obj for t in two_list] + self.assertEqual(set(["/cpu", "/cpu:0"]), set(devices)) + + three_list = const_three.op._device_assignments + self.assertEqual(1, len(three_list)) + func_description = three_list[0].obj + expected_regex = r"device_func<.*ops_test.py, [0-9]+" + self.assertRegexpMatches(func_description, expected_regex) + + @test_util.run_deprecated_v1 + def testDeviceAssignmentMetadataForGraphDeviceAndTfDeviceFunctions(self): + + with ops.device("/cpu"): + const_one = constant_op.constant([1.0], name="one") + with ops.get_default_graph().device("/cpu"): + const_two = constant_op.constant([2.0], name="two") + + one_metadata = const_one.op._device_assignments[0] + two_metadata = const_two.op._device_assignments[0] + + # Verify both types of device assignment return the right stack info. + self.assertRegexpMatches("ops_test.py", + os.path.basename(one_metadata.filename)) + self.assertEqual(one_metadata.filename, two_metadata.filename) + self.assertEqual(one_metadata.lineno + 2, two_metadata.lineno) + + +class ColocationGroupTest(test_util.TensorFlowTestCase): + + @test_util.run_deprecated_v1 + def testBasic(self): + a = constant_op.constant([2.0], name="a") + with ops.colocate_with(a.op): + b = constant_op.constant(3.0) + c = constant_op.constant(4.0) + self.assertEqual([b"loc:@a"], a.op.colocation_groups()) + self.assertEqual([b"loc:@a"], b.op.colocation_groups()) + with self.assertRaises(ValueError): + c.op.get_attr("_class") + + @test_util.run_deprecated_v1 + def testBasicColocationMetadata(self): + const_two = constant_op.constant([2.0], name="two") + with ops.colocate_with(const_two.op): + const_three = constant_op.constant(3.0, name="three") + locations_dict = const_three.op._colocation_dict + self.assertIn("two", locations_dict) + metadata = locations_dict["two"] + self.assertIsNone(metadata.obj) + # Check that this test's filename is recorded as the file containing the + # colocation statement. + self.assertEqual("ops_test.py", os.path.basename(metadata.filename)) + + @test_util.run_deprecated_v1 + def testColocationDeviceInteraction(self): + with ops.device("/cpu:0"): + with ops.device("/device:GPU:0"): + a = constant_op.constant([2.0], name="a") + with ops.colocate_with(a.op): + # 'b' is created in the scope of /cpu:0, but it is + # colocated with 'a', which is on '/device:GPU:0'. colocate_with + # overrides devices because it is a stronger constraint. + b = constant_op.constant(3.0) + self.assertEqual([b"loc:@a"], b.op.colocation_groups()) + self.assertEqual(a.op.device, b.op.device) + + @test_util.run_deprecated_v1 + def testColocationCanonicalization(self): + with ops.device("/device:GPU:0"): + _ = constant_op.constant(2.0) + with ops.device(lambda op: "/device:GPU:0"): + b = constant_op.constant(3.0) + with ops.get_default_graph().colocate_with(b): + with ops.device("/device:GPU:0"): + c = constant_op.constant(4.0) + + # A's device will be /device:GPU:0 + # B's device will be /device:GPU:0 + # C's device will be /device:GPU:0 because it + # inherits B's device name, after canonicalizing the names. + self.assertEqual(b.op.device, c.op.device) + + @test_util.run_deprecated_v1 + def testLocationOverrides(self): + with ops.device("/cpu:0"): + with ops.device("/device:GPU:0"): + a = constant_op.constant([2.0], name="a") + # Note that this colocation is "redundant", since we are + # within the scope of "/device:GPU:0". However, we would like to + # preserve in the GraphDef that these two ops should be + # colocated in a portable way. + with ops.colocate_with(a.op): + b = constant_op.constant(3.0) + c = constant_op.constant(4.0) + d = constant_op.constant(5.0) + + self.assertEqual([b"loc:@a"], b.op.colocation_groups()) + self.assertEqual("/device:GPU:0", a.op.device) + self.assertEqual(a.op.device, b.op.device) + + # Test that device function stack is restored. + self.assertEqual("/device:GPU:0", c.op.device) + self.assertEqual("/device:CPU:0", d.op.device) + + @test_util.run_deprecated_v1 + def testNestedColocateWith(self): + a = constant_op.constant([2.0], name="a") + with ops.colocate_with(a.op): + b = constant_op.constant(3.0) + with ops.colocate_with(b.op): + c = constant_op.constant(4.0) + self.assertEqual([b"loc:@a"], b.op.colocation_groups()) + self.assertEqual([b"loc:@a"], c.op.colocation_groups()) + + @test_util.run_deprecated_v1 + def testMultiColocationGroups(self): + a = constant_op.constant([2.0], name="a") + b = constant_op.constant(3.0, name="b") + with ops.colocate_with(a.op): + with ops.colocate_with(b.op): + c = constant_op.constant(4.0) + self.assertEqual(set([b"loc:@a", b"loc:@b"]), set(c.op.colocation_groups())) + + @test_util.run_deprecated_v1 + def testColocationIgnoreStack(self): + a = constant_op.constant([2.0], name="a") + b = constant_op.constant(3.0, name="b") + with ops.colocate_with(a.op): + with ops.colocate_with(b.op, ignore_existing=True): + c = constant_op.constant(4.0) + self.assertEqual(set([b"loc:@b"]), set(c.op.colocation_groups())) + + @test_util.run_deprecated_v1 + def testColocateWithReset(self): + a = constant_op.constant([2.0], name="a") + with ops.colocate_with(a.op): + b = constant_op.constant(3.0, name="b") + with ops.colocate_with(None, ignore_existing=True): + c = constant_op.constant(4.0, name="c") + self.assertEqual([b"loc:@a"], b.op.colocation_groups()) + self.assertEqual([b"loc:@c"], c.op.colocation_groups()) + + @test_util.run_deprecated_v1 + def testColocateWithInitialNoneThenNested(self): + a = constant_op.constant([2.0], name="a") + with ops.colocate_with(a.op): + with ops.colocate_with(None, ignore_existing=True): + b = constant_op.constant(3.0, name="b") + with ops.colocate_with(b.op): + c = constant_op.constant(4.0, name="c") + self.assertEqual([b"loc:@b"], b.op.colocation_groups()) + self.assertEqual([b"loc:@b"], c.op.colocation_groups()) + + @test_util.run_deprecated_v1 + def testColocateVariables(self): + a = variables.Variable([2.0], name="a") + with ops.colocate_with(a.op): + b = variables.Variable([3.0], name="b") + self.assertEqual([b"loc:@a"], b.op.colocation_groups()) + + +class DeprecatedTest(test_util.TensorFlowTestCase): + + def testSuccess(self): + with ops.Graph().as_default() as g: + test_util.set_producer_version(g, 7) + old = test_ops.old() + with self.session(graph=g): + old.run() + + def _error(self): + return ((r"Op Old is not available in GraphDef version %d\. " + r"It has been removed in version 8\. For reasons\.") % + versions.GRAPH_DEF_VERSION) + + def testGraphConstructionFail(self): + with ops.Graph().as_default(): + with self.assertRaisesRegexp(NotImplementedError, self._error()): + test_ops.old() + + +class DenseTensorLikeTypeTest(test_util.TensorFlowTestCase): + + def testSuccess(self): + op = ops.Operation( + ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32]) + t = op.outputs[0] + self.assertTrue(ops.is_dense_tensor_like(t)) + + v = variables.Variable([17]) + self.assertTrue(ops.is_dense_tensor_like(v)) + + class BadClassNoName(object): + pass + + class BadClassBadName(object): + + def name(self): + pass + + class BadClassNoDtype(object): + + @property + def name(self): + pass + + class BadClassBadDtype(object): + + @property + def name(self): + pass + + def dtype(self): + pass + + def testBadClass(self): + with self.assertRaisesRegexp(TypeError, "`name`"): + ops.register_dense_tensor_like_type( + DenseTensorLikeTypeTest.BadClassNoName) + with self.assertRaisesRegexp(TypeError, "`name`"): + ops.register_dense_tensor_like_type( + DenseTensorLikeTypeTest.BadClassBadName) + with self.assertRaisesRegexp(TypeError, "`dtype`"): + ops.register_dense_tensor_like_type( + DenseTensorLikeTypeTest.BadClassNoDtype) + with self.assertRaisesRegexp(TypeError, "`dtype`"): + ops.register_dense_tensor_like_type( + DenseTensorLikeTypeTest.BadClassBadDtype) + + +class NameScopeTest(test_util.TensorFlowTestCase): + + def testStripAndPrependScope(self): + strs = [ + "hidden1/hidden1/weights", # Same prefix. Should strip. + "hidden1///hidden1/weights", # Extra "/". Should strip. + "^hidden1/hidden1/weights", # Same prefix. Should strip. + "loc:@hidden1/hidden1/weights", # Same prefix. Should strip. + "hhidden1/hidden1/weights", # Different prefix. Should keep. + "hidden1" + ] # Not a prefix. Should keep. + expected_striped = [ + "hidden1/weights", "hidden1/weights", "^hidden1/weights", + "loc:@hidden1/weights", "hhidden1/hidden1/weights", "hidden1" + ] + expected_prepended = [ + "hidden2/hidden1/weights", "hidden2/hidden1/weights", + "^hidden2/hidden1/weights", "loc:@hidden2/hidden1/weights", + "hidden2/hhidden1/hidden1/weights", "hidden2/hidden1" + ] + name_scope_to_strip = "hidden1" + name_scope_to_add = "hidden2" + for es, ep, s in zip(expected_striped, expected_prepended, strs): + striped = ops.strip_name_scope(s, name_scope_to_strip) + self.assertEqual(es, striped) + self.assertEqual(ep, ops.prepend_name_scope(striped, name_scope_to_add)) + + def testGetNameScope(self): + with ops.Graph().as_default() as g: + with ops.name_scope("scope1"): + with ops.name_scope("scope2"): + with ops.name_scope("scope3"): + self.assertEqual("scope1/scope2/scope3", g.get_name_scope()) + self.assertEqual("scope1/scope2", g.get_name_scope()) + self.assertEqual("scope1", g.get_name_scope()) + self.assertEqual("", g.get_name_scope()) + + def testTwoGraphs(self): + + def f(): + g1 = ops.Graph() + g2 = ops.Graph() + with g1.as_default(): + with g2.as_default(): + with ops.name_scope("_"): + pass + + self.assertRaisesRegexp(ValueError, "'_' is not a valid scope name", f) + + +class TracebackTest(test_util.TensorFlowTestCase): + + @test_util.run_deprecated_v1 + def testTracebackWithStartLines(self): + with self.cached_session() as sess: + a = constant_op.constant(2.0) + sess.run( + a, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE)) + self.assertTrue(sess.graph.get_operations()) + + # Tests that traceback_with_start_lines is the same as traceback + # but includes one more element at the end. + for op in sess.graph.get_operations(): + self.assertEquals(len(op.traceback), len(op.traceback_with_start_lines)) + for frame, frame_with_start_line in zip( + op.traceback, op.traceback_with_start_lines): + self.assertEquals(5, len(frame_with_start_line)) + self.assertEquals(frame, frame_with_start_line[:-1]) + + +class EnableEagerExecutionTest(test_util.TensorFlowTestCase): + + @test_util.run_v1_only("b/120545219") + def testBadArgumentsToEnableEagerExecution(self): + with self.assertRaisesRegexp(TypeError, "config must be a tf.ConfigProto"): + ops.enable_eager_execution(context.DEVICE_PLACEMENT_SILENT) + with self.assertRaisesRegexp(ValueError, "device_policy must be one of"): + c = config_pb2.ConfigProto() + ops.enable_eager_execution(c, c) + with self.assertRaisesRegexp(ValueError, "execution_mode must be one of"): + c = config_pb2.ConfigProto() + ops.enable_eager_execution(c, execution_mode=c) + + +if __name__ == "__main__": + googletest.main()