Browse Source

Change RunInAutoMode to ExecuteOp

tags/v0.40-tf2.4-tstring
Oceania2018 4 years ago
parent
commit
0471e28f6a
7 changed files with 141 additions and 267 deletions
  1. +5
    -35
      src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs
  2. +27
    -44
      src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs
  3. +2
    -2
      src/TensorFlowNET.Core/Operations/array_ops.cs
  4. +60
    -114
      src/TensorFlowNET.Core/Operations/gen_array_ops.cs
  5. +5
    -15
      src/TensorFlowNET.Core/Operations/gen_image_ops.cs
  6. +39
    -54
      src/TensorFlowNET.Core/Operations/gen_math_ops.cs
  7. +3
    -3
      src/TensorFlowNET.Core/Operations/math_ops.cs

+ 5
- 35
src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs View File

@@ -30,49 +30,19 @@ namespace Tensorflow.Contexts
public sealed partial class Context
{
// [DebuggerStepThrough]
public T RunInAutoMode<T>(Func<T> graphAction, Func<T> eagerAction, params object[] args)
{
if (tf.Context.has_graph_arg(args))
{
if (executing_eagerly())
{
graph_mode();
var result = graphAction();
restore_mode();
return result;
}
else
{
return graphAction();
}
}
else
{
if (tf.Context.executing_eagerly())
{
return eagerAction();
}
else
{
return graphAction();
}
}
}
// [DebuggerStepThrough]
public Tensors RunInAutoMode2(string OpType, string Name, AutoModeArgs args)
public Tensors ExecuteOp(string OpType, string Name, AutoModeArgs args)
{
var inputArgs = ConvertToDict(args.OpInputArgs);
var attrDict = ConvertToDict(args.OpAttrs);
Func<Tensor> graphAction = () =>
Func<Tensors> graphAction = () =>
{
foreach (var attr in attrDict)
inputArgs[attr.Key] = attr.Value;
return tf.OpDefLib._apply_op_helper(OpType, Name, inputArgs).output;
return tf.OpDefLib._apply_op_helper(OpType, Name, inputArgs).outputs;
};

Func<Tensor> eagerAction = () =>
Func<Tensors> eagerAction = () =>
{
var attrs = new object[attrDict.Count() * 2];
int i = 0;
@@ -87,7 +57,7 @@ namespace Tensorflow.Contexts
OpType, Name,
null,
inputArgs.Values.ToArray(),
attrs).FirstOrDefault();
attrs);
};

if (tf.Context.has_graph_arg(inputArgs.Values))


+ 27
- 44
src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs View File

@@ -269,29 +269,24 @@ namespace Tensorflow.Operations
}

public static Tensor[] fused_batch_norm_grad_v3(FusedBatchNormParams @params)
=> tf.Context.RunInAutoMode(()
=> tf.OpDefLib._apply_op_helper("FusedBatchNormGradV3", name: @params.Name,
args: new
{
y_backprop = @params.YBackprop,
x = @params.X,
scale = @params.Scale,
reserve_space_1 = @params.ReserveSpace1,
reserve_space_2 = @params.ReserveSpace2,
reserve_space_3 = @params.ReserveSpace3,
epsilon = @params.Epsilon,
data_format = @params.DataFormat,
is_training = @params.IsTraining
}).outputs, ()
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"FusedBatchNormGradV3", @params.Name,
null,
@params.YBackprop, @params.X, @params.Scale,
@params.ReserveSpace1, @params.ReserveSpace2, @params.ReserveSpace3,
"epsilon", @params.Epsilon,
"data_format", @params.DataFormat,
"is_training", @params.IsTraining),
@params.YBackprop);
=> tf.Context.ExecuteOp("FusedBatchNormGradV3", @params.Name, new AutoModeArgs
{
OpInputArgs = new
{
y_backprop = @params.YBackprop,
x = @params.X,
scale = @params.Scale,
reserve_space_1 = @params.ReserveSpace1,
reserve_space_2 = @params.ReserveSpace2,
reserve_space_3 = @params.ReserveSpace3
},
OpAttrs = new
{
epsilon = @params.Epsilon,
data_format = @params.DataFormat,
is_training = @params.IsTraining
}
});

