Browse Source

TFE_TapeVariableAccessed

tags/v0.20
Oceania2018 5 years ago
parent
commit
ca6f8b2a46
15 changed files with 252 additions and 64 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/APIs/tf.math.cs
  2. +3
    -0
      src/TensorFlowNET.Core/Eager/EagerTensor.cs
  3. +3
    -0
      src/TensorFlowNET.Core/Eager/c_api.eager.cs
  4. +22
    -3
      src/TensorFlowNET.Core/Gradients/GradientActor.cs
  5. +5
    -0
      src/TensorFlowNET.Core/Gradients/Tape.cs
  6. +59
    -12
      src/TensorFlowNET.Core/Operations/gen_array_ops.cs
  7. +102
    -40
      src/TensorFlowNET.Core/Operations/gen_math_ops.cs
  8. +2
    -2
      src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs
  9. +14
    -1
      src/TensorFlowNET.Core/Operations/math_ops.cs
  10. +1
    -1
      src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs
  11. +1
    -1
      src/TensorFlowNET.Core/Tensors/Tensor.Value.cs
  12. +15
    -0
      src/TensorFlowNET.Core/Tensors/constant_op.cs
  13. +18
    -2
      src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs
  14. +5
    -0
      src/TensorFlowNET.Core/Variables/ResourceVariable.Operators.cs
  15. +1
    -1
      src/TensorFlowNET.Core/ops.cs

+ 1
- 1
src/TensorFlowNET.Core/APIs/tf.math.cs View File

@@ -390,7 +390,7 @@ namespace Tensorflow
=> x / ops.convert_to_tensor(y, dtype: x.dtype.as_base_dtype(), name: "y");

public Tensor pow<T1, T2>(T1 x, T2 y, string name = "pow")
=> gen_math_ops.pow(x, y, name: name);
=> math_ops.pow(x, y, name: name);

/// <summary>
/// Divides `x / y` elementwise, rounding toward the most negative integer.


+ 3
- 0
src/TensorFlowNET.Core/Eager/EagerTensor.cs View File

@@ -53,6 +53,9 @@ namespace Tensorflow.Eager

public static string GetFormattedString(TF_DataType dtype, NDArray nd)
{
if (nd.size == 0)
return "[]";

switch (dtype)
{
case TF_DataType.TF_STRING:


+ 3
- 0
src/TensorFlowNET.Core/Eager/c_api.eager.cs View File

@@ -375,6 +375,9 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)]
public static extern void TFE_TapeWatch(IntPtr tape, IntPtr tensor);

/// <summary>
/// P/Invoke declaration for the native <c>TFE_TapeVariableAccessed</c> entry point,
/// resolved from the library named by <c>TensorFlowLibName</c>.
/// </summary>
/// <param name="variable">
/// Native pointer identifying the accessed variable.
/// NOTE(review): presumably the variable's eager tensor handle - confirm against the native export.
/// </param>
[DllImport(TensorFlowLibName)]
public static extern void TFE_TapeVariableAccessed(IntPtr variable);
[DllImport(TensorFlowLibName)]
public static extern IntPtr TFE_TapeGradient(IntPtr tape,
IntPtr[] target, int target_size,


+ 22
- 3
src/TensorFlowNET.Core/Gradients/GradientActor.cs View File

@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Eager;
using static Tensorflow.Binding;
@@ -65,7 +66,7 @@ namespace Tensorflow.Gradients
_tape.watch(x as EagerTensor);
}

public Tensor gradient(Tensor target, Tensor sources)
public Tensor gradient(Tensor target, Tensor source)
{
if(_recording)
{
@@ -76,15 +77,33 @@ namespace Tensorflow.Gradients
using var status = new Status();
var et = c_api.TFE_TapeGradient(_tape,
new [] { (target as EagerTensor).EagerTensorHandle }, 1,
new [] { (sources as EagerTensor).EagerTensorHandle }, 1,
new [] { (source as EagerTensor).EagerTensorHandle }, 1,
status);
status.Check(true);
return new EagerTensor(et);
}

/// <summary>
/// Calls <c>c_api.TFE_TapeGradient</c> on this tape with <paramref name="target"/> as the
/// single target and the handles of <paramref name="sources"/> as the gradient sources.
/// </summary>
/// <param name="target">
/// Target tensor. It is cast with <c>as EagerTensor</c>, so anything that is not an
/// <c>EagerTensor</c> results in a null dereference here.
/// </param>
/// <param name="sources">
/// Variables whose <c>handle</c> (cast with <c>as EagerTensor</c>) supplies each source handle.
/// </param>
/// <returns>
/// The handle returned by the native call, converted to a <see cref="Tensor"/> on return.
/// </returns>
public Tensor gradient(Tensor target, ResourceVariable[] sources)
{
    // A tape that is still recording and is not persistent is popped before the call.
    if (_recording && !_persistent)
        _pop_tape();

    using var status = new Status();
    EagerTensorHandle gradientHandle = c_api.TFE_TapeGradient(
        _tape,
        new[] { (target as EagerTensor).EagerTensorHandle }, 1,
        sources.Select(variable => (variable.handle as EagerTensor).EagerTensorHandle).ToArray(), sources.Length,
        status);

    // Surface any native error recorded in the status before returning.
    status.Check(true);
    return gradientHandle;
}

/// <summary>
/// Pops the tape if this instance is still recording at disposal time; otherwise does nothing.
/// </summary>
public void Dispose()
{
    if (!_recording)
        return;

    _pop_tape();
}
}
}

