@@ -30,49 +30,19 @@ namespace Tensorflow.Contexts | |||
public sealed partial class Context | |||
{ | |||
// [DebuggerStepThrough] | |||
public T RunInAutoMode<T>(Func<T> graphAction, Func<T> eagerAction, params object[] args) | |||
{ | |||
if (tf.Context.has_graph_arg(args)) | |||
{ | |||
if (executing_eagerly()) | |||
{ | |||
graph_mode(); | |||
var result = graphAction(); | |||
restore_mode(); | |||
return result; | |||
} | |||
else | |||
{ | |||
return graphAction(); | |||
} | |||
} | |||
else | |||
{ | |||
if (tf.Context.executing_eagerly()) | |||
{ | |||
return eagerAction(); | |||
} | |||
else | |||
{ | |||
return graphAction(); | |||
} | |||
} | |||
} | |||
// [DebuggerStepThrough] | |||
public Tensors RunInAutoMode2(string OpType, string Name, AutoModeArgs args) | |||
public Tensors ExecuteOp(string OpType, string Name, AutoModeArgs args) | |||
{ | |||
var inputArgs = ConvertToDict(args.OpInputArgs); | |||
var attrDict = ConvertToDict(args.OpAttrs); | |||
Func<Tensor> graphAction = () => | |||
Func<Tensors> graphAction = () => | |||
{ | |||
foreach (var attr in attrDict) | |||
inputArgs[attr.Key] = attr.Value; | |||
return tf.OpDefLib._apply_op_helper(OpType, Name, inputArgs).output; | |||
return tf.OpDefLib._apply_op_helper(OpType, Name, inputArgs).outputs; | |||
}; | |||
Func<Tensor> eagerAction = () => | |||
Func<Tensors> eagerAction = () => | |||
{ | |||
var attrs = new object[attrDict.Count() * 2]; | |||
int i = 0; | |||
@@ -87,7 +57,7 @@ namespace Tensorflow.Contexts | |||
OpType, Name, | |||
null, | |||
inputArgs.Values.ToArray(), | |||
attrs).FirstOrDefault(); | |||
attrs); | |||
}; | |||
if (tf.Context.has_graph_arg(inputArgs.Values)) | |||
@@ -269,29 +269,24 @@ namespace Tensorflow.Operations | |||
} | |||
public static Tensor[] fused_batch_norm_grad_v3(FusedBatchNormParams @params) | |||
=> tf.Context.RunInAutoMode(() | |||
=> tf.OpDefLib._apply_op_helper("FusedBatchNormGradV3", name: @params.Name, | |||
args: new | |||
{ | |||
y_backprop = @params.YBackprop, | |||
x = @params.X, | |||
scale = @params.Scale, | |||
reserve_space_1 = @params.ReserveSpace1, | |||
reserve_space_2 = @params.ReserveSpace2, | |||
reserve_space_3 = @params.ReserveSpace3, | |||
epsilon = @params.Epsilon, | |||
data_format = @params.DataFormat, | |||
is_training = @params.IsTraining | |||
}).outputs, () | |||
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"FusedBatchNormGradV3", @params.Name, | |||
null, | |||
@params.YBackprop, @params.X, @params.Scale, | |||
@params.ReserveSpace1, @params.ReserveSpace2, @params.ReserveSpace3, | |||
"epsilon", @params.Epsilon, | |||
"data_format", @params.DataFormat, | |||
"is_training", @params.IsTraining), | |||
@params.YBackprop); | |||
=> tf.Context.ExecuteOp("FusedBatchNormGradV3", @params.Name, new AutoModeArgs | |||
{ | |||
OpInputArgs = new | |||
{ | |||
y_backprop = @params.YBackprop, | |||
x = @params.X, | |||
scale = @params.Scale, | |||
reserve_space_1 = @params.ReserveSpace1, | |||
reserve_space_2 = @params.ReserveSpace2, | |||
reserve_space_3 = @params.ReserveSpace3 | |||
}, | |||
OpAttrs = new | |||
{ | |||
epsilon = @params.Epsilon, | |||
data_format = @params.DataFormat, | |||
is_training = @params.IsTraining | |||
} | |||
}); | |||
public static Tensor[] fused_batch_norm(Tensor x, | |||
Tensor scale, | |||
@@ -388,14 +383,10 @@ namespace Tensorflow.Operations | |||
} | |||
public static Tensor log_softmax(Tensor logits, string name = null) | |||
=> tf.Context.RunInAutoMode(() | |||
=> tf.OpDefLib._apply_op_helper("LogSoftmax", name: name, | |||
args: new { logits }).output, () | |||
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"LogSoftmax", name, | |||
null, | |||
logits).FirstOrDefault(), | |||
logits); | |||
=> tf.Context.ExecuteOp("LogSoftmax", name, new AutoModeArgs | |||
{ | |||
OpInputArgs = new { logits } | |||
}); | |||
/// <summary> | |||
/// Says whether the targets are in the top `K` predictions. | |||
@@ -418,19 +409,11 @@ namespace Tensorflow.Operations | |||
} | |||
public static Tensor leaky_relu(Tensor features, float alpha = 0.2f, string name = null) | |||
=> tf.Context.RunInAutoMode(() | |||
=> tf.OpDefLib._apply_op_helper("LeakyRelu", name: name, | |||
args: new | |||
{ | |||
features, | |||
alpha | |||
}).output, () | |||
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"LeakyRelu", name, | |||
null, | |||
features, | |||
"alpha", alpha).FirstOrDefault(), | |||
features); | |||
=> tf.Context.ExecuteOp("LeakyRelu", name, new AutoModeArgs | |||
{ | |||
OpInputArgs = new { features }, | |||
OpAttrs = new { alpha } | |||
}); | |||
public static Tensor max_pool(Tensor input, | |||
int[] ksize, | |||
@@ -737,7 +737,7 @@ namespace Tensorflow | |||
public static Tensor strided_slice_grad(Tensor shape, Tensor begin, Tensor end, Tensor strides, Tensor dy, | |||
long begin_mask = 0, long end_mask = 0, long ellipsis_mask = 0, long new_axis_mask = 0, | |||
long shrink_axis_mask = 0, string name = null) | |||
=> tf.Context.RunInAutoMode2("StridedSliceGrad", name, new AutoModeArgs | |||
=> tf.Context.ExecuteOp("StridedSliceGrad", name, new AutoModeArgs | |||
{ | |||
OpInputArgs = new | |||
{ | |||
@@ -960,7 +960,7 @@ namespace Tensorflow | |||
=> gen_array_ops.slice(input, begin, size, name: name); | |||
public static Tensor slice(Tensor input, Tensor begin, Tensor size, string name = null) | |||
=> tf.Context.RunInAutoMode2("Slice", name, new AutoModeArgs | |||
=> tf.Context.ExecuteOp("Slice", name, new AutoModeArgs | |||
{ | |||
OpInputArgs = new { input, begin, size }, | |||
GetGradientAttrs = (op) => new | |||
@@ -72,14 +72,10 @@ namespace Tensorflow | |||
} | |||
public static Tensor concat_v2(Tensor[] values, int axis, string name = null) | |||
=> tf.Context.RunInAutoMode(() | |||
=> tf.OpDefLib._apply_op_helper("ConcatV2", name: name, | |||
args: new { values, axis }).output, () | |||
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"ConcatV2", name, | |||
null, | |||
values, axis).FirstOrDefault(), | |||
values); | |||
=> tf.Context.ExecuteOp("ConcatV2", name, new AutoModeArgs | |||
{ | |||
OpInputArgs = new { values, axis } | |||
}); | |||
private static Tensor concat_v2_eager_fallback<T1, T2>(T1[] values, T2 axis, string name, Context ctx) | |||
{ | |||
@@ -202,14 +198,11 @@ namespace Tensorflow | |||
} | |||
public static Tensor pack(Tensor[] values, int axis = 0, string name = null) | |||
=> tf.Context.RunInAutoMode(() | |||
=> tf.OpDefLib._apply_op_helper("Pack", name, new { values, axis }).output, () | |||
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"Pack", name, | |||
null, | |||
values, | |||
"axis", axis).FirstOrDefault(), | |||
values, axis); | |||
=> tf.Context.ExecuteOp("Pack", name, new AutoModeArgs | |||
{ | |||
OpInputArgs = new { values }, | |||
OpAttrs = new { axis } | |||
}); | |||
/// <summary> | |||
/// Return a tensor with the same shape and contents as the input tensor or value. | |||
@@ -326,31 +319,16 @@ namespace Tensorflow | |||
} | |||
public static Tensor reshape<T>(Tensor tensor, T shape, string name = null) | |||
=> tf.Context.RunInAutoMode(() | |||
=> tf.OpDefLib._apply_op_helper("Reshape", name, new { tensor, shape }).output, () | |||
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"Reshape", name, | |||
null, | |||
tensor, shape).FirstOrDefault(), | |||
tensor, shape); | |||
=> tf.Context.ExecuteOp("Reshape", name, new AutoModeArgs | |||
{ | |||
OpInputArgs = new { tensor, shape } | |||
}); | |||
public static Tensor reshape(Tensor tensor, object[] shape, string name = null) | |||
{ | |||
try | |||
=> tf.Context.ExecuteOp("Reshape", name, new AutoModeArgs | |||
{ | |||
return tf.Context.RunInAutoMode(() | |||
=> tf.OpDefLib._apply_op_helper("Reshape", name, new { tensor, shape }).output, () | |||
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"Reshape", name, | |||
null, | |||
tensor, shape).FirstOrDefault(), | |||
tensor, shape); | |||
} | |||
catch (InvalidArgumentError ex) | |||
{ | |||
return reshape_eager_fallback(tensor, shape, name, tf.Context); | |||
} | |||
} | |||
OpInputArgs = new { tensor, shape } | |||
}); | |||
private static Tensor reshape_eager_fallback(Tensor tensor, object[] shape, string name, Context ctx) | |||
{ | |||
@@ -467,15 +445,11 @@ namespace Tensorflow | |||
} | |||
public static Tensor shape(Tensor input, TF_DataType out_type = TF_DataType.TF_INT32, string name = null) | |||
=> tf.Context.RunInAutoMode(() | |||
=> tf.OpDefLib._apply_op_helper("Shape", name, | |||
new { input, out_type }).output, () | |||
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"Shape", name, | |||
null, | |||
input, | |||
"out_type", out_type).FirstOrDefault(), | |||
input); | |||
=> tf.Context.ExecuteOp("Shape", name, new AutoModeArgs | |||
{ | |||
OpInputArgs = new { input }, | |||
OpAttrs = new { out_type } | |||
}); | |||
/// <summary> | |||
/// Returns shape of tensors. | |||
@@ -559,22 +533,16 @@ namespace Tensorflow | |||
} | |||
public static Tensor tile(Tensor input, Tensor multiples, string name = null) | |||
=> tf.Context.RunInAutoMode(() | |||
=> tf.OpDefLib._apply_op_helper("Tile", name, new { input, multiples }).output, () | |||
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"Tile", name, | |||
null, | |||
input, multiples).FirstOrDefault(), | |||
input, multiples); | |||
=> tf.Context.ExecuteOp("Tile", name, new AutoModeArgs | |||
{ | |||
OpInputArgs = new { input, multiples } | |||
}); | |||
public static Tensor tile(Tensor input, object[] multiples, string name = null) | |||
=> tf.Context.RunInAutoMode(() | |||
=> tf.OpDefLib._apply_op_helper("Tile", name, new { input, multiples }).output, () | |||
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"Tile", name, | |||
null, | |||
input, multiples).FirstOrDefault(), | |||
input, multiples); | |||
=> tf.Context.ExecuteOp("Tile", name, new AutoModeArgs | |||
{ | |||
OpInputArgs = new { input, multiples } | |||
}); | |||
public static Tensor transpose<T1>(Tensor x, T1 perm, string name = null) | |||
{ | |||
@@ -592,22 +560,16 @@ namespace Tensorflow | |||
} | |||
public static Tensor ones_like(Tensor x, string name = null) | |||
=> tf.Context.RunInAutoMode(() | |||
=> tf.OpDefLib._apply_op_helper("OnesLike", name, new { x }).output, () | |||
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"OnesLike", name, | |||
null, | |||
x).FirstOrDefault(), | |||
x); | |||
=> tf.Context.ExecuteOp("OnesLike", name, new AutoModeArgs | |||
{ | |||
OpInputArgs = new { x } | |||
}); | |||
public static Tensor zeros_like(Tensor x, string name = null) | |||
=> tf.Context.RunInAutoMode(() | |||
=> tf.OpDefLib._apply_op_helper("ZerosLike", name, new { x }).output, () | |||
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"ZerosLike", name, | |||
null, | |||
x).FirstOrDefault(), | |||
x); | |||
=> tf.Context.ExecuteOp("ZerosLike", name, new AutoModeArgs | |||
{ | |||
OpInputArgs = new { x } | |||
}); | |||
public static Tensor stop_gradient(Tensor x, string name = null) | |||
{ | |||
@@ -623,53 +585,37 @@ namespace Tensorflow | |||
long new_axis_mask = 0, | |||
long shrink_axis_mask = 0, | |||
string name = null) | |||
=> tf.Context.RunInAutoMode(() | |||
=> tf.OpDefLib._apply_op_helper("StridedSlice", name, new | |||
=> tf.Context.ExecuteOp("StridedSlice", name, new AutoModeArgs | |||
{ | |||
input, | |||
begin, | |||
end, | |||
strides, | |||
begin_mask, | |||
end_mask, | |||
ellipsis_mask, | |||
new_axis_mask, | |||
shrink_axis_mask | |||
}).output, () | |||
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"StridedSlice", name, | |||
null, | |||
input, begin, end, strides, | |||
"begin_mask", begin_mask, | |||
"end_mask", end_mask, | |||
"ellipsis_mask", ellipsis_mask, | |||
"new_axis_mask", new_axis_mask, | |||
"shrink_axis_mask", shrink_axis_mask).FirstOrDefault(), | |||
input, begin, end, strides); | |||
public static Operation resource_strided_slice_assign(Tensor input, Tensor begin, Tensor end, Tensor strides, Tensor value, | |||
OpInputArgs = new { input, begin, end, strides }, | |||
OpAttrs = new | |||
{ | |||
begin_mask, | |||
end_mask, | |||
ellipsis_mask, | |||
new_axis_mask, | |||
shrink_axis_mask | |||
} | |||
}); | |||
public static Tensor resource_strided_slice_assign(Tensor input, Tensor begin, Tensor end, Tensor strides, Tensor value, | |||
int begin_mask = 0, | |||
int end_mask = 0, | |||
int ellipsis_mask = 0, | |||
int new_axis_mask = 0, | |||
int shrink_axis_mask = 0, | |||
string name = null) | |||
=> tf.Context.RunInAutoMode(() | |||
=> tf.OpDefLib._apply_op_helper("ResourceStridedSliceAssign", name, new | |||
=> tf.Context.ExecuteOp("ResourceStridedSliceAssign", name, new AutoModeArgs | |||
{ | |||
input, begin, end, strides, value, | |||
begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask | |||
}).output, () | |||
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"ResourceStridedSliceAssign", name, | |||
null, | |||
input, begin, end, strides, value, | |||
"begin_mask", begin_mask, | |||
"end_mask", end_mask, | |||
"ellipsis_mask", ellipsis_mask, | |||
"new_axis_mask", new_axis_mask, | |||
"shrink_axis_mask", shrink_axis_mask).FirstOrDefault(), | |||
input, begin, end, strides, value); | |||
OpInputArgs = new { input, begin, end, strides, value }, | |||
OpAttrs = new { | |||
begin_mask, | |||
end_mask, | |||
ellipsis_mask, | |||
new_axis_mask, | |||
shrink_axis_mask | |||
} | |||
}); | |||
public static Tensor strided_slice<T>(Tensor input, T[] begin, T[] end, T[] strides, | |||
int begin_mask = 0, | |||
@@ -222,25 +222,15 @@ namespace Tensorflow | |||
public static Tensor resize_nearest_neighbor<Tsize>(Tensor images, Tsize size, bool align_corners = false, | |||
bool half_pixel_centers = false, string name = null) | |||
=> tf.Context.RunInAutoMode(() | |||
=> tf.OpDefLib._apply_op_helper("ResizeNearestNeighbor", name: name, args: new | |||
=> tf.Context.ExecuteOp("ResizeNearestNeighbor", name, new AutoModeArgs | |||
{ | |||
images, | |||
size, | |||
align_corners, | |||
half_pixel_centers | |||
}).output, () | |||
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"ResizeNearestNeighbor", name, | |||
null, | |||
images, size, | |||
"align_corners", align_corners, | |||
"half_pixel_centers", half_pixel_centers).FirstOrDefault(), | |||
images); | |||
OpInputArgs = new { images, size }, | |||
OpAttrs = new { align_corners, half_pixel_centers } | |||
}); | |||
public static Tensor resize_nearest_neighbor_grad(Tensor grads, Tensor size, bool align_corners = false, | |||
bool half_pixel_centers = false, string name = null) | |||
=> tf.Context.RunInAutoMode2("ResizeNearestNeighborGrad", name, new AutoModeArgs | |||
=> tf.Context.ExecuteOp("ResizeNearestNeighborGrad", name, new AutoModeArgs | |||
{ | |||
OpInputArgs = new { grads, size }, | |||
OpAttrs = new { align_corners, half_pixel_centers }, | |||
@@ -116,13 +116,10 @@ namespace Tensorflow | |||
/// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) | |||
/// </remarks> | |||
public static Tensor div_no_nan(Tensor x, Tensor y, string name = null) | |||
=> tf.Context.RunInAutoMode(() | |||
=> tf.OpDefLib._apply_op_helper("DivNoNan", name: name, new { x, y }).output, () | |||
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"DivNoNan", name, | |||
null, | |||
x, y).FirstOrDefault(), | |||
x, y); | |||
=> tf.Context.ExecuteOp("DivNoNan", name, new AutoModeArgs | |||
{ | |||
OpInputArgs = new { x, y } | |||
}); | |||
public static Tensor mean(Tensor input, int axis, bool keep_dims = false, string name = null) | |||
=> mean(input, ops.convert_to_tensor(axis), keep_dims: keep_dims, name: name); | |||
@@ -141,7 +138,7 @@ namespace Tensorflow | |||
/// <param name="name"> A name for the operation (optional).</param> | |||
/// <returns> A `Tensor`. Has the same type as `input`.</returns> | |||
public static Tensor mean(Tensor input, Tensor axis, bool keep_dims = false, string name = null) | |||
=> tf.Context.RunInAutoMode2("Mean", name, new AutoModeArgs | |||
=> tf.Context.ExecuteOp("Mean", name, new AutoModeArgs | |||
{ | |||
OpInputArgs = new { input, axis }, | |||
OpAttrs = new { keep_dims, reduction_indices = axis }, | |||
@@ -318,13 +315,10 @@ namespace Tensorflow | |||
/// Specifically, <c>y = 1 / (1 + exp(-x))</c>. | |||
/// </remarks> | |||
public static Tensor sigmoid(Tensor x, string name = "Sigmoid") | |||
=> tf.Context.RunInAutoMode(() | |||
=> tf.OpDefLib._apply_op_helper("Sigmoid", name: name, new { x }).output, () | |||
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"Sigmoid", name, | |||
null, | |||
x).FirstOrDefault(), | |||
x); | |||
=> tf.Context.ExecuteOp("Sigmoid", name, new AutoModeArgs | |||
{ | |||
OpInputArgs = new { x } | |||
}); | |||
/// <summary> | |||
/// Computes the gradient of the sigmoid of <c>x</c> wrt its input. | |||
@@ -344,7 +338,7 @@ namespace Tensorflow | |||
/// <c>dy</c> is the corresponding input gradient. | |||
/// </remarks> | |||
public static Tensor sigmoid_grad(Tensor y, Tensor dy, string name = "SigmoidGrad") | |||
=> tf.Context.RunInAutoMode2("SigmoidGrad", name, new AutoModeArgs | |||
=> tf.Context.ExecuteOp("SigmoidGrad", name, new AutoModeArgs | |||
{ | |||
OpInputArgs = new { y, dy } | |||
}); | |||
@@ -576,13 +570,10 @@ namespace Tensorflow | |||
} | |||
public static Tensor log1p(Tensor x, string name = null) | |||
=> tf.Context.RunInAutoMode(() | |||
=> tf.OpDefLib._apply_op_helper("Log1p", name: name, new { x }).output, () | |||
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"Log1p", name, | |||
null, | |||
x).FirstOrDefault(), | |||
x); | |||
=> tf.Context.ExecuteOp("Log1p", name, new AutoModeArgs | |||
{ | |||
OpInputArgs = new { x } | |||
}); | |||
public static Tensor logical_and(Tensor x, Tensor y, string name = null) | |||
=> tf.OpDefLib._apply_op_helper("LogicalAnd", name, args: new { x, y }); | |||
@@ -691,13 +682,10 @@ namespace Tensorflow | |||
/// <param name="name"> A name for the operation (optional).</param> | |||
/// <returns> A `Tensor`. Has the same type as `x`.</returns> | |||
public static Tensor exp(Tensor x, string name = null) | |||
=> tf.Context.RunInAutoMode(() | |||
=> tf.OpDefLib._apply_op_helper("Exp", name, args: new { x }).output, () | |||
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"Exp", name, | |||
null, | |||
x).FirstOrDefault(), | |||
x); | |||
=> tf.Context.ExecuteOp("Exp", name, new AutoModeArgs | |||
{ | |||
OpInputArgs = new { x } | |||
}); | |||
/// <summary> | |||
/// Computes natural logarithm of x element-wise. | |||
@@ -739,14 +727,11 @@ namespace Tensorflow | |||
} | |||
public static Tensor cast(Tensor x, TF_DataType DstT, bool Truncate = false, string name = null) | |||
=> tf.Context.RunInAutoMode(() | |||
=> tf.OpDefLib._apply_op_helper("Cast", name, args: new { x, DstT, Truncate }).output, () | |||
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"Cast", name, | |||
null, | |||
x, | |||
"DstT", DstT, "Truncate", Truncate).FirstOrDefault(), | |||
x); | |||
=> tf.Context.ExecuteOp("Cast", name, new AutoModeArgs | |||
{ | |||
OpInputArgs = new { x }, | |||
OpAttrs = new { DstT, Truncate } | |||
}); | |||
public static Tensor neg(Tensor x, string name = null) | |||
{ | |||
@@ -783,7 +768,7 @@ namespace Tensorflow | |||
} | |||
public static Tensor sub(Tensor x, Tensor y, string name = null) | |||
=> tf.Context.RunInAutoMode2("Sub", name, new AutoModeArgs | |||
=> tf.Context.ExecuteOp("Sub", name, new AutoModeArgs | |||
{ | |||
OpInputArgs = new { x, y } | |||
}); | |||
@@ -1087,14 +1072,17 @@ namespace Tensorflow | |||
} | |||
public static Tensor _max<Tx, Ty>(Tx input, Ty axis, bool keep_dims = false, string name = null) | |||
=> tf.Context.RunInAutoMode(() | |||
=> tf.OpDefLib._apply_op_helper("Max", name, new { input, reduction_indices = axis, keep_dims }).output, () | |||
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"Max", name, | |||
null, | |||
input, axis, | |||
"keep_dims", keep_dims).FirstOrDefault(), | |||
input as Tensor); | |||
=> tf.Context.ExecuteOp("Max", name, new AutoModeArgs | |||
{ | |||
OpInputArgs = new { input, axis }, | |||
OpAttrs = new { keep_dims, reduction_indices = axis }, | |||
GetGradientAttrs = (op) => new | |||
{ | |||
T = op.get_attr<TF_DataType>("T"), | |||
align_corners = op.get_attr<bool>("align_corners"), | |||
half_pixel_centers = op.get_attr<bool>("half_pixel_centers") | |||
} | |||
}); | |||
public static Tensor _min<Tx, Ty>(Tx input, Ty axis, bool keep_dims = false, string name = null) | |||
{ | |||
@@ -1170,13 +1158,10 @@ namespace Tensorflow | |||
/// <param name="name"></param> | |||
/// <returns></returns> | |||
public static Tensor range(Tensor start, Tensor limit, Tensor delta, string name = null) | |||
=> tf.Context.RunInAutoMode(() | |||
=> tf.OpDefLib._apply_op_helper("Range", name, new { start, limit, delta }).output, () | |||
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"Range", name, | |||
null, | |||
start, limit, delta).FirstOrDefault(), | |||
start, limit, delta); | |||
=> tf.Context.ExecuteOp("Range", name, new AutoModeArgs | |||
{ | |||
OpInputArgs = new { start, limit, delta } | |||
}); | |||
/// <summary> | |||
/// Rounds the values of a tensor to the nearest integer, element-wise. | |||
@@ -45,7 +45,7 @@ namespace Tensorflow | |||
=> gen_math_ops.add(x, y, name); | |||
public static Tensor add_v2(Tensor x, Tensor y, string name = null) | |||
=> tf.Context.RunInAutoMode2("AddV2", name, new AutoModeArgs | |||
=> tf.Context.ExecuteOp("AddV2", name, new AutoModeArgs | |||
{ | |||
OpInputArgs = new { x, y } | |||
}); | |||
@@ -261,7 +261,7 @@ namespace Tensorflow | |||
/// <param name="name"></param> | |||
/// <returns></returns> | |||
public static Tensor erf(Tensor x, string name = null) | |||
=> tf.Context.RunInAutoMode2("Erf", name, new AutoModeArgs | |||
=> tf.Context.ExecuteOp("Erf", name, new AutoModeArgs | |||
{ | |||
OpInputArgs = new { x } | |||
}); | |||
@@ -270,7 +270,7 @@ namespace Tensorflow | |||
=> gen_math_ops.sqrt(x, name: name); | |||
public static Tensor multiply(Tensor x, Tensor y, string name = null) | |||
=> tf.Context.RunInAutoMode2("Mul", name, new AutoModeArgs | |||
=> tf.Context.ExecuteOp("Mul", name, new AutoModeArgs | |||
{ | |||
OpInputArgs = new { x, y } | |||
}); | |||