public static Tensor[] fused_batch_norm(Tensor x,
Tensor scale,
@@ -388,14 +383,10 @@ namespace Tensorflow.Operations
}

public static Tensor log_softmax(Tensor logits, string name = null)
=> tf.Context.RunInAutoMode(()
=> tf.OpDefLib._apply_op_helper("LogSoftmax", name: name,
args: new { logits }).output, ()
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"LogSoftmax", name,
null,
logits).FirstOrDefault(),
logits);
=> tf.Context.ExecuteOp("LogSoftmax", name, new AutoModeArgs
{
OpInputArgs = new { logits }
});

/// <summary>
/// Says whether the targets are in the top `K` predictions.
@@ -418,19 +409,11 @@ namespace Tensorflow.Operations
}

public static Tensor leaky_relu(Tensor features, float alpha = 0.2f, string name = null)
=> tf.Context.RunInAutoMode(()
=> tf.OpDefLib._apply_op_helper("LeakyRelu", name: name,
args: new
{
features,
alpha
}).output, ()
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"LeakyRelu", name,
null,
features,
"alpha", alpha).FirstOrDefault(),
features);
=> tf.Context.ExecuteOp("LeakyRelu", name, new AutoModeArgs
{
OpInputArgs = new { features },
OpAttrs = new { alpha }
});

public static Tensor max_pool(Tensor input,
int[] ksize,


+ 2
- 2
src/TensorFlowNET.Core/Operations/array_ops.cs View File

@@ -737,7 +737,7 @@ namespace Tensorflow
public static Tensor strided_slice_grad(Tensor shape, Tensor begin, Tensor end, Tensor strides, Tensor dy,
long begin_mask = 0, long end_mask = 0, long ellipsis_mask = 0, long new_axis_mask = 0,
long shrink_axis_mask = 0, string name = null)
=> tf.Context.RunInAutoMode2("StridedSliceGrad", name, new AutoModeArgs
=> tf.Context.ExecuteOp("StridedSliceGrad", name, new AutoModeArgs
{
OpInputArgs = new
{
@@ -960,7 +960,7 @@ namespace Tensorflow
=> gen_array_ops.slice(input, begin, size, name: name);

public static Tensor slice(Tensor input, Tensor begin, Tensor size, string name = null)
=> tf.Context.RunInAutoMode2("Slice", name, new AutoModeArgs
=> tf.Context.ExecuteOp("Slice", name, new AutoModeArgs
{
OpInputArgs = new { input, begin, size },
GetGradientAttrs = (op) => new


+ 60
- 114
src/TensorFlowNET.Core/Operations/gen_array_ops.cs View File

@@ -72,14 +72,10 @@ namespace Tensorflow
}

public static Tensor concat_v2(Tensor[] values, int axis, string name = null)
=> tf.Context.RunInAutoMode(()
=> tf.OpDefLib._apply_op_helper("ConcatV2", name: name,
args: new { values, axis }).output, ()
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"ConcatV2", name,
null,
values, axis).FirstOrDefault(),
values);
=> tf.Context.ExecuteOp("ConcatV2", name, new AutoModeArgs
{
OpInputArgs = new { values, axis }
});

private static Tensor concat_v2_eager_fallback<T1, T2>(T1[] values, T2 axis, string name, Context ctx)
{
@@ -202,14 +198,11 @@ namespace Tensorflow
}

public static Tensor pack(Tensor[] values, int axis = 0, string name = null)
=> tf.Context.RunInAutoMode(()
=> tf.OpDefLib._apply_op_helper("Pack", name, new { values, axis }).output, ()
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Pack", name,
null,
values,
"axis", axis).FirstOrDefault(),
values, axis);
=> tf.Context.ExecuteOp("Pack", name, new AutoModeArgs
{
OpInputArgs = new { values },
OpAttrs = new { axis }
});

/// <summary>
/// Return a tensor with the same shape and contents as the input tensor or value.
@@ -326,31 +319,16 @@ namespace Tensorflow
}

public static Tensor reshape<T>(Tensor tensor, T shape, string name = null)
=> tf.Context.RunInAutoMode(()
=> tf.OpDefLib._apply_op_helper("Reshape", name, new { tensor, shape }).output, ()
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Reshape", name,
null,
tensor, shape).FirstOrDefault(),
tensor, shape);
=> tf.Context.ExecuteOp("Reshape", name, new AutoModeArgs
{
OpInputArgs = new { tensor, shape }
});

public static Tensor reshape(Tensor tensor, object[] shape, string name = null)
{
try
=> tf.Context.ExecuteOp("Reshape", name, new AutoModeArgs
{
return tf.Context.RunInAutoMode(()
=> tf.OpDefLib._apply_op_helper("Reshape", name, new { tensor, shape }).output, ()
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Reshape", name,
null,
tensor, shape).FirstOrDefault(),
tensor, shape);
}
catch (InvalidArgumentError ex)
{
return reshape_eager_fallback(tensor, shape, name, tf.Context);
}
}
OpInputArgs = new { tensor, shape }
});

private static Tensor reshape_eager_fallback(Tensor tensor, object[] shape, string name, Context ctx)
{
@@ -467,15 +445,11 @@ namespace Tensorflow
}

public static Tensor shape(Tensor input, TF_DataType out_type = TF_DataType.TF_INT32, string name = null)
=> tf.Context.RunInAutoMode(()
=> tf.OpDefLib._apply_op_helper("Shape", name,
new { input, out_type }).output, ()
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Shape", name,
null,
input,
"out_type", out_type).FirstOrDefault(),
input);
=> tf.Context.ExecuteOp("Shape", name, new AutoModeArgs
{
OpInputArgs = new { input },
OpAttrs = new { out_type }
});