+ 5
- 0
src/TensorFlowNET.Core/Gradients/Tape.cs View File

@@ -25,6 +25,11 @@ namespace Tensorflow.Gradients
c_api.TFE_TapeSetRemove(tape);
}

/// <summary>
/// Forwards a variable access to the native <c>TFE_TapeVariableAccessed</c> call.
/// </summary>
/// <param name="variable">
/// Variable whose <c>handle</c>, cast with <c>as EagerTensor</c>, is handed to the native call.
/// </param>
public static void variable_accessed(ResourceVariable variable)
    => c_api.TFE_TapeVariableAccessed(variable.handle as EagerTensor);

public static bool IsDtypeTrainable(DataType dtype)
{
switch (dtype)


+ 59
- 12
src/TensorFlowNET.Core/Operations/gen_array_ops.cs View File

@@ -220,6 +220,18 @@ namespace Tensorflow
/// <param name="name"></param>
public static Tensor identity(Tensor input, string name = null)
{
if (tf.context.executing_eagerly())
{
using var status = new Status();
EagerTensorHandle tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
"Identity", name, new IntPtr[]
{
input as EagerTensor
}, 1, null, status);
status.Check(true);
return tensor;
}

var _op = _op_def_lib._apply_op_helper("Identity", name, new { input });

return _op.output;
@@ -258,14 +270,14 @@ namespace Tensorflow
if (tf.context.executing_eagerly())
{
using var status = new Status();
var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
EagerTensorHandle tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
"Fill", name, new IntPtr[]
{
dims as EagerTensor,
value as EagerTensor
}, 2, null, status);
status.Check(true);
return new EagerTensor(tensor);
return tensor;
}

var _op = _op_def_lib._apply_op_helper("Fill", name, new { dims, value });
@@ -281,6 +293,18 @@ namespace Tensorflow
/// <returns>A tuple of `Tensor` objects (r0, r1).</returns>
public static (Tensor, Tensor) broadcast_gradient_args(Tensor s0, Tensor s1, string name = "")
{
// Eager branch: runs "BroadcastGradientArgs" through the fast-path executor with the
// two shape tensors s0 and s1 as inputs and no op attributes.
// NOTE(review): `_result` is never returned or otherwise used, so after the status check
// control falls through to the graph-mode `_apply_op_helper` call below even when
// executing eagerly. This looks unfinished - the two outputs (r0, r1) presumably need to
// be extracted from `_result` and returned here; confirm how multi-output results are exposed.
if (tf.context.executing_eagerly())
{
using var status = new Status();
var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
"BroadcastGradientArgs", name, new IntPtr[]
{
s0 as EagerTensor,
s1 as EagerTensor
}, 2, null, status);
status.Check(true);
}

var _op = _op_def_lib._apply_op_helper("BroadcastGradientArgs", name, new { s0, s1 });

return (_op.outputs[0], _op.outputs[1]);
@@ -371,10 +395,19 @@ namespace Tensorflow
{
if (tf.context.executing_eagerly())
{
var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name,
"Shape", name, null,
input, "out_type", out_type);
return _result;
using var status = new Status();
EagerTensorHandle tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
"Shape", name, new IntPtr[]
{
input as EagerTensor,
}, 1,
op => wrap_tfe_src.SetOpAttrs(tf.context, op, new object[]
{
"out_type", out_type
}, status),
status);
status.Check(true);
return tensor;
}

