@@ -116,6 +116,9 @@ namespace Tensorflow | |||
public Tensor fill<T>(Tensor dims, T value, string name = null) | |||
=> gen_array_ops.fill(dims, value, name: name); | |||
public Tensor fill<T>(Shape dims, T value, string name = null) | |||
=> array_ops.fill(dims, value, name: name); | |||
/// <summary> | |||
/// Return a tensor with the same shape and contents as input. | |||
/// </summary> | |||
@@ -68,12 +68,12 @@ namespace Tensorflow.Eager | |||
Tensor[] op_inputs, | |||
object[] attrs, | |||
Tensor[] op_outputs) | |||
=> (output_grads, unneeded_gradients) => | |||
=> (out_grads, unneeded_gradients) => | |||
{ | |||
if (ops.gradientFunctions[op_name] == null) | |||
return new Tensor[op_inputs.Length]; | |||
var op = new EagerOperation | |||
var oper = new EagerOperation | |||
{ | |||
Name = op_name, | |||
NumInputs = op_inputs.Length, | |||
@@ -84,7 +84,40 @@ namespace Tensorflow.Eager | |||
Attrs = attrs | |||
}; | |||
return ops.gradientFunctions[op_name](op, output_grads); | |||
/*return op_name switch | |||
{ | |||
"Add" => math_grad._AddGrad(oper, out_grads), | |||
"AddV2" => math_grad._AddV2Grad(oper, out_grads), | |||
"BiasAdd" => nn_grad._BiasAddGrad(oper, out_grads), | |||
"Cast" => math_grad._CastGrad(oper, out_grads), | |||
"ConcatV2" => array_grad._ConcatV2Grad(oper, out_grads), | |||
"Conv2D" => nn_grad._Conv2DGrad(oper, out_grads), | |||
"ExpandDims" => array_grad._ExpandDimsGrad(oper, out_grads), | |||
"Exp" => math_grad._ExpGrad(oper, out_grads), | |||
"FusedBatchNormV3" => nn_grad._FusedBatchNormV3Grad(oper, out_grads), | |||
"Id" => math_grad._IdGrad(oper, out_grads), | |||
"LeakyRelu" => nn_grad._LeakyReluGrad(oper, out_grads), | |||
"Log1p" => math_grad._Log1pGrad(oper, out_grads), | |||
"Maximum" => math_grad._MaximumGrad(oper, out_grads), | |||
"Mean" => math_grad._MeanGrad(oper, out_grads), | |||
"Minimum" => math_grad._MinimumGrad(oper, out_grads), | |||
"Mul" => math_grad._MulGrad(oper, out_grads), | |||
"Neg" => math_grad._NegGrad(oper, out_grads), | |||
"Pad" => array_grad._PadGrad(oper, out_grads), | |||
"Pow" => math_grad._PowGrad(oper, out_grads), | |||
"RealDiv" => math_grad._RealDivGrad(oper, out_grads), | |||
"Read" => resource_variable_grad._ReadGrad(oper, out_grads), | |||
"Reshape" => array_grad._ReshapeGrad(oper, out_grads), | |||
"ResizeNearestNeighbor" => image_grad._ResizeNearestNeighborGrad(oper, out_grads), | |||
"Select" => math_grad._SelectGrad(oper, out_grads), | |||
"Sigmoid" => math_grad._SigmoidGrad(oper, out_grads), | |||
"Sum" => math_grad._SumGrad(oper, out_grads), | |||
"Sub" => math_grad._SubGrad(oper, out_grads), | |||
"StridedSlice" => array_grad._StridedSliceGrad(oper, out_grads), | |||
_ => ops.gradientFunctions[op_name](oper, out_grads) | |||
};*/ | |||
return ops.gradientFunctions[op_name](oper, out_grads); | |||
}; | |||
bool CouldForwardprop() | |||
@@ -15,7 +15,7 @@ namespace Tensorflow.Eager | |||
/// </summary> | |||
public partial class EagerRunner | |||
{ | |||
UnorderedMap<string, SafeOpHandle> thread_local_eager_operation_map = new UnorderedMap<string, SafeOpHandle>(); | |||
UnorderedMap<string, SafeEagerOpHandle> thread_local_eager_operation_map = new UnorderedMap<string, SafeEagerOpHandle>(); | |||
public void ClearEagerOperationMap() | |||
=> thread_local_eager_operation_map.Clear(); | |||
@@ -157,7 +157,7 @@ namespace Tensorflow.Eager | |||
return flat_result; | |||
} | |||
SafeOpHandle GetOp(Context ctx, string op_or_function_name, Status status) | |||
SafeEagerOpHandle GetOp(Context ctx, string op_or_function_name, Status status) | |||
{ | |||
if (thread_local_eager_operation_map.find(op_or_function_name, out var op)) | |||
c_api.TFE_OpReset(op, op_or_function_name, ctx.DeviceName, status.Handle); | |||
@@ -205,7 +205,7 @@ namespace Tensorflow.Eager | |||
ArgDef input_arg, | |||
List<object> flattened_attrs, | |||
List<Tensor> flattened_inputs, | |||
SafeOpHandle op, | |||
SafeEagerOpHandle op, | |||
Status status) | |||
{ | |||
var tensor = tf.convert_to_tensor(inputs); | |||
@@ -225,7 +225,7 @@ namespace Tensorflow.Eager | |||
return true; | |||
} | |||
public void SetOpAttrs(SafeOpHandle op, params object[] attrs) | |||
public void SetOpAttrs(SafeEagerOpHandle op, params object[] attrs) | |||
{ | |||
var status = tf.Status; | |||
var len = attrs.Length; | |||
@@ -258,7 +258,7 @@ namespace Tensorflow.Eager | |||
/// <param name="attr_value"></param> | |||
/// <param name="attr_list_sizes"></param> | |||
/// <param name="status"></param> | |||
void SetOpAttrWithDefaults(Context ctx, SafeOpHandle op, AttrDef attr, | |||
void SetOpAttrWithDefaults(Context ctx, SafeEagerOpHandle op, AttrDef attr, | |||
string attr_name, object attr_value, | |||
Dictionary<string, long> attr_list_sizes, | |||
Status status) | |||
@@ -280,7 +280,7 @@ namespace Tensorflow.Eager | |||
} | |||
} | |||
bool SetOpAttrList(Context ctx, SafeOpHandle op, | |||
bool SetOpAttrList(Context ctx, SafeEagerOpHandle op, | |||
string key, object values, TF_AttrType type, | |||
Dictionary<string, long> attr_list_sizes, | |||
Status status) | |||
@@ -326,7 +326,7 @@ namespace Tensorflow.Eager | |||
return true; | |||
} | |||
bool SetOpAttrScalar(Context ctx, SafeOpHandle op, | |||
bool SetOpAttrScalar(Context ctx, SafeEagerOpHandle op, | |||
string key, object value, TF_AttrType type, | |||
Dictionary<string, long> attr_list_sizes, | |||
Status status) | |||
@@ -16,18 +16,17 @@ | |||
using System; | |||
using Tensorflow.Util; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Eager | |||
{ | |||
public sealed class SafeOpHandle : SafeTensorflowHandle | |||
public sealed class SafeEagerOpHandle : SafeTensorflowHandle | |||
{ | |||
private SafeOpHandle() | |||
private SafeEagerOpHandle() | |||
{ | |||
} | |||
public SafeOpHandle(IntPtr handle) | |||
public SafeEagerOpHandle(IntPtr handle) | |||
: base(handle) | |||
{ | |||
@@ -59,7 +59,7 @@ namespace Tensorflow | |||
/// <param name="status">TF_Status*</param> | |||
/// <returns></returns> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern TF_AttrType TFE_OpGetAttrType(SafeOpHandle op, string attr_name, ref byte is_list, SafeStatusHandle status); | |||
public static extern TF_AttrType TFE_OpGetAttrType(SafeEagerOpHandle op, string attr_name, ref byte is_list, SafeStatusHandle status); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern TF_AttrType TFE_OpNameGetAttrType(SafeContextHandle ctx, string op_or_function_name, string attr_name, ref byte is_list, SafeStatusHandle status); | |||
@@ -72,7 +72,7 @@ namespace Tensorflow | |||
/// <param name="input_name">const char*</param> | |||
/// <param name="status">TF_Status*</param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern int TFE_OpGetInputLength(SafeOpHandle op, string input_name, SafeStatusHandle status); | |||
public static extern int TFE_OpGetInputLength(SafeEagerOpHandle op, string input_name, SafeStatusHandle status); | |||
/// <summary> | |||
/// Returns the length (number of tensors) of the output argument `output_name` | |||
@@ -83,7 +83,7 @@ namespace Tensorflow | |||
/// <param name="status"></param> | |||
/// <returns></returns> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern int TFE_OpGetOutputLength(SafeOpHandle op, string input_name, SafeStatusHandle status); | |||
public static extern int TFE_OpGetOutputLength(SafeEagerOpHandle op, string input_name, SafeStatusHandle status); | |||
/// <summary> | |||
/// | |||
@@ -94,7 +94,7 @@ namespace Tensorflow | |||
/// <param name="status">TF_Status*</param> | |||
/// <returns></returns> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern int TFE_OpAddInputList(SafeOpHandle op, [In, MarshalAs(UnmanagedType.CustomMarshaler, MarshalTypeRef = typeof(SafeHandleArrayMarshaler))] SafeTensorHandleHandle[] inputs, int num_inputs, SafeStatusHandle status); | |||
public static extern int TFE_OpAddInputList(SafeEagerOpHandle op, [In, MarshalAs(UnmanagedType.CustomMarshaler, MarshalTypeRef = typeof(SafeHandleArrayMarshaler))] SafeTensorHandleHandle[] inputs, int num_inputs, SafeStatusHandle status); | |||
/// <summary> | |||
/// | |||
@@ -161,7 +161,7 @@ namespace Tensorflow | |||
/// <param name="retvals"></param> | |||
/// <param name="num_retvals"></param> | |||
/// <param name="status"></param> | |||
public static void TFE_Execute(SafeOpHandle op, SafeTensorHandleHandle[] retvals, out int num_retvals, SafeStatusHandle status) | |||
public static void TFE_Execute(SafeEagerOpHandle op, SafeTensorHandleHandle[] retvals, out int num_retvals, SafeStatusHandle status) | |||
{ | |||
unsafe | |||
{ | |||
@@ -187,7 +187,7 @@ namespace Tensorflow | |||
/// <param name="num_retvals">int*</param> | |||
/// <param name="status">TF_Status*</param> | |||
[DllImport(TensorFlowLibName)] | |||
private static unsafe extern void TFE_Execute(SafeOpHandle op, IntPtr* retvals, ref int num_retvals, SafeStatusHandle status); | |||
private static unsafe extern void TFE_Execute(SafeEagerOpHandle op, IntPtr* retvals, ref int num_retvals, SafeStatusHandle status); | |||
/// <summary> | |||
/// | |||
@@ -197,7 +197,7 @@ namespace Tensorflow | |||
/// <param name="status">TF_Status*</param> | |||
/// <returns></returns> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern SafeOpHandle TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status); | |||
public static extern SafeEagerOpHandle TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status); | |||
/// <summary> | |||
/// Resets `op_to_reset` with `op_or_function_name` and `raw_device_name`. This | |||
@@ -213,7 +213,7 @@ namespace Tensorflow | |||
/// <param name="raw_device_name">const char*</param> | |||
/// <param name="status">TF_Status*</param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TFE_OpReset(SafeOpHandle op_to_reset, string op_or_function_name, string raw_device_name, SafeStatusHandle status); | |||
public static extern void TFE_OpReset(SafeEagerOpHandle op_to_reset, string op_or_function_name, string raw_device_name, SafeStatusHandle status); | |||
/// <summary> | |||
/// | |||
@@ -229,13 +229,13 @@ namespace Tensorflow | |||
/// <param name="attr_name">const char*</param> | |||
/// <param name="value">TF_DataType</param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TFE_OpSetAttrType(SafeOpHandle op, string attr_name, TF_DataType value); | |||
public static extern void TFE_OpSetAttrType(SafeEagerOpHandle op, string attr_name, TF_DataType value); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TFE_OpSetAttrInt(SafeOpHandle op, string attr_name, long value); | |||
public static extern void TFE_OpSetAttrInt(SafeEagerOpHandle op, string attr_name, long value); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TFE_OpSetAttrFloat(SafeOpHandle op, string attr_name, float value); | |||
public static extern void TFE_OpSetAttrFloat(SafeEagerOpHandle op, string attr_name, float value); | |||
/// <summary> | |||
/// | |||
@@ -246,19 +246,19 @@ namespace Tensorflow | |||
/// <param name="num_dims">const int</param> | |||
/// <param name="out_status">TF_Status*</param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TFE_OpSetAttrShape(SafeOpHandle op, string attr_name, long[] dims, int num_dims, SafeStatusHandle out_status); | |||
public static extern void TFE_OpSetAttrShape(SafeEagerOpHandle op, string attr_name, long[] dims, int num_dims, SafeStatusHandle out_status); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TFE_OpSetAttrShapeList(SafeOpHandle op, string attr_name, IntPtr[] dims, int[] num_dims, int num_values, SafeStatusHandle out_status); | |||
public static extern void TFE_OpSetAttrShapeList(SafeEagerOpHandle op, string attr_name, IntPtr[] dims, int[] num_dims, int num_values, SafeStatusHandle out_status); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TFE_OpSetAttrStringList(SafeOpHandle op, string attr_name, string[] values, ulong[] lengths, int num_values); | |||
public static extern void TFE_OpSetAttrStringList(SafeEagerOpHandle op, string attr_name, string[] values, ulong[] lengths, int num_values); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TFE_OpSetAttrBool(SafeOpHandle op, string attr_name, bool value); | |||
public static extern void TFE_OpSetAttrBool(SafeEagerOpHandle op, string attr_name, bool value); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TFE_OpSetAttrFunctionName(SafeOpHandle op, string attr_name, string data, int length); | |||
public static extern void TFE_OpSetAttrFunctionName(SafeEagerOpHandle op, string attr_name, string data, int length); | |||
/// <summary> | |||
/// | |||
@@ -268,16 +268,16 @@ namespace Tensorflow | |||
/// <param name="value">const void*</param> | |||
/// <param name="length">size_t</param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TFE_OpSetAttrString(SafeOpHandle op, string attr_name, string value, ulong length); | |||
public static extern void TFE_OpSetAttrString(SafeEagerOpHandle op, string attr_name, string value, ulong length); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TFE_OpSetAttrTypeList(SafeOpHandle op, string attr_name, TF_DataType[] values, int num_values); | |||
public static extern void TFE_OpSetAttrTypeList(SafeEagerOpHandle op, string attr_name, TF_DataType[] values, int num_values); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TFE_OpSetAttrIntList(SafeOpHandle op, string attr_name, long[] values, int num_values); | |||
public static extern void TFE_OpSetAttrIntList(SafeEagerOpHandle op, string attr_name, long[] values, int num_values); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TFE_OpSetAttrValueProto(SafeOpHandle op, string attr_name, IMessage[] proto, int proto_len, SafeStatusHandle status); | |||
public static extern void TFE_OpSetAttrValueProto(SafeEagerOpHandle op, string attr_name, IMessage[] proto, int proto_len, SafeStatusHandle status); | |||
/// <summary> | |||
/// | |||
@@ -286,7 +286,7 @@ namespace Tensorflow | |||
/// <param name="device_name"></param> | |||
/// <param name="status"></param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TFE_OpSetDevice(SafeOpHandle op, string device_name, SafeStatusHandle status); | |||
public static extern void TFE_OpSetDevice(SafeEagerOpHandle op, string device_name, SafeStatusHandle status); | |||
/// <summary> | |||
/// | |||
@@ -295,7 +295,7 @@ namespace Tensorflow | |||
/// <param name="h">TFE_TensorHandle*</param> | |||
/// <param name="status">TF_Status*</param> | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TFE_OpAddInput(SafeOpHandle op, SafeTensorHandleHandle h, SafeStatusHandle status); | |||
public static extern void TFE_OpAddInput(SafeEagerOpHandle op, SafeTensorHandleHandle h, SafeStatusHandle status); | |||
/// <summary> | |||
/// | |||
@@ -51,7 +51,7 @@ namespace Tensorflow.Gradients | |||
} | |||
[RegisterGradient("ConcatV2")] | |||
public static Tensor[] _ConcatGradV2(Operation op, Tensor[] grads) | |||
public static Tensor[] _ConcatV2Grad(Operation op, Tensor[] grads) | |||
{ | |||
var grad = grads[0]; | |||
return _ConcatGradHelper(op, grad, start_value_index: 0, end_value_index: -1, dim_index: -1); | |||
@@ -50,42 +50,11 @@ namespace Tensorflow | |||
{ | |||
// tf.Logger.Debug($"Caculate Gradient: {oper.name} {m.Name}"); | |||
var results = m.Name switch | |||
{ | |||
/*"_AddGrad" => math_grad._AddGrad(oper, out_grads), | |||
"_AddV2Grad" => math_grad._AddV2Grad(oper, out_grads), | |||
"_BiasAddGrad" => nn_grad._BiasAddGrad(oper, out_grads), | |||
"_CastGrad" => math_grad._CastGrad(oper, out_grads), | |||
"_ConcatGradV2" => array_grad._ConcatGradV2(oper, out_grads), | |||
"_Conv2DGrad" => nn_grad._Conv2DGrad(oper, out_grads), | |||
"_ExpandDimsGrad" => array_grad._ExpandDimsGrad(oper, out_grads), | |||
"_ExpGrad" => math_grad._ExpGrad(oper, out_grads), | |||
"_FusedBatchNormV3Grad" => nn_grad._FusedBatchNormV3Grad(oper, out_grads), | |||
"_IdGrad" => math_grad._IdGrad(oper, out_grads), | |||
"_LeakyReluGrad" => nn_grad._LeakyReluGrad(oper, out_grads), | |||
"_Log1pGrad" => math_grad._Log1pGrad(oper, out_grads), | |||
"_MaximumGrad" => math_grad._MaximumGrad(oper, out_grads), | |||
"_MeanGrad" => math_grad._MeanGrad(oper, out_grads), | |||
"_MinimumGrad" => math_grad._MinimumGrad(oper, out_grads), | |||
"_MulGrad" => math_grad._MulGrad(oper, out_grads), | |||
"_NegGrad" => math_grad._NegGrad(oper, out_grads), | |||
"_PadGrad" => array_grad._PadGrad(oper, out_grads), | |||
"_PowGrad" => math_grad._PowGrad(oper, out_grads), | |||
"_RealDivGrad" => math_grad._RealDivGrad(oper, out_grads), | |||
"_ReadGrad" => resource_variable_grad._ReadGrad(oper, out_grads), | |||
"_ReshapeGrad" => array_grad._ReshapeGrad(oper, out_grads), | |||
"_ResizeNearestNeighborGrad" => image_grad._ResizeNearestNeighborGrad(oper, out_grads), | |||
"_SelectGrad" => math_grad._SelectGrad(oper, out_grads), | |||
"_SigmoidGrad" => math_grad._SigmoidGrad(oper, out_grads), | |||
"_SumGrad" => math_grad._SumGrad(oper, out_grads), | |||
"_SubGrad" => math_grad._SubGrad(oper, out_grads), | |||
"_StridedSliceGrad" => array_grad._StridedSliceGrad(oper, out_grads),*/ | |||
_ => g.InvokeMember(m.Name, | |||
BindingFlags.InvokeMethod, | |||
null, | |||
null, | |||
args: new object[] { oper, out_grads }) as Tensor[] | |||
}; | |||
var results = g.InvokeMember(m.Name, | |||
BindingFlags.InvokeMethod, | |||
null, | |||
null, | |||
args: new object[] { oper, out_grads }) as Tensor[]; | |||
// foreach (var result in results.Where(x => x != null)) | |||
// tf.Logger.Debug($"Gradient: {result.name} {result.shape}"); | |||
@@ -39,7 +39,7 @@ namespace Tensorflow | |||
public void _add_control_input(Operation op) | |||
{ | |||
c_api.TF_AddControlInput(OpDesc, op); | |||
// c_api.TF_AddControlInput(_opDesc, op); | |||
//c_api.AddControlInput(graph, _handle, op); | |||
} | |||
@@ -46,7 +46,6 @@ namespace Tensorflow | |||
private readonly IntPtr _handle; // _c_op in python | |||
private readonly Graph _graph; | |||
private NodeDef _node_def; | |||
public string type => OpType; | |||
@@ -57,24 +56,14 @@ namespace Tensorflow | |||
public int _id_value { get; set; } | |||
public Operation op => this; | |||
public TF_DataType dtype => TF_DataType.DtInvalid; | |||
public virtual string name => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationName(_handle)); | |||
public string OpType => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationOpType(_handle)); | |||
public virtual string name => _handle == IntPtr.Zero ? "" : c_api.StringPiece(c_api.TF_OperationName(_handle)); | |||
public string OpType => _handle == IntPtr.Zero ? "" : c_api.StringPiece(c_api.TF_OperationOpType(_handle)); | |||
public string Device => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationDevice(_handle)); | |||
public string Device => _handle == IntPtr.Zero ? "" : c_api.StringPiece(c_api.TF_OperationDevice(_handle)); | |||
bool _is_stateful; | |||
public OperationDescription OpDesc { get; set; } | |||
// OperationDescription _opDesc; | |||
public NodeDef node_def | |||
{ | |||
get | |||
{ | |||
if (_node_def == null) | |||
_node_def = GetNodeDef(); | |||
return _node_def; | |||
} | |||
} | |||
public NodeDef node_def => GetNodeDef(); | |||
public Operation(IntPtr handle, Graph g = null) | |||
{ | |||
@@ -168,8 +157,7 @@ namespace Tensorflow | |||
if (op_def == null) | |||
op_def = g.GetOpDef(node_def.Op); | |||
(_handle, OpDesc) = ops._create_c_op(g, node_def, inputs, control_input_ops.ToArray(), op_def); | |||
_is_stateful = op_def.IsStateful; | |||
(_handle, _) = ops._create_c_op(g, node_def, inputs, control_input_ops.ToArray(), op_def); | |||
// Initialize self._outputs. | |||
output_types = new TF_DataType[NumOutputs]; | |||
@@ -199,16 +187,11 @@ namespace Tensorflow | |||
if (tf.executing_eagerly()) | |||
return (T[])get_attr(name); | |||
AttrValue x = null; | |||
using var buf = new Buffer(); | |||
c_api.TF_OperationGetAttrValueProto(_handle, name, buf.Handle, tf.Status.Handle); | |||
tf.Status.Check(true); | |||
lock (Locks.ProcessWide) | |||
{ | |||
using var buf = new Buffer(); | |||
c_api.TF_OperationGetAttrValueProto(_handle, name, buf.Handle, tf.Status.Handle); | |||
tf.Status.Check(true); | |||
x = AttrValue.Parser.ParseFrom(buf.ToArray()); | |||
} | |||
var x = AttrValue.Parser.ParseFrom(buf.ToArray()); | |||
string oneof_value = x.ValueCase.ToString(); | |||
if (string.IsNullOrEmpty(oneof_value)) | |||
@@ -227,16 +210,11 @@ namespace Tensorflow | |||
public virtual object get_attr(string name) | |||
{ | |||
AttrValue x = null; | |||
lock (Locks.ProcessWide) | |||
{ | |||
using var buf = new Buffer(); | |||
c_api.TF_OperationGetAttrValueProto(_handle, name, buf.Handle, tf.Status.Handle); | |||
tf.Status.Check(true); | |||
using var buf = new Buffer(); | |||
c_api.TF_OperationGetAttrValueProto(_handle, name, buf.Handle, tf.Status.Handle); | |||
tf.Status.Check(true); | |||
x = AttrValue.Parser.ParseFrom(buf.ToArray()); | |||
} | |||
var x = AttrValue.Parser.ParseFrom(buf.ToArray()); | |||
string oneof_value = x.ValueCase.ToString(); | |||
if (string.IsNullOrEmpty(oneof_value)) | |||
@@ -262,15 +240,10 @@ namespace Tensorflow | |||
private NodeDef GetNodeDef() | |||
{ | |||
lock (Locks.ProcessWide) | |||
using (var s = new Status()) | |||
using (var buffer = new Buffer()) | |||
{ | |||
c_api.TF_OperationToNodeDef(_handle, buffer.Handle, s.Handle); | |||
s.Check(); | |||
return NodeDef.Parser.ParseFrom(buffer.ToArray()); | |||
} | |||
using var buffer = new Buffer(); | |||
c_api.TF_OperationToNodeDef(_handle, buffer.Handle, tf.Status.Handle); | |||
tf.Status.Check(throwException: true); | |||
return NodeDef.Parser.ParseFrom(buffer.ToArray()); | |||
} | |||
/// <summary> | |||
@@ -284,21 +257,21 @@ namespace Tensorflow | |||
{ | |||
_assert_same_graph(tensor); | |||
var input = _tf_input(index); | |||
var output = tensor._as_tf_output(); | |||
// var input = _tf_input(index); | |||
// var output = tensor._as_tf_output(); | |||
// Reset cached inputs. | |||
_inputs_val = null; | |||
_node_def = null; | |||
// _node_def = null; | |||
// after the c_api call next time _inputs is accessed | |||
// the updated inputs are reloaded from the c_api | |||
lock (Locks.ProcessWide) | |||
{ | |||
// lock (Locks.ProcessWide) | |||
// { | |||
// disable | |||
// c_api.TF_UpdateEdge(_graph, output, input, tf.Status.Handle); | |||
//var updated_inputs = inputs; | |||
tf.Status.Check(); | |||
} | |||
// tf.Status.Check(); | |||
// } | |||
} | |||
private void _assert_same_graph(Tensor tensor) | |||
@@ -311,7 +284,7 @@ namespace Tensorflow | |||
/// </summary> | |||
public TF_Output _tf_output(int output_idx) | |||
{ | |||
return new TF_Output(op, output_idx); | |||
return new TF_Output(_handle, output_idx); | |||
} | |||
/// <summary> | |||
@@ -319,7 +292,7 @@ namespace Tensorflow | |||
/// </summary> | |||
public TF_Input _tf_input(int input_idx) | |||
{ | |||
return new TF_Input(op, input_idx); | |||
return new TF_Input(_handle, input_idx); | |||
} | |||
public NDArray numpy() => throw new NotImplementedException(""); | |||
@@ -80,27 +80,16 @@ namespace Tensorflow | |||
return tf_with(ops.name_scope(name, "zeros", shape), scope => | |||
{ | |||
name = scope; | |||
var shape_tensor = constant_op._tensor_shape_tensor_conversion_function(shape); | |||
Tensor zeros = null; | |||
switch (dtype) | |||
// var shape_tensor = constant_op._tensor_shape_tensor_conversion_function(shape); | |||
Tensor zeros = dtype switch | |||
{ | |||
case TF_DataType.TF_DOUBLE: | |||
zeros = constant(0d); | |||
break; | |||
case TF_DataType.TF_FLOAT: | |||
zeros = constant(0f); | |||
break; | |||
case TF_DataType.TF_INT8: | |||
zeros = constant((sbyte)0); | |||
break; | |||
case TF_DataType.TF_UINT8: | |||
zeros = constant((byte)0); | |||
break; | |||
default: | |||
zeros = constant(0); | |||
break; | |||
} | |||
return fill(shape_tensor, zeros, name: name); | |||
TF_DataType.TF_DOUBLE => constant(0d), | |||
TF_DataType.TF_FLOAT => constant(0f), | |||
TF_DataType.TF_INT8 => constant((sbyte)0), | |||
TF_DataType.TF_UINT8 => constant((byte)0), | |||
_ => constant(0) | |||
}; | |||
return fill(shape, zeros, name: name); | |||
}); | |||
} | |||
else | |||
@@ -311,12 +300,8 @@ namespace Tensorflow | |||
/// <param name="value">A value to fill the returned `tf.Tensor`.</param> | |||
/// <param name="name">Optional string. The name of the output `tf.Tensor`.</param> | |||
/// <returns>A `tf.Tensor` with shape `dims` and the same dtype as `value`.</returns> | |||
public static Tensor fill(Tensor dims, Tensor value, string name = null) | |||
{ | |||
var result = gen_array_ops.fill(dims, value, name: name); | |||
// tensor_util.maybe_set_static_shape(result, dims) | |||
return result; | |||
} | |||
public static Tensor fill<T>(Shape dims, T value, string name = null) | |||
=> gen_array_ops.fill(dims, value, name: name); | |||
/// <summary> | |||
/// Returns the rank of a tensor. | |||
@@ -425,25 +410,18 @@ namespace Tensorflow | |||
dtype = dtype.as_base_dtype(); | |||
name = scope; | |||
Tensor ones = null; | |||
switch (dtype) | |||
Tensor ones = dtype switch | |||
{ | |||
case TF_DataType.TF_DOUBLE: | |||
ones = constant(1.0d); | |||
break; | |||
case TF_DataType.TF_FLOAT: | |||
ones = constant(1.0f); | |||
break; | |||
default: | |||
ones = constant(1); | |||
break; | |||
} | |||
TF_DataType.TF_DOUBLE => constant(1.0d), | |||
TF_DataType.TF_FLOAT => constant(1.0f), | |||
_ => constant(1) | |||
}; | |||
if (shape.ndim == 0) | |||
return ones; | |||
var shape_tensor = constant_op._tensor_shape_tensor_conversion_function(shape); | |||
return fill(shape_tensor, ones, name: name); | |||
// var shape_tensor = constant_op._tensor_shape_tensor_conversion_function(shape); | |||
return fill(shape, ones, name: name); | |||
}); | |||
public static Tensor one_hot(Tensor indices, Tensor depth, | |||
@@ -2086,8 +2086,7 @@ new_height, new_width"); | |||
gather_idx), | |||
new[] { batch_size, -1 }); | |||
} | |||
var invalid_index = array_ops.fill(ops.convert_to_tensor(new object[] { batch_size, max_output_size }), | |||
tf.constant(0)); | |||
var invalid_index = array_ops.fill(new Shape((int)batch_size, (int)max_output_size), 0); | |||
var idx_index = array_ops.expand_dims(math_ops.range(max_output_size), 0); | |||
var num_valid_expanded = array_ops.expand_dims(num_valid, 1); | |||
idx = array_ops.where(idx_index < num_valid_expanded, | |||
@@ -223,42 +223,39 @@ namespace Tensorflow | |||
var input_tensors = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); | |||
lock (Locks.ProcessWide) | |||
{ | |||
var op_desc = graph.NewOperation(node_def.Op, node_def.Name); | |||
var op_desc = graph.NewOperation(node_def.Op, node_def.Name); | |||
if (!string.IsNullOrEmpty(node_def.Device)) | |||
c_api.TF_SetDevice(op_desc, node_def.Device); | |||
if (!string.IsNullOrEmpty(node_def.Device)) | |||
c_api.TF_SetDevice(op_desc, node_def.Device); | |||
// Add inputs | |||
foreach (var op_input in input_tensors) | |||
{ | |||
if (op_input.IsList) | |||
c_api.TF_AddInputList(op_desc, op_input.Select(x => x._as_tf_output()).ToArray(), op_input.Count()); | |||
else if (op_input.Count() == 1) | |||
c_api.TF_AddInput(op_desc, op_input[0]._as_tf_output()); | |||
} | |||
// Add inputs | |||
foreach (var op_input in input_tensors) | |||
{ | |||
if (op_input.IsList) | |||
c_api.TF_AddInputList(op_desc, op_input.Select(x => x._as_tf_output()).ToArray(), op_input.Count()); | |||
else if (op_input.Count() == 1) | |||
c_api.TF_AddInput(op_desc, op_input[0]._as_tf_output()); | |||
} | |||
var status = tf.Status; | |||
var status = tf.Status; | |||
// Add control inputs | |||
foreach (var control_input in control_inputs) | |||
c_api.TF_AddControlInput(op_desc, control_input); | |||
// Add control inputs | |||
foreach (var control_input in control_inputs) | |||
c_api.TF_AddControlInput(op_desc, control_input); | |||
// Add attrs | |||
foreach (var attr in node_def.Attr) | |||
{ | |||
var bytes = attr.Value.ToByteArray(); | |||
c_api.TF_SetAttrValueProto(op_desc, attr.Key, bytes, proto_len: bytes.Length, status: status.Handle); | |||
status.Check(true); | |||
} | |||
// Add attrs | |||
foreach (var attr in node_def.Attr) | |||
{ | |||
var bytes = attr.Value.ToByteArray(); | |||
c_api.TF_SetAttrValueProto(op_desc, attr.Key, bytes, proto_len: bytes.Length, status: status.Handle); | |||
status.Check(true); | |||
} | |||
var c_op = c_api.TF_FinishOperation(op_desc, status.Handle); | |||
var c_op = op_desc.FinishOperation(status); | |||
status.Check(true); | |||
status.Check(true); | |||
return (c_op, op_desc); | |||
} | |||
return (c_op, op_desc); | |||
} | |||
public static Tensors[] _reconstruct_sequence_inputs(OpDef op_def, Tensor[] inputs, MapField<string, AttrValue> attrs) | |||
@@ -101,9 +101,9 @@ namespace Tensorflow.Keras.Engine | |||
var result_pairs = string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2:F6}")); | |||
Binding.tf_output_redirect.WriteLine($"Epoch: {epoch + 1:D3}/{epochs:D3}, Step: {step + 1:D4}/{data_handler.Inferredsteps:D4}, {result_pairs}"); | |||
} | |||
} | |||
GC.Collect(); | |||
GC.Collect(); | |||
} | |||
GC.WaitForPendingFinalizers(); | |||
} | |||
} | |||
@@ -5,18 +5,23 @@ using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Engine; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Keras.Layers { | |||
public class Tanh : Layer { | |||
public Tanh ( LayerArgs args ) : base(args) { | |||
// Tanh has no arguments | |||
} | |||
protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { | |||
Tensor x = inputs; | |||
namespace Tensorflow.Keras.Layers | |||
{ | |||
public class Tanh : Layer | |||
{ | |||
public Tanh(LayerArgs args) : base(args) | |||
{ | |||
// Tanh has no arguments | |||
} | |||
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||
{ | |||
Tensor x = inputs; | |||
return tf.tanh(x); | |||
} | |||
public override Shape ComputeOutputShape ( Shape input_shape ) { | |||
return input_shape; | |||
} | |||
} | |||
return tf.tanh(x); | |||
} | |||
public override Shape ComputeOutputShape(Shape input_shape) | |||
{ | |||
return input_shape; | |||
} | |||
} | |||
} |
@@ -80,25 +80,25 @@ namespace Tensorflow.Native.UnitTest | |||
protected ulong TF_TensorByteSize(SafeTensorHandle t) | |||
=> c_api.TF_TensorByteSize(t); | |||
protected void TFE_OpAddInput(SafeOpHandle op, SafeTensorHandleHandle h, SafeStatusHandle status) | |||
protected void TFE_OpAddInput(SafeEagerOpHandle op, SafeTensorHandleHandle h, SafeStatusHandle status) | |||
=> c_api.TFE_OpAddInput(op, h, status); | |||
protected void TFE_OpSetAttrType(SafeOpHandle op, string attr_name, TF_DataType value) | |||
protected void TFE_OpSetAttrType(SafeEagerOpHandle op, string attr_name, TF_DataType value) | |||
=> c_api.TFE_OpSetAttrType(op, attr_name, value); | |||
protected void TFE_OpSetAttrShape(SafeOpHandle op, string attr_name, long[] dims, int num_dims, SafeStatusHandle out_status) | |||
protected void TFE_OpSetAttrShape(SafeEagerOpHandle op, string attr_name, long[] dims, int num_dims, SafeStatusHandle out_status) | |||
=> c_api.TFE_OpSetAttrShape(op, attr_name, dims, num_dims, out_status); | |||
protected void TFE_OpSetAttrString(SafeOpHandle op, string attr_name, string value, uint length) | |||
protected void TFE_OpSetAttrString(SafeEagerOpHandle op, string attr_name, string value, uint length) | |||
=> c_api.TFE_OpSetAttrString(op, attr_name, value, length); | |||
protected SafeOpHandle TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status) | |||
protected SafeEagerOpHandle TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status) | |||
=> c_api.TFE_NewOp(ctx, op_or_function_name, status); | |||
protected SafeTensorHandleHandle TFE_NewTensorHandle(SafeTensorHandle t, SafeStatusHandle status) | |||
=> c_api.TFE_NewTensorHandle(t, status); | |||
protected void TFE_Execute(SafeOpHandle op, SafeTensorHandleHandle[] retvals, out int num_retvals, SafeStatusHandle status) | |||
protected void TFE_Execute(SafeEagerOpHandle op, SafeTensorHandleHandle[] retvals, out int num_retvals, SafeStatusHandle status) | |||
=> c_api.TFE_Execute(op, retvals, out num_retvals, status); | |||
protected SafeContextOptionsHandle TFE_NewContextOptions() | |||
@@ -107,13 +107,13 @@ namespace Tensorflow.Native.UnitTest | |||
protected SafeContextHandle TFE_NewContext(SafeContextOptionsHandle opts, SafeStatusHandle status) | |||
=> c_api.TFE_NewContext(opts, status); | |||
protected int TFE_OpGetInputLength(SafeOpHandle op, string input_name, SafeStatusHandle status) | |||
protected int TFE_OpGetInputLength(SafeEagerOpHandle op, string input_name, SafeStatusHandle status) | |||
=> c_api.TFE_OpGetInputLength(op, input_name, status); | |||
protected int TFE_OpAddInputList(SafeOpHandle op, SafeTensorHandleHandle[] inputs, int num_inputs, SafeStatusHandle status) | |||
protected int TFE_OpAddInputList(SafeEagerOpHandle op, SafeTensorHandleHandle[] inputs, int num_inputs, SafeStatusHandle status) | |||
=> c_api.TFE_OpAddInputList(op, inputs, num_inputs, status); | |||
protected int TFE_OpGetOutputLength(SafeOpHandle op, string input_name, SafeStatusHandle status) | |||
protected int TFE_OpGetOutputLength(SafeEagerOpHandle op, string input_name, SafeStatusHandle status) | |||
=> c_api.TFE_OpGetOutputLength(op, input_name, status); | |||
protected void TFE_DeleteTensorHandle(IntPtr h) | |||
@@ -149,7 +149,7 @@ namespace Tensorflow.Native.UnitTest | |||
protected SafeTensorHandleHandle TFE_TensorHandleCopyToDevice(SafeTensorHandleHandle h, SafeContextHandle ctx, string device_name, SafeStatusHandle status) | |||
=> c_api.TFE_TensorHandleCopyToDevice(h, ctx, device_name, status); | |||
protected void TFE_OpSetDevice(SafeOpHandle op, string device_name, SafeStatusHandle status) | |||
protected void TFE_OpSetDevice(SafeEagerOpHandle op, string device_name, SafeStatusHandle status) | |||
=> c_api.TFE_OpSetDevice(op, device_name, status); | |||
} | |||
} |
@@ -25,7 +25,7 @@ namespace Tensorflow.Native.UnitTest.Eager | |||
return th; | |||
} | |||
SafeOpHandle MatMulOp(SafeContextHandle ctx, SafeTensorHandleHandle a, SafeTensorHandleHandle b) | |||
SafeEagerOpHandle MatMulOp(SafeContextHandle ctx, SafeTensorHandleHandle a, SafeTensorHandleHandle b) | |||
{ | |||
using var status = TF_NewStatus(); | |||
@@ -63,7 +63,7 @@ namespace Tensorflow.Native.UnitTest.Eager | |||
return false; | |||
} | |||
SafeOpHandle ShapeOp(SafeContextHandle ctx, SafeTensorHandleHandle a) | |||
SafeEagerOpHandle ShapeOp(SafeContextHandle ctx, SafeTensorHandleHandle a) | |||
{ | |||
using var status = TF_NewStatus(); | |||