/// <summary>
/// Returns shape of tensors.
@@ -559,22 +533,16 @@ namespace Tensorflow
}

public static Tensor tile(Tensor input, Tensor multiples, string name = null)
=> tf.Context.RunInAutoMode(()
=> tf.OpDefLib._apply_op_helper("Tile", name, new { input, multiples }).output, ()
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Tile", name,
null,
input, multiples).FirstOrDefault(),
input, multiples);
=> tf.Context.ExecuteOp("Tile", name, new AutoModeArgs
{
OpInputArgs = new { input, multiples }
});

public static Tensor tile(Tensor input, object[] multiples, string name = null)
=> tf.Context.RunInAutoMode(()
=> tf.OpDefLib._apply_op_helper("Tile", name, new { input, multiples }).output, ()
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Tile", name,
null,
input, multiples).FirstOrDefault(),
input, multiples);
=> tf.Context.ExecuteOp("Tile", name, new AutoModeArgs
{
OpInputArgs = new { input, multiples }
});

public static Tensor transpose<T1>(Tensor x, T1 perm, string name = null)
{
@@ -592,22 +560,16 @@ namespace Tensorflow
}

public static Tensor ones_like(Tensor x, string name = null)
=> tf.Context.RunInAutoMode(()
=> tf.OpDefLib._apply_op_helper("OnesLike", name, new { x }).output, ()
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"OnesLike", name,
null,
x).FirstOrDefault(),
x);
=> tf.Context.ExecuteOp("OnesLike", name, new AutoModeArgs
{
OpInputArgs = new { x }
});

public static Tensor zeros_like(Tensor x, string name = null)
=> tf.Context.RunInAutoMode(()
=> tf.OpDefLib._apply_op_helper("ZerosLike", name, new { x }).output, ()
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"ZerosLike", name,
null,
x).FirstOrDefault(),
x);
=> tf.Context.ExecuteOp("ZerosLike", name, new AutoModeArgs
{
OpInputArgs = new { x }
});

