Browse Source

merge with

tags/v0.9
Oceania2018 6 years ago
parent
commit
44316321cf
11 changed files with 2963 additions and 169 deletions
  1. +2
    -2
      src/TensorFlowNET.Core/Keras/Layers/Conv.cs
  2. +59
    -57
      src/TensorFlowNET.Core/Operations/Operation.cs
  3. +0
    -22
      src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
  4. +1
    -1
      src/TensorFlowNET.Core/Operations/nn_ops.cs
  5. +6
    -0
      src/TensorFlowNET.Core/Sessions/Session.cs
  6. +33
    -18
      test/TensorFlowNET.UnitTest/PythonTest.cs
  7. +1
    -68
      test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs
  8. +507
    -0
      test/TensorFlowNET.UnitTest/gradients_test/GradientsTest.cs
  9. +1104
    -0
      test/TensorFlowNET.UnitTest/gradients_test/gradients_test.py
  10. +1243
    -0
      test/TensorFlowNET.UnitTest/nn_test/nn_test.py
  11. +7
    -1
      test/TensorFlowNET.UnitTest/ops_test/ControlDependenciesTest.cs

+ 2
- 2
src/TensorFlowNET.Core/Keras/Layers/Conv.cs View File

@@ -101,8 +101,8 @@ namespace Tensorflow.Keras.Layers
throw new NotImplementedException("call channels_first");
}
else
{
outputs = nn_ops.bias_add(outputs, bias, data_format: "NHWC");
{
outputs = nn_ops.bias_add(outputs, bias._AsTensor(), data_format: "NHWC");
}
}



+ 59
- 57
src/TensorFlowNET.Core/Operations/Operation.cs View File

@@ -10,22 +10,22 @@ using System.Text;
namespace Tensorflow
{
/// <summary>
/// Represents a graph node that performs computation on tensors.
///
/// An `Operation` is a node in a TensorFlow `Graph` that takes zero or
/// more `Tensor` objects as input, and produces zero or more `Tensor`
/// objects as output. Objects of type `Operation` are created by
/// calling an op constructor(such as `tf.matmul`)
/// or `tf.Graph.create_op`.
///
/// For example `c = tf.matmul(a, b)` creates an `Operation` of type
/// "MatMul" that takes tensors `a` and `b` as input, and produces `c`
/// as output.
///
/// After the graph has been launched in a session, an `Operation` can
/// be executed by passing it to
/// `tf.Session.run`.
/// <summary>
/// Represents a graph node that performs computation on tensors.
///
/// An `Operation` is a node in a TensorFlow `Graph` that takes zero or
/// more `Tensor` objects as input, and produces zero or more `Tensor`
/// objects as output. Objects of type `Operation` are created by
/// calling an op constructor(such as `tf.matmul`)
/// or `tf.Graph.create_op`.
///
/// For example `c = tf.matmul(a, b)` creates an `Operation` of type
/// "MatMul" that takes tensors `a` and `b` as input, and produces `c`
/// as output.
///
/// After the graph has been launched in a session, an `Operation` can
/// be executed by passing it to
/// `tf.Session.run`.
/// `op.run()` is a shortcut for calling `tf.get_default_session().run(op)`.
/// </summary>
public partial class Operation : ITensorOrOperation
@@ -271,47 +271,49 @@ namespace Tensorflow
return base.Equals(obj);
}
/// <summary>
/// Update the input to this operation at the given index.
///
/// NOTE: This is for TF internal use only.Please don't use it.
/// </summary>
/// <param name="index">the index of the input to update.</param>
/// <param name="tensor"> the Tensor to be used as the input at the given index.</param>
public void _update_input(int index, Tensor tensor)
{
_assert_same_graph(tensor);
var input = _tf_input(index);
/// <summary>
/// Update the input to this operation at the given index.
///
/// NOTE: This is for TF internal use only.Please don't use it.
/// </summary>
/// <param name="index">the index of the input to update.</param>
/// <param name="tensor"> the Tensor to be used as the input at the given index.</param>
public void _update_input(int index, Tensor tensor)
{
_assert_same_graph(tensor);
var input = _tf_input(index);
var output = tensor._as_tf_output();
// Reset cached inputs.
_inputs = null;// new InputList(new Tensor[] { tensor }); // is this right? original code: self._inputs_val=None
// TODO: implement below code dependencies
c_api.TF_UpdateEdge(graph, output, input, status);
}

private void _assert_same_graph(Tensor tensor)
{
//TODO: implement
}

/// <summary>
/// Create and return a new TF_Output for output_idx'th output of this op.
/// </summary>
public TF_Output _tf_output(int output_idx)
{
var tf_output = new TF_Output(op, output_idx);
return tf_output;
}

/// <summary>
/// Create and return a new TF_Input for input_idx'th input of this op.
/// </summary>
public TF_Input _tf_input(int input_idx)
{
var tf_input = new TF_Input(op, input_idx);
return tf_input;
}
}
}
_inputs = null;
// after the c_api call next time _inputs is accessed
// the updated inputs are reloaded from the c_api
c_api.TF_UpdateEdge(_graph, output, input, status);
//var updated_inputs = inputs;
}
private void _assert_same_graph(Tensor tensor)
{
//TODO: implement
}
/// <summary>
/// Create and return a new TF_Output for output_idx'th output of this op.
/// </summary>
public TF_Output _tf_output(int output_idx)
{
var tf_output = new TF_Output(op, output_idx);
return tf_output;
}
/// <summary>
/// Create and return a new TF_Input for input_idx'th input of this op.
/// </summary>
public TF_Input _tf_input(int input_idx)
{
var tf_input = new TF_Input(op, input_idx);
return tf_input;
}
}
}

