You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

ops_test_r1.13.py 114 kB


  1. # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Tests for tensorflow.python.framework.ops."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import gc
  20. import os
  21. import threading
  22. import weakref
  23. from tensorflow.core.framework import attr_value_pb2
  24. from tensorflow.core.protobuf import config_pb2
  25. from tensorflow.python.client import session
  26. from tensorflow.python.eager import context
  27. from tensorflow.python.eager import function as eager_function
  28. from tensorflow.python.framework import common_shapes
  29. from tensorflow.python.framework import constant_op
  30. from tensorflow.python.framework import device as pydev
  31. from tensorflow.python.framework import dtypes
  32. from tensorflow.python.framework import errors
  33. from tensorflow.python.framework import function
  34. from tensorflow.python.framework import ops
  35. from tensorflow.python.framework import sparse_tensor
  36. from tensorflow.python.framework import tensor_shape
  37. from tensorflow.python.framework import tensor_util
  38. from tensorflow.python.framework import test_ops
  39. from tensorflow.python.framework import test_util
  40. from tensorflow.python.framework import versions
  41. from tensorflow.python.ops import array_ops
  42. from tensorflow.python.ops import control_flow_ops
  43. from tensorflow.python.ops import math_ops
  44. from tensorflow.python.ops import resource_variable_ops
  45. from tensorflow.python.ops import resources
  46. from tensorflow.python.ops import variable_scope
  47. from tensorflow.python.ops import variables
  48. import tensorflow.python.ops.gradients # pylint: disable=unused-import
  49. from tensorflow.python.platform import googletest
  50. from tensorflow.python.util import compat