public static Tensor stop_gradient(Tensor x, string name = null)
{
@@ -623,53 +585,37 @@ namespace Tensorflow
long new_axis_mask = 0,
long shrink_axis_mask = 0,
string name = null)
=> tf.Context.RunInAutoMode(()
=> tf.OpDefLib._apply_op_helper("StridedSlice", name, new
=> tf.Context.ExecuteOp("StridedSlice", name, new AutoModeArgs
{
input,
begin,
end,
strides,
begin_mask,
end_mask,
ellipsis_mask,
new_axis_mask,
shrink_axis_mask
}).output, ()
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"StridedSlice", name,
null,
input, begin, end, strides,
"begin_mask", begin_mask,
"end_mask", end_mask,
"ellipsis_mask", ellipsis_mask,
"new_axis_mask", new_axis_mask,
"shrink_axis_mask", shrink_axis_mask).FirstOrDefault(),
input, begin, end, strides);

public static Operation resource_strided_slice_assign(Tensor input, Tensor begin, Tensor end, Tensor strides, Tensor value,
OpInputArgs = new { input, begin, end, strides },
OpAttrs = new
{
begin_mask,
end_mask,
ellipsis_mask,
new_axis_mask,
shrink_axis_mask
}
});

public static Tensor resource_strided_slice_assign(Tensor input, Tensor begin, Tensor end, Tensor strides, Tensor value,
int begin_mask = 0,
int end_mask = 0,
int ellipsis_mask = 0,
int new_axis_mask = 0,
int shrink_axis_mask = 0,
string name = null)
=> tf.Context.RunInAutoMode(()
=> tf.OpDefLib._apply_op_helper("ResourceStridedSliceAssign", name, new
=> tf.Context.ExecuteOp("ResourceStridedSliceAssign", name, new AutoModeArgs
{
input, begin, end, strides, value,
begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask
}).output, ()
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"ResourceStridedSliceAssign", name,
null,
input, begin, end, strides, value,
"begin_mask", begin_mask,
"end_mask", end_mask,
"ellipsis_mask", ellipsis_mask,
"new_axis_mask", new_axis_mask,
"shrink_axis_mask", shrink_axis_mask).FirstOrDefault(),
input, begin, end, strides, value);
OpInputArgs = new { input, begin, end, strides, value },
OpAttrs = new {
begin_mask,
end_mask,
ellipsis_mask,
new_axis_mask,
shrink_axis_mask
}
});

public static Tensor strided_slice<T>(Tensor input, T[] begin, T[] end, T[] strides,
int begin_mask = 0,


+ 5
- 15
src/TensorFlowNET.Core/Operations/gen_image_ops.cs View File

@@ -222,25 +222,15 @@ namespace Tensorflow

public static Tensor resize_nearest_neighbor<Tsize>(Tensor images, Tsize size, bool align_corners = false,
bool half_pixel_centers = false, string name = null)
=> tf.Context.RunInAutoMode(()
=> tf.OpDefLib._apply_op_helper("ResizeNearestNeighbor", name: name, args: new
=> tf.Context.ExecuteOp("ResizeNearestNeighbor", name, new AutoModeArgs
{
images,
size,
align_corners,
half_pixel_centers
}).output, ()
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"ResizeNearestNeighbor", name,
null,
images, size,
"align_corners", align_corners,
"half_pixel_centers", half_pixel_centers).FirstOrDefault(),
images);
OpInputArgs = new { images, size },
OpAttrs = new { align_corners, half_pixel_centers }
});