+ 0
- 22
src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs View File

@@ -290,33 +290,11 @@ namespace Tensorflow
{
// TODO: here a chunk of original code is missing
/*
if fn1 is not None:
if true_fn is not None:
raise TypeError("cond(): true_fn and fn1 may not be set simultaneously.")
true_fn = fn1
elif true_fn is None:
raise TypeError("cond(): true_fn argument required")
if fn2 is not None:
if false_fn is not None:
raise TypeError("cond(): false_fn and fn2 may not be set simultaneously.")
false_fn = fn2
elif false_fn is None:
raise TypeError("cond(): false_fn argument required")

if not callable(true_fn):
raise TypeError("true_fn must be callable.")
if not callable(false_fn):
raise TypeError("false_fn must be callable.")

with ops.name_scope(name, "cond", [pred]):
if context.executing_eagerly():
if pred:
return _UnpackIfSingleton(true_fn())
return _UnpackIfSingleton(false_fn())

# Add the Switch to the graph.
if isinstance(pred, bool):
raise TypeError("pred must not be a Python bool")
*/

// Add the Switch to the graph.


+ 1
- 1
src/TensorFlowNET.Core/Operations/nn_ops.cs View File

@@ -30,7 +30,7 @@ namespace Tensorflow
/// <param name="name"></param>
/// <returns></returns>
public static Tensor bias_add(Tensor value,
RefVariable bias,
Tensor bias,
string data_format = null,
string name = null)
{


+ 6
- 0
src/TensorFlowNET.Core/Sessions/Session.cs View File

@@ -38,6 +38,12 @@ namespace Tensorflow
Status.Check(true);
}

public Session as_default()
{
tf.defaultSession = this;
return this;
}

public static Session LoadFromSavedModel(string path)
{
var graph = c_api.TF_NewGraph();


+ 33
- 18
test/TensorFlowNET.UnitTest/PythonTest.cs View File

@@ -132,24 +132,44 @@ namespace TensorFlowNET.UnitTest
}
/// <summary>
/// Evaluates tensors and returns numpy values.
/// <param name="tensors">A Tensor or a nested list/tuple of Tensors.</param>
/// This function is used in many original tensorflow unit tests to evaluate tensors
/// in a test session with special settings (for instance constant folding off)
///
/// </summary>
/// <returns> tensors numpy values.</returns>
[Obsolete("Why do we need this function? we already have Tensor.eval().")]
public object evaluate(params Tensor[] tensors)
public T evaluate<T>(Tensor tensor)
{
var results = new Dictionary<string, NDArray>();
// if context.executing_eagerly():
// return self._eval_helper(tensors)
// else:
{
var sess = ops.get_default_session();
if (sess == None)
with(self.session(), s => sess = s);
return sess.run(tensors);
if (sess == null)
sess = self.session();
T t_result = (T)(object)null;
with<Session>(sess, s =>
{
var ndarray=tensor.eval();
if (typeof(T) == typeof(double))
{
double d = ndarray;
t_result = (T)(object)d;
}
else if (typeof(T) == typeof(int))
{
int d = ndarray;
t_result = (T) (object) d;
}
else
{
t_result = (T)(object)ndarray;
}
});
return t_result;
}
}
//Returns a TensorFlow Session for use in executing tests.
public Session session(Graph graph = null, object config = null, bool use_gpu = false, bool force_gpu = false)
{
@@ -189,16 +209,11 @@ namespace TensorFlowNET.UnitTest
//if (context.executing_eagerly())
// yield None
//else
{
with<Session>(self._create_session(graph, config, force_gpu), sess =>
{
with(self._constrain_devices_and_set_default(sess, use_gpu, force_gpu), (x) =>
{
s = sess;
});
});
}
return s;
//{
s = self._create_session(graph, config, force_gpu);
self._constrain_devices_and_set_default(s, use_gpu, force_gpu);
//}
return s.as_default();
}
private IPython _constrain_devices_and_set_default(Session sess, bool useGpu, bool forceGpu)


+ 1
- 68
test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs View File

@@ -91,74 +91,7 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test
});
}
[Ignore("Todo")]
[TestMethod]
public void testCondTrueLegacy()
{
// def testCondTrueLegacy(self):
// x = constant_op.constant(2)
// y = constant_op.constant(5)
// z = control_flow_ops.cond(
// math_ops.less(x, y),
// fn1=lambda: math_ops.multiply(x, 17),
// fn2=lambda: math_ops.add(y, 23))
// self.assertEquals(self.evaluate(z), 34)
}
[Ignore("Todo")]
[TestMethod]
public void testCondFalseLegacy()
{
// def testCondFalseLegacy(self):
// x = constant_op.constant(2)
// y = constant_op.constant(1)
// z = control_flow_ops.cond(
// math_ops.less(x, y),
// fn1=lambda: math_ops.multiply(x, 17),
// fn2=lambda: math_ops.add(y, 23))
// self.assertEquals(self.evaluate(z), 24)
}
[Ignore("Todo")]
[TestMethod]
public void testCondMissingArg1()
{
// def testCondMissingArg1(self):
// x = constant_op.constant(1)
// with self.assertRaises(TypeError):
// control_flow_ops.cond(True, false_fn=lambda: x)
}
[Ignore("Todo")]
[TestMethod]
public void testCondMissingArg2()
{
// def testCondMissingArg2(self):
// x = constant_op.constant(1)
// with self.assertRaises(TypeError):
// control_flow_ops.cond(True, lambda: x)
}
[Ignore("Todo")]
[TestMethod]
public void testCondDuplicateArg1()
{
// def testCondDuplicateArg1(self):
// x = constant_op.constant(1)
// with self.assertRaises(TypeError):
// control_flow_ops.cond(True, lambda: x, lambda: x, fn1=lambda: x)
}
[Ignore("Todo")]
[TestMethod]
public void testCondDuplicateArg2()
{
// def testCondDuplicateArg2(self):
// x = constant_op.constant(1)
// with self.assertRaises(TypeError):
// control_flow_ops.cond(True, lambda: x, lambda: x, fn2=lambda: x)
}
// NOTE: all other test python test cases of this class are either not needed due to strong typing or dest a deprecated api
}
}

