Browse Source

upgrade to tensorflow v1.13-rc1

tags/v0.8.0
haiping008 6 years ago
parent
commit
7c420df94f
9 changed files with 57 additions and 26 deletions
  1. +4
    -1
      src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs
  2. +4
    -3
      src/TensorFlowNET.Core/Gradients/math_grad.py.cs
  3. +11
    -6
      src/TensorFlowNET.Core/Graphs/Graph.cs
  4. +5
    -0
      src/TensorFlowNET.Core/Sessions/BaseSession.cs
  5. +5
    -12
      src/TensorFlowNET.Core/Variables/RefVariable.cs
  6. +25
    -0
      src/TensorFlowNET.Core/ops.py.cs
  7. BIN
      src/TensorFlowNET.Core/runtimes/win-x64/native/tensorflow.dll
  8. +2
    -2
      test/TensorFlowNET.Examples/LinearRegression.cs
  9. +1
    -2
      test/TensorFlowNET.UnitTest/GradientTest.cs

+ 4
- 1
src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs View File

@@ -64,6 +64,8 @@ namespace Tensorflow
// Get a uid for this call to gradients that can be used to help
// cluster ops for compilation.
var gradient_uid = ops.get_default_graph().unique_name("uid");
ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name: "y");
xs = ops.internal_convert_n_to_tensor_or_indexed_slices(xs, name: "x", as_ref: true);
grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops, gradient_uid);

/**
@@ -148,7 +150,7 @@ namespace Tensorflow
}
else
{
in_grads = _NonEagerInputs(op, xs).Select(x => new Tensor(IntPtr.Zero)).ToArray();
in_grads = new Tensor[_NonEagerInputs(op, xs).Count()];
}

var inputs = _NonEagerInputs(op, xs).ToList();
@@ -226,6 +228,7 @@ namespace Tensorflow

private static Tensor[] _MaybeCompile(string scope, Operation op, Tensor out_grads, Action func, Func<Operation, Tensor, Tensor[]> grad_fn)
{
scope = scope.EndsWith("/") ? scope.Substring(0, scope.Length - 1) : scope;
return grad_fn(op, out_grads);
}



+ 4
- 3
src/TensorFlowNET.Core/Gradients/math_grad.py.cs View File

@@ -72,9 +72,10 @@ namespace Tensorflow

public static bool _ShapesFullySpecifiedAndEqual(Tensor x, Tensor y, Tensor grad)
{
return false;
/*return string.Join(",", x.shape).Equals(string.Join(",", y.shape)) &&
string.Join(",", x.shape).Equals(string.Join(",", grad.shape));*/
if (x.NDims == 0 && y.NDims == 0 && grad.NDims == 0) return true;

return string.Join(",", x.shape).Equals(string.Join(",", y.shape)) &&
string.Join(",", x.shape).Equals(string.Join(",", grad.shape));
}

public static (Tensor, Tensor) _SumGrad(Operation op, Tensor grad)


+ 11
- 6
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -41,17 +41,20 @@ namespace Tensorflow
_graph_key = $"grap-key-{ops.uid()}/";
}

public T as_graph_element<T>(T obj, bool allow_tensor = true, bool allow_operation = true)
public object as_graph_element(object obj, bool allow_tensor = true, bool allow_operation = true)
{
return _as_graph_element_locked(obj, allow_tensor, allow_operation);
}

private Func<object> _as_graph_element(object obj)
private Tensor _as_graph_element(object obj)
{
if (obj is RefVariable var)
return var._as_graph_element();

return null;
}

private T _as_graph_element_locked<T>(T obj, bool allow_tensor = true, bool allow_operation = true)
private object _as_graph_element_locked(object obj, bool allow_tensor = true, bool allow_operation = true)
{
string types_str = "";

@@ -69,12 +72,14 @@ namespace Tensorflow
}

var temp_obj = _as_graph_element(obj);
if (temp_obj != null)
obj = temp_obj;

