|
|
@@ -221,7 +221,7 @@ namespace TensorFlowNET.UnitTest |
|
|
|
} |
|
|
|
|
|
|
|
///Returns a TensorFlow Session for use in executing tests. |
|
|
|
public IEnumerable<Session> cached_session( |
|
|
|
public Session cached_session( |
|
|
|
Graph graph = null, object config = null, bool use_gpu = false, bool force_gpu = false) |
|
|
|
{ |
|
|
|
// This method behaves differently than self.session(): for performance reasons |
|
|
@@ -267,9 +267,8 @@ namespace TensorFlowNET.UnitTest |
|
|
|
{ |
|
|
|
var sess = self._get_cached_session( |
|
|
|
graph, config, force_gpu, crash_if_inconsistent_args: true); |
|
|
|
var cached = self._constrain_devices_and_set_default(sess, use_gpu, force_gpu); |
|
|
|
return cached; |
|
|
|
|
|
|
|
using var cached = self._constrain_devices_and_set_default(sess, use_gpu, force_gpu); |
|
|
|
return cached; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
@@ -318,13 +317,12 @@ namespace TensorFlowNET.UnitTest |
|
|
|
return s.as_default(); |
|
|
|
} |
|
|
|
|
|
|
|
private IEnumerable<Session> _constrain_devices_and_set_default(Session sess, bool use_gpu, bool force_gpu) |
|
|
|
private Session _constrain_devices_and_set_default(Session sess, bool use_gpu, bool force_gpu) |
|
|
|
{ |
|
|
|
// Set the session and its graph to global default and constrain devices.""" |
|
|
|
// if context.executing_eagerly(): |
|
|
|
// yield None |
|
|
|
// else: |
|
|
|
{ |
|
|
|
if (tf.executing_eagerly()) |
|
|
|
return null; |
|
|
|
else { |
|
|
|
sess.graph.as_default(); |
|
|
|
sess.as_default(); |
|
|
|
{ |
|
|
@@ -340,13 +338,13 @@ namespace TensorFlowNET.UnitTest |
|
|
|
using (sess.graph.device(gpu_name)) { |
|
|
|
yield return sess; |
|
|
|
}*/ |
|
|
|
yield return sess; |
|
|
|
return sess; |
|
|
|
} |
|
|
|
else if (use_gpu) |
|
|
|
yield return sess; |
|
|
|
return sess; |
|
|
|
else |
|
|
|
using (sess.graph.device("/device:CPU:0")) |
|
|
|
yield return sess; |
|
|
|
return sess; |
|
|
|
} |
|
|
|
|
|
|
|
} |
|
|
|