+ 507
- 0
test/TensorFlowNET.UnitTest/gradients_test/GradientsTest.cs View File

@@ -0,0 +1,507 @@
using System;
using System.Collections.Generic;
using System.Text;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow;
namespace TensorFlowNET.UnitTest.gradients_test
{
[TestClass]
public class GradientsTest : PythonTest
{
//[Ignore("TODO")]
[TestMethod]
public void testGradients()
{
with(tf.Graph().as_default(), g =>
{
var inp = tf.constant(1.0, shape: new[]{32, 100}, name:"in");
var w = tf.constant(1.0, shape: new[] { 100, 10}, name:"w");
var b = tf.constant(1.0, shape: new[] { 10}, name:"b");
var xw = math_ops.matmul(inp, w, name: "xw");
var h = nn_ops.bias_add(xw, b, name: "h");
var w_grad = gradients_impl.gradients(new []{h}, new[] { w})[0];
self.assertEquals("MatMul", w_grad.op.type);
// TODO: Operation._original_op
//self.assertEquals(w_grad.op._original_op, xw.op);
self.assertTrue((bool)w_grad.op.get_attr("transpose_a"));
self.assertFalse((bool)w_grad.op.get_attr("transpose_b"));
});
}
[Ignore("TODO")]
[TestMethod]
public void testUnusedOutput()
{
//def testUnusedOutput(self):
// with ops.Graph().as_default():
// w = constant(1.0, shape=[2, 2])
// x = constant(1.0, shape=[2, 2])
// wx = math_ops.matmul(w, x)
// split_wx = array_ops.split(value=wx, num_or_size_splits=2, axis=0)
// c = math_ops.reduce_sum(split_wx[1])
// gw = gradients.gradients(c, [w])[0]
// self.assertEquals("MatMul", gw.op.type)
}
[Ignore("TODO")]
[TestMethod]
public void testColocateGradients()
{
//def testColocateGradients(self):
// with ops.Graph().as_default() as g:
// w = constant(1.0, shape=[1, 1])
// x = constant(1.0, shape=[1, 2])
// with g.device("/device:GPU:0"):
// wx = math_ops.matmul(w, x)
// gw = gradients.gradients(wx, [w], colocate_gradients_with_ops=True)[0]
// self.assertEqual(gw.op.colocation_groups(), wx.op.colocation_groups())
}
[Ignore("TODO")]
[TestMethod]
public void testColocateGradientsWithAggregation()
{
//def testColocateGradientsWithAggregation(self):
// with ops.Graph().as_default() as g:
// with g.device("/device:GPU:1"):
// w = constant(1.0, shape=[1, 1])
// x = constant(1.0, shape=[1, 2])
// y = constant(1.0, shape=[1, 2])
// wx = math_ops.matmul(w, x)
// wy = math_ops.matmul(w, y)
// with g.device("/device:GPU:0"):
// z = wx + wy
// gw1 = gradients.gradients(z, [w], colocate_gradients_with_ops=True)[0]
// self.assertEqual(gw1.op.colocation_groups(), wx.op.colocation_groups())
// gw2 = gradients.gradients(z, [w], colocate_gradients_with_ops=False)[0]
// self.assertTrue(wx.op.colocation_groups() != gw2.op.colocation_groups())
}
[Ignore("TODO")]
[TestMethod]
public void testColocateGradientsWithAggregationInMultipleDevices()
{
//def testColocateGradientsWithAggregationInMultipleDevices(self):
// with ops.Graph().as_default() as g:
// with g.device("/device:GPU:1"):
// w = constant(1.0, shape=[1, 1])
// x = constant(1.0, shape=[1, 2])
// y = constant(1.0, shape=[1, 2])
// with g.device("/task:1"):
// wx = math_ops.matmul(w, x)
// with g.device("/task:2"):
// wy = math_ops.matmul(w, y)
// with g.device("/device:GPU:0"):
// z = wx + wy
// gw1 = gradients.gradients(z, [w], colocate_gradients_with_ops=True)[0]
// self.assertEqual(gw1.op.colocation_groups(), w.op.colocation_groups())
// gw2 = gradients.gradients(z, [w], colocate_gradients_with_ops=False)[0]
// self.assertTrue(w.op.colocation_groups() != gw2.op.colocation_groups())
}
[Ignore("TODO")]
[TestMethod]
public void testColocateGradientsWithGateGradients()
{
//def testColocateGradientsWithGateGradients(self):
// if not test_util.is_gpu_available():
// self.skipTest("No GPU available")
// with ops.Graph().as_default() as g:
// with g.device("/device:CPU:0"):
// x = constant(1.0, shape=[1, 1])
// y = constant(1.0, shape=[1, 1])
// s = x + y
// with g.device("/device:GPU:0"):
// z = math_ops.reduce_sum(s)
// gz_x = gradients.gradients(z, [x], colocate_gradients_with_ops=True,
// gate_gradients=True)[0]
// with session.Session():
// # Make sure the placer doesn't complain.
// self.evaluate(gz_x)
}
[Ignore("TODO")]
[TestMethod]
public void testBoundaryStop()
{
//def testBoundaryStop(self):
// # Test that we don't differentiate 'x'. The gradient function for 'x' is
// # set explicitly to None so we will get an exception if the gradient code
// # tries to differentiate 'x'.
// with ops.Graph().as_default():
// c = constant(1.0)
// x = array_ops.identity(c)
// y = x + 1.0
// z = y + 1
// grads = gradients.gradients(z, [x])
// self.assertTrue(all(x is not None for x in grads))
}
[Ignore("TODO")]
[TestMethod]
public void testBoundaryContinue()
{
//@test_util.run_v1_only("b/120545219")
//def testBoundaryContinue(self):
// # Test that we differentiate both 'x' and 'y' correctly when x is a
// # predecessor of y.
// with self.cached_session():
// x = constant(1.0)
// y = x * 2.0
// z = y * 3.0
// grads = gradients.gradients(z, [x, y])
// self.assertTrue(all(x is not None for x in grads))
// self.assertEqual(6.0, grads[0].eval())
}
[Ignore("TODO")]
[TestMethod]
public void testAggregationMethodAccumulateN()
{
//@test_util.run_v1_only("b/120545219")
//def testAggregationMethodAccumulateN(self):
// with self.cached_session():
// x = constant(1.0)
// y = x * 2.0
// z = y + y + y + y + y + y + y + y + y + y
// grads = gradients.gradients(
// z, [x, y],
// aggregation_method=gradients.AggregationMethod.
// EXPERIMENTAL_ACCUMULATE_N)
// self.assertTrue(all(x is not None for x in grads))
// self.assertEqual(20.0, grads[0].eval())
// self.assertEqual(10.0, grads[1].eval())
}
[Ignore("TODO")]
[TestMethod]
public void testAggregationMethodAddN()
{
//@test_util.run_v1_only("b/120545219")
//def testAggregationMethodAddN(self):
// with self.cached_session():
// x = constant(1.0)
// y = x * 2.0
// z = y + y + y + y + y + y + y + y + y + y
// grads = gradients.gradients(
// z, [x, y], aggregation_method=gradients.AggregationMethod.ADD_N)
// self.assertTrue(all(x is not None for x in grads))
// self.assertEqual(20.0, grads[0].eval())
// self.assertEqual(10.0, grads[1].eval())
}
[Ignore("TODO")]
[TestMethod]
public void testAggregationMethodTree()
{
//@test_util.run_v1_only("b/120545219")
//def testAggregationMethodTree(self):
// with self.cached_session():
// x = constant(1.0)
// y = x * 2.0
// z = y + y + y + y + y + y + y + y + y + y
// grads = gradients.gradients(
// z, [x, y],
// aggregation_method=gradients.AggregationMethod.EXPERIMENTAL_TREE)
// self.assertTrue(all(x is not None for x in grads))
// self.assertEqual(20.0, grads[0].eval())
// self.assertEqual(10.0, grads[1].eval())
}
[Ignore("TODO")]
[TestMethod]
public void testNoGradientForStringOutputs()
{
//def testNoGradientForStringOutputs(self):
// with ops.Graph().as_default():
// def _TestOpGrad(_, float_grad, string_grad):
// """Gradient function for TestStringOutput."""
// self.assertEquals(float_grad.dtype, dtypes.float32)
// self.assertFalse(string_grad)
// return float_grad
// ops.RegisterGradient("TestStringOutput")(_TestOpGrad)
// c = constant(1.0)
// x, _ = test_ops.test_string_output(c)
// z = x * 2.0
// w = z * 3.0
// grads = gradients.gradients(z, [c])
// self.assertTrue(isinstance(grads[0], ops.Tensor))
// grads = gradients.gradients(w, [c])
// self.assertTrue(isinstance(grads[0], ops.Tensor))
}
[Ignore("TODO")]
[TestMethod]
public void testSingletonIndexedSlices()
{
//def testSingletonIndexedSlices(self):
// with ops.Graph().as_default():
// x = array_ops.placeholder(dtypes.float32)
// y = array_ops.identity(x)
// dy = ops.IndexedSlices(
// array_ops.placeholder(dtypes.float32),
// array_ops.placeholder(dtypes.int32))
// dx, = gradients.gradients(y, x, grad_ys=dy)
// # The IndexedSlices gradient of tf.identity is the identity map.
// with self.cached_session() as sess:
// vdx, vdy = sess.run(
// [dx, dy], feed_dict={x: [1.0], dy.indices: [0], dy.values: [2.0]})
// self.assertEqual(vdx, vdy)
}
[Ignore("TODO")]
[TestMethod]
public void testNonDifferentiableSwitchInWhileLoop()
{
//@test_util.run_v1_only("b/120545219")
//def testNonDifferentiableSwitchInWhileLoop(self):
// with ops.Graph().as_default():
// v = array_ops.placeholder(dtypes.float32, [])
// def _Step(i, a, ta):
// a += math_ops.cast(v, dtypes.int32)
// return (i + 1, a, ta.write(i, a))
// n = 4
// i, _, ta = control_flow_ops.while_loop(
// lambda i, *_: i < n,
// _Step, [0, 0, tensor_array_ops.TensorArray(
// dtypes.int32, size=n)])
// target = ta.read(i - 1)
// grad, = gradients.gradients(target, v)
// self.assertIsNone(grad)
}
[Ignore("TODO")]
[TestMethod]
public void testVariableReadValueGradient()
{
//def testVariableReadValueGradient(self):
// with ops.Graph().as_default():
// init = constant_op.constant(100.0)
// var = variables.Variable(init)
// gradient = gradients.gradients(var.read_value(), var)
// self.assertIsNotNone(gradient)
}
[Ignore("TODO")]
[TestMethod]
public void testVariableAsGraphElementGradient()
{
//def testVariableAsGraphElementGradient(self):
// with ops.Graph().as_default() as graph:
// init = constant_op.constant(100.0)
// var = variables.Variable(init)
// gradient = gradients.gradients(graph.as_graph_element(var), var)
// self.assertIsNotNone(gradient)
}
[Ignore("TODO")]
[TestMethod]
public void testVariableRefGradient()
{
//@test_util.run_v1_only("b/120545219")
//def testVariableRefGradient(self):
// with ops.Graph().as_default():
// init = constant_op.constant(100.0)
// var = variables.VariableV1(init)
// gradient = gradients.gradients(var._ref(), var)
// self.assertIsNotNone(gradient)
}
[Ignore("TODO")]
[TestMethod]
public void testDependentYs()
{
//@test_util.run_v1_only("b/120545219")
//def testDependentYs(self):
// with self.cached_session():
// x = constant_op.constant(3.0)
// y = math_ops.square(x)
// y1 = math_ops.square(y)
// y2 = math_ops.square(y1)
// g = gradients.gradients([y, y2], x)
// self.assertAllClose(17502.0, g[0].eval())
// g = gradients.gradients(y + y2, x)
// self.assertAllClose(17502.0, g[0].eval())
// z = array_ops.identity(y)
// z2 = array_ops.identity(y2)
// g = gradients.gradients([z, z2], x)
// self.assertAllClose(17502.0, g[0].eval())
}
[Ignore("TODO")]
[TestMethod]
public void testPartialDerivatives()
{
//@test_util.run_v1_only("b/120545219")
//def testPartialDerivatives(self):
// with self.cached_session():
// x = constant_op.constant(1.)
// y = 2 * x
// z = x + y
// totalg = gradients.gradients(z, [x, y])
// self.assertEqual([3.0, 1.0], [g.eval() for g in totalg])
// partialg = gradients.gradients(z, [x, y], stop_gradients=[x, y])
// self.assertEqual([1.0, 1.0], [g.eval() for g in partialg])
}
[Ignore("TODO")]
[TestMethod]
public void testStopGradients()
{
//@test_util.run_v1_only("b/120545219")
//def testStopGradients(self):
// def _MakeGraph(rng, stop_gradients=()):
// def _FunctionOf(xs, k=3):
// return ops.convert_to_tensor(
// sum(math_ops.matmul(rng.rand(k, k), x) for x in xs)
// + rng.rand(k, k))
// a = _FunctionOf([])
// if "a" in stop_gradients: a = array_ops.stop_gradient(a)
// b = _FunctionOf([a])
// if "b" in stop_gradients: b = array_ops.stop_gradient(b)
// c = _FunctionOf([a, b])
// if "c" in stop_gradients: c = array_ops.stop_gradient(c)
// d = _FunctionOf([b, c])
// if "d" in stop_gradients: d = array_ops.stop_gradient(d)
// return dict(a=a, b=b, c=c, d=d)
// def _Gradients(ys, xs, **kwargs):
// dydxs = gradients.gradients(ys, xs, **kwargs)
// dydxs = [0. * x if dydx is None else dydx
// for x, dydx in zip(xs, dydxs)]
// return dydxs
// seed = np.random.randint(1000)
// cases = []
// subsets = [""] + "a b c d ab ac ad bc bd cd abc abd acd bcd abcd".split()
// graph = _MakeGraph(np.random.RandomState(seed))
// for constants in subsets:
// graph_with_stops = _MakeGraph(np.random.RandomState(seed), constants)
// for variables_ in subsets:
// # compute the gradient when stopped using tf.stop_gradients
// grad1 = _Gradients([graph_with_stops["d"]],
// [graph_with_stops[v] for v in variables_])
// # compute the gradient when stopped using the stop_gradients kwarg
// grad2 = _Gradients([graph["d"]],
// [graph[v] for v in variables_],
// stop_gradients=[graph[v] for v in constants])
// cases.append(dict(grad1=grad1, grad2=grad2,
// constants=constants, variables=variables_))
// # evaluate all tensors in one call to session.run for speed
// with self.cached_session() as sess:
// results = sess.run([(case["grad1"], case["grad2"]) for case in cases])
// for (npgrad1, npgrad2), case in zip(results, cases):
// for a, b in zip(npgrad1, npgrad2):
// np.testing.assert_allclose(a, b)
}
[Ignore("TODO")]
[TestMethod]
public void testUnconnectedGradientsNoneUnconnectedGradients()
{
//def testUnconnectedGradientsNoneUnconnectedGradients(self):
// with ops.Graph().as_default():
// x = constant(1.0, shape=[2, 2])
// y = constant(3.0, shape=[3, 1])
// grad = gradients.gradients(
// [y], [x], unconnected_gradients="none")
// self.assertIsNone(grad[0])
}
[Ignore("TODO")]
[TestMethod]
public void testUnconnectedGradientsZerosUnconnectedGradients()
{
//def testUnconnectedGradientsZerosUnconnectedGradients(self):
// with ops.Graph().as_default():
// x = constant(1.0, shape=[2, 2])
// y = constant(3.0, shape=[3, 1])
// grads = gradients.gradients(
// [y], [x], unconnected_gradients="zero")
// with self.cached_session() as sess:
// self.assertAllEqual([[0.0, 0.0], [0.0, 0.0]], self.evaluate(grads)[0])
}
[Ignore("TODO")]
[TestMethod]
public void testUnconnectedGradientsZeroConnectedGradients()
{
//def testUnconnectedGradientsZeroConnectedGradients(self):
// with ops.Graph().as_default():
// x = constant(1.0)
// y = x * 3.0
// grad = gradients.gradients(
// [y], [x], unconnected_gradients="zero")
// with self.cached_session() as sess:
// self.assertEquals(3.0, self.evaluate(grad)[0])
}
[Ignore("TODO")]
[TestMethod]
public void testUnknownUnconnectedGradientsValueGiven()
{
//def testUnknownUnconnectedGradientsValueGiven(self):
// with ops.Graph().as_default():
// x = constant(1.0)
// y = constant(1.0)
// with self.assertRaisesRegexp(
// ValueError, "Unknown value for unconnected_gradients: 'nonsense'"):
// gradients.gradients([y], [x], unconnected_gradients="nonsense")
}
/*
*/
}
}

