@@ -116,6 +116,9 @@ namespace Tensorflow | |||||
public Tensor fill<T>(Tensor dims, T value, string name = null) | public Tensor fill<T>(Tensor dims, T value, string name = null) | ||||
=> gen_array_ops.fill(dims, value, name: name); | => gen_array_ops.fill(dims, value, name: name); | ||||
public Tensor fill<T>(Shape dims, T value, string name = null) | |||||
=> array_ops.fill(dims, value, name: name); | |||||
/// <summary> | /// <summary> | ||||
/// Return a tensor with the same shape and contents as input. | /// Return a tensor with the same shape and contents as input. | ||||
/// </summary> | /// </summary> | ||||
@@ -68,12 +68,12 @@ namespace Tensorflow.Eager | |||||
Tensor[] op_inputs, | Tensor[] op_inputs, | ||||
object[] attrs, | object[] attrs, | ||||
Tensor[] op_outputs) | Tensor[] op_outputs) | ||||
=> (output_grads, unneeded_gradients) => | |||||
=> (out_grads, unneeded_gradients) => | |||||
{ | { | ||||
if (ops.gradientFunctions[op_name] == null) | if (ops.gradientFunctions[op_name] == null) | ||||
return new Tensor[op_inputs.Length]; | return new Tensor[op_inputs.Length]; | ||||
var op = new EagerOperation | |||||
var oper = new EagerOperation | |||||
{ | { | ||||
Name = op_name, | Name = op_name, | ||||
NumInputs = op_inputs.Length, | NumInputs = op_inputs.Length, | ||||
@@ -84,7 +84,40 @@ namespace Tensorflow.Eager | |||||
Attrs = attrs | Attrs = attrs | ||||
}; | }; | ||||
return ops.gradientFunctions[op_name](op, output_grads); | |||||
/*return op_name switch | |||||
{ | |||||
"Add" => math_grad._AddGrad(oper, out_grads), | |||||
"AddV2" => math_grad._AddV2Grad(oper, out_grads), | |||||
"BiasAdd" => nn_grad._BiasAddGrad(oper, out_grads), | |||||
"Cast" => math_grad._CastGrad(oper, out_grads), | |||||
"ConcatV2" => array_grad._ConcatV2Grad(oper, out_grads), | |||||
"Conv2D" => nn_grad._Conv2DGrad(oper, out_grads), | |||||
"ExpandDims" => array_grad._ExpandDimsGrad(oper, out_grads), | |||||
"Exp" => math_grad._ExpGrad(oper, out_grads), | |||||
"FusedBatchNormV3" => nn_grad._FusedBatchNormV3Grad(oper, out_grads), | |||||
"Id" => math_grad._IdGrad(oper, out_grads), | |||||
"LeakyRelu" => nn_grad._LeakyReluGrad(oper, out_grads), | |||||
"Log1p" => math_grad._Log1pGrad(oper, out_grads), | |||||
"Maximum" => math_grad._MaximumGrad(oper, out_grads), | |||||
"Mean" => math_grad._MeanGrad(oper, out_grads), | |||||
"Minimum" => math_grad._MinimumGrad(oper, out_grads), | |||||
"Mul" => math_grad._MulGrad(oper, out_grads), | |||||
"Neg" => math_grad._NegGrad(oper, out_grads), | |||||
"Pad" => array_grad._PadGrad(oper, out_grads), | |||||
"Pow" => math_grad._PowGrad(oper, out_grads), | |||||
"RealDiv" => math_grad._RealDivGrad(oper, out_grads), | |||||
"Read" => resource_variable_grad._ReadGrad(oper, out_grads), | |||||
"Reshape" => array_grad._ReshapeGrad(oper, out_grads), | |||||
"ResizeNearestNeighbor" => image_grad._ResizeNearestNeighborGrad(oper, out_grads), | |||||
"Select" => math_grad._SelectGrad(oper, out_grads), | |||||
"Sigmoid" => math_grad._SigmoidGrad(oper, out_grads), | |||||
"Sum" => math_grad._SumGrad(oper, out_grads), | |||||
"Sub" => math_grad._SubGrad(oper, out_grads), | |||||
"StridedSlice" => array_grad._StridedSliceGrad(oper, out_grads), | |||||
_ => ops.gradientFunctions[op_name](oper, out_grads) | |||||
};*/ | |||||
return ops.gradientFunctions[op_name](oper, out_grads); | |||||
}; | }; | ||||
bool CouldForwardprop() | bool CouldForwardprop() | ||||
@@ -15,7 +15,7 @@ namespace Tensorflow.Eager | |||||
/// </summary> | /// </summary> | ||||
public partial class EagerRunner | public partial class EagerRunner | ||||
{ | { | ||||
UnorderedMap<string, SafeOpHandle> thread_local_eager_operation_map = new UnorderedMap<string, SafeOpHandle>(); | |||||
UnorderedMap<string, SafeEagerOpHandle> thread_local_eager_operation_map = new UnorderedMap<string, SafeEagerOpHandle>(); | |||||
public void ClearEagerOperationMap() | public void ClearEagerOperationMap() | ||||
=> thread_local_eager_operation_map.Clear(); | => thread_local_eager_operation_map.Clear(); | ||||
@@ -157,7 +157,7 @@ namespace Tensorflow.Eager | |||||
return flat_result; | return flat_result; | ||||
} | } | ||||
SafeOpHandle GetOp(Context ctx, string op_or_function_name, Status status) | |||||
SafeEagerOpHandle GetOp(Context ctx, string op_or_function_name, Status status) | |||||
{ | { | ||||
if (thread_local_eager_operation_map.find(op_or_function_name, out var op)) | if (thread_local_eager_operation_map.find(op_or_function_name, out var op)) | ||||
c_api.TFE_OpReset(op, op_or_function_name, ctx.DeviceName, status.Handle); | c_api.TFE_OpReset(op, op_or_function_name, ctx.DeviceName, status.Handle); | ||||
@@ -205,7 +205,7 @@ namespace Tensorflow.Eager | |||||
ArgDef input_arg, | ArgDef input_arg, | ||||
List<object> flattened_attrs, | List<object> flattened_attrs, | ||||
List<Tensor> flattened_inputs, | List<Tensor> flattened_inputs, | ||||
SafeOpHandle op, | |||||
SafeEagerOpHandle op, | |||||
Status status) | Status status) | ||||
{ | { | ||||
var tensor = tf.convert_to_tensor(inputs); | var tensor = tf.convert_to_tensor(inputs); | ||||
@@ -225,7 +225,7 @@ namespace Tensorflow.Eager | |||||
return true; | return true; | ||||
} | } | ||||
public void SetOpAttrs(SafeOpHandle op, params object[] attrs) | |||||
public void SetOpAttrs(SafeEagerOpHandle op, params object[] attrs) | |||||
{ | { | ||||
var status = tf.Status; | var status = tf.Status; | ||||
var len = attrs.Length; | var len = attrs.Length; | ||||
@@ -258,7 +258,7 @@ namespace Tensorflow.Eager | |||||
/// <param name="attr_value"></param> | /// <param name="attr_value"></param> | ||||
/// <param name="attr_list_sizes"></param> | /// <param name="attr_list_sizes"></param> | ||||
/// <param name="status"></param> | /// <param name="status"></param> | ||||
void SetOpAttrWithDefaults(Context ctx, SafeOpHandle op, AttrDef attr, | |||||
void SetOpAttrWithDefaults(Context ctx, SafeEagerOpHandle op, AttrDef attr, | |||||
string attr_name, object attr_value, | string attr_name, object attr_value, | ||||
Dictionary<string, long> attr_list_sizes, | Dictionary<string, long> attr_list_sizes, | ||||
Status status) | Status status) | ||||
@@ -280,7 +280,7 @@ namespace Tensorflow.Eager | |||||
} | } | ||||
} | } | ||||
bool SetOpAttrList(Context ctx, SafeOpHandle op, | |||||
bool SetOpAttrList(Context ctx, SafeEagerOpHandle op, | |||||
string key, object values, TF_AttrType type, | string key, object values, TF_AttrType type, | ||||
Dictionary<string, long> attr_list_sizes, | Dictionary<string, long> attr_list_sizes, | ||||
Status status) | Status status) | ||||
@@ -326,7 +326,7 @@ namespace Tensorflow.Eager | |||||
return true; | return true; | ||||
} | } | ||||
bool SetOpAttrScalar(Context ctx, SafeOpHandle op, | |||||
bool SetOpAttrScalar(Context ctx, SafeEagerOpHandle op, | |||||
string key, object value, TF_AttrType type, | string key, object value, TF_AttrType type, | ||||
Dictionary<string, long> attr_list_sizes, | Dictionary<string, long> attr_list_sizes, | ||||
Status status) | Status status) | ||||
@@ -16,18 +16,17 @@ | |||||
using System; | using System; | ||||
using Tensorflow.Util; | using Tensorflow.Util; | ||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow.Eager | namespace Tensorflow.Eager | ||||
{ | { | ||||
public sealed class SafeOpHandle : SafeTensorflowHandle | |||||
public sealed class SafeEagerOpHandle : SafeTensorflowHandle | |||||
{ | { | ||||
private SafeOpHandle() | |||||
private SafeEagerOpHandle() | |||||
{ | { | ||||
} | } | ||||
public SafeOpHandle(IntPtr handle) | |||||
public SafeEagerOpHandle(IntPtr handle) | |||||
: base(handle) | : base(handle) | ||||
{ | { | ||||
@@ -59,7 +59,7 @@ namespace Tensorflow | |||||
/// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern TF_AttrType TFE_OpGetAttrType(SafeOpHandle op, string attr_name, ref byte is_list, SafeStatusHandle status); | |||||
public static extern TF_AttrType TFE_OpGetAttrType(SafeEagerOpHandle op, string attr_name, ref byte is_list, SafeStatusHandle status); | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern TF_AttrType TFE_OpNameGetAttrType(SafeContextHandle ctx, string op_or_function_name, string attr_name, ref byte is_list, SafeStatusHandle status); | public static extern TF_AttrType TFE_OpNameGetAttrType(SafeContextHandle ctx, string op_or_function_name, string attr_name, ref byte is_list, SafeStatusHandle status); | ||||
@@ -72,7 +72,7 @@ namespace Tensorflow | |||||
/// <param name="input_name">const char*</param> | /// <param name="input_name">const char*</param> | ||||
/// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern int TFE_OpGetInputLength(SafeOpHandle op, string input_name, SafeStatusHandle status); | |||||
public static extern int TFE_OpGetInputLength(SafeEagerOpHandle op, string input_name, SafeStatusHandle status); | |||||
/// <summary> | /// <summary> | ||||
/// Returns the length (number of tensors) of the output argument `output_name` | /// Returns the length (number of tensors) of the output argument `output_name` | ||||
@@ -83,7 +83,7 @@ namespace Tensorflow | |||||
/// <param name="status"></param> | /// <param name="status"></param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern int TFE_OpGetOutputLength(SafeOpHandle op, string input_name, SafeStatusHandle status); | |||||
public static extern int TFE_OpGetOutputLength(SafeEagerOpHandle op, string input_name, SafeStatusHandle status); | |||||
/// <summary> | /// <summary> | ||||
/// | /// | ||||
@@ -94,7 +94,7 @@ namespace Tensorflow | |||||
/// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern int TFE_OpAddInputList(SafeOpHandle op, [In, MarshalAs(UnmanagedType.CustomMarshaler, MarshalTypeRef = typeof(SafeHandleArrayMarshaler))] SafeTensorHandleHandle[] inputs, int num_inputs, SafeStatusHandle status); | |||||
public static extern int TFE_OpAddInputList(SafeEagerOpHandle op, [In, MarshalAs(UnmanagedType.CustomMarshaler, MarshalTypeRef = typeof(SafeHandleArrayMarshaler))] SafeTensorHandleHandle[] inputs, int num_inputs, SafeStatusHandle status); | |||||
/// <summary> | /// <summary> | ||||
/// | /// | ||||
@@ -161,7 +161,7 @@ namespace Tensorflow | |||||
/// <param name="retvals"></param> | /// <param name="retvals"></param> | ||||
/// <param name="num_retvals"></param> | /// <param name="num_retvals"></param> | ||||
/// <param name="status"></param> | /// <param name="status"></param> | ||||
public static void TFE_Execute(SafeOpHandle op, SafeTensorHandleHandle[] retvals, out int num_retvals, SafeStatusHandle status) | |||||
public static void TFE_Execute(SafeEagerOpHandle op, SafeTensorHandleHandle[] retvals, out int num_retvals, SafeStatusHandle status) | |||||
{ | { | ||||
unsafe | unsafe | ||||
{ | { | ||||
@@ -187,7 +187,7 @@ namespace Tensorflow | |||||
/// <param name="num_retvals">int*</param> | /// <param name="num_retvals">int*</param> | ||||
/// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
private static unsafe extern void TFE_Execute(SafeOpHandle op, IntPtr* retvals, ref int num_retvals, SafeStatusHandle status); | |||||
private static unsafe extern void TFE_Execute(SafeEagerOpHandle op, IntPtr* retvals, ref int num_retvals, SafeStatusHandle status); | |||||
/// <summary> | /// <summary> | ||||
/// | /// | ||||
@@ -197,7 +197,7 @@ namespace Tensorflow | |||||
/// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
/// <returns></returns> | /// <returns></returns> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern SafeOpHandle TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status); | |||||
public static extern SafeEagerOpHandle TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status); | |||||
/// <summary> | /// <summary> | ||||
/// Resets `op_to_reset` with `op_or_function_name` and `raw_device_name`. This | /// Resets `op_to_reset` with `op_or_function_name` and `raw_device_name`. This | ||||
@@ -213,7 +213,7 @@ namespace Tensorflow | |||||
/// <param name="raw_device_name">const char*</param> | /// <param name="raw_device_name">const char*</param> | ||||
/// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TFE_OpReset(SafeOpHandle op_to_reset, string op_or_function_name, string raw_device_name, SafeStatusHandle status); | |||||
public static extern void TFE_OpReset(SafeEagerOpHandle op_to_reset, string op_or_function_name, string raw_device_name, SafeStatusHandle status); | |||||
/// <summary> | /// <summary> | ||||
/// | /// | ||||
@@ -229,13 +229,13 @@ namespace Tensorflow | |||||
/// <param name="attr_name">const char*</param> | /// <param name="attr_name">const char*</param> | ||||
/// <param name="value">TF_DataType</param> | /// <param name="value">TF_DataType</param> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TFE_OpSetAttrType(SafeOpHandle op, string attr_name, TF_DataType value); | |||||
public static extern void TFE_OpSetAttrType(SafeEagerOpHandle op, string attr_name, TF_DataType value); | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TFE_OpSetAttrInt(SafeOpHandle op, string attr_name, long value); | |||||
public static extern void TFE_OpSetAttrInt(SafeEagerOpHandle op, string attr_name, long value); | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TFE_OpSetAttrFloat(SafeOpHandle op, string attr_name, float value); | |||||
public static extern void TFE_OpSetAttrFloat(SafeEagerOpHandle op, string attr_name, float value); | |||||
/// <summary> | /// <summary> | ||||
/// | /// | ||||
@@ -246,19 +246,19 @@ namespace Tensorflow | |||||
/// <param name="num_dims">const int</param> | /// <param name="num_dims">const int</param> | ||||
/// <param name="out_status">TF_Status*</param> | /// <param name="out_status">TF_Status*</param> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TFE_OpSetAttrShape(SafeOpHandle op, string attr_name, long[] dims, int num_dims, SafeStatusHandle out_status); | |||||
public static extern void TFE_OpSetAttrShape(SafeEagerOpHandle op, string attr_name, long[] dims, int num_dims, SafeStatusHandle out_status); | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TFE_OpSetAttrShapeList(SafeOpHandle op, string attr_name, IntPtr[] dims, int[] num_dims, int num_values, SafeStatusHandle out_status); | |||||
public static extern void TFE_OpSetAttrShapeList(SafeEagerOpHandle op, string attr_name, IntPtr[] dims, int[] num_dims, int num_values, SafeStatusHandle out_status); | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TFE_OpSetAttrStringList(SafeOpHandle op, string attr_name, string[] values, ulong[] lengths, int num_values); | |||||
public static extern void TFE_OpSetAttrStringList(SafeEagerOpHandle op, string attr_name, string[] values, ulong[] lengths, int num_values); | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TFE_OpSetAttrBool(SafeOpHandle op, string attr_name, bool value); | |||||
public static extern void TFE_OpSetAttrBool(SafeEagerOpHandle op, string attr_name, bool value); | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TFE_OpSetAttrFunctionName(SafeOpHandle op, string attr_name, string data, int length); | |||||
public static extern void TFE_OpSetAttrFunctionName(SafeEagerOpHandle op, string attr_name, string data, int length); | |||||
/// <summary> | /// <summary> | ||||
/// | /// | ||||
@@ -268,16 +268,16 @@ namespace Tensorflow | |||||
/// <param name="value">const void*</param> | /// <param name="value">const void*</param> | ||||
/// <param name="length">size_t</param> | /// <param name="length">size_t</param> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TFE_OpSetAttrString(SafeOpHandle op, string attr_name, string value, ulong length); | |||||
public static extern void TFE_OpSetAttrString(SafeEagerOpHandle op, string attr_name, string value, ulong length); | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TFE_OpSetAttrTypeList(SafeOpHandle op, string attr_name, TF_DataType[] values, int num_values); | |||||
public static extern void TFE_OpSetAttrTypeList(SafeEagerOpHandle op, string attr_name, TF_DataType[] values, int num_values); | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TFE_OpSetAttrIntList(SafeOpHandle op, string attr_name, long[] values, int num_values); | |||||
public static extern void TFE_OpSetAttrIntList(SafeEagerOpHandle op, string attr_name, long[] values, int num_values); | |||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TFE_OpSetAttrValueProto(SafeOpHandle op, string attr_name, IMessage[] proto, int proto_len, SafeStatusHandle status); | |||||
public static extern void TFE_OpSetAttrValueProto(SafeEagerOpHandle op, string attr_name, IMessage[] proto, int proto_len, SafeStatusHandle status); | |||||
/// <summary> | /// <summary> | ||||
/// | /// | ||||
@@ -286,7 +286,7 @@ namespace Tensorflow | |||||
/// <param name="device_name"></param> | /// <param name="device_name"></param> | ||||
/// <param name="status"></param> | /// <param name="status"></param> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TFE_OpSetDevice(SafeOpHandle op, string device_name, SafeStatusHandle status); | |||||
public static extern void TFE_OpSetDevice(SafeEagerOpHandle op, string device_name, SafeStatusHandle status); | |||||
/// <summary> | /// <summary> | ||||
/// | /// | ||||
@@ -295,7 +295,7 @@ namespace Tensorflow | |||||
/// <param name="h">TFE_TensorHandle*</param> | /// <param name="h">TFE_TensorHandle*</param> | ||||
/// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
[DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
public static extern void TFE_OpAddInput(SafeOpHandle op, SafeTensorHandleHandle h, SafeStatusHandle status); | |||||
public static extern void TFE_OpAddInput(SafeEagerOpHandle op, SafeTensorHandleHandle h, SafeStatusHandle status); | |||||
/// <summary> | /// <summary> | ||||
/// | /// | ||||
@@ -51,7 +51,7 @@ namespace Tensorflow.Gradients | |||||
} | } | ||||
[RegisterGradient("ConcatV2")] | [RegisterGradient("ConcatV2")] | ||||
public static Tensor[] _ConcatGradV2(Operation op, Tensor[] grads) | |||||
public static Tensor[] _ConcatV2Grad(Operation op, Tensor[] grads) | |||||
{ | { | ||||
var grad = grads[0]; | var grad = grads[0]; | ||||
return _ConcatGradHelper(op, grad, start_value_index: 0, end_value_index: -1, dim_index: -1); | return _ConcatGradHelper(op, grad, start_value_index: 0, end_value_index: -1, dim_index: -1); | ||||
@@ -50,42 +50,11 @@ namespace Tensorflow | |||||
{ | { | ||||
// tf.Logger.Debug($"Caculate Gradient: {oper.name} {m.Name}"); | // tf.Logger.Debug($"Caculate Gradient: {oper.name} {m.Name}"); | ||||
var results = m.Name switch | |||||
{ | |||||
/*"_AddGrad" => math_grad._AddGrad(oper, out_grads), | |||||
"_AddV2Grad" => math_grad._AddV2Grad(oper, out_grads), | |||||
"_BiasAddGrad" => nn_grad._BiasAddGrad(oper, out_grads), | |||||
"_CastGrad" => math_grad._CastGrad(oper, out_grads), | |||||
"_ConcatGradV2" => array_grad._ConcatGradV2(oper, out_grads), | |||||
"_Conv2DGrad" => nn_grad._Conv2DGrad(oper, out_grads), | |||||
"_ExpandDimsGrad" => array_grad._ExpandDimsGrad(oper, out_grads), | |||||
"_ExpGrad" => math_grad._ExpGrad(oper, out_grads), | |||||
"_FusedBatchNormV3Grad" => nn_grad._FusedBatchNormV3Grad(oper, out_grads), | |||||
"_IdGrad" => math_grad._IdGrad(oper, out_grads), | |||||
"_LeakyReluGrad" => nn_grad._LeakyReluGrad(oper, out_grads), | |||||
"_Log1pGrad" => math_grad._Log1pGrad(oper, out_grads), | |||||
"_MaximumGrad" => math_grad._MaximumGrad(oper, out_grads), | |||||
"_MeanGrad" => math_grad._MeanGrad(oper, out_grads), | |||||
"_MinimumGrad" => math_grad._MinimumGrad(oper, out_grads), | |||||
"_MulGrad" => math_grad._MulGrad(oper, out_grads), | |||||
"_NegGrad" => math_grad._NegGrad(oper, out_grads), | |||||
"_PadGrad" => array_grad._PadGrad(oper, out_grads), | |||||
"_PowGrad" => math_grad._PowGrad(oper, out_grads), | |||||
"_RealDivGrad" => math_grad._RealDivGrad(oper, out_grads), | |||||
"_ReadGrad" => resource_variable_grad._ReadGrad(oper, out_grads), | |||||
"_ReshapeGrad" => array_grad._ReshapeGrad(oper, out_grads), | |||||
"_ResizeNearestNeighborGrad" => image_grad._ResizeNearestNeighborGrad(oper, out_grads), | |||||
"_SelectGrad" => math_grad._SelectGrad(oper, out_grads), | |||||
"_SigmoidGrad" => math_grad._SigmoidGrad(oper, out_grads), | |||||
"_SumGrad" => math_grad._SumGrad(oper, out_grads), | |||||
"_SubGrad" => math_grad._SubGrad(oper, out_grads), | |||||
"_StridedSliceGrad" => array_grad._StridedSliceGrad(oper, out_grads),*/ | |||||
_ => g.InvokeMember(m.Name, | |||||
BindingFlags.InvokeMethod, | |||||
null, | |||||
null, | |||||
args: new object[] { oper, out_grads }) as Tensor[] | |||||
}; | |||||
var results = g.InvokeMember(m.Name, | |||||
BindingFlags.InvokeMethod, | |||||
null, | |||||
null, | |||||
args: new object[] { oper, out_grads }) as Tensor[]; | |||||
// foreach (var result in results.Where(x => x != null)) | // foreach (var result in results.Where(x => x != null)) | ||||
// tf.Logger.Debug($"Gradient: {result.name} {result.shape}"); | // tf.Logger.Debug($"Gradient: {result.name} {result.shape}"); | ||||
@@ -39,7 +39,7 @@ namespace Tensorflow | |||||
public void _add_control_input(Operation op) | public void _add_control_input(Operation op) | ||||
{ | { | ||||
c_api.TF_AddControlInput(OpDesc, op); | |||||
// c_api.TF_AddControlInput(_opDesc, op); | |||||
//c_api.AddControlInput(graph, _handle, op); | //c_api.AddControlInput(graph, _handle, op); | ||||
} | } | ||||
@@ -46,7 +46,6 @@ namespace Tensorflow | |||||
private readonly IntPtr _handle; // _c_op in python | private readonly IntPtr _handle; // _c_op in python | ||||
private readonly Graph _graph; | private readonly Graph _graph; | ||||
private NodeDef _node_def; | |||||
public string type => OpType; | public string type => OpType; | ||||
@@ -57,24 +56,14 @@ namespace Tensorflow | |||||
public int _id_value { get; set; } | public int _id_value { get; set; } | ||||
public Operation op => this; | public Operation op => this; | ||||
public TF_DataType dtype => TF_DataType.DtInvalid; | public TF_DataType dtype => TF_DataType.DtInvalid; | ||||
public virtual string name => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationName(_handle)); | |||||
public string OpType => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationOpType(_handle)); | |||||
public virtual string name => _handle == IntPtr.Zero ? "" : c_api.StringPiece(c_api.TF_OperationName(_handle)); | |||||
public string OpType => _handle == IntPtr.Zero ? "" : c_api.StringPiece(c_api.TF_OperationOpType(_handle)); | |||||
public string Device => _handle == IntPtr.Zero ? null : c_api.StringPiece(c_api.TF_OperationDevice(_handle)); | |||||
public string Device => _handle == IntPtr.Zero ? "" : c_api.StringPiece(c_api.TF_OperationDevice(_handle)); | |||||
bool _is_stateful; | |||||
public OperationDescription OpDesc { get; set; } | |||||
// OperationDescription _opDesc; | |||||
public NodeDef node_def | |||||
{ | |||||
get | |||||
{ | |||||
if (_node_def == null) | |||||
_node_def = GetNodeDef(); | |||||
return _node_def; | |||||
} | |||||
} | |||||
public NodeDef node_def => GetNodeDef(); | |||||
public Operation(IntPtr handle, Graph g = null) | public Operation(IntPtr handle, Graph g = null) | ||||
{ | { | ||||
@@ -168,8 +157,7 @@ namespace Tensorflow | |||||
if (op_def == null) | if (op_def == null) | ||||
op_def = g.GetOpDef(node_def.Op); | op_def = g.GetOpDef(node_def.Op); | ||||
(_handle, OpDesc) = ops._create_c_op(g, node_def, inputs, control_input_ops.ToArray(), op_def); | |||||
_is_stateful = op_def.IsStateful; | |||||
(_handle, _) = ops._create_c_op(g, node_def, inputs, control_input_ops.ToArray(), op_def); | |||||
// Initialize self._outputs. | // Initialize self._outputs. | ||||
output_types = new TF_DataType[NumOutputs]; | output_types = new TF_DataType[NumOutputs]; | ||||
@@ -199,16 +187,11 @@ namespace Tensorflow | |||||
if (tf.executing_eagerly()) | if (tf.executing_eagerly()) | ||||
return (T[])get_attr(name); | return (T[])get_attr(name); | ||||
AttrValue x = null; | |||||
using var buf = new Buffer(); | |||||
c_api.TF_OperationGetAttrValueProto(_handle, name, buf.Handle, tf.Status.Handle); | |||||
tf.Status.Check(true); | |||||
lock (Locks.ProcessWide) | |||||
{ | |||||
using var buf = new Buffer(); | |||||
c_api.TF_OperationGetAttrValueProto(_handle, name, buf.Handle, tf.Status.Handle); | |||||
tf.Status.Check(true); | |||||
x = AttrValue.Parser.ParseFrom(buf.ToArray()); | |||||
} | |||||
var x = AttrValue.Parser.ParseFrom(buf.ToArray()); | |||||
string oneof_value = x.ValueCase.ToString(); | string oneof_value = x.ValueCase.ToString(); | ||||
if (string.IsNullOrEmpty(oneof_value)) | if (string.IsNullOrEmpty(oneof_value)) | ||||
@@ -227,16 +210,11 @@ namespace Tensorflow | |||||
public virtual object get_attr(string name) | public virtual object get_attr(string name) | ||||
{ | { | ||||
AttrValue x = null; | |||||
lock (Locks.ProcessWide) | |||||
{ | |||||
using var buf = new Buffer(); | |||||
c_api.TF_OperationGetAttrValueProto(_handle, name, buf.Handle, tf.Status.Handle); | |||||
tf.Status.Check(true); | |||||
using var buf = new Buffer(); | |||||
c_api.TF_OperationGetAttrValueProto(_handle, name, buf.Handle, tf.Status.Handle); | |||||
tf.Status.Check(true); | |||||
x = AttrValue.Parser.ParseFrom(buf.ToArray()); | |||||
} | |||||
var x = AttrValue.Parser.ParseFrom(buf.ToArray()); | |||||
string oneof_value = x.ValueCase.ToString(); | string oneof_value = x.ValueCase.ToString(); | ||||
if (string.IsNullOrEmpty(oneof_value)) | if (string.IsNullOrEmpty(oneof_value)) | ||||
@@ -262,15 +240,10 @@ namespace Tensorflow | |||||
private NodeDef GetNodeDef() | private NodeDef GetNodeDef() | ||||
{ | { | ||||
lock (Locks.ProcessWide) | |||||
using (var s = new Status()) | |||||
using (var buffer = new Buffer()) | |||||
{ | |||||
c_api.TF_OperationToNodeDef(_handle, buffer.Handle, s.Handle); | |||||
s.Check(); | |||||
return NodeDef.Parser.ParseFrom(buffer.ToArray()); | |||||
} | |||||
using var buffer = new Buffer(); | |||||
c_api.TF_OperationToNodeDef(_handle, buffer.Handle, tf.Status.Handle); | |||||
tf.Status.Check(throwException: true); | |||||
return NodeDef.Parser.ParseFrom(buffer.ToArray()); | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
@@ -284,21 +257,21 @@ namespace Tensorflow | |||||
{ | { | ||||
_assert_same_graph(tensor); | _assert_same_graph(tensor); | ||||
var input = _tf_input(index); | |||||
var output = tensor._as_tf_output(); | |||||
// var input = _tf_input(index); | |||||
// var output = tensor._as_tf_output(); | |||||
// Reset cached inputs. | // Reset cached inputs. | ||||
_inputs_val = null; | _inputs_val = null; | ||||
_node_def = null; | |||||
// _node_def = null; | |||||
// after the c_api call next time _inputs is accessed | // after the c_api call next time _inputs is accessed | ||||
// the updated inputs are reloaded from the c_api | // the updated inputs are reloaded from the c_api | ||||
lock (Locks.ProcessWide) | |||||
{ | |||||
// lock (Locks.ProcessWide) | |||||
// { | |||||
// disable | // disable | ||||
// c_api.TF_UpdateEdge(_graph, output, input, tf.Status.Handle); | // c_api.TF_UpdateEdge(_graph, output, input, tf.Status.Handle); | ||||
//var updated_inputs = inputs; | //var updated_inputs = inputs; | ||||
tf.Status.Check(); | |||||
} | |||||
// tf.Status.Check(); | |||||
// } | |||||
} | } | ||||
private void _assert_same_graph(Tensor tensor) | private void _assert_same_graph(Tensor tensor) | ||||
@@ -311,7 +284,7 @@ namespace Tensorflow | |||||
/// </summary> | /// </summary> | ||||
public TF_Output _tf_output(int output_idx) | public TF_Output _tf_output(int output_idx) | ||||
{ | { | ||||
return new TF_Output(op, output_idx); | |||||
return new TF_Output(_handle, output_idx); | |||||
} | } | ||||
/// <summary> | /// <summary> | ||||
@@ -319,7 +292,7 @@ namespace Tensorflow | |||||
/// </summary> | /// </summary> | ||||
public TF_Input _tf_input(int input_idx) | public TF_Input _tf_input(int input_idx) | ||||
{ | { | ||||
return new TF_Input(op, input_idx); | |||||
return new TF_Input(_handle, input_idx); | |||||
} | } | ||||
public NDArray numpy() => throw new NotImplementedException(""); | public NDArray numpy() => throw new NotImplementedException(""); | ||||
@@ -80,27 +80,16 @@ namespace Tensorflow | |||||
return tf_with(ops.name_scope(name, "zeros", shape), scope => | return tf_with(ops.name_scope(name, "zeros", shape), scope => | ||||
{ | { | ||||
name = scope; | name = scope; | ||||
var shape_tensor = constant_op._tensor_shape_tensor_conversion_function(shape); | |||||
Tensor zeros = null; | |||||
switch (dtype) | |||||
// var shape_tensor = constant_op._tensor_shape_tensor_conversion_function(shape); | |||||
Tensor zeros = dtype switch | |||||
{ | { | ||||
case TF_DataType.TF_DOUBLE: | |||||
zeros = constant(0d); | |||||
break; | |||||
case TF_DataType.TF_FLOAT: | |||||
zeros = constant(0f); | |||||
break; | |||||
case TF_DataType.TF_INT8: | |||||
zeros = constant((sbyte)0); | |||||
break; | |||||
case TF_DataType.TF_UINT8: | |||||
zeros = constant((byte)0); | |||||
break; | |||||
default: | |||||
zeros = constant(0); | |||||
break; | |||||
} | |||||
return fill(shape_tensor, zeros, name: name); | |||||
TF_DataType.TF_DOUBLE => constant(0d), | |||||
TF_DataType.TF_FLOAT => constant(0f), | |||||
TF_DataType.TF_INT8 => constant((sbyte)0), | |||||
TF_DataType.TF_UINT8 => constant((byte)0), | |||||
_ => constant(0) | |||||
}; | |||||
return fill(shape, zeros, name: name); | |||||
}); | }); | ||||
} | } | ||||
else | else | ||||
@@ -311,12 +300,8 @@ namespace Tensorflow | |||||
/// <param name="value">A value to fill the returned `tf.Tensor`.</param> | /// <param name="value">A value to fill the returned `tf.Tensor`.</param> | ||||
/// <param name="name">Optional string. The name of the output `tf.Tensor`.</param> | /// <param name="name">Optional string. The name of the output `tf.Tensor`.</param> | ||||
/// <returns>A `tf.Tensor` with shape `dims` and the same dtype as `value`.</returns> | /// <returns>A `tf.Tensor` with shape `dims` and the same dtype as `value`.</returns> | ||||
public static Tensor fill(Tensor dims, Tensor value, string name = null) | |||||
{ | |||||
var result = gen_array_ops.fill(dims, value, name: name); | |||||
// tensor_util.maybe_set_static_shape(result, dims) | |||||
return result; | |||||
} | |||||
public static Tensor fill<T>(Shape dims, T value, string name = null) | |||||
=> gen_array_ops.fill(dims, value, name: name); | |||||
/// <summary> | /// <summary> | ||||
/// Returns the rank of a tensor. | /// Returns the rank of a tensor. | ||||
@@ -425,25 +410,18 @@ namespace Tensorflow | |||||
dtype = dtype.as_base_dtype(); | dtype = dtype.as_base_dtype(); | ||||
name = scope; | name = scope; | ||||
Tensor ones = null; | |||||
switch (dtype) | |||||
Tensor ones = dtype switch | |||||
{ | { | ||||
case TF_DataType.TF_DOUBLE: | |||||
ones = constant(1.0d); | |||||
break; | |||||
case TF_DataType.TF_FLOAT: | |||||
ones = constant(1.0f); | |||||
break; | |||||
default: | |||||
ones = constant(1); | |||||
break; | |||||
} | |||||
TF_DataType.TF_DOUBLE => constant(1.0d), | |||||
TF_DataType.TF_FLOAT => constant(1.0f), | |||||
_ => constant(1) | |||||
}; | |||||
if (shape.ndim == 0) | if (shape.ndim == 0) | ||||
return ones; | return ones; | ||||
var shape_tensor = constant_op._tensor_shape_tensor_conversion_function(shape); | |||||
return fill(shape_tensor, ones, name: name); | |||||
// var shape_tensor = constant_op._tensor_shape_tensor_conversion_function(shape); | |||||
return fill(shape, ones, name: name); | |||||
}); | }); | ||||
public static Tensor one_hot(Tensor indices, Tensor depth, | public static Tensor one_hot(Tensor indices, Tensor depth, | ||||
@@ -2086,8 +2086,7 @@ new_height, new_width"); | |||||
gather_idx), | gather_idx), | ||||
new[] { batch_size, -1 }); | new[] { batch_size, -1 }); | ||||
} | } | ||||
var invalid_index = array_ops.fill(ops.convert_to_tensor(new object[] { batch_size, max_output_size }), | |||||
tf.constant(0)); | |||||
var invalid_index = array_ops.fill(new Shape((int)batch_size, (int)max_output_size), 0); | |||||
var idx_index = array_ops.expand_dims(math_ops.range(max_output_size), 0); | var idx_index = array_ops.expand_dims(math_ops.range(max_output_size), 0); | ||||
var num_valid_expanded = array_ops.expand_dims(num_valid, 1); | var num_valid_expanded = array_ops.expand_dims(num_valid, 1); | ||||
idx = array_ops.where(idx_index < num_valid_expanded, | idx = array_ops.where(idx_index < num_valid_expanded, | ||||
@@ -223,42 +223,39 @@ namespace Tensorflow | |||||
var input_tensors = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); | var input_tensors = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); | ||||
lock (Locks.ProcessWide) | |||||
{ | |||||
var op_desc = graph.NewOperation(node_def.Op, node_def.Name); | |||||
var op_desc = graph.NewOperation(node_def.Op, node_def.Name); | |||||
if (!string.IsNullOrEmpty(node_def.Device)) | |||||
c_api.TF_SetDevice(op_desc, node_def.Device); | |||||
if (!string.IsNullOrEmpty(node_def.Device)) | |||||
c_api.TF_SetDevice(op_desc, node_def.Device); | |||||
// Add inputs | |||||
foreach (var op_input in input_tensors) | |||||
{ | |||||
if (op_input.IsList) | |||||
c_api.TF_AddInputList(op_desc, op_input.Select(x => x._as_tf_output()).ToArray(), op_input.Count()); | |||||
else if (op_input.Count() == 1) | |||||
c_api.TF_AddInput(op_desc, op_input[0]._as_tf_output()); | |||||
} | |||||
// Add inputs | |||||
foreach (var op_input in input_tensors) | |||||
{ | |||||
if (op_input.IsList) | |||||
c_api.TF_AddInputList(op_desc, op_input.Select(x => x._as_tf_output()).ToArray(), op_input.Count()); | |||||
else if (op_input.Count() == 1) | |||||
c_api.TF_AddInput(op_desc, op_input[0]._as_tf_output()); | |||||
} | |||||
var status = tf.Status; | |||||
var status = tf.Status; | |||||
// Add control inputs | |||||
foreach (var control_input in control_inputs) | |||||
c_api.TF_AddControlInput(op_desc, control_input); | |||||
// Add control inputs | |||||
foreach (var control_input in control_inputs) | |||||
c_api.TF_AddControlInput(op_desc, control_input); | |||||
// Add attrs | |||||
foreach (var attr in node_def.Attr) | |||||
{ | |||||
var bytes = attr.Value.ToByteArray(); | |||||
c_api.TF_SetAttrValueProto(op_desc, attr.Key, bytes, proto_len: bytes.Length, status: status.Handle); | |||||
status.Check(true); | |||||
} | |||||
// Add attrs | |||||
foreach (var attr in node_def.Attr) | |||||
{ | |||||
var bytes = attr.Value.ToByteArray(); | |||||
c_api.TF_SetAttrValueProto(op_desc, attr.Key, bytes, proto_len: bytes.Length, status: status.Handle); | |||||
status.Check(true); | |||||
} | |||||
var c_op = c_api.TF_FinishOperation(op_desc, status.Handle); | |||||
var c_op = op_desc.FinishOperation(status); | |||||
status.Check(true); | |||||
status.Check(true); | |||||
return (c_op, op_desc); | |||||
} | |||||
return (c_op, op_desc); | |||||
} | } | ||||
public static Tensors[] _reconstruct_sequence_inputs(OpDef op_def, Tensor[] inputs, MapField<string, AttrValue> attrs) | public static Tensors[] _reconstruct_sequence_inputs(OpDef op_def, Tensor[] inputs, MapField<string, AttrValue> attrs) | ||||
@@ -101,9 +101,9 @@ namespace Tensorflow.Keras.Engine | |||||
var result_pairs = string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2:F6}")); | var result_pairs = string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2:F6}")); | ||||
Binding.tf_output_redirect.WriteLine($"Epoch: {epoch + 1:D3}/{epochs:D3}, Step: {step + 1:D4}/{data_handler.Inferredsteps:D4}, {result_pairs}"); | Binding.tf_output_redirect.WriteLine($"Epoch: {epoch + 1:D3}/{epochs:D3}, Step: {step + 1:D4}/{data_handler.Inferredsteps:D4}, {result_pairs}"); | ||||
} | } | ||||
} | |||||
GC.Collect(); | |||||
GC.Collect(); | |||||
} | |||||
GC.WaitForPendingFinalizers(); | GC.WaitForPendingFinalizers(); | ||||
} | } | ||||
} | } | ||||
@@ -5,18 +5,23 @@ using Tensorflow.Keras.ArgsDefinition; | |||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace Tensorflow.Keras.Layers { | |||||
public class Tanh : Layer { | |||||
public Tanh ( LayerArgs args ) : base(args) { | |||||
// Tanh has no arguments | |||||
} | |||||
protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { | |||||
Tensor x = inputs; | |||||
namespace Tensorflow.Keras.Layers | |||||
{ | |||||
public class Tanh : Layer | |||||
{ | |||||
public Tanh(LayerArgs args) : base(args) | |||||
{ | |||||
// Tanh has no arguments | |||||
} | |||||
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | |||||
{ | |||||
Tensor x = inputs; | |||||
return tf.tanh(x); | |||||
} | |||||
public override Shape ComputeOutputShape ( Shape input_shape ) { | |||||
return input_shape; | |||||
} | |||||
} | |||||
return tf.tanh(x); | |||||
} | |||||
public override Shape ComputeOutputShape(Shape input_shape) | |||||
{ | |||||
return input_shape; | |||||
} | |||||
} | |||||
} | } |
@@ -80,25 +80,25 @@ namespace Tensorflow.Native.UnitTest | |||||
protected ulong TF_TensorByteSize(SafeTensorHandle t) | protected ulong TF_TensorByteSize(SafeTensorHandle t) | ||||
=> c_api.TF_TensorByteSize(t); | => c_api.TF_TensorByteSize(t); | ||||
protected void TFE_OpAddInput(SafeOpHandle op, SafeTensorHandleHandle h, SafeStatusHandle status) | |||||
protected void TFE_OpAddInput(SafeEagerOpHandle op, SafeTensorHandleHandle h, SafeStatusHandle status) | |||||
=> c_api.TFE_OpAddInput(op, h, status); | => c_api.TFE_OpAddInput(op, h, status); | ||||
protected void TFE_OpSetAttrType(SafeOpHandle op, string attr_name, TF_DataType value) | |||||
protected void TFE_OpSetAttrType(SafeEagerOpHandle op, string attr_name, TF_DataType value) | |||||
=> c_api.TFE_OpSetAttrType(op, attr_name, value); | => c_api.TFE_OpSetAttrType(op, attr_name, value); | ||||
protected void TFE_OpSetAttrShape(SafeOpHandle op, string attr_name, long[] dims, int num_dims, SafeStatusHandle out_status) | |||||
protected void TFE_OpSetAttrShape(SafeEagerOpHandle op, string attr_name, long[] dims, int num_dims, SafeStatusHandle out_status) | |||||
=> c_api.TFE_OpSetAttrShape(op, attr_name, dims, num_dims, out_status); | => c_api.TFE_OpSetAttrShape(op, attr_name, dims, num_dims, out_status); | ||||
protected void TFE_OpSetAttrString(SafeOpHandle op, string attr_name, string value, uint length) | |||||
protected void TFE_OpSetAttrString(SafeEagerOpHandle op, string attr_name, string value, uint length) | |||||
=> c_api.TFE_OpSetAttrString(op, attr_name, value, length); | => c_api.TFE_OpSetAttrString(op, attr_name, value, length); | ||||
protected SafeOpHandle TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status) | |||||
protected SafeEagerOpHandle TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status) | |||||
=> c_api.TFE_NewOp(ctx, op_or_function_name, status); | => c_api.TFE_NewOp(ctx, op_or_function_name, status); | ||||
protected SafeTensorHandleHandle TFE_NewTensorHandle(SafeTensorHandle t, SafeStatusHandle status) | protected SafeTensorHandleHandle TFE_NewTensorHandle(SafeTensorHandle t, SafeStatusHandle status) | ||||
=> c_api.TFE_NewTensorHandle(t, status); | => c_api.TFE_NewTensorHandle(t, status); | ||||
protected void TFE_Execute(SafeOpHandle op, SafeTensorHandleHandle[] retvals, out int num_retvals, SafeStatusHandle status) | |||||
protected void TFE_Execute(SafeEagerOpHandle op, SafeTensorHandleHandle[] retvals, out int num_retvals, SafeStatusHandle status) | |||||
=> c_api.TFE_Execute(op, retvals, out num_retvals, status); | => c_api.TFE_Execute(op, retvals, out num_retvals, status); | ||||
protected SafeContextOptionsHandle TFE_NewContextOptions() | protected SafeContextOptionsHandle TFE_NewContextOptions() | ||||
@@ -107,13 +107,13 @@ namespace Tensorflow.Native.UnitTest | |||||
protected SafeContextHandle TFE_NewContext(SafeContextOptionsHandle opts, SafeStatusHandle status) | protected SafeContextHandle TFE_NewContext(SafeContextOptionsHandle opts, SafeStatusHandle status) | ||||
=> c_api.TFE_NewContext(opts, status); | => c_api.TFE_NewContext(opts, status); | ||||
protected int TFE_OpGetInputLength(SafeOpHandle op, string input_name, SafeStatusHandle status) | |||||
protected int TFE_OpGetInputLength(SafeEagerOpHandle op, string input_name, SafeStatusHandle status) | |||||
=> c_api.TFE_OpGetInputLength(op, input_name, status); | => c_api.TFE_OpGetInputLength(op, input_name, status); | ||||
protected int TFE_OpAddInputList(SafeOpHandle op, SafeTensorHandleHandle[] inputs, int num_inputs, SafeStatusHandle status) | |||||
protected int TFE_OpAddInputList(SafeEagerOpHandle op, SafeTensorHandleHandle[] inputs, int num_inputs, SafeStatusHandle status) | |||||
=> c_api.TFE_OpAddInputList(op, inputs, num_inputs, status); | => c_api.TFE_OpAddInputList(op, inputs, num_inputs, status); | ||||
protected int TFE_OpGetOutputLength(SafeOpHandle op, string input_name, SafeStatusHandle status) | |||||
protected int TFE_OpGetOutputLength(SafeEagerOpHandle op, string input_name, SafeStatusHandle status) | |||||
=> c_api.TFE_OpGetOutputLength(op, input_name, status); | => c_api.TFE_OpGetOutputLength(op, input_name, status); | ||||
protected void TFE_DeleteTensorHandle(IntPtr h) | protected void TFE_DeleteTensorHandle(IntPtr h) | ||||
@@ -149,7 +149,7 @@ namespace Tensorflow.Native.UnitTest | |||||
protected SafeTensorHandleHandle TFE_TensorHandleCopyToDevice(SafeTensorHandleHandle h, SafeContextHandle ctx, string device_name, SafeStatusHandle status) | protected SafeTensorHandleHandle TFE_TensorHandleCopyToDevice(SafeTensorHandleHandle h, SafeContextHandle ctx, string device_name, SafeStatusHandle status) | ||||
=> c_api.TFE_TensorHandleCopyToDevice(h, ctx, device_name, status); | => c_api.TFE_TensorHandleCopyToDevice(h, ctx, device_name, status); | ||||
protected void TFE_OpSetDevice(SafeOpHandle op, string device_name, SafeStatusHandle status) | |||||
protected void TFE_OpSetDevice(SafeEagerOpHandle op, string device_name, SafeStatusHandle status) | |||||
=> c_api.TFE_OpSetDevice(op, device_name, status); | => c_api.TFE_OpSetDevice(op, device_name, status); | ||||
} | } | ||||
} | } |
@@ -25,7 +25,7 @@ namespace Tensorflow.Native.UnitTest.Eager | |||||
return th; | return th; | ||||
} | } | ||||
SafeOpHandle MatMulOp(SafeContextHandle ctx, SafeTensorHandleHandle a, SafeTensorHandleHandle b) | |||||
SafeEagerOpHandle MatMulOp(SafeContextHandle ctx, SafeTensorHandleHandle a, SafeTensorHandleHandle b) | |||||
{ | { | ||||
using var status = TF_NewStatus(); | using var status = TF_NewStatus(); | ||||
@@ -63,7 +63,7 @@ namespace Tensorflow.Native.UnitTest.Eager | |||||
return false; | return false; | ||||
} | } | ||||
SafeOpHandle ShapeOp(SafeContextHandle ctx, SafeTensorHandleHandle a) | |||||
SafeEagerOpHandle ShapeOp(SafeContextHandle ctx, SafeTensorHandleHandle a) | |||||
{ | { | ||||
using var status = TF_NewStatus(); | using var status = TF_NewStatus(); | ||||