Browse Source

Merge pull request #1172 from novikov-alexander/alnovi/cached_session

cached_session for graph tests
tags/v0.110.4-Transformer-Model
Haiping GitHub 2 years ago
parent
commit
9ddff69963
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 156 additions and 16 deletions
  1. +2
    -1
      test/TensorFlowNET.Graph.UnitTest/ControlFlowTest/WhileContextTestCase.cs
  2. +11
    -12
      test/TensorFlowNET.Graph.UnitTest/GradientTest/GradientTest.cs
  3. +143
    -3
      test/TensorFlowNET.Graph.UnitTest/PythonTest.cs

+ 2
- 1
test/TensorFlowNET.Graph.UnitTest/ControlFlowTest/WhileContextTestCase.cs View File

@@ -1,5 +1,6 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Linq;
using Tensorflow;
using static Tensorflow.Binding;

@@ -23,7 +24,7 @@ namespace TensorFlowNET.UnitTest.ControlFlowTest
private void _testWhileContextHelper(int maximum_iterations)
{
// TODO: implement missing code dependencies
var sess = this.cached_session();
using var sess = this.cached_session();
var i = constant_op.constant(0, name: "i");
var c = new Func<Tensor, Tensor>(x => gen_math_ops.less(x, ops.convert_to_tensor(10), name: "c"));
var b = new Func<Tensor, Tensor>(x => math_ops.add(x, 1, name: "c"));


+ 11
- 12
test/TensorFlowNET.Graph.UnitTest/GradientTest/GradientTest.cs View File

@@ -388,22 +388,21 @@ namespace TensorFlowNET.UnitTest.Gradient

}

[Ignore("TODO")]
[TestMethod]
public void testBoundaryContinue()
{
//@test_util.run_v1_only("b/120545219")
//def testBoundaryContinue(self):
// # Test that we differentiate both 'x' and 'y' correctly when x is a
// # predecessor of y.
// with self.cached_session():
// x = constant(1.0)
// y = x * 2.0
// z = y * 3.0
// grads = gradients.gradients(z, [x, y])
// self.assertTrue(all(x is not None for x in grads))
// self.assertEqual(6.0, grads[0].eval())
// Test that we differentiate both 'x' and 'y' correctly when x is a
// predecessor of y.

using (self.cached_session())
{
var x = tf.constant(1.0);
var y = x * 2.0;
var z = y * 3.0;
var grads = tf.gradients(z, new[] { x, y });
self.assertTrue(all(grads.Select(x => x != null)));
self.assertEqual(6.0, grads[0].eval());
}
}

[Ignore("TODO")]


+ 143
- 3
test/TensorFlowNET.Graph.UnitTest/PythonTest.cs View File

@@ -6,6 +6,8 @@ using System.Collections;
using System.Linq;
using Tensorflow;
using static Tensorflow.Binding;
using OneOf.Types;
using System.Collections.Generic;

namespace TensorFlowNET.UnitTest
{
@@ -139,6 +141,21 @@ namespace TensorFlowNET.UnitTest

#region tensor evaluation and test session

private Session _cached_session = null;
private Graph _cached_graph = null;
private object _cached_config = null;
private bool _cached_force_gpu = false;

private void _ClearCachedSession()
{
if (self._cached_session != null)
{
self._cached_session.Dispose();
self._cached_session = null;
}
}


//protected object _eval_helper(Tensor[] tensors)
//{
// if (tensors == null)
@@ -203,10 +220,56 @@ namespace TensorFlowNET.UnitTest
}
}