# Module-level setup run once at import: hands common_shapes.call_cpp_shape_fn
# to `ops` before any test below builds a graph. Presumably this makes
# Python-created ops use the C++ shape functions for shape inference --
# NOTE(review): confirm against ops._set_call_cpp_shape_fn.
ops._set_call_cpp_shape_fn(common_shapes.call_cpp_shape_fn)
  52. class ResourceTest(test_util.TensorFlowTestCase):
  53. @test_util.run_deprecated_v1
  54. def testBuildGraph(self):
  55. with self.cached_session():
  56. pt = test_ops.stub_resource_handle_op(container="a", shared_name="b")
  57. test_ops.resource_create_op(pt).run()
  58. @test_util.run_deprecated_v1
  59. def testInitialize(self):
  60. with self.cached_session():
  61. handle = test_ops.stub_resource_handle_op(container="a", shared_name="b")
  62. resources.register_resource(
  63. handle=handle,
  64. create_op=test_ops.resource_create_op(handle),
  65. is_initialized_op=test_ops.resource_initialized_op(handle))
  66. self.assertEquals(
  67. len(
  68. resources.report_uninitialized_resources(
  69. resources.shared_resources()).eval()), 1)
  70. resources.initialize_resources(resources.shared_resources()).run()
  71. self.assertEquals(
  72. len(
  73. resources.report_uninitialized_resources(
  74. resources.shared_resources()).eval()), 0)
  75. class TensorAndShapeTest(test_util.TensorFlowTestCase):
  76. def testShape(self):
  77. op = ops.Operation(
  78. ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32])
  79. t = op.outputs[0]
  80. self.assertEqual(tensor_shape.unknown_shape(), t.get_shape())
  81. t.set_shape([1, 2, 3])
  82. self.assertEqual([1, 2, 3], t.get_shape())
  83. def testIterable(self):
  84. op = ops.Operation(
  85. ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32])
  86. t = op.outputs[0]
  87. self.assertTrue(isinstance(t, ops.Tensor))
  88. with self.assertRaisesRegexp(TypeError, "iter"):
  89. for _ in t:
  90. pass
  91. def testAddShape(self):
  92. with self.cached_session():
  93. a = array_ops.zeros([2, 3])
  94. b = array_ops.ones([1, 3])
  95. c = a + b
  96. self.assertEqual([2, 3], c.shape)
  97. @test_util.run_deprecated_v1
  98. def testUnknownDim(self):
  99. with self.cached_session():
  100. a = array_ops.placeholder(dtype=dtypes.float32, shape=[2, None, 3])
  101. b = array_ops.placeholder(dtype=dtypes.float32, shape=[2, None, 3])
  102. c = a + b
  103. self.assertEqual([2, None, 3], c.shape.as_list())
  104. @test_util.run_deprecated_v1
  105. def testUnknownShape(self):
  106. with self.cached_session():
  107. a = array_ops.placeholder(dtype=dtypes.float32, shape=None)
  108. b = array_ops.ones([1, 3])
  109. c = a + b
  110. self.assertEqual(tensor_shape.unknown_shape(), c.shape)
  111. @test_util.run_deprecated_v1
  112. def testScalarShape(self):
  113. with self.cached_session():
  114. a = array_ops.placeholder(dtype=dtypes.float32, shape=[])
  115. b = array_ops.ones([])
  116. c = a + b
  117. self.assertEqual(tensor_shape.scalar(), c.shape)
  118. @test_util.run_deprecated_v1
  119. def testShapeFunctionError(self):
  120. with self.cached_session():
  121. a = array_ops.ones([1, 2, 3])
  122. b = array_ops.ones([4, 5, 6])
  123. with self.assertRaisesRegexp(
  124. ValueError,
  125. r"Dimensions must be equal, but are 2 and 5 for 'add' \(op: 'Add'\) "
  126. r"with input shapes: \[1,2,3\], \[4,5,6\]."):
  127. _ = a + b
  128. class IndexedSlicesTest(test_util.TensorFlowTestCase):
  129. @test_util.run_in_graph_and_eager_modes
  130. def testToTensor(self):
  131. values = constant_op.constant([2, 3, 5, 7], shape=[2, 2])
  132. indices = constant_op.constant([0, 2])
  133. dense_shape = constant_op.constant([3, 2])
  134. x = ops.IndexedSlices(values, indices, dense_shape)
  135. tensor = ops.convert_to_tensor(x, name="tensor")
  136. self.assertAllEqual(self.evaluate(tensor), [[2, 3], [0, 0], [5, 7]])
  137. @test_util.run_deprecated_v1
  138. def testNegation(self):
  139. with self.cached_session():
  140. values = constant_op.constant([2, 3, 5, 7], shape=[2, 2])
  141. indices = constant_op.constant([0, 2])
  142. x = -ops.IndexedSlices(values, indices)
  143. self.assertAllEqual(x.values.eval(), [[-2, -3], [-5, -7]])
  144. self.assertAllEqual(x.indices.eval(), [0, 2])
  145. @test_util.run_deprecated_v1
  146. def testScalarMul(self):
  147. with self.cached_session():
  148. values = constant_op.constant([2, 3, 5, 7], shape=[2, 2])
  149. indices = constant_op.constant([0, 2])
  150. x = math_ops.scalar_mul(-2, ops.IndexedSlices(values, indices))
  151. self.assertAllEqual(x.values.eval(), [[-4, -6], [-10, -14]])
  152. self.assertAllEqual(x.indices.eval(), [0, 2])
  153. class NodeDefConstructorTest(test_util.TensorFlowTestCase):
  154. def testNoArgs(self):
  155. nodedef = ops._NodeDef("None", "bar")
  156. self.assertProtoEquals("op: 'None' name: 'bar'", nodedef)
  157. def testArgs(self):
  158. nodedef = ops._NodeDef("foo", "bar", device="/device:baz:*")
  159. self.assertProtoEquals("op:'foo' name:'bar' device:'/device:baz:*'",
  160. nodedef)
  161. nodedef = ops._NodeDef("foo", "bar", device=pydev.DeviceSpec(job="j"))
  162. self.assertProtoEquals("op:'foo' name:'bar' device:'/job:j'", nodedef)
  163. def _apply_op(g, *args, **kwargs):
  164. op = g.create_op(*args, **kwargs)
  165. if len(op.outputs) == 1:
  166. return op.outputs[0]
  167. else:
  168. return op.outputs
  169. class OperationTest(test_util.TensorFlowTestCase):
  170. @test_util.run_deprecated_v1
  171. def testNoInputs(self):
  172. op = test_ops.float_output_string_output(name="myop").a.op
  173. self.assertEqual(2, len(op.values()))
  174. self.assertEqual(0, len(op.inputs))
  175. self.assertEqual("myop", op.name)
  176. float_t, label_str_t = op.values()
  177. self.assertEqual(dtypes.float32, float_t.dtype)
  178. self.assertEqual(op, float_t.op)
  179. self.assertEqual(0, float_t._value_index)
  180. self.assertEqual(0, len(float_t.consumers()))
  181. self.assertEqual("myop", float_t._as_node_def_input())
  182. self.assertEqual(dtypes.string, label_str_t.dtype)
  183. self.assertEqual(op, label_str_t.op)
  184. self.assertEqual(1, label_str_t._value_index)
  185. self.assertEqual(0, len(label_str_t.consumers()))
  186. self.assertEqual("myop:1", label_str_t._as_node_def_input())
  187. self.assertProtoEquals("op:'FloatOutputStringOutput' name:'myop'",
  188. op.node_def)
  189. @test_util.run_deprecated_v1
  190. def testNoOutputs(self):
  191. op1 = test_ops.float_output(name="myop1").op
  192. float_t, = op1.values()
  193. op2 = test_ops.float_input(float_t, name="myop2")
  194. self.assertEqual(0, len(op2.values()))
  195. self.assertEqual(1, len(op2.inputs))
  196. self.assertIs(float_t, op2.inputs[0])
  197. self.assertEqual(1, len(float_t.consumers()))
  198. self.assertEqual(op2, float_t.consumers()[0])
  199. self.assertProtoEquals("op:'FloatOutput' name:'myop1'", op1.node_def)
  200. self.assertProtoEquals("op:'FloatInput' name:'myop2' input:'myop1'",
  201. op2.node_def)
  202. @test_util.run_deprecated_v1
  203. def testInputsAndOutputs(self):
  204. op1 = test_ops.float_output(name="myop1").op
  205. self.assertEqual(1, len(op1.values()))
  206. float1_t, = op1.values()
  207. op2 = test_ops.float_output_string_output(name="myop2").a.op
  208. self.assertEqual(2, len(op2.values()))
  209. float2_t, label2_str_t = op2.values()
  210. # Note that we consume label2_str_t twice here.
  211. op3 = test_ops.foo2(float1_t, label2_str_t, label2_str_t, name="myop3").d.op
  212. self.assertEqual(2, len(op3.values()))
  213. self.assertEqual(1, len(float1_t.consumers()))
  214. self.assertEqual(op3, float1_t.consumers()[0])
  215. self.assertEqual(0, len(float2_t.consumers()))
  216. self.assertEqual(2, len(label2_str_t.consumers()))
  217. self.assertEqual(op3, label2_str_t.consumers()[0])
  218. self.assertEqual(op3, label2_str_t.consumers()[1])
  219. self.assertProtoEquals("""
  220. op:'Foo2' name:'myop3'
  221. input:'myop1' input:'myop2:1' input:'myop2:1'
  222. """, op3.node_def)
  223. def testDeviceFromNodeDef(self):
  224. op = ops.Operation(
  225. ops._NodeDef("None", "myop", device="/job:goo/device:GPU:0"),
  226. ops.Graph(), [], [])
  227. self.assertEqual("/job:goo/device:GPU:0", op.device)
  228. def testDeviceObject(self):
  229. op = ops.Operation(ops._NodeDef("None", "myop"), ops.Graph(), [], [])
  230. op._set_device("/job:goo/device:GPU:0")
  231. self.assertProtoEquals(
  232. "op:'None' name:'myop' device:'/job:goo/device:GPU:0' ", op.node_def)
  233. op = ops.Operation(ops._NodeDef("None", "op2"), ops.Graph(), [], [])
  234. op._set_device(
  235. pydev.DeviceSpec(
  236. job="muu", device_type="CPU", device_index=0))
  237. self.assertProtoEquals(
  238. "op:'None' name:'op2' device:'/job:muu/device:CPU:0'", op.node_def)
  239. def testReferenceInput(self):
  240. g = ops.Graph()
  241. op1 = ops.Operation(
  242. ops._NodeDef("RefOutputFloatOutput", "op1"), g, [],
  243. [dtypes.float32_ref, dtypes.float32])
  244. self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'", op1.node_def)
  245. self.assertEquals([], list(op1.inputs))
  246. ref_t, nonref_t = op1.values()
  247. # NOTE(mrry): Must specify input_types to preserve ref-typed input.
  248. op2 = ops.Operation(
  249. ops._NodeDef("RefInputFloatInput", "op2"),
  250. g, [ref_t, nonref_t], [],
  251. input_types=[dtypes.float32_ref, dtypes.float32])
  252. self.assertProtoEquals(
  253. "op:'RefInputFloatInput' name:'op2' input:'op1' input:'op1:1'",
  254. op2.node_def)
  255. self.assertEquals([ref_t, nonref_t], list(op2.inputs))
  256. op3 = ops.Operation(
  257. ops._NodeDef("TwoFloatInputs", "op3"), g, [ref_t, nonref_t], [])
  258. self.assertProtoEquals(
  259. "op:'TwoFloatInputs' name:'op3' input:'op1' input:'op1:1'",
  260. op3.node_def)
  261. def testInvalidNames(self):
  262. g = ops.Graph()
  263. with self.assertRaises(ValueError):
  264. ops.Operation(ops._NodeDef("op", ""), g)
  265. with self.assertRaises(ValueError):
  266. ops.Operation(ops._NodeDef("op", "_invalid"), g)
  267. with self.assertRaises(ValueError):
  268. ops.Operation(ops._NodeDef("op", "-invalid"), g)
  269. with self.assertRaises(ValueError):
  270. ops.Operation(ops._NodeDef("op", "/invalid"), g)
  271. with self.assertRaises(ValueError):
  272. ops.Operation(ops._NodeDef("op", "invalid:0"), g)
  273. @test_util.run_deprecated_v1
  274. def testNoShapeFunction(self):
  275. op = test_ops.a()
  276. self.assertEqual(tensor_shape.unknown_shape(), op.get_shape())
  277. @test_util.run_in_graph_and_eager_modes
  278. def testConvertToTensorNestedArray(self):
  279. values = [[2], [3], [5], [7]]
  280. tensor = ops.convert_to_tensor(values)
  281. self.assertAllEqual((4, 1), tensor.get_shape().as_list())
  282. self.assertAllEqual(values, self.evaluate(tensor))
  283. def testShapeTuple(self):
  284. with self.cached_session():
  285. c = constant_op.constant(1)
  286. self.assertEqual(c._shape_tuple(), ()) # pylint: disable=protected-access
  287. def testConvertToTensorEager(self):
  288. with context.eager_mode():
  289. t = constant_op.constant(1)
  290. self.assertTrue(isinstance(t, ops.EagerTensor))
  291. converted = ops.convert_to_tensor(t)
  292. self.assertTrue(isinstance(converted, ops.EagerTensor))
  293. converted = ops.convert_to_tensor(1)
  294. self.assertTrue(isinstance(converted, ops.EagerTensor))
  295. @test_util.run_in_graph_and_eager_modes
  296. def testConvertToTensorNestedTuple(self):
  297. values = ((2,), (3,), (5,), (7,))
  298. tensor = ops.convert_to_tensor(values)
  299. self.assertAllEqual((4, 1), tensor.get_shape().as_list())
  300. self.assertAllEqual(values, self.evaluate(ops.convert_to_tensor(values)))
  301. @test_util.run_in_graph_and_eager_modes
  302. def testConvertToTensorNestedTensors(self):
  303. values = ((2,), (3,), (5,), (7,))
  304. tensor = ops.convert_to_tensor(
  305. [constant_op.constant(row) for row in values])
  306. self.assertAllEqual((4, 1), tensor.get_shape().as_list())
  307. self.assertAllEqual(values, self.evaluate(tensor))
  308. tensor = ops.convert_to_tensor(
  309. [[constant_op.constant(v) for v in row] for row in values])
  310. self.assertAllEqual((4, 1), tensor.get_shape().as_list())
  311. self.assertAllEqual(values, self.evaluate(tensor))
  312. @test_util.run_in_graph_and_eager_modes
  313. def testConvertToTensorNestedMix(self):
  314. values = ([2], (3,), [constant_op.constant(5)], constant_op.constant([7]))
  315. tensor = ops.convert_to_tensor(values)
  316. self.assertAllEqual((4, 1), tensor.get_shape().as_list())
  317. self.assertAllEqual(((2,), (3,), (5,), (7,)), self.evaluate(tensor))
  318. @test_util.run_in_graph_and_eager_modes
  319. def testConvertToTensorPreferred(self):
  320. values = [2, 3, 5, 7]
  321. tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.float32)
  322. self.assertEqual(dtypes.float32, tensor.dtype)
  323. # Convert empty tensor to anything.
  324. values = []
  325. tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.int64)
  326. self.assertEqual(dtypes.int64, tensor.dtype)
  327. # The preferred dtype is a type error and will convert to
  328. # float32 instead.
  329. values = [1.23]
  330. tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.int64)
  331. self.assertEqual(dtypes.float32, tensor.dtype)
  332. @test_util.run_in_graph_and_eager_modes
  333. def testConvertToInvalidTensorType(self):
  334. with self.assertRaises(TypeError):
  335. # Forcing an invalid dtype should fail with a type error.
  336. values = [1.23]
  337. ops.convert_to_tensor(values, dtype=dtypes.int64)
  338. @test_util.run_in_graph_and_eager_modes
  339. def testConvertToTensorFromInvalidTensor(self):
  340. tensor = constant_op.constant(42.0, dtype=dtypes.float32)
  341. with self.assertRaises(ValueError):
  342. ops.convert_to_tensor(tensor, dtype=dtypes.int32)
  343. @test_util.run_deprecated_v1
  344. def testNoConvert(self):
  345. # Operation cannot be converted to Tensor.
  346. op = control_flow_ops.no_op()
  347. with self.assertRaisesRegexp(TypeError,
  348. r"Can't convert Operation '.*' to Tensor"):
  349. ops.convert_to_tensor(op)
  350. def testStr(self):
  351. node_def = ops._NodeDef("None", "op1")
  352. op = ops.Operation(node_def, ops.Graph(), [], [dtypes.float32])
  353. self.assertEqual(str(node_def), str(op))
  354. def testRepr(self):
  355. op = ops.Operation(
  356. ops._NodeDef("None", "op1"), ops.Graph(), [], [dtypes.float32])
  357. self.assertEqual("<tf.Operation 'op1' type=None>", repr(op))
  358. @test_util.run_deprecated_v1
  359. def testGetAttr(self):
  360. op = test_ops.default_attrs()
  361. self.assertEqual(op.get_attr("string_val"), b"abc")
  362. self.assertEqual(op.get_attr("string_list_val"), [b"abc", b""])
  363. self.assertEqual(op.get_attr("int_val"), 123)
  364. self.assertEqual(op.get_attr("int_list_val"), [1, 2, 3])
  365. self.assertEqual(op.get_attr("float_val"), 10.0)
  366. self.assertEqual(op.get_attr("float_list_val"), [10.0])
  367. self.assertEqual(op.get_attr("bool_val"), True)
  368. self.assertEqual(op.get_attr("bool_list_val"), [True, False])
  369. self.assertEqual(op.get_attr("shape_val"),
  370. tensor_shape.as_shape([2, 1]).as_proto())
  371. self.assertEqual(op.get_attr("shape_list_val"),
  372. [tensor_shape.as_shape([]).as_proto(),
  373. tensor_shape.as_shape([1]).as_proto()])
  374. self.assertEqual(op.get_attr("tensor_val"),
  375. tensor_util.make_tensor_proto(1, dtypes.int32))
  376. self.assertEqual(op.get_attr("tensor_list_val"),
  377. [tensor_util.make_tensor_proto(1, dtypes.int32)])
  378. type_val = op.get_attr("type_val")
  379. # First check that type_val is a DType, because the assertEquals will work
  380. # no matter what since DType overrides __eq__
  381. self.assertIsInstance(type_val, dtypes.DType)
  382. self.assertEqual(type_val, dtypes.int32)
  383. type_list_val = op.get_attr("type_list_val")
  384. self.assertTrue(all(isinstance(x, dtypes.DType) for x in type_list_val))
  385. self.assertEqual(type_list_val, [dtypes.int32, dtypes.float32])
  386. @function.Defun(dtypes.float32, func_name="MyFunc")
  387. def func(x):
  388. return x
  389. op = test_ops.func_attr(func)
  390. self.assertEqual(op.get_attr("f"),
  391. attr_value_pb2.NameAttrList(name="MyFunc"))
  392. # Try fetching missing attr
  393. with self.assertRaisesRegexp(
  394. ValueError, "Operation 'FuncAttr' has no attr named 'FakeAttr'."):
  395. op.get_attr("FakeAttr")
  396. # TODO(b/65162920): remove this test when users who are directly mutating the
  397. # node_def have been updated to proper usage.
  398. @test_util.run_deprecated_v1
  399. def testSetAttr(self):
  400. op = test_ops.int_attr().op
  401. op._set_attr("foo", attr_value_pb2.AttrValue(i=2))
  402. # TODO(skyewm): add node_def check
  403. self.assertEqual(op.get_attr("foo"), 2)
  404. # TODO(nolivia): test all error cases
  405. def testAddControlInput(self):
  406. with ops.Graph().as_default():
  407. x = constant_op.constant(1).op
  408. y = constant_op.constant(2).op
  409. z = constant_op.constant(3).op
  410. z._add_control_input(x) # pylint: disable=protected-access
  411. self.assertEqual(z.control_inputs, [x])
  412. z._add_control_input(x) # pylint: disable=protected-access
  413. self.assertEqual(z.control_inputs, [x])
  414. z._add_control_inputs([x, y, y]) # pylint: disable=protected-access
  415. self.assertEqual(z.control_inputs, [x, y])
  416. self.assertEqual(x._control_outputs, [z])
  417. @test_util.run_deprecated_v1
  418. def testRemoveAllControlInputs(self):
  419. a = constant_op.constant(1)
  420. with ops.control_dependencies([a]):
  421. b = constant_op.constant(2)
  422. c = constant_op.constant(3)
  423. d = constant_op.constant(4)
  424. e = constant_op.constant(5)
  425. with ops.control_dependencies([a, c]):
  426. f = d + e
  427. self.assertEqual(a.op.control_inputs, [])
  428. self.assertEqual(b.op.control_inputs, [a.op])
  429. self.assertEqual(f.op.control_inputs, [a.op, c.op])
  430. a.op._remove_all_control_inputs() # pylint: disable=protected-access
  431. self.assertEqual(a.op.control_inputs, [])
  432. b.op._remove_all_control_inputs() # pylint: disable=protected-access
  433. self.assertEqual(b.op.control_inputs, [])
  434. f.op._remove_all_control_inputs() # pylint: disable=protected-access
  435. self.assertEqual(f.op.control_inputs, [])
  436. self.assertEqual(list(f.op.inputs), [d, e])
  437. @test_util.run_deprecated_v1
  438. def testControlInputCycle(self):
  439. graph = ops.Graph()
  440. with graph.as_default():
  441. z = constant_op.constant(0)
  442. x = constant_op.constant(1)
  443. y = constant_op.constant(2)
  444. y.op._add_control_input(z.op) # pylint: disable=protected-access
  445. y.op._add_control_input(x.op) # pylint: disable=protected-access
  446. x.op._add_control_input(y.op) # pylint: disable=protected-access
  447. with self.session(graph=graph) as sess:
  448. with self.assertRaisesRegexp(
  449. errors.InvalidArgumentError,
  450. "Graph is invalid, contains a cycle with 2 nodes"):
  451. self.evaluate(x)
  452. def testUpdateInput(self):
  453. g = ops.Graph()
  454. with g.as_default():
  455. x = constant_op.constant(1)
  456. y = constant_op.constant(2)
  457. z = x + y
  458. z.op._update_input(0, y) # pylint: disable=protected-access
  459. self.assertEquals(list(z.op.inputs), [y, y])
  460. self.assertEquals(x.consumers(), [])
  461. self.assertEquals(y.consumers(), [z.op, z.op])
  462. with session.Session(graph=g) as sess:
  463. self.assertEquals(self.evaluate(z), 4)
  464. z.op._update_input(0, x) # pylint: disable=protected-access
  465. self.assertEquals(list(z.op.inputs), [x, y])
  466. self.assertEquals(x.consumers(), [z.op])
  467. self.assertEquals(y.consumers(), [z.op])
  468. with session.Session(graph=g) as sess:
  469. self.assertEquals(self.evaluate(z), 3)
  470. z.op._update_input(1, y) # pylint: disable=protected-access
  471. self.assertEquals(list(z.op.inputs), [x, y])
  472. self.assertEquals(x.consumers(), [z.op])
  473. self.assertEquals(y.consumers(), [z.op])
  474. with session.Session(graph=g) as sess:
  475. self.assertEquals(self.evaluate(z), 3)
  476. def testUpdateInputGraphError(self):
  477. g_0 = ops.Graph()
  478. g_1 = ops.Graph()
  479. with g_0.as_default():
  480. x = constant_op.constant(1)
  481. with g_1.as_default():
  482. y = constant_op.constant(2)
  483. z = y * 2
  484. with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
  485. z.op._update_input(0, x) # pylint: disable=protected-access
  486. def testUpdateInputTypeError(self):
  487. g = ops.Graph()
  488. with g.as_default():
  489. w = constant_op.constant(0)
  490. x = constant_op.constant("")
  491. y = constant_op.constant(1)
  492. z = y + w
  493. z.op._update_input(0, x) # pylint: disable=protected-access
  494. with session.Session(graph=g) as sess:
  495. with self.assertRaisesRegexp(
  496. errors.InvalidArgumentError,
  497. "Input 0 of node add was passed string from Const_1:0 incompatible "
  498. "with expected int32"):
  499. self.evaluate(z)
  500. def testUpdateInputShapeError(self):
  501. g = ops.Graph()
  502. with g.as_default():
  503. w = constant_op.constant(2, shape=[3, 1])
  504. x = constant_op.constant(0, shape=[3, 1])
  505. y = constant_op.constant(1, shape=[2, 2])
  506. z = w + x
  507. with self.assertRaisesRegexp(
  508. errors.InvalidArgumentError,
  509. r"Cannot update edge, incompatible shapes: \[2,2\] and \[3,1\]"):
  510. z.op._update_input(0, y) # pylint: disable=protected-access
  511. def testUpdateInputOutOfRange(self):
  512. g = ops.Graph()
  513. with g.as_default():
  514. x = constant_op.constant(1)
  515. with self.assertRaisesRegexp(
  516. errors.OutOfRangeError,
  517. r"Cannot update edge. Input index \[1\] is greater than the number of "
  518. r"total inputs \[0\]."
  519. ):
  520. x.op._update_input(1, x) # pylint: disable=protected-access
  521. @test_util.enable_control_flow_v2
  522. @test_util.run_v1_only("b/120545219")
  523. def testAddWhileInput(self):
  524. @eager_function.defun
  525. def test():
  526. output = control_flow_ops.while_loop(lambda x: x < 3, lambda x: x + 1,
  527. [1])
  528. while_op = output.op.inputs[0].op
  529. self.assertEqual(while_op.type, "While")
  530. orig_num_inputs = len(while_op.inputs)
  531. # Make sure we can handle the while op having a control input.
  532. while_op._add_control_input(constant_op.constant(0).op)
  533. new_input1 = constant_op.constant(1.0)
  534. new_input2 = constant_op.constant(True)
  535. while_op._set_type_list_attr("T",
  536. [t.dtype for t in while_op.inputs] +
  537. [new_input1.dtype, new_input2.dtype])
  538. while_op._add_while_inputs([new_input1, new_input2])
  539. # Can't add an edge beyond what's specified by "T"
  540. with self.assertRaises(errors.OutOfRangeError):
  541. while_op._add_while_inputs([new_input2])
  542. self.assertEqual(len(while_op.inputs), orig_num_inputs + 2) # pylint: disable=g-deprecated-assert
  543. test()
  544. @test_util.run_deprecated_v1
  545. def testOpDef(self):
  546. x = constant_op.constant(0)
  547. y = constant_op.constant(1)
  548. z = x + y
  549. self.assertEqual(x.op.op_def.name, "Const")
  550. self.assertEqual(len(x.op.op_def.input_arg), 0)
  551. self.assertEqual(len(x.op.op_def.output_arg), 1)
  552. self.assertEqual(z.op.op_def.name, "Add")
  553. self.assertEqual(len(z.op.op_def.input_arg), 2)
  554. self.assertEqual(len(z.op.op_def.output_arg), 1)
  555. def testInputFromDifferentGraphError(self):
  556. g_0 = ops.Graph()
  557. g_1 = ops.Graph()
  558. with g_0.as_default():
  559. x = constant_op.constant(1)
  560. with g_1.as_default():
  561. y = constant_op.constant(2)
  562. with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
  563. y * x # pylint: disable=pointless-statement
  564. def testInputsAreImmutable(self):
  565. g = ops.Graph()
  566. with g.as_default():
  567. x = test_ops.int_output()
  568. op = test_ops.int_input_int_output(x, name="myop").op
  569. with self.assertRaisesRegexp(
  570. AttributeError, "'_InputList' object has no attribute 'append'"):
  571. op.inputs.append(None)
  572. class CreateOpTest(test_util.TensorFlowTestCase):
  573. def testNodeDefArgs(self):
  574. g = ops.Graph()
  575. op1 = g.create_op("FloatOutput", [], [dtypes.float32], None, name="myop1")
  576. with g.device("/device:GPU:0"):
  577. op2 = g.create_op(
  578. "FloatOutputStringOutput", [], [dtypes.float32, dtypes.string], None,
  579. name="myop2")
  580. op3 = g.create_op(
  581. "Foo3",
  582. [list(op1.values())[0], list(op2.values())[1], list(op2.values())[0]],
  583. [dtypes.float32, dtypes.int32],
  584. None,
  585. name="myop3")
  586. self.assertDeviceEqual(None, op1.device)
  587. self.assertDeviceEqual("/device:GPU:0", op2.device)
  588. self.assertDeviceEqual(None, op3.device)
  589. self.assertProtoEquals("name:'myop1' op:'FloatOutput'", op1.node_def)
  590. self.assertProtoEquals(
  591. "name:'myop2' op:'FloatOutputStringOutput' device:'/device:GPU:0'",
  592. op2.node_def)
  593. self.assertProtoEquals(
  594. "name:'myop3' input:'myop1' input:'myop2:1' input:'myop2' op:'Foo3'",
  595. op3.node_def)
  596. def testReferenceInput(self):
  597. g = ops.Graph()
  598. op1 = g.create_op(
  599. "RefOutputFloatOutput", [], [dtypes.float32_ref, dtypes.float32],
  600. name="op1")
  601. self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'", op1.node_def)
  602. ref_t, nonref_t = op1.values()
  603. # NOTE(mrry): Must specify input_types to preserve ref-typed input.
  604. op2 = g.create_op(
  605. "RefInputFloatInput", [ref_t, nonref_t], [],
  606. input_types=[dtypes.float32_ref, dtypes.float32],
  607. name="op2")
  608. self.assertProtoEquals(
  609. "op:'RefInputFloatInput' name:'op2' input:'op1' input:'op1:1'",
  610. op2.node_def)
  611. op3 = g.create_op("TwoFloatInputs", [ref_t, nonref_t], [], name="op3")
  612. self.assertProtoEquals(
  613. "op:'TwoFloatInputs' name:'op3' input:'op1' input:'op1:1'",
  614. op3.node_def)
  615. def testFinalized(self):
  616. g = ops.Graph()
  617. g.finalize()
  618. with self.assertRaises(RuntimeError):
  619. g.create_op("FloatOutput", [], [dtypes.float32], None, name="myop1")
  620. # Test unfinalize.
  621. g._unsafe_unfinalize()
  622. g.create_op("FloatOutput", [], [dtypes.float32], None, name="myop1")
  623. # NOTE(skyewm): these cases test the private Graph._create_op_from_tf_operation
  624. # method. Arguably we should only test the public APIs that depend on this
  625. # method. However, this logic is complex and tricky, and it can be difficult to
  626. # ascertain if we have adequate coverage (e.g. a graph may run successfully if
  627. # the control flow context isn't set properly, but a more complicated use case
  628. # that might not be obvious to test will fail). Thus we instead explicitly test
  629. # the low-level behavior.
  630. class CreateOpFromTFOperationTest(test_util.TensorFlowTestCase):
  @test_util.run_deprecated_v1
  def testBasic(self):
    """A C-created op is wrapped with the expected Python-side metadata."""
    graph = ops.Graph()
    with graph.as_default():
      int_t = test_ops.int_output()
      raw_op = ops._create_c_op(
          graph, ops._NodeDef("IntInputIntOutput", "myop"), [int_t], [])
      wrapped = graph._create_op_from_tf_operation(raw_op)

    self.assertEqual(wrapped.name, "myop")
    self.assertEqual(wrapped.type, "IntInputIntOutput")
    self.assertEqual(len(wrapped.outputs), 1)
    self.assertEqual(wrapped.outputs[0].shape, tensor_shape.unknown_shape())
    self.assertEqual(list(wrapped.inputs), [int_t])
    self.assertEqual(wrapped.control_inputs, [])
    self.assertEqual(wrapped.graph, graph)
    self.assertEqual(int_t.consumers(), [wrapped])
    self.assertIsNotNone(wrapped.traceback)
    # The wrapper must be registered so name-based lookups find it.
    self.assertEqual(graph.get_operation_by_name("myop"), wrapped)
    self.assertEqual(graph.get_tensor_by_name("myop:0"), wrapped.outputs[0])
  def testShape(self):
    """The wrapped Identity op inherits the static shape of its input."""
    graph = ops.Graph()
    with graph.as_default():
      matrix = constant_op.constant([[1, 2, 3], [4, 5, 6]])
      raw_op = ops._create_c_op(graph, ops._NodeDef("Identity", "myop"),
                                [matrix], [])
      wrapped = graph._create_op_from_tf_operation(raw_op)

    self.assertEqual(wrapped.name, "myop")
    self.assertEqual(wrapped.type, "Identity")
    self.assertEqual(len(wrapped.outputs), 1)
    self.assertEqual(wrapped.outputs[0].shape, tensor_shape.matrix(2, 3))
  def testUniqueName(self):
    """Names taken by wrapped C ops are avoided by later Python ops."""
    graph = ops.Graph()
    with graph.as_default():
      # Both C ops are created before either is wrapped, as in the original
      # sequence; the order matters for name uniquification.
      raw_ops = [
          ops._create_c_op(graph, ops._NodeDef("IntOutput", op_name), [], [])
          for op_name in ("myop", "myop_1")
      ]
      wrapped = [graph._create_op_from_tf_operation(raw) for raw in raw_ops]

      # Create ops with same names as op1 and op2. We expect the new names to be
      # uniquified.
      clash_first = test_ops.int_output(name="myop").op
      clash_second = test_ops.int_output(name="myop_1").op

    self.assertEqual(wrapped[0].name, "myop")
    self.assertEqual(wrapped[1].name, "myop_1")
    self.assertEqual(clash_first.name, "myop_2")
    self.assertEqual(clash_second.name, "myop_1_1")
  @test_util.run_v1_only("b/120545219")
  def testCond(self):
    """An op created inside a cond branch is placed in the cond context."""
    graph = ops.Graph()
    with graph.as_default():
      source = test_ops.int_output()

      def true_branch():
        ops._create_c_op(ops.get_default_graph(),
                         ops._NodeDef("IntInput", "cond/myop"), [source], [])
        added = graph._add_new_tf_operations()
        self.assertEqual(len(added), 1)
        return source

      control_flow_ops.cond(source < 10, true_branch, lambda: source)

    cond_op = graph.get_operation_by_name("cond/myop")
    self.assertIsNotNone(cond_op)
    self.assertEqual(cond_op.name, "cond/myop")
    self.assertEqual(cond_op.type, "IntInput")
    self.assertEqual(cond_op.outputs, [])
    # The external input must be routed through a Switch op.
    switch_op = cond_op.inputs[0].op
    self.assertEqual(switch_op.type, "Switch")
    self.assertEqual(switch_op.inputs[0], source)
    self.assertEqual(cond_op.graph, graph)
    # pylint: disable=protected-access
    flow_ctx = cond_op._get_control_flow_context()
    self.assertIsNotNone(flow_ctx)
    self.assertEqual(cond_op._get_control_flow_context().name,
                     "cond/cond_text")
    # pylint: enable=protected-access
  @test_util.run_v1_only("b/120545219")
  def testWhileLoop(self):
    """A C op created inside a while-loop body is attached to the loop."""
    graph = ops.Graph()
    with graph.as_default():
      loop_input = test_ops.int_output()

      def loop_body(i):
        node_def = ops._NodeDef("IntInput", "myloop/myop")
        ops._create_c_op(ops.get_default_graph(), node_def, [loop_input], [])
        added = graph._add_new_tf_operations()
        # Only the op created just above should be picked up.
        self.assertEqual(len(added), 1)
        return i

      control_flow_ops.while_loop(lambda i: i < 10, loop_body, [0],
                                  name="myloop")

    created = graph.get_operation_by_name("myloop/myop")
    self.assertIsNotNone(created)
    self.assertEqual(created.name, "myloop/myop")
    self.assertEqual(created.type, "IntInput")
    self.assertEqual(created.outputs, [])
    # The loop-external tensor reaches the op through an Enter node.
    enter_op = created.inputs[0].op
    self.assertEqual(enter_op.type, "Enter")
    self.assertEqual(list(enter_op.inputs), [loop_input])
    self.assertEqual(created.graph, graph)
    # pylint: disable=protected-access
    self.assertIsNotNone(created._get_control_flow_context())
    self.assertEqual(created._get_control_flow_context().name,
                     "myloop/while_context")
    # pylint: enable=protected-access
  @test_util.run_v1_only("b/120545219")
  def testWhileLoopWithInternalControlDep(self):
    """A control dependency on an op inside the loop body is kept as-is."""
    graph = ops.Graph()
    with graph.as_default():
      loop_input = test_ops.int_output()

      def loop_body(i):
        inner_const = constant_op.constant(1.0, name="c")
        node_def = ops._NodeDef("IntInput", "myloop/myop")
        ops._create_c_op(ops.get_default_graph(), node_def, [loop_input], [])
        with ops.control_dependencies([inner_const]):
          added = graph._add_new_tf_operations()
        self.assertEqual(len(added), 1)
        return i

      control_flow_ops.while_loop(lambda i: i < 10, loop_body, [0],
                                  name="myloop")

    created = graph.get_operation_by_name("myloop/myop")
    self.assertIsNotNone(created)
    const_op = graph.get_operation_by_name("myloop/c")
    self.assertIsNotNone(const_op)
    # The in-loop constant remains the op's only control input.
    self.assertEqual(created.control_inputs, [const_op])
  @test_util.run_v1_only("b/120545219")
  def testWhileLoopWithExternalControlDep(self):
    """A control dependency on an op outside the loop gets replaced."""
    graph = ops.Graph()
    with graph.as_default():
      loop_input = test_ops.int_output()
      outer_const = constant_op.constant(1.0)

      def loop_body(i):
        node_def = ops._NodeDef("IntInput", "myloop/myop")
        ops._create_c_op(ops.get_default_graph(), node_def, [loop_input], [])
        with ops.control_dependencies([outer_const]):
          added = graph._add_new_tf_operations()
        self.assertEqual(len(added), 1)
        return i

      control_flow_ops.while_loop(lambda i: i < 10, loop_body, [0],
                                  name="myloop")

    created = graph.get_operation_by_name("myloop/myop")
    self.assertIsNotNone(created)
    # The outer constant must not be the control input; whatever replaces it
    # has to carry a control-flow context (i.e. it lives inside the loop).
    self.assertNotEqual(created.control_inputs[0], outer_const.op)
    self.assertIsNotNone(created.control_inputs[0]._get_control_flow_context())
  766. class ApplyOpTest(test_util.TensorFlowTestCase):
  767. def testNodeDefArgs(self):
  768. g = ops.Graph()
  769. t1 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop1")
  770. with g.device("/device:GPU:0"):
  771. t2 = _apply_op(
  772. g, "TwoIntOutputs", [], [dtypes.int32, dtypes.int32], name="myop2")
  773. t3 = _apply_op(
  774. g,
  775. "Foo1", [t1, t2[1], t2[0]], [dtypes.float32, dtypes.int32],
  776. name="myop3")
  777. self.assertTrue(isinstance(t1, ops.Tensor))
  778. self.assertTrue(isinstance(t2, list))
  779. self.assertTrue(isinstance(t3, list))
  780. self.assertTrue(isinstance(t3[0], ops.Tensor))
  781. self.assertEqual("myop1", t1._as_node_def_input())
  782. self.assertEqual("myop2", t2[0]._as_node_def_input())
  783. self.assertEqual("myop2:1", t2[1]._as_node_def_input())
  784. self.assertEqual("myop3", t3[0]._as_node_def_input())
  785. # Validate that we got the right ops as well
  786. self.assertProtoEquals("name:'myop1' op:'FloatOutput'", t1.op.node_def)
  787. self.assertProtoEquals(
  788. "name:'myop2' op:'TwoIntOutputs' device:'/device:GPU:0'",
  789. t2[0].op.node_def)
  790. self.assertProtoEquals(
  791. "name:'myop3' input:'myop1' input:'myop2:1' input:'myop2' op:'Foo1'",
  792. t3[0].op.node_def)
  793. def testReferenceInput(self):
  794. g = ops.Graph()
  795. ref_t, nonref_t = _apply_op(
  796. g, "RefOutputFloatOutput", [], [dtypes.float32_ref, dtypes.float32],
  797. name="op1")
  798. self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'",
  799. ref_t.op.node_def)
  800. # NOTE(mrry): Must specify input_types to preserve ref-typed input.
  801. out_2 = _apply_op(
  802. g,
  803. "RefInputFloatInputIntOutput", [ref_t, nonref_t], [dtypes.int32],
  804. input_types=[dtypes.float32_ref, dtypes.float32],
  805. name="op2")
  806. self.assertProtoEquals(
  807. "op:'RefInputFloatInputIntOutput' name:'op2' input:'op1' input:'op1:1'",
  808. out_2.op.node_def)
  809. out_3 = _apply_op(
  810. g, "TwoFloatInputsIntOutput", [ref_t, nonref_t], [dtypes.int32],
  811. name="op3")
  812. self.assertProtoEquals(
  813. "op:'TwoFloatInputsIntOutput' name:'op3' input:'op1' input:'op1:1'",
  814. out_3.op.node_def)
  815. class NameStackTest(test_util.TensorFlowTestCase):
  816. def testBasics(self):
  817. g = ops.Graph()
  818. self.assertEqual("foo", g.unique_name("foo", mark_as_used=False))
  819. self.assertEqual("foo", g.unique_name("foo", mark_as_used=False))
  820. self.assertEqual("foo", g.unique_name("foo"))
  821. self.assertEqual("foo_1", g.unique_name("foo", mark_as_used=False))
  822. self.assertEqual("foo_1", g.unique_name("foo"))
  823. self.assertEqual("foo_2", g.unique_name("foo", mark_as_used=False))
  824. self.assertEqual("foo_2", g.unique_name("foo"))
  825. self.assertEqual("foo_1_1", g.unique_name("foo_1", mark_as_used=False))
  826. self.assertEqual("foo_1_1", g.unique_name("foo_1"))
  827. self.assertEqual("foo_1_2", g.unique_name("foo_1", mark_as_used=False))
  828. self.assertEqual("foo_1_2", g.unique_name("foo_1"))
  829. self.assertEqual("foo_1_2_1", g.unique_name("foo_1_2", mark_as_used=False))
  830. self.assertEqual("foo_1_2_1", g.unique_name("foo_1_2"))
  831. with g.name_scope("bar"):
  832. self.assertEqual("bar/foo", g.unique_name("foo", mark_as_used=False))
  833. self.assertEqual("bar/foo", g.unique_name("foo"))
  834. self.assertEqual("bar/foo_1", g.unique_name("foo", mark_as_used=False))
  835. self.assertEqual("bar/foo_1", g.unique_name("foo"))
  836. with g.name_scope(None):
  837. self.assertEqual("foo_3", g.unique_name("foo", mark_as_used=False))
  838. self.assertEqual("foo_3", g.unique_name("foo"))
  839. with g.name_scope("baz"):
  840. self.assertEqual(
  841. "bar/baz/foo", g.unique_name(
  842. "foo", mark_as_used=False))
  843. self.assertEqual("bar/baz/foo", g.unique_name("foo"))
  844. self.assertEqual(
  845. "bar/baz/foo_1", g.unique_name(
  846. "foo", mark_as_used=False))
  847. self.assertEqual("bar/baz/foo_1", g.unique_name("foo"))
  848. with g.name_scope("baz"):
  849. self.assertEqual(
  850. "bar/baz_1/foo", g.unique_name(
  851. "foo", mark_as_used=False))
  852. self.assertEqual("bar/baz_1/foo", g.unique_name("foo"))
  853. self.assertEqual(
  854. "bar/baz_1/foo_1", g.unique_name(
  855. "foo", mark_as_used=False))
  856. self.assertEqual("bar/baz_1/foo_1", g.unique_name("foo"))
  857. with g.name_scope("quux"):
  858. self.assertEqual("quux/foo", g.unique_name("foo", mark_as_used=False))
  859. self.assertEqual("quux/foo", g.unique_name("foo"))
  860. with g.name_scope("bar"):
  861. with g.name_scope("baz"):
  862. self.assertEqual(
  863. "bar_1/baz/foo", g.unique_name(
  864. "foo", mark_as_used=False))
  865. self.assertEqual("bar_1/baz/foo", g.unique_name("foo"))
  866. self.assertEqual("foo_4", g.unique_name("foo", mark_as_used=False))
  867. self.assertEqual("foo_4", g.unique_name("foo"))
  868. self.assertEqual("bar_2", g.unique_name("bar", mark_as_used=False))
  869. self.assertEqual("bar_2", g.unique_name("bar"))
  870. @test_util.run_deprecated_v1
  871. def testNameAndVariableScope(self):
  872. with self.cached_session() as sess:
  873. with sess.graph.name_scope("l0"):
  874. with variable_scope.variable_scope("l1"):
  875. with sess.graph.name_scope("l1") as scope:
  876. self.assertEqual("l0/l1/l1/", scope)
  877. self.assertEqual(
  878. "l0/l1/l1/foo",
  879. sess.graph.unique_name(
  880. "foo", mark_as_used=False))
  881. self.assertEqual("l0/l1/l1/foo", sess.graph.unique_name("foo"))
  882. with sess.graph.name_scope("l2") as scope:
  883. self.assertEqual("l0/l1/l2/", scope)
  884. self.assertEqual(
  885. "l0/l1/l2/foo",
  886. sess.graph.unique_name(
  887. "foo", mark_as_used=False))
  888. self.assertEqual("l0/l1/l2/foo", sess.graph.unique_name("foo"))
  889. def testOutOfOrderUniqueName(self):
  890. g = ops.Graph()
  891. self.assertEqual("foo_2", g.unique_name("foo_2"))
  892. self.assertEqual("foo", g.unique_name("foo"))
  893. self.assertEqual("foo_1", g.unique_name("foo"))
  894. self.assertEqual("foo_3", g.unique_name("foo"))
  895. def testUniqueNameCaseInsensitivity(self):
  896. g = ops.Graph()
  897. self.assertEqual("foo", g.unique_name("foo"))
  898. self.assertEqual("Foo_1", g.unique_name("Foo"))
  899. with g.name_scope("bar"):
  900. self.assertEqual("bar/foo", g.unique_name("foo"))
  901. with g.name_scope("Bar"):
  902. self.assertEqual("Bar_1/foo", g.unique_name("foo"))
  903. def testInvalidNameRaisesError(self):
  904. g = ops.Graph()
  905. with g.name_scope(""): # Should not raise
  906. pass
  907. with g.name_scope("foo/"): # Should not raise
  908. with g.name_scope("_bar"): # Should not raise
  909. pass
  910. with self.assertRaises(ValueError):
  911. with g.name_scope("foo:0"):
  912. pass
  913. with self.assertRaises(ValueError):
  914. with g.name_scope("_bar"):
  915. pass
  916. class NameTest(test_util.TensorFlowTestCase):
  917. def testGenerateName(self):
  918. g = ops.Graph()
  919. op0 = g.create_op("TwoFloatOutputs", [], [dtypes.float32, dtypes.float32])
  920. self.assertEqual("TwoFloatOutputs", op0.name)
  921. self.assertEqual("TwoFloatOutputs:0", op0.outputs[0].name)
  922. self.assertEqual("TwoFloatOutputs:1", op0.outputs[1].name)
  923. op1 = g.create_op("FloatOutput", [], [dtypes.float32])
  924. self.assertEqual("FloatOutput", op1.name)
  925. self.assertEqual("FloatOutput:0", op1.outputs[0].name)
  926. op2 = g.create_op("FloatOutput", [], [dtypes.float32])
  927. self.assertEqual("FloatOutput_1", op2.name)
  928. self.assertEqual("FloatOutput_1:0", op2.outputs[0].name)
  929. op3 = g.create_op("FloatOutput", [], [dtypes.float32], name="my_op")
  930. self.assertEqual("my_op", op3.name)
  931. self.assertEqual("my_op:0", op3.outputs[0].name)
  932. def testNameScope(self):
  933. g = ops.Graph()
  934. with g.name_scope("foo") as foo:
  935. self.assertEqual("foo/", foo)
  936. with g.name_scope("foo2") as foo2:
  937. self.assertEqual("foo/foo2/", foo2)
  938. with g.name_scope(None) as empty1:
  939. self.assertEqual("", empty1)
  940. with g.name_scope("foo3") as foo3:
  941. self.assertEqual("foo3/", foo3)
  942. with g.name_scope("") as empty2:
  943. self.assertEqual("", empty2)
  944. self.assertEqual("FloatOutput",
  945. g.create_op("FloatOutput", [], [dtypes.float32]).name)
  946. with g.name_scope("bar") as scope:
  947. self.assertEqual("bar/FloatOutput",
  948. g.create_op("FloatOutput", [], [dtypes.float32]).name)
  949. self.assertEqual("bar/FloatOutput_1",
  950. g.create_op("FloatOutput", [], [dtypes.float32]).name)
  951. # If you use the value from "with .. as", that values is used as-is.
  952. self.assertEqual(
  953. "bar", g.create_op(
  954. "FloatOutput", [], [dtypes.float32], name=scope).name)
  955. with g.name_scope("baz") as scope:
  956. with g.name_scope("quux"):
  957. self.assertEqual("baz/quux/FloatOutput",
  958. g.create_op("FloatOutput", [], [dtypes.float32]).name)
  959. # If you use the value from the enclosing "with .. as", nothing is pushed.
  960. with g.name_scope(scope):
  961. self.assertEqual("baz/FloatOutput",
  962. g.create_op("FloatOutput", [], [dtypes.float32]).name)
  963. self.assertEqual(
  964. "baz", g.create_op(
  965. "FloatOutput", [], [dtypes.float32], name=scope).name)
  966. self.assertEqual(
  967. "trailing",
  968. g.create_op(
  969. "FloatOutput", [], [dtypes.float32], name="trailing/").name)
  970. with g.name_scope("bar"):
  971. self.assertEqual("bar_1/FloatOutput",
  972. g.create_op("FloatOutput", [], [dtypes.float32]).name)
  973. with g.name_scope("bar/"):
  974. self.assertEqual("bar/FloatOutput_2",
  975. g.create_op("FloatOutput", [], [dtypes.float32]).name)
  976. class DeviceTest(test_util.TensorFlowTestCase):
  977. def testNoDevice(self):
  978. g = ops.Graph()
  979. op = g.create_op("FloatOutput", [], [dtypes.float32])
  980. self.assertDeviceEqual(None, op.device)
  981. gd = g.as_graph_def()
  982. self.assertProtoEqualsVersion("""
  983. node { name: "FloatOutput" op: "FloatOutput" }
  984. """, gd)
  985. def testEagerBackingDevice(self):
  986. with context.eager_mode():
  987. with ops.device("/device:CPU:0"):
  988. t = constant_op.constant(1.0)
  989. self.assertRegexpMatches(t.device, "/device:CPU:0")
  990. self.assertRegexpMatches(t.backing_device, "/device:CPU:0")
  991. def testDevicePartialString(self):
  992. g = ops.Graph()
  993. with g.device("/job:worker/replica:2"):
  994. g.create_op("FloatOutput", [], [dtypes.float32])
  995. gd = g.as_graph_def()
  996. self.assertProtoEqualsVersion("""
  997. node { name: "FloatOutput" op: "FloatOutput"
  998. device: "/job:worker/replica:2" }
  999. """, gd)
  1000. def testDeviceFull(self):
  1001. g = ops.Graph()
  1002. with g.device(
  1003. pydev.DeviceSpec(
  1004. job="worker", replica=2, task=0, device_type="CPU",
  1005. device_index=3)):
  1006. g.create_op("FloatOutput", [], [dtypes.float32])
  1007. gd = g.as_graph_def()
  1008. self.assertProtoEqualsVersion("""
  1009. node { name: "FloatOutput" op: "FloatOutput"
  1010. device: "/job:worker/replica:2/task:0/device:CPU:3" }
  1011. """, gd)
  1012. def testNesting(self):
  1013. g = ops.Graph()
  1014. with g.device("/job:worker/replica:2"):
  1015. g.create_op("FloatOutput", [], [dtypes.float32])
  1016. with g.device("/job:worker/replica:3/task:0"):
  1017. g.create_op("FloatOutput", [], [dtypes.float32])
  1018. g.create_op("FloatOutput", [], [dtypes.float32])
  1019. gd = g.as_graph_def()
  1020. self.assertProtoEqualsVersion("""
  1021. node { name: "FloatOutput" op: "FloatOutput"
  1022. device: "/job:worker/replica:2" }
  1023. node { name: "FloatOutput_1" op: "FloatOutput"
  1024. device: "/job:worker/replica:3/task:0" }
  1025. node { name: "FloatOutput_2" op: "FloatOutput"
  1026. device: "/job:worker/replica:2" }
  1027. """, gd)
  1028. def testNestingString(self):
  1029. g = ops.Graph()
  1030. with g.device("/job:worker/replica:2"):
  1031. g.create_op("FloatOutput", [], [dtypes.float32])
  1032. with g.device("/job:worker/replica:3/task:0"):
  1033. g.create_op("FloatOutput", [], [dtypes.float32])
  1034. g.create_op("FloatOutput", [], [dtypes.float32])
  1035. gd = g.as_graph_def()
  1036. self.assertProtoEqualsVersion("""
  1037. node { name: "FloatOutput" op: "FloatOutput"
  1038. device: "/job:worker/replica:2" }
  1039. node { name: "FloatOutput_1" op: "FloatOutput"
  1040. device: "/job:worker/replica:3/task:0" }
  1041. node { name: "FloatOutput_2" op: "FloatOutput"
  1042. device: "/job:worker/replica:2" }
  1043. """, gd)
  1044. def testNestingOverrideGpuCpu(self):
  1045. g = ops.Graph()
  1046. with g.device("/job:worker/replica:2/device:CPU:1"):
  1047. g.create_op("FloatOutput", [], [dtypes.float32])
  1048. with g.device("/job:worker/replica:2/device:GPU:2"):
  1049. g.create_op("FloatOutput", [], [dtypes.float32])
  1050. g.create_op("FloatOutput", [], [dtypes.float32])
  1051. gd = g.as_graph_def()
  1052. self.assertProtoEqualsVersion("""
  1053. node { name: "FloatOutput" op: "FloatOutput"
  1054. device: "/job:worker/replica:2/device:CPU:1" }
  1055. node { name: "FloatOutput_1" op: "FloatOutput"
  1056. device: "/job:worker/replica:2/device:GPU:2" }
  1057. node { name: "FloatOutput_2" op: "FloatOutput"
  1058. device: "/job:worker/replica:2/device:CPU:1" }
  1059. """, gd)
  1060. def testNestingWithMergeDeviceFunction(self):
  1061. g = ops.Graph()
  1062. with g.device(pydev.merge_device("/device:GPU:0")):
  1063. g.create_op("FloatOutput", [], [dtypes.float32])
  1064. with g.device(pydev.merge_device("/job:worker")):
  1065. g.create_op("FloatOutput", [], [dtypes.float32])
  1066. with g.device(pydev.merge_device("/device:CPU:0")):
  1067. g.create_op("FloatOutput", [], [dtypes.float32])
  1068. with g.device(pydev.merge_device("/job:ps")):
  1069. g.create_op("FloatOutput", [], [dtypes.float32])
  1070. with g.device(pydev.merge_device(None)):
  1071. g.create_op("FloatOutput", [], [dtypes.float32])
  1072. gd = g.as_graph_def()
  1073. self.assertProtoEqualsVersion("""
  1074. node { name: "FloatOutput" op: "FloatOutput"
  1075. device: "/device:GPU:0" }
  1076. node { name: "FloatOutput_1" op: "FloatOutput"
  1077. device: "/job:worker/device:GPU:0" }
  1078. node { name: "FloatOutput_2" op: "FloatOutput"
  1079. device: "/job:worker/device:CPU:0" }
  1080. node { name: "FloatOutput_3" op: "FloatOutput"
  1081. device: "/job:ps/device:CPU:0" }
  1082. node { name: "FloatOutput_4" op: "FloatOutput"
  1083. device: "/job:ps/device:CPU:0" }
  1084. """, gd)
  1085. def testNestingWithDeviceStrings(self):
  1086. g = ops.Graph()
  1087. with g.device("/device:GPU:0"):
  1088. g.create_op("FloatOutput", [], [dtypes.float32])
  1089. with g.device("/job:worker"):
  1090. g.create_op("FloatOutput", [], [dtypes.float32])
  1091. with g.device("/device:CPU:0"):
  1092. g.create_op("FloatOutput", [], [dtypes.float32])
  1093. with g.device("/job:ps"):
  1094. g.create_op("FloatOutput", [], [dtypes.float32])
  1095. with g.device(""):
  1096. g.create_op("FloatOutput", [], [dtypes.float32])
  1097. gd = g.as_graph_def()
  1098. self.assertProtoEqualsVersion("""
  1099. node { name: "FloatOutput" op: "FloatOutput"
  1100. device: "/device:GPU:0" }
  1101. node { name: "FloatOutput_1" op: "FloatOutput"
  1102. device: "/job:worker/device:GPU:0" }
  1103. node { name: "FloatOutput_2" op: "FloatOutput"
  1104. device: "/job:worker/device:CPU:0" }
  1105. node { name: "FloatOutput_3" op: "FloatOutput"
  1106. device: "/job:ps/device:CPU:0" }
  1107. node { name: "FloatOutput_4" op: "FloatOutput"
  1108. device: "/job:ps/device:CPU:0" }
  1109. """, gd)
  1110. def testNestingWithDeviceStringWildcard(self):
  1111. g = ops.Graph()
  1112. with g.device("/device:GPU:7"):
  1113. g.create_op("FloatOutput", [], [dtypes.float32])
  1114. with g.device("/device:GPU:*"):
  1115. g.create_op("FloatOutput", [], [dtypes.float32])
  1116. with g.device("/device:CPU:*"):
  1117. g.create_op("FloatOutput", [], [dtypes.float32])
  1118. with g.device("/device:CPU:5"):
  1119. g.create_op("FloatOutput", [], [dtypes.float32])
  1120. gd = g.as_graph_def()
  1121. self.assertProtoEqualsVersion("""
  1122. node { name: "FloatOutput" op: "FloatOutput"
  1123. device: "/device:GPU:7" }
  1124. node { name: "FloatOutput_1" op: "FloatOutput"
  1125. device: "/device:GPU:7" }
  1126. node { name: "FloatOutput_2" op: "FloatOutput"
  1127. device: "/device:CPU:*" }
  1128. node { name: "FloatOutput_3" op: "FloatOutput"
  1129. device: "/device:CPU:5" }
  1130. """, gd)
  1131. def testNoneClearsDefault(self):
  1132. g = ops.Graph()
  1133. with g.device("/job:worker/replica:2/device:CPU:1"):
  1134. g.create_op("FloatOutput", [], [dtypes.float32])
  1135. with g.device(None):
  1136. g.create_op("FloatOutput", [], [dtypes.float32])
  1137. g.create_op("FloatOutput", [], [dtypes.float32])
  1138. gd = g.as_graph_def()
  1139. self.assertProtoEqualsVersion("""
  1140. node { name: "FloatOutput" op: "FloatOutput"
  1141. device: "/job:worker/replica:2/device:CPU:1" }
  1142. node { name: "FloatOutput_1" op: "FloatOutput" }
  1143. node { name: "FloatOutput_2" op: "FloatOutput"
  1144. device: "/job:worker/replica:2/device:CPU:1" }
  1145. """, gd)
  1146. def testNoneIgnoresOuterDeviceFunction(self):
  1147. g = ops.Graph()
  1148. with g.device(lambda op: "/job:worker/replica:2/device:CPU:1"):
  1149. g.create_op("FloatOutput", [], [dtypes.float32])
  1150. with g.device(None):
  1151. g.create_op("FloatOutput", [], [dtypes.float32])
  1152. g.create_op("FloatOutput", [], [dtypes.float32])
  1153. gd = g.as_graph_def()
  1154. self.assertProtoEqualsVersion("""
  1155. node { name: "FloatOutput" op: "FloatOutput"
  1156. device: "/job:worker/replica:2/device:CPU:1" }
  1157. node { name: "FloatOutput_1" op: "FloatOutput" }
  1158. node { name: "FloatOutput_2" op: "FloatOutput"
  1159. device: "/job:worker/replica:2/device:CPU:1" }
  1160. """, gd)
  1161. def _overwritingDeviceFunction(self, unused_op):
  1162. # This device function unconditionally overwrites the device of ops.
  1163. #
  1164. # NOTE(mrry): Writing device functions like this is not
  1165. # recommended. Instead, in most cases you should use
  1166. # `pydev.merge_device("/job:ps")` or simply `"/job:ps"` as the
  1167. # argument to `tf.device()` and the device component will be merged in.
  1168. return "/job:overwrite"
  1169. def testOverwritingBehavior(self):
  1170. g = ops.Graph()
  1171. with g.device(self._overwritingDeviceFunction):
  1172. g.create_op("FloatOutput", [], [dtypes.float32])
  1173. with g.device("/job:ps"): # Will be overwritten.
  1174. g.create_op("FloatOutput", [], [dtypes.float32])
  1175. with g.device(pydev.merge_device("/job:ps")): # Will be overwritten.
  1176. g.create_op("FloatOutput", [], [dtypes.float32])
  1177. with g.device(None): # Disables overwriting device function
  1178. with g.device("/job:ps"):
  1179. g.create_op("FloatOutput", [], [dtypes.float32])
  1180. with g.device(None): # Disables overwriting device function
  1181. with g.device(pydev.merge_device("/job:ps")):
  1182. g.create_op("FloatOutput", [], [dtypes.float32])
  1183. gd = g.as_graph_def()
  1184. self.assertProtoEqualsVersion("""
  1185. node { name: "FloatOutput" op: "FloatOutput"
  1186. device: "/job:overwrite" }
  1187. node { name: "FloatOutput_1" op: "FloatOutput"
  1188. device: "/job:overwrite" }
  1189. node { name: "FloatOutput_2" op: "FloatOutput"
  1190. device: "/job:overwrite" }
  1191. node { name: "FloatOutput_3" op: "FloatOutput"
  1192. device: "/job:ps" }
  1193. node { name: "FloatOutput_4" op: "FloatOutput"
  1194. device: "/job:ps" }
  1195. """, gd)
  1196. class MultithreadedGraphStateTest(test_util.TensorFlowTestCase):
  1197. class TestThread(threading.Thread):
  1198. def __init__(self, graph, replica_id):
  1199. super(MultithreadedGraphStateTest.TestThread, self).__init__()
  1200. self._graph = graph
  1201. self._replica_id = replica_id
  1202. # This thread sets this event when it mutated the graph. The caller can
  1203. # wait for that.
  1204. self.has_mutated_graph = threading.Event()
  1205. # This thread waits for when it should continue. The caller can set this
  1206. # event.
  1207. self.should_continue = threading.Event()
  1208. def run(self):
  1209. # Mutate a graph's stack, then set `has_mutated_graph`, then wait for
  1210. # `should_continue`, then add an op to the graph affected by the graph's
  1211. # stack.
  1212. raise NotImplementedError("must be implemented in descendants")
  1213. def testDeviceFunctionStack(self):
  1214. class DeviceSettingThread(self.TestThread):
  1215. def run(self):
  1216. with g.device("/job:worker/replica:{}".format(self._replica_id)):
  1217. self.has_mutated_graph.set()
  1218. self.should_continue.wait()
  1219. self.should_continue.clear()
  1220. g.create_op(
  1221. "FloatOutput", [], [dtypes.float32],
  1222. name="FloatOutput_{}".format(self._replica_id))
  1223. g = ops.Graph()
  1224. # If `switch_to_thread` isn't called, then device placement of the ops
  1225. # below is not deterministic.
  1226. g.switch_to_thread_local()
  1227. threads = [DeviceSettingThread(g, i) for i in range(3)]
  1228. for t in threads:
  1229. t.start()
  1230. t.has_mutated_graph.wait()
  1231. t.has_mutated_graph.clear()
  1232. for t in threads:
  1233. t.should_continue.set()
  1234. t.join()
  1235. gd = g.as_graph_def()
  1236. self.assertProtoEqualsVersion("""
  1237. node { name: "FloatOutput_0" op: "FloatOutput"
  1238. device: "/job:worker/replica:0" }
  1239. node { name: "FloatOutput_1" op: "FloatOutput"
  1240. device: "/job:worker/replica:1" }
  1241. node { name: "FloatOutput_2" op: "FloatOutput"
  1242. device: "/job:worker/replica:2" }
  1243. """, gd)
  1244. def testColocateWith(self):
  1245. class ColocatingThread(self.TestThread):
  1246. def __init__(self, graph, replica_id, op_to_colocate_with):
  1247. super(ColocatingThread, self).__init__(graph, replica_id)
  1248. self._op_to_colocate_with = op_to_colocate_with
  1249. def run(self):
  1250. with g.colocate_with(self._op_to_colocate_with):
  1251. self.has_mutated_graph.set()
  1252. self.should_continue.wait()
  1253. self.should_continue.clear()
  1254. g.create_op(
  1255. "FloatOutput", [], [dtypes.float32],
  1256. name="FloatOutput_{}".format(self._replica_id))
  1257. g = ops.Graph()
  1258. ops_to_colocate_with = []
  1259. for i in range(3):
  1260. with g.device("/job:worker/replica:{}".format(i)):
  1261. ops_to_colocate_with.append(
  1262. g.create_op(
  1263. "FloatOutput", [], [dtypes.float32],
  1264. name="ColocateWithMe_{}".format(i)))
  1265. # If `switch_to_thread` isn't called, then `device` and `attr` values for
  1266. # the ops below are not deterministic.
  1267. g.switch_to_thread_local()
  1268. threads = [
  1269. ColocatingThread(g, i, ops_to_colocate_with[i]) for i in range(3)
  1270. ]
  1271. for t in threads:
  1272. t.start()
  1273. t.has_mutated_graph.wait()
  1274. t.has_mutated_graph.clear()
  1275. for t in threads:
  1276. t.should_continue.set()
  1277. t.join()
  1278. gd = g.as_graph_def()
  1279. self.assertProtoEqualsVersion("""
  1280. node { name: "ColocateWithMe_0" op: "FloatOutput"
  1281. device: "/job:worker/replica:0" }
  1282. node { name: "ColocateWithMe_1" op: "FloatOutput"
  1283. device: "/job:worker/replica:1" }
  1284. node { name: "ColocateWithMe_2" op: "FloatOutput"
  1285. device: "/job:worker/replica:2" }
  1286. node { name: "FloatOutput_0" op: "FloatOutput"
  1287. device: "/job:worker/replica:0"
  1288. attr { key: "_class"
  1289. value { list {
  1290. s: "loc:@ColocateWithMe_0"}}}}
  1291. node { name: "FloatOutput_1" op: "FloatOutput"
  1292. device: "/job:worker/replica:1"
  1293. attr { key: "_class"
  1294. value { list {
  1295. s: "loc:@ColocateWithMe_1"}}}}
  1296. node { name: "FloatOutput_2" op: "FloatOutput"
  1297. device: "/job:worker/replica:2"
  1298. attr { key: "_class"
  1299. value { list {
  1300. s: "loc:@ColocateWithMe_2"}}}}
  1301. """, gd)
  1302. def testControlDependencies(self):
  1303. class DependingThread(self.TestThread):
  1304. def __init__(self, graph, replica_id, dependency_op):
  1305. super(DependingThread, self).__init__(graph, replica_id)
  1306. self._dependency_op = dependency_op
  1307. def run(self):
  1308. with g.control_dependencies([self._dependency_op]):
  1309. self.has_mutated_graph.set()
  1310. self.should_continue.wait()
  1311. self.should_continue.clear()
  1312. g.create_op(
  1313. "FloatOutput", [], [dtypes.float32],
  1314. name="FloatOutput_{}".format(self._replica_id))
  1315. g = ops.Graph()
  1316. dependency_ops = []
  1317. for i in range(3):
  1318. dependency_ops.append(
  1319. g.create_op(
  1320. "FloatOutput", [], [dtypes.float32],
  1321. name="ColocateWithMe_{}".format(i)))
  1322. # If `switch_to_thread` isn't called, then `input` values for the ops below
  1323. # are not deterministic.
  1324. g.switch_to_thread_local()
  1325. threads = [DependingThread(g, i, dependency_ops[i]) for i in range(3)]
  1326. for t in threads:
  1327. t.start()
  1328. t.has_mutated_graph.wait()
  1329. t.has_mutated_graph.clear()
  1330. for t in threads:
  1331. t.should_continue.set()
  1332. t.join()
  1333. gd = g.as_graph_def()
  1334. self.assertProtoEqualsVersion("""
  1335. node { name: "ColocateWithMe_0" op: "FloatOutput" }
  1336. node { name: "ColocateWithMe_1" op: "FloatOutput" }
  1337. node { name: "ColocateWithMe_2" op: "FloatOutput" }
  1338. node { name: "FloatOutput_0" op: "FloatOutput"
  1339. input: "^ColocateWithMe_0" }
  1340. node { name: "FloatOutput_1" op: "FloatOutput"
  1341. input: "^ColocateWithMe_1" }
  1342. node { name: "FloatOutput_2" op: "FloatOutput"
  1343. input: "^ColocateWithMe_2" }
  1344. """, gd)
  1345. def testNameStack(self):
  1346. class NameSettingThread(self.TestThread):
  1347. def run(self):
  1348. with g.name_scope("foo"):
  1349. op1 = g.create_op("FloatOutput", [], [dtypes.float32])
  1350. self.has_mutated_graph.set()
  1351. self.should_continue.wait()
  1352. self.should_continue.clear()
  1353. op2 = g.create_op("FloatOutput", [], [dtypes.float32])
  1354. self.result = (op1, op2)
  1355. g = ops.Graph()
  1356. threads = [NameSettingThread(g, i) for i in range(3)]
  1357. for t in threads:
  1358. t.start()
  1359. t.has_mutated_graph.wait()
  1360. t.has_mutated_graph.clear()
  1361. for t in threads:
  1362. t.should_continue.set()
  1363. t.join()
  1364. suffixes = ["", "_1", "_2"]
  1365. for t, s in zip(threads, suffixes):
  1366. self.assertEquals("foo" + s + "/FloatOutput", t.result[0].name)
  1367. self.assertEquals("foo" + s + "/FloatOutput_1", t.result[1].name)
  1368. class ObjectWithName(object):
  1369. def __init__(self, name):
  1370. self._name = name
  1371. @property
  1372. def name(self):
  1373. return self._name
  1374. class CollectionTest(test_util.TensorFlowTestCase):
  def test_get_collections(self):
    """`Graph.collections` lists each collection key once it has items."""
    graph = ops.Graph()
    self.assertSequenceEqual(graph.collections, [])
    for value in (12, 15):
      graph.add_to_collection("key", value)
    # Two additions under one key still yield a single collection name.
    self.assertSequenceEqual(graph.collections, ["key"])
    graph.add_to_collection("other", "foo")
    self.assertSequenceEqual(sorted(graph.collections), ["key", "other"])
  def test_add_to_collection(self):
    """Covers add_to_collection, scope-filtered lookups and copy/ref access."""
    g = ops.Graph()
    g.add_to_collection("key", 12)
    g.add_to_collection("other", "foo")
    g.add_to_collection("key", 34)

    g.add_to_collection("blah", 27)
    blank1 = ObjectWithName("prefix/foo")
    g.add_to_collection("blah", blank1)
    blank2 = ObjectWithName("junk/foo")
    g.add_to_collection("blah", blank2)

    self.assertEqual([12, 34], g.get_collection("key"))
    self.assertEqual([], g.get_collection("nothing"))
    self.assertEqual([27, blank1, blank2], g.get_collection("blah"))
    # With a scope filter, only blank1 ("prefix/foo") is returned.
    self.assertEqual([blank1], g.get_collection("blah", "prefix"))
    self.assertEqual([blank1], g.get_collection("blah", ".*x"))

    # Make sure that get_collection() returns a first-level
    # copy of the collection, while get_collection_ref() returns
    # the original list.
    other_collection_snapshot = g.get_collection("other")
    other_collection_ref = g.get_collection_ref("other")
    self.assertEqual(["foo"], other_collection_snapshot)
    self.assertEqual(["foo"], other_collection_ref)
    g.add_to_collection("other", "bar")
    self.assertEqual(["foo"], other_collection_snapshot)
    self.assertEqual(["foo", "bar"], other_collection_ref)
    self.assertEqual(["foo", "bar"], g.get_collection("other"))
    # assertIs/assertIsNot report both objects on failure, unlike
    # assertTrue(a is b).
    self.assertIs(other_collection_ref, g.get_collection_ref("other"))

    # Verify that getting an empty collection ref returns a modifiable list.
    empty_coll_ref = g.get_collection_ref("empty")
    self.assertEqual([], empty_coll_ref)
    empty_coll = g.get_collection("empty")
    self.assertEqual([], empty_coll)
    self.assertIsNot(empty_coll, empty_coll_ref)
    empty_coll_ref2 = g.get_collection_ref("empty")
    self.assertIs(empty_coll_ref2, empty_coll_ref)
    # Add to the collection.
    empty_coll_ref.append("something")
    self.assertEqual(["something"], empty_coll_ref)
    self.assertEqual(["something"], empty_coll_ref2)
    self.assertEqual([], empty_coll)
    self.assertEqual(["something"], g.get_collection("empty"))
    empty_coll_ref3 = g.get_collection_ref("empty")
    self.assertIs(empty_coll_ref3, empty_coll_ref)
  1427. def test_add_to_collections_uniquify(self):
  1428. g = ops.Graph()
  1429. g.add_to_collections([1, 2, 1], "key")
  1430. # Make sure "key" is not added twice
  1431. self.assertEqual(["key"], g.get_collection(1))
  1432. def test_add_to_collections_from_list(self):
  1433. g = ops.Graph()
  1434. g.add_to_collections(["abc", "123"], "key")
  1435. self.assertEqual(["key"], g.get_collection("abc"))
  1436. self.assertEqual(["key"], g.get_collection("123"))
  1437. def test_add_to_collections_from_tuple(self):
  1438. g = ops.Graph()
  1439. g.add_to_collections(("abc", "123"), "key")
  1440. self.assertEqual(["key"], g.get_collection("abc"))
  1441. self.assertEqual(["key"], g.get_collection("123"))
  1442. def test_add_to_collections_from_generator(self):
  1443. g = ops.Graph()
  1444. def generator():
  1445. yield "abc"
  1446. yield "123"
  1447. g.add_to_collections(generator(), "key")
  1448. self.assertEqual(["key"], g.get_collection("abc"))
  1449. self.assertEqual(["key"], g.get_collection("123"))
  1450. def test_add_to_collections_from_set(self):
  1451. g = ops.Graph()
  1452. g.add_to_collections(set(["abc", "123"]), "key")
  1453. self.assertEqual(["key"], g.get_collection("abc"))
  1454. self.assertEqual(["key"], g.get_collection("123"))
  1455. def test_add_to_collections_from_string(self):
  1456. g = ops.Graph()
  1457. g.add_to_collections("abc", "key")
  1458. self.assertEqual(["key"], g.get_collection("abc"))
  1459. def test_default_graph(self):
  1460. with ops.Graph().as_default():
  1461. ops.add_to_collection("key", 90)
  1462. ops.add_to_collection("key", 100)
  1463. # Collections are ordered.
  1464. self.assertEqual([90, 100], ops.get_collection("key"))
  1465. def test_defun(self):
  1466. with context.eager_mode():
  1467. @eager_function.defun
  1468. def defun():
  1469. ops.add_to_collection("int", 1)
  1470. ops.add_to_collection("tensor", constant_op.constant(2))
  1471. @eager_function.defun
  1472. def inner_defun():
  1473. self.assertEqual(ops.get_collection("int"), [1])
  1474. three = ops.get_collection("tensor")[0] + ops.get_collection("int")[0]
  1475. ops.add_to_collection("int", 2)
  1476. self.assertEqual(ops.get_collection("int"), [1, 2])
  1477. ops.add_to_collection("foo", "bar")
  1478. self.assertEqual(ops.get_collection("foo"), ["bar"])
  1479. return three
  1480. self.assertEqual(ops.get_collection("int"), [1])
  1481. three = inner_defun()
  1482. self.assertEqual(ops.get_collection("int"), [1])
  1483. self.assertEqual(ops.get_collection("foo"), [])
  1484. return three
  1485. three = defun()
  1486. self.assertEqual(three.numpy(), 3)
  1487. ops.NotDifferentiable("FloatOutput")
  1488. @ops.RegisterGradient("CopyOp")
  1489. def _CopyGrad(op, x_grad): # pylint: disable=invalid-name
  1490. _ = op
  1491. return x_grad
  1492. @ops.RegisterGradient("copy_override")
  1493. def _CopyOverrideGrad(op, x_grad): # pylint: disable=invalid-name
  1494. _ = op
  1495. return x_grad
  1496. class RegistrationTest(test_util.TensorFlowTestCase):
  1497. @test_util.run_deprecated_v1
  1498. def testRegisterGradients(self):
  1499. x = test_ops.float_output()
  1500. y = test_ops.copy_op(x)
  1501. fn = ops.get_gradient_function(y.op)
  1502. self.assertEqual(_CopyGrad, fn)
  1503. def testOverrideGradients(self):
  1504. g = ops.Graph()
  1505. with g.as_default():
  1506. x = test_ops.float_output()
  1507. with g.gradient_override_map({"CopyOp": "copy_override"}):
  1508. y = test_ops.copy_op(x)
  1509. fn = ops.get_gradient_function(y.op)
  1510. self.assertEqual(_CopyOverrideGrad, fn)
  1511. def testNonExistentOverride(self):
  1512. g = ops.Graph()
  1513. with g.as_default():
  1514. x = test_ops.float_output()
  1515. with g.gradient_override_map({"CopyOp": "unknown_override"}):
  1516. y = test_ops.copy_op(x)
  1517. with self.assertRaisesRegexp(LookupError, "unknown_override"):
  1518. ops.get_gradient_function(y.op)
  1519. class ComparisonTest(test_util.TensorFlowTestCase):
  1520. def testMembershipAllowed(self):
  1521. g = ops.Graph()
  1522. t1 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop1")
  1523. t2 = _apply_op(g, "FloatOutput", [], [dtypes.float32], name="myop2")
  1524. self.assertTrue(isinstance(t1, ops.Tensor))
  1525. self.assertTrue(isinstance(t2, ops.Tensor))
  1526. self.assertTrue(t1 in [t1])
  1527. self.assertTrue(t1 not in [t2])
  1528. class ControlDependenciesTest(test_util.TensorFlowTestCase):
  1529. @test_util.run_deprecated_v1
  1530. def testBasic(self):
  1531. g = ops.Graph()
  1532. with g.as_default():
  1533. # Creating unregistered ops with _apply_op() doesn't work with the C API
  1534. # TODO(skyewm): address this more consistently. Possible solutions are
  1535. # to use registered ops in all tests, create a way to register ops in
  1536. # Python tests, or conditionally disable the op registration check in
  1537. # the C API.
  1538. a = constant_op.constant(1.0)
  1539. b = constant_op.constant(1.0)
  1540. with g.control_dependencies([a]):
  1541. c = constant_op.constant(1.0)
  1542. d = array_ops.identity(b)
  1543. e = array_ops.identity(c)
  1544. self.assertEqual(c.op.control_inputs, [a.op])
  1545. self.assertEqual(d.op.control_inputs, [a.op])
  1546. # e should be dominated by c.
  1547. self.assertEqual(e.op.control_inputs, [])
  1548. @test_util.run_in_graph_and_eager_modes
  1549. def testEager(self):
  1550. def future():
  1551. future.calls += 1
  1552. return constant_op.constant(2.0)
  1553. future.calls = 0
  1554. if context.executing_eagerly():
  1555. a = constant_op.constant(1.0)
  1556. b = future
  1557. with ops.control_dependencies([a, b]):
  1558. c = constant_op.constant(3.0)
  1559. self.assertEqual(future.calls, 1)
  1560. else:
  1561. g = ops.Graph()
  1562. with g.as_default():
  1563. a = constant_op.constant(1.0)
  1564. b = future()
  1565. with g.control_dependencies([a, b]):
  1566. c = constant_op.constant(3.0)
  1567. self.assertEqual(c.op.control_inputs, [a.op, b.op])
  1568. self.assertEqual(future.calls, 1)
  1569. def testBasicWithConversion(self):
  1570. g = ops.Graph()
  1571. a = _apply_op(g, "FloatOutput", [], [dtypes.float32])
  1572. class ConvertibleObj(object):
  1573. def _as_graph_element(self):
  1574. return a
  1575. with g.control_dependencies([ConvertibleObj()]):
  1576. c = _apply_op(g, "FloatOutput", [], [dtypes.float32])
  1577. self.assertEqual(c.op.control_inputs, [a.op])
  1578. def testNested(self):
  1579. g = ops.Graph()
  1580. a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
  1581. a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
  1582. a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
  1583. a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
  1584. with g.control_dependencies([a_1, a_2, a_3, a_4]):
  1585. b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
  1586. with g.control_dependencies([a_1]):
  1587. with g.control_dependencies([a_2]):
  1588. with g.control_dependencies([a_3]):
  1589. with g.control_dependencies([a_4]):
  1590. b_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
  1591. self.assertItemsEqual([a_1.op, a_2.op, a_3.op, a_4.op],
  1592. b_1.op.control_inputs)
  1593. self.assertItemsEqual(b_1.op.control_inputs, b_2.op.control_inputs)
  1594. def testClear(self):
  1595. g = ops.Graph()
  1596. a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
  1597. a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
  1598. a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
  1599. a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
  1600. with g.control_dependencies([a_1]):
  1601. with g.control_dependencies([a_2]):
  1602. with g.control_dependencies(None):
  1603. with g.control_dependencies([a_3]):
  1604. with g.control_dependencies([a_4]):
  1605. # deps [a_3, a_4]
  1606. b_3_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
  1607. # deps = [a_3]
  1608. b_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
  1609. # deps back to None
  1610. b_none = _apply_op(g, "FloatOutput", [], [dtypes.float32])
  1611. # deps back to [a_1, a_2]
  1612. b_1_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
  1613. # deps back to [a_1]
  1614. b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
  1615. with g.control_dependencies(None):
  1616. # deps are None again
  1617. b_none2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
  1618. self.assertItemsEqual([a_3.op, a_4.op], b_3_4.op.control_inputs)
  1619. self.assertItemsEqual([a_3.op], b_3.op.control_inputs)
  1620. self.assertItemsEqual([], b_none.op.control_inputs)
  1621. self.assertItemsEqual([a_1.op, a_2.op], b_1_2.op.control_inputs)
  1622. self.assertItemsEqual([a_1.op], b_1.op.control_inputs)
  1623. self.assertItemsEqual([], b_none2.op.control_inputs)
  1624. def testComplex(self):
  1625. g = ops.Graph()
  1626. # Usage pattern:
  1627. # * Nodes a_i are constants defined at the outermost scope, and are used
  1628. # as control inputs for the ith nested scope.
  1629. # * Nodes b_i are defined as Mul(a_3, a_4) at each scope.
  1630. # * Nodes c_i are defined as Mul(a_1, b_1) at each scope.
  1631. # * Nodes d_i are defined as Mul(b_i, c_i) at each scope.
  1632. # * Nodes e_i are defined as Mul(e_i-1, e_i-1) at each scope i > 1.
  1633. a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
  1634. a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
  1635. a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
  1636. a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
  1637. with g.control_dependencies([a_1]):
  1638. b_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4],
  1639. [dtypes.float32])
  1640. c_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1],
  1641. [dtypes.float32])
  1642. d_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_1, c_1],
  1643. [dtypes.float32])
  1644. e_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
  1645. with g.control_dependencies([a_2]):
  1646. b_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4],
  1647. [dtypes.float32])
  1648. c_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1],
  1649. [dtypes.float32])
  1650. d_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_2, c_2],
  1651. [dtypes.float32])
  1652. e_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_1, e_1],
  1653. [dtypes.float32])
  1654. with g.control_dependencies([a_3]):
  1655. b_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4],
  1656. [dtypes.float32])
  1657. c_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1],
  1658. [dtypes.float32])
  1659. d_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_3, c_3],
  1660. [dtypes.float32])
  1661. e_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_2, e_2],
  1662. [dtypes.float32])
  1663. with g.control_dependencies([a_4]):
  1664. b_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4],
  1665. [dtypes.float32])
  1666. c_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1],
  1667. [dtypes.float32])
  1668. d_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_4, c_4],
  1669. [dtypes.float32])
  1670. e_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_3, e_3],
  1671. [dtypes.float32])
  1672. self.assertItemsEqual([a_1.op], b_1.op.control_inputs)
  1673. self.assertItemsEqual([a_1.op, a_2.op], b_2.op.control_inputs)
  1674. self.assertItemsEqual([a_1.op, a_2.op], b_3.op.control_inputs)
  1675. self.assertItemsEqual([a_1.op, a_2.op], b_4.op.control_inputs)
  1676. self.assertItemsEqual([], c_1.op.control_inputs)
  1677. self.assertItemsEqual([a_2.op], c_2.op.control_inputs)
  1678. self.assertItemsEqual([a_2.op, a_3.op], c_3.op.control_inputs)
  1679. self.assertItemsEqual([a_2.op, a_3.op, a_4.op], c_4.op.control_inputs)
  1680. self.assertItemsEqual([], d_1.op.control_inputs)
  1681. self.assertItemsEqual([], d_2.op.control_inputs)
  1682. self.assertItemsEqual([], d_3.op.control_inputs)
  1683. self.assertItemsEqual([], d_4.op.control_inputs)
  1684. self.assertItemsEqual([a_1.op], e_1.op.control_inputs)
  1685. self.assertItemsEqual([a_2.op], e_2.op.control_inputs)
  1686. self.assertItemsEqual([a_3.op], e_3.op.control_inputs)
  1687. self.assertItemsEqual([a_4.op], e_4.op.control_inputs)
  1688. def testRepeatedDependency(self):
  1689. g = ops.Graph()
  1690. a = g.create_op("TwoFloatOutputs", [], [dtypes.float32, dtypes.float32])
  1691. a_0, a_1 = a.outputs
  1692. with g.control_dependencies([a_0]):
  1693. b = _apply_op(g, "FloatOutput", [], [dtypes.float32])
  1694. with g.control_dependencies([a_1]):
  1695. c = _apply_op(g, "FloatOutput", [], [dtypes.float32])
  1696. self.assertEqual(b.op.control_inputs, [a])
  1697. self.assertEqual(c.op.control_inputs, [a])
  1698. def testNoControlDependencyWithDataDependency(self):
  1699. g = ops.Graph()
  1700. a = _apply_op(g, "FloatOutput", [], [dtypes.float32])
  1701. with g.control_dependencies([a]):
  1702. b = _apply_op(g, "Identity", [a], [dtypes.float32])
  1703. self.assertEqual(b.op.control_inputs, [])
  1704. class OpScopeTest(test_util.TensorFlowTestCase):
  1705. @test_util.run_in_graph_and_eager_modes
  1706. def testNames(self):
  1707. with ops.name_scope("foo") as foo:
  1708. self.assertEqual("foo/", foo)
  1709. with ops.name_scope("foo2") as foo2:
  1710. self.assertEqual("foo/foo2/", foo2)
  1711. with ops.name_scope(None) as empty1:
  1712. self.assertEqual("", empty1)
  1713. with ops.name_scope("foo3") as foo3:
  1714. self.assertEqual("foo3/", foo3)
  1715. with ops.name_scope("") as empty2:
  1716. self.assertEqual("", empty2)
  1717. with ops.name_scope("foo/") as outer_foo:
  1718. self.assertEqual("foo/", outer_foo)
  1719. with ops.name_scope("") as empty3:
  1720. self.assertEqual("", empty3)
  1721. with ops.name_scope("foo4") as foo4:
  1722. self.assertEqual("foo/foo4/", foo4)
  1723. with ops.name_scope("foo5//") as foo5:
  1724. self.assertEqual("foo5//", foo5)
  1725. with ops.name_scope("foo6") as foo6:
  1726. self.assertEqual("foo5//foo6/", foo6)
  1727. with ops.name_scope("/") as foo7:
  1728. self.assertEqual("/", foo7)
  1729. with ops.name_scope("//") as foo8:
  1730. self.assertEqual("//", foo8)
  1731. with ops.name_scope("a//b/c") as foo9:
  1732. self.assertEqual("foo/a//b/c/", foo9)
  1733. with ops.name_scope("a//b/c") as foo10:
  1734. self.assertEqual("a//b/c/", foo10)
  1735. @test_util.run_in_graph_and_eager_modes
  1736. def testEagerDefaultScopeName(self):
  1737. with ops.name_scope(None, "default") as scope:
  1738. self.assertEqual(scope, "default/")
  1739. with ops.name_scope(None, "default2") as scope2:
  1740. self.assertEqual(scope2, "default/default2/")
  1741. @test_util.run_deprecated_v1
  1742. def testNoScopeName(self):
  1743. g0 = ops.Graph()
  1744. values = [
  1745. g0.create_op("A", [], [dtypes.float32]),
  1746. g0.create_op("B", [], [dtypes.float32])
  1747. ]
  1748. with self.assertRaises(ValueError):
  1749. with ops.name_scope(None, values=values):
  1750. pass
  1751. with self.assertRaises(ValueError):
  1752. with ops.name_scope(None, None, values):
  1753. pass
  1754. @test_util.run_deprecated_v1
  1755. def testEmptyScopeName(self):
  1756. g0 = ops.Graph()
  1757. a = g0.create_op("A", [], [dtypes.float32])
  1758. b = g0.create_op("B", [], [dtypes.float32])
  1759. with ops.name_scope("", values=[a, b]) as scope:
  1760. self.assertEqual("", scope)
  1761. self.assertEqual(g0, ops.get_default_graph())
  1762. with ops.name_scope("", "my_default_scope", [a, b]) as scope:
  1763. self.assertEqual("", scope)
  1764. self.assertEqual(g0, ops.get_default_graph())
  1765. @test_util.run_deprecated_v1
  1766. def testDefaultScopeName(self):
  1767. g0 = ops.Graph()
  1768. a = g0.create_op("A", [], [dtypes.float32])
  1769. b = g0.create_op("B", [], [dtypes.float32])
  1770. scope_name = "my_scope"
  1771. default_scope_name = "my_default_scope"
  1772. with ops.name_scope(scope_name, default_scope_name, [a, b]) as scope:
  1773. self.assertEqual("%s/" % scope_name, scope)
  1774. self.assertEqual(g0, ops.get_default_graph())
  1775. with ops.name_scope(None, default_scope_name, [a, b]) as scope:
  1776. self.assertEqual("%s/" % default_scope_name, scope)
  1777. self.assertEqual(g0, ops.get_default_graph())
  1778. def _testGraphElements(self, graph_elements):
  1779. scope_name = "my_scope"
  1780. with ops.name_scope(scope_name, values=graph_elements) as scope:
  1781. self.assertEqual("%s/" % scope_name, scope)
  1782. self.assertEqual(graph_elements[0].graph, ops.get_default_graph())
  1783. g1 = ops.Graph()
  1784. a = g1.create_op("A", [], [dtypes.float32])
  1785. with self.assertRaises(ValueError):
  1786. with ops.name_scope(scope_name, values=graph_elements + [a]):
  1787. pass
  1788. @test_util.run_deprecated_v1
  1789. def testTensor(self):
  1790. g0 = ops.Graph()
  1791. a = g0.create_op("A", [], [dtypes.float32])
  1792. b = g0.create_op("B", [], [dtypes.float32])
  1793. self._testGraphElements([a, b])
  1794. @test_util.run_deprecated_v1
  1795. def testSparseTensor(self):
  1796. g0 = ops.Graph()
  1797. a = g0.create_op("A", [], [dtypes.float32])
  1798. b = g0.create_op("B", [], [dtypes.float32])
  1799. sparse = sparse_tensor.SparseTensor(
  1800. _apply_op(g0, "Int64Output", [], [dtypes.int64]),
  1801. _apply_op(g0, "FloatOutput", [], [dtypes.float32]),
  1802. _apply_op(g0, "Int64Output", [], [dtypes.int64]))
  1803. self._testGraphElements([a, sparse, b])
  1804. @test_util.run_deprecated_v1
  1805. def testVariable(self):
  1806. g0 = ops.Graph()
  1807. with g0.as_default():
  1808. variable = variables.Variable([1.0])
  1809. a = g0.create_op("A", [], [dtypes.float32])
  1810. b = g0.create_op("B", [], [dtypes.float32])
  1811. self._testGraphElements([a, variable, b])
  1812. class InitScopeTest(test_util.TensorFlowTestCase):
  1813. def testClearsControlDependencies(self):
  1814. g = ops.Graph()
  1815. a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
  1816. a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
  1817. a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
  1818. a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
  1819. with g.as_default():
  1820. with g.control_dependencies([a_1]):
  1821. with g.control_dependencies([a_2]):
  1822. with ops.init_scope():
  1823. with g.control_dependencies([a_3]):
  1824. with g.control_dependencies([a_4]):
  1825. # deps [a_3, a_4]
  1826. b_3_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
  1827. # deps = [a_3]
  1828. b_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
  1829. # deps back to None
  1830. b_none = _apply_op(g, "FloatOutput", [], [dtypes.float32])
  1831. # deps back to [a_1, a_2]
  1832. b_1_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
  1833. # deps back to [a_1]
  1834. b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
  1835. with ops.init_scope():
  1836. # deps are None again
  1837. b_none2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
  1838. self.assertItemsEqual([a_3.op, a_4.op], b_3_4.op.control_inputs)
  1839. self.assertItemsEqual([a_3.op], b_3.op.control_inputs)
  1840. self.assertItemsEqual([], b_none.op.control_inputs)
  1841. self.assertItemsEqual([a_1.op, a_2.op], b_1_2.op.control_inputs)
  1842. self.assertItemsEqual([a_1.op], b_1.op.control_inputs)
  1843. self.assertItemsEqual([], b_none2.op.control_inputs)
  1844. def testLiftsOpsFromFunctions(self):
  1845. g0 = ops.Graph()
  1846. g1 = ops.Graph()
  1847. g1._building_function = True # pylint: disable=protected-access
  1848. g2 = ops.Graph()
  1849. g2._building_function = True # pylint: disable=protected-access
  1850. with g0.as_default():
  1851. with g1.as_default():
  1852. with g2.as_default():
  1853. with ops.init_scope():
  1854. _ = constant_op.constant(1.0)
  1855. self.assertEqual(len(g2.get_operations()), 0)
  1856. self.assertEqual(len(g1.get_operations()), 0)
  1857. self.assertEqual(len(g0.get_operations()), 1)
  1858. def testPreservesDevices(self):
  1859. g0 = ops.Graph()
  1860. with g0.as_default(), ops.device("CPU:0"):
  1861. g1 = ops.Graph()
  1862. g1._building_function = True # pylint: disable=protected-access
  1863. with g1.as_default(), ops.device("GPU:0"):
  1864. with ops.init_scope():
  1865. # init_scope should preserve device set under `g1`.
  1866. on_gpu = constant_op.constant(1.0)
  1867. self.assertEqual(on_gpu.device, "/device:GPU:0")
  1868. still_on_gpu = constant_op.constant(1.0)
  1869. self.assertEqual(still_on_gpu.device, "/device:GPU:0")
  1870. on_cpu = constant_op.constant(1.0)
  1871. self.assertEqual(on_cpu.device, "/device:CPU:0")
  1872. def testComposes(self):
  1873. g0 = ops.Graph()
  1874. g1 = ops.Graph()
  1875. g1._building_function = True # pylint: disable=protected-access
  1876. g2 = ops.Graph()
  1877. g2._building_function = True # pylint: disable=protected-access
  1878. g3 = ops.Graph()
  1879. g3._building_function = False # pylint: disable=protected-access
  1880. with g0.as_default():
  1881. with g1.as_default():
  1882. with ops.init_scope():
  1883. # This op should be lifted into g0.
  1884. _ = constant_op.constant(1.0)
  1885. self.assertIs(g0, ops.get_default_graph())
  1886. self.assertEqual(len(g2.get_operations()), 0)
  1887. self.assertEqual(len(g1.get_operations()), 0)
  1888. self.assertEqual(len(g0.get_operations()), 1)
  1889. with g2.as_default():
  1890. with ops.init_scope():
  1891. # This op should be lifted into g0.
  1892. _ = constant_op.constant(1.0)
  1893. self.assertIs(g0, ops.get_default_graph())
  1894. with g3.as_default():
  1895. with ops.init_scope():
  1896. # This op should be lifted into g3, because g3 is not building a
  1897. # function.
  1898. _ = constant_op.constant(1.0)
  1899. self.assertIs(g3, ops.get_default_graph())
  1900. self.assertEqual(len(g3.get_operations()), 1)
  1901. self.assertEqual(len(g2.get_operations()), 0)
  1902. self.assertEqual(len(g1.get_operations()), 0)
  1903. self.assertEqual(len(g0.get_operations()), 2)
  1904. def testEscapesToEagerContext(self):
  1905. g = ops.Graph()
  1906. g._building_function = True # pylint: disable=protected-access
  1907. with context.eager_mode():
  1908. with context.graph_mode():
  1909. with g.as_default():
  1910. with ops.init_scope():
  1911. # Because g is building a function, init_scope should
  1912. # escape out to the eager context.
  1913. self.assertTrue(context.executing_eagerly())
  1914. # g should be reinstated as the default graph, and the
  1915. # graph context should be re-entered.
  1916. self.assertIs(g, ops.get_default_graph())
  1917. self.assertFalse(context.executing_eagerly())
  1918. def testStaysInEagerWhenOnlyEagerContextActive(self):
  1919. with context.eager_mode():
  1920. with ops.init_scope():
  1921. self.assertTrue(context.eager_mode())
  1922. self.assertTrue(context.eager_mode())
  1923. def testEscapesDefunWhenInEagerMode(self):
  1924. def function_with_variables():
  1925. with ops.init_scope():
  1926. self.v = resource_variable_ops.ResourceVariable(3)
  1927. return self.v.assign_add(1)
  1928. with context.eager_mode():
  1929. # Each invocation of function_with_variables recreates a variable.
  1930. self.assertEqual(4, int(function_with_variables()))
  1931. self.assertEqual(4, int(function_with_variables()))
  1932. compiled = eager_function.defun(function_with_variables)
  1933. # The init_scope in function_with_variables lifts the variable out
  1934. # of the graph function constructed by defun; hence,
  1935. # compiled now appears to be stateful.
  1936. self.assertEqual(4, int(compiled()))
  1937. self.assertEqual(5, int(compiled()))
  1938. def testEscapesDefunWhenInGraphMode(self):
  1939. def function_with_variables(name):
  1940. with ops.init_scope():
  1941. _ = variable_scope.get_variable(name, shape=(1,))
  1942. g = ops.Graph()
  1943. with g.as_default():
  1944. with self.cached_session():
  1945. # First ensure that graphs that are not building functions are
  1946. # not escaped.
  1947. function_with_variables("foo")
  1948. with self.assertRaisesRegexp(ValueError,
  1949. r"Variable foo already exists.*"):
  1950. # This will fail because reuse is not set to True.
  1951. function_with_variables("foo")
  1952. compiled = eager_function.defun(function_with_variables)
  1953. compiled("bar")
  1954. self.assertEqual(
  1955. len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)), 2)
  1956. # The second call to `compiled` should not create variables: the
  1957. # init_scope has lifted the variable creation code out of the defun.
  1958. compiled("bar")
  1959. self.assertEqual(
  1960. len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)), 2)
  1961. def testEscapesNestedDefun(self):
  1962. def inner_function():
  1963. with ops.init_scope():
  1964. self.v = resource_variable_ops.ResourceVariable(1)
  1965. return self.v.assign_add(2)
  1966. def outer_function(inner=None):
  1967. with ops.init_scope():
  1968. self.v0 = resource_variable_ops.ResourceVariable(0)
  1969. return self.v0.assign_add(1) + inner()
  1970. with context.eager_mode():
  1971. # Each invocation of outer_function recreates variables.
  1972. self.assertEqual(4, int(outer_function(inner=inner_function)))
  1973. self.assertEqual(4, int(outer_function(inner=inner_function)))
  1974. compiled_inner = eager_function.defun(inner_function)
  1975. compiled_outer = eager_function.defun(outer_function)
  1976. # The init_scope lifts variables out of the graph functions
  1977. # constructed by defun; hence, compiled_outer should now appear to be
  1978. # stateful.
  1979. self.assertEqual(4, int(compiled_outer(inner=compiled_inner)))
  1980. self.assertEqual(7, int(compiled_outer(inner=compiled_inner)))
  1981. @test_util.run_v1_only("b/120545219")
  1982. def testFallsBackToGlobalGraphWhenAllGraphsAreBuildingFunctions(self):
  1983. with context.graph_mode():
  1984. ops.reset_default_graph()
  1985. # This doesn't push anything onto the graph stack, but it does
  1986. # set the stack's global graph.
  1987. global_graph = ops.get_default_graph()
  1988. fn_graph = ops.Graph()
  1989. # pylint: disable=protected-access
  1990. fn_graph._building_function = True
  1991. self.assertEqual(len(ops._default_graph_stack.stack), 0)
  1992. with fn_graph.as_default():
  1993. self.assertEqual(len(ops._default_graph_stack.stack), 1)
  1994. with ops.init_scope():
  1995. self.assertGreater(len(ops._default_graph_stack.stack), 1)
  1996. dummy = constant_op.constant(1.0)
  1997. self.assertEqual(len(ops._default_graph_stack.stack), 1)
  1998. # Note that the global graph is _not_ on the graph stack.
  1999. self.assertEqual(len(ops._default_graph_stack.stack), 0)
  2000. # Ensure that `dummy` was added to the global graph.
  2001. self.assertEqual(global_graph, dummy.graph)
  2002. # pylint: enable=protected-access
  2003. def testInstallsDefaultGraphWhenGraphStackIsEmptyInGraphMode(self):
  2004. with context.graph_mode():
  2005. # pylint: disable=protected-access
  2006. self.assertEqual(len(ops._default_graph_stack.stack), 0)
  2007. with ops.init_scope():
  2008. self.assertGreater(len(ops._default_graph_stack.stack), 0)
  2009. self.assertEqual(len(ops._default_graph_stack.stack), 0)
  2010. # pylint: enable=protected-access
  2011. def testPreservesNameScopeInGraphConstruction(self):
  2012. with ops.Graph().as_default():
  2013. function_graph = ops.Graph()
  2014. with function_graph.as_default():
  2015. with ops.name_scope("inner"), ops.init_scope():
  2016. self.assertEqual(ops.get_name_scope(), "inner")
  2017. self.assertEqual(ops.get_name_scope(), "")
  2018. def testEnteringGraphFromEagerIsSticky(self):
  2019. with context.eager_mode():
  2020. g = ops.Graph()
  2021. with g.as_default():
  2022. with ops.init_scope():
  2023. self.assertFalse(context.executing_eagerly())
  2024. self.assertEqual(g, ops.get_default_graph())
  2025. def testMixGraphEager(self):
  2026. with context.eager_mode():
  2027. c = constant_op.constant(1.0)
  2028. with ops.Graph().as_default():
  2029. with self.assertRaisesRegexp(
  2030. RuntimeError, "Attempting to capture an EagerTensor"):
  2031. math_ops.add(c, c)
  2032. c2 = constant_op.constant(2.0)
  2033. with self.assertRaisesRegexp(
  2034. TypeError, "contains objects other than 'EagerTensor'"):
  2035. math_ops.add(c2, c2)
  2036. def testPreservesNameScopeInEagerExecution(self):
  2037. with context.eager_mode():
  2038. def foo():
  2039. with ops.name_scope("inner"), ops.init_scope():
  2040. if context.executing_eagerly():
  2041. # A trailing slash is always appended when eager execution is
  2042. # enabled.
  2043. self.assertEqual(context.context().scope_name, "inner/")
  2044. else:
  2045. self.assertEqual(ops.get_name_scope(), "inner")
  2046. foo()
  2047. self.assertEqual(ops.get_name_scope(), "")
  2048. foo_compiled = eager_function.defun(foo)
  2049. foo_compiled()
  2050. self.assertEqual(ops.get_name_scope(), "")
  2051. def testExecutingEagerlyOutsideFunctions(self):
  2052. @eager_function.defun
  2053. def f():
  2054. return ops.executing_eagerly_outside_functions()
  2055. with context.eager_mode():
  2056. self.assertTrue(ops.executing_eagerly_outside_functions())
  2057. self.assertTrue(f())
  2058. g = ops.Graph()
  2059. with g.as_default():
  2060. self.assertFalse(ops.executing_eagerly_outside_functions())
  2061. class GraphTest(test_util.TensorFlowTestCase):
  def setUp(self):
    """Runs the base-class setup, then starts each test from a fresh graph."""
    # Chain to TensorFlowTestCase.setUp so that whatever per-test setup it
    # performs is not skipped by this override (presumably session-cache
    # clearing and seeding -- confirm against test_util).
    super(GraphTest, self).setUp()
    ops.reset_default_graph()
  def _AssertDefault(self, expected):
    """Fails unless `expected` is, by identity, the current default graph."""
    current = ops.get_default_graph()
    self.assertIs(expected, current)
  def testResetDefaultGraphNesting(self):
    # Resetting while another graph is installed as the default must fail.
    with self.assertRaises(AssertionError):
      with ops.Graph().as_default():
        ops.reset_default_graph()
  def testGraphContextManagerCancelsEager(self):
    with context.eager_mode(), ops.Graph().as_default():
      # Installing a graph as the default switches off eager execution.
      self.assertFalse(context.executing_eagerly())
  def testGraphContextManager(self):
    graph = ops.Graph()
    # `as_default()` yields the very graph it was called on.
    with graph.as_default() as entered:
      self.assertIs(graph, entered)
  def testDefaultGraph(self):
    original = ops.get_default_graph()
    self._AssertDefault(original)
    # Neither constructing a graph nor creating its context manager changes
    # the default; only entering the manager does.
    outer = ops.Graph()
    self._AssertDefault(original)
    outer_ctx = outer.as_default()
    self._AssertDefault(original)
    with outer_ctx as outer:
      self._AssertDefault(outer)
      with ops.Graph().as_default() as inner:
        self._AssertDefault(inner)
      self._AssertDefault(outer)
    self._AssertDefault(original)
  def testPreventFeeding(self):
    """A tensor becomes unfeedable once `prevent_feeding` is called on it."""
    g = ops.Graph()
    # Build the tensor inside `g` so that it belongs to the graph being
    # queried, rather than to whatever graph (or eager context) is ambient.
    with g.as_default():
      a = constant_op.constant(2.0)
    self.assertTrue(g.is_feedable(a))
    g.prevent_feeding(a)
    self.assertFalse(g.is_feedable(a))
  2098. @test_util.run_deprecated_v1
  2099. def testPreventFetching(self):
  2100. g = ops.Graph()
  2101. a = constant_op.constant(2.0)
  2102. self.assertTrue(g.is_fetchable(a))
  2103. g.prevent_fetching(a.op)
  2104. self.assertFalse(g.is_fetchable(a))
  2105. def testAsGraphElementConversions(self):
  2106. class ConvertibleObj(object):
  2107. def _as_graph_element(self):
  2108. return "FloatOutput:0"
  2109. class NonConvertibleObj(object):
  2110. pass
  2111. g = ops.Graph()
  2112. a = _apply_op(g, "FloatOutput", [], [dtypes.float32])
  2113. self.assertEqual(a, g.as_graph_element(ConvertibleObj()))
  2114. with self.assertRaises(TypeError):
  2115. g.as_graph_element(NonConvertibleObj())
  2116. # Regression test against creating custom __del__ functions in classes
  2117. # involved in cyclic references, e.g. Graph and Operation. (Python won't gc
  2118. # cycles that require calling a __del__ method, because the __del__ method can
  2119. # theoretically increase the object's refcount to "save" it from gc, and any
  2120. # already-deleted objects in the cycle would have be to restored.)
  2121. def testGarbageCollected(self):
  2122. # Create a graph we can delete and a weak reference to monitor if it's gc'd
  2123. g = ops.Graph()
  2124. g_ref = weakref.ref(g)
  2125. # Create some ops
  2126. with g.as_default():
  2127. a = constant_op.constant(2.0)
  2128. b = constant_op.constant(3.0)
  2129. c = math_ops.add(a, b)
  2130. # Create a session we can delete
  2131. with session.Session(graph=g) as sess:
  2132. self.evaluate(c)
  2133. # Delete all references and trigger gc
  2134. del g
  2135. del a
  2136. del b
  2137. del c
  2138. del sess
  2139. gc.collect()
  2140. self.assertIsNone(g_ref())
  2141. def testRunnableAfterInvalidShape(self):
  2142. with ops.Graph().as_default():
  2143. with self.assertRaises(ValueError):
  2144. math_ops.add([1, 2], [1, 2, 3])
  2145. a = constant_op.constant(1)
  2146. with session.Session() as sess:
  2147. self.evaluate(a)
  2148. def testRunnableAfterInvalidShapeWithKernelLabelMap(self):
  2149. g = ops.Graph()
  2150. with g.as_default():
  2151. with g._kernel_label_map({"KernelLabelRequired": "overload_1"}):
  2152. with self.assertRaises(ValueError):
  2153. test_ops.kernel_label_required(1)
  2154. a = constant_op.constant(1)
  2155. with session.Session() as sess:
  2156. self.evaluate(a)
  class AttrScopeTest(test_util.TensorFlowTestCase):
    """Tests Graph._attr_scope: attributes attached to ops built in scope."""

    def _get_test_attrs(self):
      """Builds a no_op and returns its ("_A", "_B") attrs as text or None."""
      probe = control_flow_ops.no_op()

      def _attr_text_or_none(attr_name):
        # get_attr raises ValueError when the attribute is not set.
        try:
          return compat.as_text(probe.get_attr(attr_name))
        except ValueError:
          return None

      return (_attr_text_or_none("_A"), _attr_text_or_none("_B"))

    @test_util.run_deprecated_v1
    def testNoLabel(self):
      with self.cached_session():
        # Outside any _attr_scope neither attribute is present.
        self.assertAllEqual((None, None), self._get_test_attrs())

    @test_util.run_deprecated_v1
    def testLabelMap(self):
      with self.cached_session() as sess:
        graph = sess.graph
        # pylint: disable=protected-access
        snapshots = [self._get_test_attrs()]
        with graph._attr_scope({
            "_A": attr_value_pb2.AttrValue(s=compat.as_bytes("foo"))
        }):
          snapshots.append(self._get_test_attrs())
          # Mapping "_A" to None clears it inside this nested scope.
          with graph._attr_scope({
              "_A": None,
              "_B": attr_value_pb2.AttrValue(s=compat.as_bytes("bar"))
          }):
            snapshots.append(self._get_test_attrs())
            with graph._attr_scope({
                "_A": attr_value_pb2.AttrValue(s=compat.as_bytes("baz"))
            }):
              snapshots.append(self._get_test_attrs())
            snapshots.append(self._get_test_attrs())
          snapshots.append(self._get_test_attrs())
        snapshots.append(self._get_test_attrs())
        # pylint: enable=protected-access

        # Each scope exit restores the attributes of the enclosing scope.
        expected = [
            (None, None),
            ("foo", None),
            (None, "bar"),
            ("baz", "bar"),
            (None, "bar"),
            ("foo", None),
            (None, None),
        ]
        for want, got in zip(expected, snapshots):
          self.assertAllEqual(want, got)
  # Register common_shapes.scalar_shape as the shape function for the
  # "KernelLabel" test op, which KernelLabelTest builds via
  # test_ops.kernel_label(). Runs once at import time.
  ops.RegisterShape("KernelLabel")(common_shapes.scalar_shape)
  class KernelLabelTest(test_util.TensorFlowTestCase):
    """Tests kernel selection through Graph._kernel_label_map."""

    @test_util.run_deprecated_v1
    def testNoLabel(self):
      with self.cached_session():
        label_tensor = test_ops.kernel_label()
        self.assertAllEqual(b"My label is: default", label_tensor.eval())

    @test_util.run_deprecated_v1
    def testLabelMap(self):
      with self.cached_session() as sess:
        graph = sess.graph
        default_1 = test_ops.kernel_label()
        # pylint: disable=protected-access
        with graph._kernel_label_map({"KernelLabel": "overload_1"}):
          overload_1_1 = test_ops.kernel_label()
          with graph._kernel_label_map({"KernelLabel": "overload_2"}):
            overload_2 = test_ops.kernel_label()
            # An empty label selects the default kernel again.
            with graph._kernel_label_map({"KernelLabel": ""}):
              default_2 = test_ops.kernel_label()
          # Back in the outer scope, the "overload_1" mapping applies again.
          overload_1_2 = test_ops.kernel_label()
        # pylint: enable=protected-access
        default_3 = test_ops.kernel_label()

        expectations = (
            (b"My label is: default", default_1),
            (b"My label is: default", default_2),
            (b"My label is: default", default_3),
            (b"My label is: overload_1", overload_1_1),
            (b"My label is: overload_1", overload_1_2),
            (b"My label is: overload_2", overload_2),
        )
        for expected_label, label_tensor in expectations:
          self.assertAllEqual(expected_label, self.evaluate(label_tensor))
  class AsGraphDefTest(test_util.TensorFlowTestCase):
    """Tests for Graph.as_graph_def and GraphDef version plumbing."""

    def testGraphDefVersion(self):
      """The graph's producer version is what the kernel reports back."""
      with ops.Graph().as_default() as g:
        producer_version = g.graph_def_versions.producer
        with self.session(graph=g):
          seen_by_kernel = test_ops.graph_def_version().eval()
          self.assertEqual(producer_version, seen_by_kernel)

    def testAddShapes(self):
      with ops.Graph().as_default() as g:
        out0, out1, out2, out3, out4 = _apply_op(g, "FiveFloatOutputs", [],
                                                 [dtypes.float32] * 5)
        # One output each: unknown rank, scalar, unknown-length vector,
        # fully defined matrix, partially defined matrix.
        static_shapes = (None, [], [None], [43, 37], [43, None])
        for tensor, shape in zip((out0, out1, out2, out3, out4),
                                 static_shapes):
          tensor.set_shape(shape)

        # A second op whose inferred (scalar) shape must also be recorded.
        constant_op.constant(1.0)

        graph_def = g.as_graph_def(add_shapes=True)
        self.assertProtoEqualsVersion("""
        node { name: "FiveFloatOutputs" op: "FiveFloatOutputs"
          attr {
            key: "_output_shapes"
            value {
              list {
                shape { unknown_rank: true }
                shape { }
                shape { dim { size: -1 } }
                shape { dim { size: 43 } dim { size: 37 } }
                shape { dim { size: 43 } dim { size: -1 } }
              }
            }
          }
        }
        node { name: "Const" op: "Const"
          attr {
            key: "_output_shapes"
            value {
              list {
                shape { }
              }
            }
          }
          attr {
            key: "dtype"
            value { type: DT_FLOAT }
          }
          attr {
            key: "value"
            value {
              tensor {
                dtype: DT_FLOAT
                tensor_shape { }
                float_val: 1.0 } } } }
        """, graph_def)
  @ops.RegisterStatistics("a", "flops")
  def _calc_a_forward_flops(unused_graph, unused_node):
    """Reports a fixed "flops" statistic of 20 for nodes of op type "a"."""
    return ops.OpStats(statistic_type="flops", value=20)
  class StatisticsTest(test_util.TensorFlowTestCase):
    """Tests for the op statistics registry and OpStats accumulation."""

    def testRegisteredNode(self):
      graph = ops.Graph()
      node = ops._NodeDef("a", "an_a")
      # ("a", "flops") is registered by _calc_a_forward_flops to report 20.
      flops = ops.get_stats_for_node_def(graph, node, "flops")
      self.assertEqual(20, flops.value)
      # A statistic with no registered function yields an empty value.
      missing_stat = ops.get_stats_for_node_def(graph, node, "missing_stat")
      self.assertIsNone(missing_stat.value)

    def testUnregisteredNode(self):
      graph = ops.Graph()
      node = ops._NodeDef("b", "a_b")
      weight_params = ops.get_stats_for_node_def(graph, node, "weight_params")
      self.assertIsNone(weight_params.value)

    def testAccumulateStatistics(self):
      flops_total = ops.OpStats("flops")
      self.assertIsNone(flops_total.value)
      # Adding to an empty total adopts the other value.
      second_flops = ops.OpStats("flops", 3)
      flops_total += second_flops
      self.assertEqual(3, flops_total.value)
      # Adding to a non-empty total sums the values.
      flops_total += ops.OpStats("flops", 4)
      self.assertEqual(7, flops_total.value)
      # Statistics of different types must not be combined.
      with self.assertRaises(ValueError):
        flops_total += ops.OpStats("weight_params", 1)
  class DeviceStackTest(test_util.TensorFlowTestCase):
    """Tests the device-assignment metadata recorded on each op."""

    @test_util.run_deprecated_v1
    def testBasicDeviceAssignmentMetadata(self):

      def device_func(unused_op):
        return "/cpu:*"

      const_zero = constant_op.constant([0.0], name="zero")
      with ops.device("/cpu"):
        const_one = constant_op.constant([1.0], name="one")
        with ops.device("/cpu:0"):
          const_two = constant_op.constant([2.0], name="two")
      with ops.device(device_func):
        const_three = constant_op.constant(3.0, name="three")

      # No device scope was active for const_zero, so nothing is recorded.
      self.assertEqual(0, len(const_zero.op._device_assignments))

      one_list = const_one.op._device_assignments
      self.assertEqual(1, len(one_list))
      self.assertEqual("/cpu", one_list[0].obj)
      self.assertEqual("ops_test.py", os.path.basename(one_list[0].filename))

      # Both enclosing device scopes are recorded for const_two.
      two_list = const_two.op._device_assignments
      self.assertEqual(2, len(two_list))
      devices = [t.obj for t in two_list]
      self.assertEqual(set(["/cpu", "/cpu:0"]), set(devices))

      # For a device function the recorded object is a description string:
      # the function name, then a file name and a number (presumably a line).
      three_list = const_three.op._device_assignments
      self.assertEqual(1, len(three_list))
      func_description = three_list[0].obj
      # The "." is escaped so it only matches a literal dot.
      expected_regex = r"device_func<.*ops_test\.py, [0-9]+"
      self.assertRegexpMatches(func_description, expected_regex)

    @test_util.run_deprecated_v1
    def testDeviceAssignmentMetadataForGraphDeviceAndTfDeviceFunctions(self):
      # NOTE: the lineno check below assumes the two device() calls are
      # exactly two source lines apart; keep this layout when editing.
      with ops.device("/cpu"):
        const_one = constant_op.constant([1.0], name="one")
      with ops.get_default_graph().device("/cpu"):
        const_two = constant_op.constant([2.0], name="two")

      one_metadata = const_one.op._device_assignments[0]
      two_metadata = const_two.op._device_assignments[0]

      # Verify both types of device assignment return the right stack info.
      # Compare the basename exactly (text first, expectation second) rather
      # than passing the basename as a regex pattern.
      self.assertEqual("ops_test.py", os.path.basename(one_metadata.filename))
      self.assertEqual(one_metadata.filename, two_metadata.filename)
      self.assertEqual(one_metadata.lineno + 2, two_metadata.lineno)
  class ColocationGroupTest(test_util.TensorFlowTestCase):
    """Tests for ops.colocate_with and the colocation groups it records."""

    @test_util.run_deprecated_v1
    def testBasic(self):
      tensor_a = constant_op.constant([2.0], name="a")
      with ops.colocate_with(tensor_a.op):
        tensor_b = constant_op.constant(3.0)
      tensor_c = constant_op.constant(4.0)
      # The anchor reports its own group and the op built in scope joins it.
      self.assertEqual([b"loc:@a"], tensor_a.op.colocation_groups())
      self.assertEqual([b"loc:@a"], tensor_b.op.colocation_groups())
      # An op built outside any colocation scope has no "_class" attr.
      with self.assertRaises(ValueError):
        tensor_c.op.get_attr("_class")

    @test_util.run_deprecated_v1
    def testBasicColocationMetadata(self):
      const_two = constant_op.constant([2.0], name="two")
      with ops.colocate_with(const_two.op):
        const_three = constant_op.constant(3.0, name="three")
      colocation_info = const_three.op._colocation_dict
      self.assertIn("two", colocation_info)
      two_metadata = colocation_info["two"]
      self.assertIsNone(two_metadata.obj)
      # The recorded filename is the file holding the colocate_with call,
      # i.e. this test file.
      self.assertEqual("ops_test.py", os.path.basename(two_metadata.filename))

    @test_util.run_deprecated_v1
    def testColocationDeviceInteraction(self):
      with ops.device("/cpu:0"):
        with ops.device("/device:GPU:0"):
          tensor_a = constant_op.constant([2.0], name="a")
        with ops.colocate_with(tensor_a.op):
          # tensor_b is built inside the /cpu:0 scope but colocated with
          # tensor_a (on /device:GPU:0); colocation wins over the device
          # scope because it is the stronger constraint.
          tensor_b = constant_op.constant(3.0)
      self.assertEqual([b"loc:@a"], tensor_b.op.colocation_groups())
      self.assertEqual(tensor_a.op.device, tensor_b.op.device)

    @test_util.run_deprecated_v1
    def testColocationCanonicalization(self):
      with ops.device("/device:GPU:0"):
        _ = constant_op.constant(2.0)
      with ops.device(lambda unused_op: "/device:GPU:0"):
        tensor_b = constant_op.constant(3.0)
      with ops.get_default_graph().colocate_with(tensor_b):
        with ops.device("/device:GPU:0"):
          tensor_c = constant_op.constant(4.0)
      # The first constant and tensor_b are both placed on /device:GPU:0
      # (tensor_b via a device function). tensor_c inherits tensor_b's
      # device name through colocation, so after canonicalizing the names
      # the two device strings compare equal.
      self.assertEqual(tensor_b.op.device, tensor_c.op.device)

    @test_util.run_deprecated_v1
    def testLocationOverrides(self):
      with ops.device("/cpu:0"):
        with ops.device("/device:GPU:0"):
          tensor_a = constant_op.constant([2.0], name="a")
          # This colocation is "redundant" given the enclosing
          # "/device:GPU:0" scope, but it records in the GraphDef, in a
          # portable way, that the two ops belong together.
          with ops.colocate_with(tensor_a.op):
            tensor_b = constant_op.constant(3.0)
          tensor_c = constant_op.constant(4.0)
        tensor_d = constant_op.constant(5.0)

      self.assertEqual([b"loc:@a"], tensor_b.op.colocation_groups())
      self.assertEqual("/device:GPU:0", tensor_a.op.device)
      self.assertEqual(tensor_a.op.device, tensor_b.op.device)

      # Leaving each scope restores the enclosing device.
      self.assertEqual("/device:GPU:0", tensor_c.op.device)
      self.assertEqual("/device:CPU:0", tensor_d.op.device)

    @test_util.run_deprecated_v1
    def testNestedColocateWith(self):
      tensor_a = constant_op.constant([2.0], name="a")
      with ops.colocate_with(tensor_a.op):
        tensor_b = constant_op.constant(3.0)
        # Colocating with an op that is itself colocated with "a" still
        # yields the "a" group.
        with ops.colocate_with(tensor_b.op):
          tensor_c = constant_op.constant(4.0)
      self.assertEqual([b"loc:@a"], tensor_b.op.colocation_groups())
      self.assertEqual([b"loc:@a"], tensor_c.op.colocation_groups())

    @test_util.run_deprecated_v1
    def testMultiColocationGroups(self):
      tensor_a = constant_op.constant([2.0], name="a")
      tensor_b = constant_op.constant(3.0, name="b")
      with ops.colocate_with(tensor_a.op):
        with ops.colocate_with(tensor_b.op):
          tensor_c = constant_op.constant(4.0)
      self.assertEqual({b"loc:@a", b"loc:@b"},
                       set(tensor_c.op.colocation_groups()))

    @test_util.run_deprecated_v1
    def testColocationIgnoreStack(self):
      tensor_a = constant_op.constant([2.0], name="a")
      tensor_b = constant_op.constant(3.0, name="b")
      with ops.colocate_with(tensor_a.op):
        # ignore_existing=True drops the outer "a" constraint.
        with ops.colocate_with(tensor_b.op, ignore_existing=True):
          tensor_c = constant_op.constant(4.0)
      self.assertEqual({b"loc:@b"}, set(tensor_c.op.colocation_groups()))

    @test_util.run_deprecated_v1
    def testColocateWithReset(self):
      tensor_a = constant_op.constant([2.0], name="a")
      with ops.colocate_with(tensor_a.op):
        tensor_b = constant_op.constant(3.0, name="b")
        # colocate_with(None, ignore_existing=True) clears all constraints,
        # so "c" ends up only in its own group.
        with ops.colocate_with(None, ignore_existing=True):
          tensor_c = constant_op.constant(4.0, name="c")
      self.assertEqual([b"loc:@a"], tensor_b.op.colocation_groups())
      self.assertEqual([b"loc:@c"], tensor_c.op.colocation_groups())

    @test_util.run_deprecated_v1
    def testColocateWithInitialNoneThenNested(self):
      tensor_a = constant_op.constant([2.0], name="a")
      with ops.colocate_with(tensor_a.op):
        with ops.colocate_with(None, ignore_existing=True):
          tensor_b = constant_op.constant(3.0, name="b")
          with ops.colocate_with(tensor_b.op):
            tensor_c = constant_op.constant(4.0, name="c")
      self.assertEqual([b"loc:@b"], tensor_b.op.colocation_groups())
      self.assertEqual([b"loc:@b"], tensor_c.op.colocation_groups())

    @test_util.run_deprecated_v1
    def testColocateVariables(self):
      var_a = variables.Variable([2.0], name="a")
      with ops.colocate_with(var_a.op):
        var_b = variables.Variable([3.0], name="b")
      self.assertEqual([b"loc:@a"], var_b.op.colocation_groups())
  class DeprecatedTest(test_util.TensorFlowTestCase):
    """Tests for test_ops.old, an op removed in GraphDef version 8."""

    def testSuccess(self):
      # With the producer version pinned to 7 the op can still be run.
      with ops.Graph().as_default() as g:
        test_util.set_producer_version(g, 7)
        old_op = test_ops.old()
        with self.session(graph=g):
          old_op.run()

    def _error(self):
      """Returns the error regex expected at the current GraphDef version."""
      pattern = (r"Op Old is not available in GraphDef version %d\. "
                 r"It has been removed in version 8\. For reasons\.")
      return pattern % versions.GRAPH_DEF_VERSION

    def testGraphConstructionFail(self):
      expected_message = self._error()
      with ops.Graph().as_default():
        with self.assertRaisesRegexp(NotImplementedError, expected_message):
          test_ops.old()
  class DenseTensorLikeTypeTest(test_util.TensorFlowTestCase):
    """Tests for is_dense_tensor_like / register_dense_tensor_like_type."""

    def testSuccess(self):
      float_op = ops.Operation(
          ops._NodeDef("FloatOutput", "myop"), ops.Graph(), [], [dtypes.float32])
      self.assertTrue(ops.is_dense_tensor_like(float_op.outputs[0]))

      variable = variables.Variable([17])
      self.assertTrue(ops.is_dense_tensor_like(variable))

    # Fixture classes that register_dense_tensor_like_type must reject.

    class BadClassNoName(object):
      """Has no `name` attribute at all."""

    class BadClassBadName(object):
      """Defines `name` as a plain method instead of a property."""

      def name(self):
        pass

    class BadClassNoDtype(object):
      """Has a `name` property but no `dtype`."""

      @property
      def name(self):
        pass

    class BadClassBadDtype(object):
      """Has a `name` property but defines `dtype` as a plain method."""

      @property
      def name(self):
        pass

      def dtype(self):
        pass

    def testBadClass(self):
      rejected_classes = (
          (DenseTensorLikeTypeTest.BadClassNoName, "`name`"),
          (DenseTensorLikeTypeTest.BadClassBadName, "`name`"),
          (DenseTensorLikeTypeTest.BadClassNoDtype, "`dtype`"),
          (DenseTensorLikeTypeTest.BadClassBadDtype, "`dtype`"),
      )
      for bad_class, expected_regex in rejected_classes:
        with self.assertRaisesRegexp(TypeError, expected_regex):
          ops.register_dense_tensor_like_type(bad_class)
  class NameScopeTest(test_util.TensorFlowTestCase):
    """Tests for name-scope helpers and Graph.get_name_scope."""

    def testStripAndPrependScope(self):
      name_scope_to_strip = "hidden1"
      name_scope_to_add = "hidden2"
      # Each case: (input name,
      #             expected after stripping "hidden1",
      #             expected after then prepending "hidden2").
      cases = [
          # Same prefix. Should strip.
          ("hidden1/hidden1/weights", "hidden1/weights",
           "hidden2/hidden1/weights"),
          # Extra "/". Should strip.
          ("hidden1///hidden1/weights", "hidden1/weights",
           "hidden2/hidden1/weights"),
          # Same prefix after a leading "^". Should strip.
          ("^hidden1/hidden1/weights", "^hidden1/weights",
           "^hidden2/hidden1/weights"),
          # Same prefix after a leading "loc:@". Should strip.
          ("loc:@hidden1/hidden1/weights", "loc:@hidden1/weights",
           "loc:@hidden2/hidden1/weights"),
          # Different prefix. Should keep.
          ("hhidden1/hidden1/weights", "hhidden1/hidden1/weights",
           "hidden2/hhidden1/hidden1/weights"),
          # Not a prefix (no trailing "/"). Should keep.
          ("hidden1", "hidden1", "hidden2/hidden1"),
      ]
      for name, expected_stripped, expected_prepended in cases:
        stripped = ops.strip_name_scope(name, name_scope_to_strip)
        self.assertEqual(expected_stripped, stripped)
        self.assertEqual(expected_prepended,
                         ops.prepend_name_scope(stripped, name_scope_to_add))

    def testGetNameScope(self):
      with ops.Graph().as_default() as graph:
        with ops.name_scope("scope1"):
          with ops.name_scope("scope2"):
            with ops.name_scope("scope3"):
              self.assertEqual("scope1/scope2/scope3", graph.get_name_scope())
            # Each exit pops exactly one level.
            self.assertEqual("scope1/scope2", graph.get_name_scope())
          self.assertEqual("scope1", graph.get_name_scope())
        self.assertEqual("", graph.get_name_scope())

    def testTwoGraphs(self):
      outer_graph = ops.Graph()
      inner_graph = ops.Graph()
      # "_" is rejected as a scope name even with two graphs stacked.
      with self.assertRaisesRegexp(ValueError,
                                   "'_' is not a valid scope name"):
        with outer_graph.as_default(), inner_graph.as_default():
          with ops.name_scope("_"):
            pass
  class TracebackTest(test_util.TensorFlowTestCase):
    """Tests for Operation.traceback and traceback_with_start_lines."""

    @test_util.run_deprecated_v1
    def testTracebackWithStartLines(self):
      with self.cached_session() as sess:
        a = constant_op.constant(2.0)
        sess.run(
            a,
            options=config_pb2.RunOptions(
                trace_level=config_pb2.RunOptions.FULL_TRACE))
        self.assertTrue(sess.graph.get_operations())

        # traceback_with_start_lines has one frame per traceback frame, and
        # each of its frames is the matching traceback frame plus one extra
        # trailing element (5 fields in total).
        for op in sess.graph.get_operations():
          self.assertEqual(len(op.traceback),
                           len(op.traceback_with_start_lines))
          for frame, frame_with_start_line in zip(
              op.traceback, op.traceback_with_start_lines):
            self.assertEqual(5, len(frame_with_start_line))
            self.assertEqual(frame, frame_with_start_line[:-1])
  class EnableEagerExecutionTest(test_util.TensorFlowTestCase):
    """Argument validation for ops.enable_eager_execution."""

    @test_util.run_v1_only("b/120545219")
    def testBadArgumentsToEnableEagerExecution(self):
      # A non-ConfigProto `config` is rejected with TypeError.
      with self.assertRaisesRegexp(TypeError, "config must be a tf.ConfigProto"):
        ops.enable_eager_execution(context.DEVICE_PLACEMENT_SILENT)
      # A ConfigProto passed where a policy/mode value is expected is rejected
      # with a ValueError that names the offending argument.
      proto = config_pb2.ConfigProto()
      with self.assertRaisesRegexp(ValueError, "device_policy must be one of"):
        ops.enable_eager_execution(proto, proto)
      with self.assertRaisesRegexp(ValueError, "execution_mode must be one of"):
        ops.enable_eager_execution(proto, execution_mode=proto)
  if __name__ == "__main__":
    # Script entry point: hand control to TensorFlow's googletest runner for
    # the test cases defined in this module.
    googletest.main()

TensorFlow 框架的 .NET 版本，提供了丰富的特性和 API，可以借此很方便地在 .NET 平台下搭建深度学习训练与推理流程。