var _op = _op_def_lib._apply_op_helper("Shape", name, new { input, out_type });
@@ -455,12 +488,26 @@ namespace Tensorflow
{
if (tf.context.executing_eagerly())
{
var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name,
"StridedSlice", name, null,
input, begin, end, strides, "begin_mask", begin_mask,
"end_mask", end_mask, "ellipsis_mask", ellipsis_mask,
"new_axis_mask", new_axis_mask, "shrink_axis_mask", shrink_axis_mask);
return _result;
using var status = new Status();
EagerTensorHandle tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
"StridedSlice", name, new IntPtr[]
{
input as EagerTensor,
begin as EagerTensor,
end as EagerTensor,
strides as EagerTensor,
}, 4,
op => wrap_tfe_src.SetOpAttrs(tf.context, op, new object[]
{
"begin_mask", begin_mask,
"end_mask", end_mask,
"ellipsis_mask", ellipsis_mask,
"new_axis_mask", new_axis_mask,
"shrink_axis_mask", shrink_axis_mask
}, status),
status);
status.Check(true);
return tensor;
}

var _op = _op_def_lib._apply_op_helper("StridedSlice", name, new


+ 102
- 40
src/TensorFlowNET.Core/Operations/gen_math_ops.cs View File

@@ -173,10 +173,20 @@ namespace Tensorflow
{
try
{
var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name,
"Prod", name, null,
input, axis, "keep_dims", keep_dims);
return _result;
using var status = new Status();
EagerTensorHandle tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
"Prod", name, new IntPtr[]
{
input as EagerTensor,
axis as EagerTensor
}, 2,
op => wrap_tfe_src.SetOpAttrs(tf.context, op, new object[]
{
"keep_dims", keep_dims
}, status),
status);
status.Check(true);
return tensor;
}
catch (Exception)
{
@@ -236,10 +246,15 @@ namespace Tensorflow
{
if (tf.context.executing_eagerly())
{
var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name,
"Add", name, null,
x, y);
return _result;
using var status = new Status();
EagerTensorHandle tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
"Add", name, new IntPtr[]
{
x as EagerTensor,
y as EagerTensor
}, 2, null, status);
status.Check(true);
return tensor;
}

var _op = _op_def_lib._apply_op_helper("Add", name, args: new { x, y });
@@ -647,10 +662,14 @@ namespace Tensorflow
{
if (tf.context.executing_eagerly())
{
var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name,
"Sqrt", name, null,
x);
return _result;
using var status = new Status();
EagerTensorHandle tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
"Sqrt", name, new IntPtr[]
{
x as EagerTensor,
}, 1, null, status);
status.Check(true);
return tensor;
}

var _op = _op_def_lib._apply_op_helper("Sqrt", name, args: new { x });
@@ -682,10 +701,15 @@ namespace Tensorflow
{
if (tf.context.executing_eagerly())
{
var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name,
"Sub", name, null,
x, y);
return _result;
using var status = new Status();
EagerTensorHandle tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
"Sub", name, new IntPtr[]
{
x as EagerTensor,
y as EagerTensor
}, 2, null, status);
status.Check(true);
return tensor;
}

var _op = _op_def_lib._apply_op_helper("Sub", name, args: new { x, y });
@@ -704,10 +728,15 @@ namespace Tensorflow
{
if (tf.context.executing_eagerly())
{
var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name,
"Equal", name, null,
x, y);
return _result;
using var status = new Status();
EagerTensorHandle tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
"Equal", name, new IntPtr[]
{
x as EagerTensor,
y as EagerTensor
}, 2, null, status);
status.Check(true);
return tensor;
}

var _op = _op_def_lib._apply_op_helper("Equal", name, args: new { x, y });
@@ -727,10 +756,15 @@ namespace Tensorflow
{
if (tf.context.executing_eagerly())
{
var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name,
"NotEqual", name, null,
x, y);
return _result;
using var status = new Status();
EagerTensorHandle tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
"NotEqual", name, new IntPtr[]
{
x as EagerTensor,
y as EagerTensor
}, 2, null, status);
status.Check(true);
return tensor;
}

var _op = _op_def_lib._apply_op_helper("NotEqual", name, args: new { x, y });
@@ -742,10 +776,15 @@ namespace Tensorflow
{
if (tf.context.executing_eagerly())
{
var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name,
"Atan2", name, null,
y, x);
return _result;
using var status = new Status();
EagerTensorHandle tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
"Atan2", name, new IntPtr[]
{
y as EagerTensor,
x as EagerTensor
}, 2, null, status);
status.Check(true);
return tensor;
}

var _op = _op_def_lib._apply_op_helper("Atan2", name, args: new { y, x });
@@ -757,14 +796,14 @@ namespace Tensorflow
if (tf.context.executing_eagerly())
{
using var status = new Status();
var _result = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
EagerTensorHandle tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
"Mul", name, new IntPtr[]
{
x as EagerTensor,
y as EagerTensor
}, 2, null, status);
status.Check(true);
return new EagerTensor(_result);
return tensor;
}