public Session cached_session()
///Returns a TensorFlow Session for use in executing tests.
public Session cached_session(
Graph graph = null, object config = null, bool use_gpu = false, bool force_gpu = false)
{
throw new NotImplementedException();
// This method behaves differently than self.session(): for performance reasons
// `cached_session` will by default reuse the same session within the same
// test.The session returned by this function will only be closed at the end
// of the test(in the TearDown function).

// Use the `use_gpu` and `force_gpu` options to control where ops are run.If
// `force_gpu` is True, all ops are pinned to `/ device:GPU:0`. Otherwise, if
// `use_gpu` is True, TensorFlow tries to run as many ops on the GPU as
// possible.If both `force_gpu and `use_gpu` are False, all ops are pinned to
// the CPU.

// Example:
// python
// class MyOperatorTest(test_util.TensorFlowTestCase) :
// def testMyOperator(self):
// with self.cached_session() as sess:
// valid_input = [1.0, 2.0, 3.0, 4.0, 5.0]
// result = MyOperator(valid_input).eval()
// self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0]
// invalid_input = [-1.0, 2.0, 7.0]
// with self.assertRaisesOpError("negative input not supported"):
// MyOperator(invalid_input).eval()


// Args:
// graph: Optional graph to use during the returned session.
// config: An optional config_pb2.ConfigProto to use to configure the
// session.
// use_gpu: If True, attempt to run as many ops as possible on GPU.
// force_gpu: If True, pin all ops to `/device:GPU:0`.

// Yields:
// A Session object that should be used as a context manager to surround
// the graph building and execution code in a test case.


// TODO:
// if context.executing_eagerly():
// return self._eval_helper(tensors)
// else:
{
var sess = self._get_cached_session(
graph, config, force_gpu, crash_if_inconsistent_args: true);
using var cached = self._constrain_devices_and_set_default(sess, use_gpu, force_gpu);
return cached;
}
}

//Returns a TensorFlow Session for use in executing tests.
@@ -254,6 +317,39 @@ namespace TensorFlowNET.UnitTest
return s.as_default();
}

private Session _constrain_devices_and_set_default(Session sess, bool use_gpu, bool force_gpu)
{
// Set the session and its graph to global default and constrain devices."""
if (tf.executing_eagerly())
return null;
else {
sess.graph.as_default();
sess.as_default();
{
if (force_gpu)
{
// TODO:

// Use the name of an actual device if one is detected, or
// '/device:GPU:0' otherwise
/* var gpu_name = gpu_device_name();
if (!gpu_name)
gpu_name = "/device:GPU:0"
using (sess.graph.device(gpu_name)) {
yield return sess;
}*/
return sess;
}
else if (use_gpu)
return sess;
else
using (sess.graph.device("/device:CPU:0"))
return sess;
}
}
}

// See session() for details.
private Session _create_session(Graph graph, object cfg, bool forceGpu)
{
@@ -298,6 +394,50 @@ namespace TensorFlowNET.UnitTest
return new Session(graph);//, config = prepare_config(config))
}

private Session _get_cached_session(
Graph graph = null,
object config = null,
bool force_gpu = false,
bool crash_if_inconsistent_args = true)
{
// See cached_session() for documentation.
if (self._cached_session == null)
{
var sess = self._create_session(graph, config, force_gpu);
self._cached_session = sess;
self._cached_graph = graph;
self._cached_config = config;
self._cached_force_gpu = force_gpu;
return sess;
} else {

if (crash_if_inconsistent_args && !self._cached_graph.Equals(graph))
throw new ValueError(@"The graph used to get the cached session is
different than the one that was used to create the
session. Maybe create a new session with
self.session()");
if (crash_if_inconsistent_args && !self._cached_config.Equals(config)) {
throw new ValueError(@"The config used to get the cached session is
different than the one that was used to create the
session. Maybe create a new session with
self.session()");
}
if (crash_if_inconsistent_args && !self._cached_force_gpu.Equals(force_gpu)) {
throw new ValueError(@"The force_gpu value used to get the cached session is
different than the one that was used to create the
session. Maybe create a new session with
self.session()");
}
return _cached_session;
}
}

[TestCleanup]
public void Cleanup()
{
_ClearCachedSession();
}

#endregion

public void AssetSequenceEqual<T>(T[] a, T[] b)


Loading…
Cancel
Save