+ 1104
- 0
test/TensorFlowNET.UnitTest/gradients_test/gradients_test.py
File diff suppressed because it is too large
View File


+ 1243
- 0
test/TensorFlowNET.UnitTest/nn_test/nn_test.py
File diff suppressed because it is too large
View File


+ 7
- 1
test/TensorFlowNET.UnitTest/ops_test/ControlDependenciesTest.cs View File

@@ -139,7 +139,7 @@ namespace TensorFlowNET.UnitTest.ops_test
var a_2 = constant_op.constant(3.0);
var a_3 = constant_op.constant(4.0);
var a_4 = constant_op.constant(5.0);
Operation b_1 = null, b_2 = null;
Tensor b_1 = null, b_2 = null;
with(g.control_dependencies(new[] { a_1, a_2, a_3, a_4 }), ctrl =>
{
b_1 = constant_op.constant(6.0);
@@ -157,6 +157,12 @@ namespace TensorFlowNET.UnitTest.ops_test
});
});
});
var z=tf.add(a_1, tf.multiply(b_2, b_1));
with(g.control_dependencies(new[] {z}), ctrl =>
{
var z1 = tf.add(a_3, tf.multiply(a_4, a_2));
});
//tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta", as_text: false);
assertItemsEqual(b_1.op.control_inputs, new[] { a_1.op, a_2.op, a_3.op, a_4.op });
assertItemsEqual(b_2.op.control_inputs, b_1.op.control_inputs);
}


Loading…
Cancel
Save