var _op = _op_def_lib._apply_op_helper("Mul", name, args: new { x, y });
@@ -776,10 +815,15 @@ namespace Tensorflow
{
if (tf.context.executing_eagerly())
{
var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, tf.context.device_name,
"Mul", name, null,
x, y);
return _result;
using var status = new Status();
// "Mul" is a binary op: both x and y are passed, so the input count handed to the
// fast-path executor must be 2 to match the array length (the previous value of 1
// disagreed with the array and with every other binary op in this file).
EagerTensorHandle tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
    "Mul", name, new IntPtr[]
    {
        x as EagerTensor,
        y as EagerTensor,
    }, 2, null, status);
// Raise if the native call recorded an error.
status.Check(true);
return tensor;
}

var _op = _op_def_lib._apply_op_helper("Mul", name, args: new { x, y });
@@ -832,8 +876,15 @@ namespace Tensorflow
{
if (tf.context.executing_eagerly())
{
var _result = wrap_tfe_src.TFE_FastPathExecute(tf.context, "", "FloorDiv", name, null, x, y);
return _result;
using var status = new Status();
EagerTensorHandle tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
"FloorDiv", name, new IntPtr[]
{
x as EagerTensor,
y as EagerTensor
}, 2, null, status);
status.Check(true);
return tensor;
}

var _op = _op_def_lib._apply_op_helper("FloorDiv", name, args: new { x, y });
@@ -864,10 +915,8 @@ namespace Tensorflow
}, 2,
op => wrap_tfe_src.SetOpAttrs(tf.context, op, new object[]
{
"transpose_a",
transpose_a,
"transpose_b",
transpose_b
"transpose_a", transpose_a,
"transpose_b", transpose_b
}, status),
status);
status.Check(true);
@@ -965,6 +1014,19 @@ namespace Tensorflow

public static Tensor pow<Tx, Ty>(Tx x, Ty y, string name = null)
{
if (tf.context.executing_eagerly())
{
using var status = new Status();
EagerTensorHandle tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
"Pow", name, new IntPtr[]
{
x as EagerTensor,
y as EagerTensor
}, 2, null, status);
status.Check(true);
return tensor;
}

var _op = _op_def_lib._apply_op_helper("Pow", name, args: new { x, y });

return _op.outputs[0];


+ 2
- 2
src/TensorFlowNET.Core/Operations/gen_resource_variable_ops.cs View File

@@ -115,13 +115,13 @@ namespace Tensorflow
if (tf.context.executing_eagerly())
{
using var status = new Status();
var tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
EagerTensorHandle tensor = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
"ReadVariableOp", name,
new IntPtr[] { resource as EagerTensor }, 1,
op => wrap_tfe_src.SetOpAttrs(tf.context, op, new object[] { "dtype", dtype }, status),
status);
status.Check(true);
return new EagerTensor(tensor);
return tensor;
}

var _op = _op_def_lib._apply_op_helper("ReadVariableOp", name, new


+ 14
- 1
src/TensorFlowNET.Core/Operations/math_ops.cs View File

@@ -17,6 +17,7 @@
using NumSharp;
using System;
using System.Collections.Generic;
using Tensorflow.Eager;
using Tensorflow.Framework;
using static Tensorflow.Binding;

@@ -540,6 +541,11 @@ namespace Tensorflow
}
else
{
if(x is EagerTensor)
{
return constant_op.constant(np.arange(x.shape.Rank));
}

var rank = array_ops.rank(x);
return range(0, rank, 1);
}
@@ -588,7 +594,14 @@ namespace Tensorflow
=> gen_math_ops.rsqrt(x, name: name);

public static Tensor pow<Tx, Ty>(Tx x, Ty y, string name = null)
=> gen_math_ops.pow(x, y, name: name);
=> tf_with(ops.name_scope(name, "Pow", new { x, y }), scope =>
{
name = scope;
var x_tensor = ops.convert_to_tensor(x, name: "x");
var y_tensor = ops.convert_to_tensor(y, name: "y", dtype: x_tensor.dtype.as_base_dtype());

return gen_math_ops.pow(x_tensor, y_tensor, name: name);
});