public static Tensor resize_nearest_neighbor_grad(Tensor grads, Tensor size, bool align_corners = false,
bool half_pixel_centers = false, string name = null)
=> tf.Context.RunInAutoMode2("ResizeNearestNeighborGrad", name, new AutoModeArgs
=> tf.Context.ExecuteOp("ResizeNearestNeighborGrad", name, new AutoModeArgs
{
OpInputArgs = new { grads, size },
OpAttrs = new { align_corners, half_pixel_centers },


+ 39
- 54
src/TensorFlowNET.Core/Operations/gen_math_ops.cs View File

@@ -116,13 +116,10 @@ namespace Tensorflow
/// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
/// </remarks>
public static Tensor div_no_nan(Tensor x, Tensor y, string name = null)
=> tf.Context.RunInAutoMode(()
=> tf.OpDefLib._apply_op_helper("DivNoNan", name: name, new { x, y }).output, ()
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"DivNoNan", name,
null,
x, y).FirstOrDefault(),
x, y);
=> tf.Context.ExecuteOp("DivNoNan", name, new AutoModeArgs
{
OpInputArgs = new { x, y }
});

public static Tensor mean(Tensor input, int axis, bool keep_dims = false, string name = null)
=> mean(input, ops.convert_to_tensor(axis), keep_dims: keep_dims, name: name);
@@ -141,7 +138,7 @@ namespace Tensorflow
/// <param name="name"> A name for the operation (optional).</param>
/// <returns> A `Tensor`. Has the same type as `input`.</returns>
public static Tensor mean(Tensor input, Tensor axis, bool keep_dims = false, string name = null)
=> tf.Context.RunInAutoMode2("Mean", name, new AutoModeArgs
=> tf.Context.ExecuteOp("Mean", name, new AutoModeArgs
{
OpInputArgs = new { input, axis },
OpAttrs = new { keep_dims, reduction_indices = axis },
@@ -318,13 +315,10 @@ namespace Tensorflow
/// Specifically, <c>y = 1 / (1 + exp(-x))</c>.
/// </remarks>
public static Tensor sigmoid(Tensor x, string name = "Sigmoid")
=> tf.Context.RunInAutoMode(()
=> tf.OpDefLib._apply_op_helper("Sigmoid", name: name, new { x }).output, ()
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Sigmoid", name,
null,
x).FirstOrDefault(),
x);
=> tf.Context.ExecuteOp("Sigmoid", name, new AutoModeArgs
{
OpInputArgs = new { x }
});

/// <summary>
/// Computes the gradient of the sigmoid of <c>x</c> wrt its input.
@@ -344,7 +338,7 @@ namespace Tensorflow
/// <c>dy</c> is the corresponding input gradient.
/// </remarks>
public static Tensor sigmoid_grad(Tensor y, Tensor dy, string name = "SigmoidGrad")
=> tf.Context.RunInAutoMode2("SigmoidGrad", name, new AutoModeArgs
=> tf.Context.ExecuteOp("SigmoidGrad", name, new AutoModeArgs
{
OpInputArgs = new { y, dy }
});
@@ -576,13 +570,10 @@ namespace Tensorflow
}

public static Tensor log1p(Tensor x, string name = null)
=> tf.Context.RunInAutoMode(()
=> tf.OpDefLib._apply_op_helper("Log1p", name: name, new { x }).output, ()
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Log1p", name,
null,
x).FirstOrDefault(),
x);
=> tf.Context.ExecuteOp("Log1p", name, new AutoModeArgs
{
OpInputArgs = new { x }
});

public static Tensor logical_and(Tensor x, Tensor y, string name = null)
=> tf.OpDefLib._apply_op_helper("LogicalAnd", name, args: new { x, y });
@@ -691,13 +682,10 @@ namespace Tensorflow
/// <param name="name"> A name for the operation (optional).</param>
/// <returns> A `Tensor`. Has the same type as `x`.</returns>
public static Tensor exp(Tensor x, string name = null)
=> tf.Context.RunInAutoMode(()
=> tf.OpDefLib._apply_op_helper("Exp", name, args: new { x }).output, ()
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Exp", name,
null,
x).FirstOrDefault(),
x);
=> tf.Context.ExecuteOp("Exp", name, new AutoModeArgs
{
OpInputArgs = new { x }
});

/// <summary>
/// Computes natural logarithm of x element-wise.
@@ -739,14 +727,11 @@ namespace Tensorflow
}
public static Tensor cast(Tensor x, TF_DataType DstT, bool Truncate = false, string name = null)
=> tf.Context.RunInAutoMode(()
=> tf.OpDefLib._apply_op_helper("Cast", name, args: new { x, DstT, Truncate }).output, ()
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Cast", name,
null,
x,
"DstT", DstT, "Truncate", Truncate).FirstOrDefault(),
x);
=> tf.Context.ExecuteOp("Cast", name, new AutoModeArgs
{
OpInputArgs = new { x },
OpAttrs = new { DstT, Truncate }
});