if (obj is Tensor tensor && allow_tensor)
{
if (tensor.Graph.Equals(this))
{
return obj;
return tensor;
}
else
{
@@ -85,7 +90,7 @@ namespace Tensorflow
{
if (op.Graph.Equals(this))
{
return obj;
return op;
}
else
{
@@ -93,7 +98,7 @@ namespace Tensorflow
}
}

throw new Exception($"Can not convert a {typeof(T).Name} into a {types_str}.");
throw new Exception($"Can not convert a {obj.GetType().Name} into a {types_str}.");
}

public void add_to_collection<T>(string name, T value)


+ 5
- 0
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -35,6 +35,11 @@ namespace Tensorflow
c_api.TF_DeleteSessionOptions(opts);
}

public virtual NDArray run(RefVariable fetches, FeedItem[] feed_dict = null)
{
return _run(fetches, feed_dict);
}

public virtual NDArray run(Tensor fetches, FeedItem[] feed_dict = null)
{
return _run(fetches, feed_dict);


+ 5
- 12
src/TensorFlowNET.Core/Variables/RefVariable.cs View File

@@ -109,20 +109,13 @@ namespace Tensorflow
});
}

public Tensor _ref()
{
return _variable;
}
public Tensor _ref() => _variable;

public Tensor value()
{
return _snapshot;
}
public Tensor value() => _snapshot;

public Tensor _AsTensor()
{
return _snapshot;
}
public Tensor _AsTensor() => _snapshot;

public Tensor _as_graph_element() => _variable;

public Tensor _TensorConversionFunction(bool as_ref = false)
{


+ 25
- 0
src/TensorFlowNET.Core/ops.py.cs View File

@@ -329,6 +329,11 @@ namespace Tensorflow
};
}

public static Tensor[] convert_n_to_tensor_or_indexed_slices(Tensor[] values, TF_DataType dtype = TF_DataType.DtInvalid, string name = "")
{
return internal_convert_n_to_tensor_or_indexed_slices(values, dtype: dtype, name: name);
}

public static Tensor convert_to_tensor_or_indexed_slices(Tensor value, TF_DataType dtype = TF_DataType.DtInvalid, string name = "")
{
return internal_convert_to_tensor_or_indexed_slices(value: value, dtype: dtype, name: name, as_ref: false);
@@ -339,6 +344,26 @@ namespace Tensorflow
return value;
}

public static Tensor[] internal_convert_n_to_tensor_or_indexed_slices(Tensor[] values, TF_DataType dtype = TF_DataType.DtInvalid, string name = "", bool as_ref = false)
{
var ret = new List<Tensor>();

foreach(var (i, value) in Python.enumerate(values))
{
if (value == null)
{
ret.Add(value);
}
else
{
var n = string.IsNullOrEmpty(name) ? "" : $"{name}_{i}";
ret.Add(internal_convert_to_tensor_or_indexed_slices(value, dtype: dtype, name: n, as_ref: as_ref));
}
}

return ret.ToArray();
}

public static Tensor[] internal_convert_n_to_tensor<T>(T[] values, DataType dtype = DataType.DtInvalid,
string name = "", DataType preferred_dtype = DataType.DtInvalid,
bool as_ref = false)


BIN
src/TensorFlowNET.Core/runtimes/win-x64/native/tensorflow.dll View File


+ 2
- 2
test/TensorFlowNET.Examples/LinearRegression.cs View File

@@ -80,9 +80,9 @@ namespace TensorFlowNET.Examples
new FeedItem(X, train_X),
new FeedItem(Y, train_Y)
});
var rW = sess.run(W);
Console.WriteLine($"Epoch: {epoch + 1} cost={c} " +
$"W={sess.run(W)} b={sess.run(b)}");
$"W={rW} b={sess.run(b)}");
}
}



+ 1
- 2
test/TensorFlowNET.UnitTest/GradientTest.cs View File

@@ -23,8 +23,7 @@ namespace TensorFlowNET.UnitTest
Assert.AreEqual(ys.op.inputs[0].name, "Const:0");
Assert.AreEqual(ys.op.inputs[1].name, "mul:0");

var xs = new Tensor[] { a, b };
var g = tf.gradients(ys, xs, stop_gradients: new Tensor[] { a, b });
var g = tf.gradients(ys, new Tensor[] { a, b }, stop_gradients: new Tensor[] { a, b });
Assert.AreEqual(g[0].name, "gradients/Fill:0");
Assert.AreEqual(g[1].name, "gradients/Fill:0");
}


Loading…
Cancel
Save