public static Tensor range(object start, object limit = null, object delta = null, TF_DataType dtype = TF_DataType.DtInvalid, string name = "range")
{


+ 1
- 1
src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs View File

@@ -54,7 +54,7 @@ namespace Tensorflow
#else
#region Compute

public static Tensor operator +(Tensor lhs, ResourceVariable rhs) => BinaryOpWrapper("add", lhs, rhs);
public static Tensor operator +(Tensor lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs);
public static Tensor operator +(Tensor lhs, NDArray rhs) => BinaryOpWrapper("add", lhs, rhs);
public static Tensor operator +(NDArray lhs, Tensor rhs) => BinaryOpWrapper("add", lhs, rhs);


+ 1
- 1
src/TensorFlowNET.Core/Tensors/Tensor.Value.cs View File

@@ -43,7 +43,7 @@ namespace Tensorflow
{
//T can only be unmanaged, I believe it is safe to say that MemoryCopy is valid for all cases this method can be called.
var src = (T*)buffer;
len *= ((long)itemsize);
len *= (long)itemsize;
System.Buffer.MemoryCopy(src, dst, len, len);
}
}


+ 15
- 0
src/TensorFlowNET.Core/Tensors/constant_op.cs View File

@@ -113,6 +113,21 @@ namespace Tensorflow

private static EagerTensor convert_to_eager_tensor(object value, Context ctx, TF_DataType dtype = TF_DataType.DtInvalid)
{
// convert data type
if (dtype != TF_DataType.DtInvalid &&
value.GetType().Name != "NDArray" &&
dtypes.as_base_dtype(dtype) != dtypes.as_dtype(value.GetType()))
{
switch (dtype)
{
case TF_DataType.TF_FLOAT:
value = Convert.ToSingle(value);
break;
default:
break;
}
}

switch (value)
{
case NDArray val:


+ 18
- 2
src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs View File

@@ -2,6 +2,7 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Gradients;
using static Tensorflow.Binding;

namespace Tensorflow
@@ -65,6 +66,7 @@ namespace Tensorflow

protected Tensor _read_variable_op()
{
variable_accessed(this);
var result = gen_resource_variable_ops.read_variable_op(_handle, _dtype);
// _maybe_set_handle_data(_dtype, _handle, result);
return result;
@@ -82,12 +84,26 @@ namespace Tensorflow
void variable_accessed(BaseResourceVariable variable)
{
if (variable.trainable)
; // tape.variable_accessed(variable)
Tape.variable_accessed(variable as ResourceVariable);
}

/// <summary>
/// Reads this variable inside a "Read" name scope and passes the value through
/// <c>array_ops.identity</c>.
///
/// Intended for situations with multiple reads, or where the value should only be
/// read after some condition is true.
/// </summary>
/// <returns>The identity op output for the value read from the variable.</returns>
Tensor read_value()
{
    return tf_with(ops.name_scope("Read"), delegate
    {
        return array_ops.identity(_read_variable_op());
    });
}

/// <summary>
/// Formats the variable's name, shape, NumPy dtype name and the result of <c>numpy()</c>
/// into a single descriptive string.
/// </summary>
public override string ToString()
{
    return $"tf.Variable '{name}' shape={shape} dtype={dtype.as_numpy_name()}, numpy={numpy()}";
}

public NDArray numpy() => _read_variable_op().numpy();
public NDArray numpy() => read_value().numpy();
}
}

+ 5
- 0
src/TensorFlowNET.Core/Variables/ResourceVariable.Operators.cs View File

@@ -14,6 +14,7 @@
limitations under the License.
******************************************************************************/

using NumSharp;
using System;
using static Tensorflow.Binding;

@@ -31,6 +32,7 @@ namespace Tensorflow
public static Tensor operator -(ResourceVariable x, Tensor y) => op_helper("sub", x, y);

public static Tensor operator *(ResourceVariable x, ResourceVariable y) => gen_math_ops.mul(x, y);
public static Tensor operator *(ResourceVariable x, NDArray y) => op_helper("mul", x, y);

public static Tensor operator <(ResourceVariable x, Tensor y) => gen_math_ops.less(x.value(), y);

@@ -53,6 +55,9 @@ namespace Tensorflow
case "sub":
result = gen_math_ops.sub(xVal, yTensor, name);
break;
case "mul":
result = gen_math_ops.mul(xVal, yTensor, name: name);
break;
default:
throw new NotImplementedException("");
}


+ 1
- 1
src/TensorFlowNET.Core/ops.cs View File

@@ -464,7 +464,7 @@ namespace Tensorflow
case RefVariable varVal:
return varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref);
case ResourceVariable varVal:
return null;
return varVal.value();
case TensorShape ts:
return constant_op.constant(ts.dims, dtype: dtype, name: name);
case int[] dims:


Loading…
Cancel
Save