public static Tensor neg(Tensor x, string name = null)
{
@@ -783,7 +768,7 @@ namespace Tensorflow
}

public static Tensor sub(Tensor x, Tensor y, string name = null)
=> tf.Context.RunInAutoMode2("Sub", name, new AutoModeArgs
=> tf.Context.ExecuteOp("Sub", name, new AutoModeArgs
{
OpInputArgs = new { x, y }
});
@@ -1087,14 +1072,17 @@ namespace Tensorflow
}

public static Tensor _max<Tx, Ty>(Tx input, Ty axis, bool keep_dims = false, string name = null)
=> tf.Context.RunInAutoMode(()
=> tf.OpDefLib._apply_op_helper("Max", name, new { input, reduction_indices = axis, keep_dims }).output, ()
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Max", name,
null,
input, axis,
"keep_dims", keep_dims).FirstOrDefault(),
input as Tensor);
=> tf.Context.ExecuteOp("Max", name, new AutoModeArgs
{
OpInputArgs = new { input, axis },
OpAttrs = new { keep_dims, reduction_indices = axis },
GetGradientAttrs = (op) => new
{
T = op.get_attr<TF_DataType>("T"),
align_corners = op.get_attr<bool>("align_corners"),
half_pixel_centers = op.get_attr<bool>("half_pixel_centers")
}
});

public static Tensor _min<Tx, Ty>(Tx input, Ty axis, bool keep_dims = false, string name = null)
{
@@ -1170,13 +1158,10 @@ namespace Tensorflow
/// <param name="name"></param>
/// <returns></returns>
public static Tensor range(Tensor start, Tensor limit, Tensor delta, string name = null)
=> tf.Context.RunInAutoMode(()
=> tf.OpDefLib._apply_op_helper("Range", name, new { start, limit, delta }).output, ()
=> tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Range", name,
null,
start, limit, delta).FirstOrDefault(),
start, limit, delta);
=> tf.Context.ExecuteOp("Range", name, new AutoModeArgs
{
OpInputArgs = new { start, limit, delta }
});

/// <summary>
/// Rounds the values of a tensor to the nearest integer, element-wise.


+ 3
- 3
src/TensorFlowNET.Core/Operations/math_ops.cs View File

@@ -45,7 +45,7 @@ namespace Tensorflow
=> gen_math_ops.add(x, y, name);

public static Tensor add_v2(Tensor x, Tensor y, string name = null)
=> tf.Context.RunInAutoMode2("AddV2", name, new AutoModeArgs
=> tf.Context.ExecuteOp("AddV2", name, new AutoModeArgs
{
OpInputArgs = new { x, y }
});
@@ -261,7 +261,7 @@ namespace Tensorflow
/// <param name="name"></param>
/// <returns></returns>
public static Tensor erf(Tensor x, string name = null)
=> tf.Context.RunInAutoMode2("Erf", name, new AutoModeArgs
=> tf.Context.ExecuteOp("Erf", name, new AutoModeArgs
{
OpInputArgs = new { x }
});
@@ -270,7 +270,7 @@ namespace Tensorflow
=> gen_math_ops.sqrt(x, name: name);

public static Tensor multiply(Tensor x, Tensor y, string name = null)
=> tf.Context.RunInAutoMode2("Mul", name, new AutoModeArgs
=> tf.Context.ExecuteOp("Mul", name, new AutoModeArgs
{
OpInputArgs = new { x, y }
});


Loading…
Cancel
Save