@@ -112,7 +112,7 @@ namespace Tensorflow | |||
var strides = new[] { 1, 1, 1, 1 }; | |||
var dilations = new[] { 1, 1, 1, 1 }; | |||
var results = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("Conv2D", null, input, filter) | |||
var results = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(tf.Context, "Conv2D", null, input, filter) | |||
{ | |||
attrs = ConvertToDict(new | |||
{ | |||
@@ -134,7 +134,7 @@ namespace Tensorflow | |||
var strides = new[] { 1, 1, 1, 1 }; | |||
var dilations = new[] { 1, 1, 1, 1 }; | |||
var results = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("Conv2D", null, input, filter) | |||
var results = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(tf.Context, "Conv2D", null, input, filter) | |||
{ | |||
attrs = ConvertToDict(new | |||
{ | |||
@@ -44,7 +44,8 @@ namespace Tensorflow | |||
/// <param name="name"></param> | |||
/// <returns></returns> | |||
public Tensor batch_to_space_nd<T>(T input, int[] block_shape, int[,] crops, string name = null) | |||
=> gen_array_ops.batch_to_space_nd(input, block_shape, crops, name: name); | |||
=> gen_array_ops.batch_to_space_nd(ops.convert_to_tensor(input), ops.convert_to_tensor(block_shape), | |||
ops.convert_to_tensor(crops), name: name); | |||
/// <summary> | |||
/// Apply boolean mask to tensor. | |||
@@ -91,7 +92,7 @@ namespace Tensorflow | |||
}); | |||
} | |||
return gen_array_ops.concat_v2(values.ToArray(), axis, name: name); | |||
return gen_array_ops.concat_v2(values.ToArray(), ops.convert_to_tensor(axis), name: name); | |||
} | |||
/// <summary> | |||
@@ -115,7 +116,7 @@ namespace Tensorflow | |||
/// <param name="name"></param> | |||
/// <returns></returns> | |||
public Tensor fill<T>(Tensor dims, T value, string name = null) | |||
=> gen_array_ops.fill(dims, value, name: name); | |||
=> gen_array_ops.fill(dims, ops.convert_to_tensor(value), name: name); | |||
public Tensor fill<T>(Shape dims, T value, string name = null) | |||
=> array_ops.fill(dims, value, name: name); | |||
@@ -138,7 +139,7 @@ namespace Tensorflow | |||
/// <param name="axis"></param> | |||
/// <returns></returns> | |||
public Tensor gather(Tensor @params, Tensor indices, string name = null, int axis = 0) | |||
=> array_ops.gather(@params, indices, name: name, axis: axis); | |||
=> array_ops.gather(@params, indices, name: name, axis: ops.convert_to_tensor(axis)); | |||
/// <summary> | |||
/// Return the elements, either from `x` or `y`, depending on the `condition`. | |||
@@ -166,7 +167,7 @@ namespace Tensorflow | |||
/// <param name="name"></param> | |||
/// <returns></returns> | |||
public Tensor reverse(Tensor tensor, int[] axis, string name = null) | |||
=> gen_array_ops.reverse(tensor, axis, name: name); | |||
=> gen_array_ops.reverse(tensor, ops.convert_to_tensor(axis), name: name); | |||
public Tensor reverse(Tensor tensor, Tensor axis, string name = null) | |||
=> gen_array_ops.reverse(tensor, axis, name: name); | |||
@@ -189,7 +190,8 @@ namespace Tensorflow | |||
/// <param name="name">A name for the operation (optional).</param> | |||
/// <returns>A `Tensor` the same type as `input`.</returns> | |||
public Tensor slice<Tb, Ts>(Tensor input, Tb[] begin, Ts[] size, string name = null) | |||
=> array_ops.slice(input, begin, size, name: name); | |||
=> array_ops.slice(input, begin.Select(x => ops.convert_to_tensor(x)).ToArray(), | |||
size.Select(x => ops.convert_to_tensor(x)).ToArray(), name: name); | |||
public Tensor squeeze(Tensor input, int axis, string name = null, int squeeze_dims = -1) | |||
=> array_ops.squeeze(input, new[] { axis }, name); | |||
@@ -255,7 +257,7 @@ namespace Tensorflow | |||
/// <param name="name">A name for the operation (optional).</param> | |||
/// <returns>A `Tensor`. Has the same type as `input`.</returns> | |||
public Tensor placeholder_with_default<T>(T input, int[] shape, string name = null) | |||
=> gen_array_ops.placeholder_with_default(input, shape, name: name); | |||
=> gen_array_ops.placeholder_with_default(ops.convert_to_tensor(input), shape, name: name); | |||
/// <summary> | |||
/// Returns the shape of a tensor. | |||
@@ -130,7 +130,7 @@ namespace Tensorflow | |||
=> gen_math_ops.add(a, b, name: name); | |||
public Tensor add<Tx, Ty>(Tx a, Ty b, string name = null) | |||
=> gen_math_ops.add(a, b, name: name); | |||
=> gen_math_ops.add(ops.convert_to_tensor(a), ops.convert_to_tensor(b), name: name); | |||
/// <summary> | |||
/// Adds all input tensors element-wise. | |||
@@ -151,10 +151,10 @@ namespace Tensorflow | |||
=> gen_math_ops.atan(x, name); | |||
public Tensor arg_max(Tensor input, int dimension, TF_DataType output_type = TF_DataType.TF_INT64, string name = null) | |||
=> gen_math_ops.arg_max(input, dimension, output_type: output_type, name: name); | |||
=> gen_math_ops.arg_max(input, ops.convert_to_tensor(dimension), output_type: output_type, name: name); | |||
public Tensor arg_min(Tensor input, int dimension, TF_DataType output_type = TF_DataType.TF_INT64, string name = null) | |||
=> gen_math_ops.arg_min(input, dimension, output_type: output_type, name: name); | |||
=> gen_math_ops.arg_min(input, ops.convert_to_tensor(dimension), output_type: output_type, name: name); | |||
public Tensor is_finite(Tensor input, string name = null) | |||
=> gen_math_ops.is_finite(input, name); | |||
@@ -199,7 +199,7 @@ namespace Tensorflow | |||
=> gen_math_ops.cos(x, name); | |||
public Tensor cos(float x, string name = null) | |||
=> gen_math_ops.cos(x, name); | |||
=> gen_math_ops.cos(ops.convert_to_tensor(x), name); | |||
/// <summary> | |||
/// Computes hyperbolic cosine of x element-wise. | |||
@@ -235,7 +235,7 @@ namespace Tensorflow | |||
/// <param name="name"></param> | |||
/// <returns></returns> | |||
public Tensor greater<Tx, Ty>(Tx x, Ty y, string name = null) | |||
=> gen_math_ops.greater(x, y, name); | |||
=> gen_math_ops.greater(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name); | |||
/// <summary> | |||
/// Returns the truth value of (x >= y) element-wise. | |||
@@ -247,7 +247,7 @@ namespace Tensorflow | |||
/// <param name="name"></param> | |||
/// <returns></returns> | |||
public Tensor greater_equal<Tx, Ty>(Tx x, Ty y, string name = null) | |||
=> gen_math_ops.greater_equal(x, y, name); | |||
=> gen_math_ops.greater_equal(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name); | |||
/// <summary> | |||
/// Returns the truth value of (x < y) element-wise. | |||
@@ -259,7 +259,7 @@ namespace Tensorflow | |||
/// <param name="name"></param> | |||
/// <returns></returns> | |||
public Tensor less<Tx, Ty>(Tx x, Ty y, string name = null) | |||
=> gen_math_ops.less(x, y, name); | |||
=> gen_math_ops.less(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name); | |||
/// <summary> | |||
/// Computes the log of the absolute value of `Gamma(x)` element-wise. | |||
@@ -280,7 +280,7 @@ namespace Tensorflow | |||
/// <param name="name"></param> | |||
/// <returns></returns> | |||
public Tensor less_equal<Tx, Ty>(Tx x, Ty y, string name = null) | |||
=> gen_math_ops.less_equal(x, y, name); | |||
=> gen_math_ops.less_equal(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name); | |||
/// <summary> | |||
/// Computes natural logarithm of (1 + x) element-wise. | |||
@@ -292,7 +292,7 @@ namespace Tensorflow | |||
=> gen_math_ops.log1p(x, name); | |||
public Tensor logical_and<T>(T x, T y, string name = null) | |||
=> gen_math_ops.logical_and(x, y, name); | |||
=> gen_math_ops.logical_and(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name); | |||
public Tensor logical_not(Tensor x, string name = null) | |||
=> gen_math_ops.logical_not(x, name); | |||
@@ -301,7 +301,10 @@ namespace Tensorflow | |||
=> gen_math_ops.logical_or(x, y, name); | |||
public Tensor logical_xor(Tensor x, Tensor y, string name = "LogicalXor") | |||
=> gen_math_ops.logical_xor(x, y, name); | |||
{ | |||
return gen_math_ops.logical_and(gen_math_ops.logical_or(x, y), | |||
gen_math_ops.logical_not(gen_math_ops.logical_and(x, y)), name); | |||
} | |||
/// <summary> | |||
/// Clips tensor values to a specified min and max. | |||
@@ -312,7 +315,7 @@ namespace Tensorflow | |||
/// <param name="name"></param> | |||
/// <returns></returns> | |||
public Tensor _clip_by_value(Tensor t, Tensor clip_value_min, Tensor clip_value_max, string name = null) | |||
=> gen_math_ops._clip_by_value(t, clip_value_min, clip_value_max); | |||
=> gen_math_ops.clip_by_value(t, clip_value_min, clip_value_max); | |||
/// <summary> | |||
/// Clips tensor values to a specified min and max. | |||
@@ -345,7 +348,7 @@ namespace Tensorflow | |||
=> clip_ops.clip_by_value(t, clip_value_min, clip_value_max, name); | |||
public Tensor sub<Tx, Ty>(Tx a, Ty b, string name = null) | |||
=> gen_math_ops.sub(a, b, name: name); | |||
=> gen_math_ops.sub(ops.convert_to_tensor(a), ops.convert_to_tensor(b), name: name); | |||
public Tensor divide(Tensor a, Tensor b) | |||
=> a / b; | |||
@@ -396,7 +399,7 @@ namespace Tensorflow | |||
/// <param name="name"></param> | |||
/// <returns></returns> | |||
public Tensor max<Tx, Ty>(Tx input, Ty axis, bool keep_dims = false, string name = null) | |||
=> gen_math_ops._max(input, axis, keep_dims: keep_dims, name: name); | |||
=> gen_math_ops.max(ops.convert_to_tensor(input), ops.convert_to_tensor(axis), keep_dims: keep_dims, name: name); | |||
/// <summary> | |||
/// Computes the minimum of elements across dimensions of a tensor. | |||
@@ -409,7 +412,7 @@ namespace Tensorflow | |||
/// <param name="name"></param> | |||
/// <returns></returns> | |||
public Tensor min<Tx, Ty>(Tx input, Ty axis, bool keep_dims = false, string name = null) | |||
=> gen_math_ops._min(input, axis, keep_dims: keep_dims, name: name); | |||
=> gen_math_ops.min(ops.convert_to_tensor(input), ops.convert_to_tensor(axis), keep_dims: keep_dims, name: name); | |||
/// <summary> | |||
/// Returns the max of x and y (i.e. x > y ? x : y) element-wise. | |||
@@ -421,7 +424,7 @@ namespace Tensorflow | |||
/// <param name="name"></param> | |||
/// <returns></returns> | |||
public Tensor maximum<T1, T2>(T1 x, T2 y, string name = null) | |||
=> gen_math_ops.maximum(x, y, name: name); | |||
=> gen_math_ops.maximum(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name: name); | |||
/// <summary> | |||
/// Returns the min of x and y (i.e. x < y ? x : y) element-wise. | |||
@@ -433,7 +436,7 @@ namespace Tensorflow | |||
/// <param name="name"></param> | |||
/// <returns></returns> | |||
public Tensor minimum<T1, T2>(T1 x, T2 y, string name = null) | |||
=> gen_math_ops.minimum(x, y, name: name); | |||
=> gen_math_ops.minimum(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name: name); | |||
public Tensor multiply(Tensor x, Tensor y, string name = null) | |||
=> gen_math_ops.mul(x, y, name: name); | |||
@@ -448,7 +451,7 @@ namespace Tensorflow | |||
/// <param name="name"></param> | |||
/// <returns></returns> | |||
public Tensor multiply<Tx, Ty>(Tx x, Ty y, string name = null) | |||
=> gen_math_ops.mul(x, y, name: name); | |||
=> gen_math_ops.mul(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name: name); | |||
public Tensor negative(Tensor x, string name = null) | |||
=> gen_math_ops.neg(x, name); | |||
@@ -577,7 +580,7 @@ namespace Tensorflow | |||
=> math_ops.sigmoid(x, name: name); | |||
public Tensor sum(Tensor input, int axis, bool keep_dims = false, string name = null) | |||
=> gen_math_ops._sum(input, axis, keep_dims: keep_dims, name: name); | |||
=> gen_math_ops.sum(input, ops.convert_to_tensor(axis), keep_dims: keep_dims, name: name); | |||
public Tensor reduce_mean(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null, int? reduction_indices = null) | |||
=> math_ops.reduce_mean(input_tensor, axis: axis, keepdims: keepdims, name: name, reduction_indices: reduction_indices); | |||
@@ -29,21 +29,8 @@ namespace Tensorflow | |||
public Tensor conv2d(Tensor input, Tensor filter, int[] strides, string padding, bool use_cudnn_on_gpu = true, | |||
string data_format = "NHWC", int[] dilations = null, string name = null) | |||
{ | |||
var parameters = new Conv2dParams | |||
{ | |||
Input = input, | |||
Filter = filter, | |||
Strides = strides, | |||
Padding = padding, | |||
UseCudnnOnGpu = use_cudnn_on_gpu, | |||
DataFormat = data_format, | |||
Name = name | |||
}; | |||
if (dilations != null) | |||
parameters.Dilations = dilations; | |||
return gen_nn_ops.conv2d(parameters); | |||
return gen_nn_ops.conv2d(input, filter, strides, padding, use_cudnn_on_gpu, | |||
data_format: data_format, dilations: dilations, name: name); | |||
} | |||
public Tensor[] ctc_greedy_decoder(Tensor inputs, Tensor sequence_length, bool merge_repeated = true, string name = null) | |||
@@ -118,7 +105,7 @@ namespace Tensorflow | |||
public IActivation softmax() => new softmax(); | |||
public Tensor tanh(Tensor x, string name = null) | |||
=> gen_nn_ops.tanh(x, name); | |||
=> gen_math_ops.tanh(x, name); | |||
public Tensor relu(Tensor features, string name = null) | |||
=> gen_nn_ops.relu(features, name); | |||
@@ -146,14 +133,14 @@ namespace Tensorflow | |||
=> nn_ops.in_top_k(predictions, targets, k, name); | |||
public Tensor[] top_k(Tensor input, int k = 1, bool sorted = true, string name = null) | |||
=> gen_nn_ops.top_kv2(input, k: k, sorted: sorted, name: name); | |||
=> gen_nn_ops.top_kv2(input, k: ops.convert_to_tensor(k), sorted: sorted, name: name); | |||
public Tensor bias_add(Tensor value, IVariableV1 bias, string data_format = null, string name = null) | |||
{ | |||
return tf_with(ops.name_scope(name, "BiasAdd", new { value, bias }), scope => | |||
{ | |||
name = scope; | |||
return gen_nn_ops.bias_add(value, bias, data_format: data_format, name: name); | |||
return gen_nn_ops.bias_add(value, ops.convert_to_tensor(bias), data_format: data_format, name: name); | |||
}); | |||
} | |||
@@ -172,7 +159,7 @@ namespace Tensorflow | |||
/// <returns></returns> | |||
public Tensor lrn(Tensor input, int depth_radius = 5, int bias = 1, | |||
int alpha = 1, float beta = 0.5f, string name = null) | |||
=> gen_nn_ops.local_response_normalization(input, depth_radius: depth_radius, bias: bias, | |||
=> gen_nn_ops.lrn(input, depth_radius: depth_radius, bias: bias, | |||
alpha: alpha, beta: beta, name: name); | |||
public Tensor leaky_relu(Tensor features, float alpha = 0.2f, string name = null) | |||
@@ -31,6 +31,6 @@ namespace Tensorflow | |||
public Tensor reshape(Tensor tensor, | |||
object[] shape, | |||
string name = null) | |||
=> gen_array_ops.reshape(tensor, shape, name); | |||
=> gen_array_ops.reshape(tensor, ops.convert_to_tensor(shape), name); | |||
} | |||
} |
@@ -46,10 +46,10 @@ namespace Tensorflow | |||
int ellipsis_mask = 0, | |||
int new_axis_mask = 0, | |||
int shrink_axis_mask = 0, | |||
string name = null) => gen_array_ops.strided_slice(input: input, | |||
begin: begin, | |||
end: end, | |||
strides: strides, | |||
string name = null) => array_ops.strided_slice(input, | |||
begin: ops.convert_to_tensor(begin), | |||
end: ops.convert_to_tensor(end), | |||
strides: ops.convert_to_tensor(strides), | |||
begin_mask: begin_mask, | |||
end_mask: end_mask, | |||
ellipsis_mask: ellipsis_mask, | |||
@@ -23,7 +23,7 @@ namespace Tensorflow | |||
=> gen_array_ops.tile(input, multiples, name); | |||
public Tensor tile(Tensor input, object[] multiples, string name = null) | |||
=> gen_array_ops.tile(input, multiples, name); | |||
=> gen_array_ops.tile(input, ops.convert_to_tensor(multiples), name); | |||
public Tensor tile(Tensor input, Shape multiples, string name = null) | |||
{ | |||
@@ -57,6 +57,21 @@ namespace Tensorflow | |||
[DllImport(TensorFlowLibName)] | |||
public static extern int TF_OperationGetAttrValueProto(IntPtr oper, string attr_name, SafeBufferHandle output_attr_value, SafeStatusHandle status); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_OperationGetAttrType(IntPtr oper, string attr_name, IntPtr value, SafeStatusHandle status); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_OperationGetAttrInt(IntPtr oper, string attr_name, IntPtr value, SafeStatusHandle status); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_OperationGetAttrFloat(IntPtr oper, string attr_name, IntPtr value, SafeStatusHandle status); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_OperationGetAttrBool(IntPtr oper, string attr_name, IntPtr value, SafeStatusHandle status); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_OperationGetAttrShape(IntPtr oper, string attr_name, long[] value, int num_dims, SafeStatusHandle status); | |||
[DllImport(TensorFlowLibName)] | |||
public static extern void TF_SetAttrBool(IntPtr desc, string attr_name, bool value); | |||
@@ -88,7 +88,7 @@ namespace Tensorflow.Clustering | |||
public Tensor op() | |||
{ | |||
var x = control_flow_ops.cond(gen_math_ops.equal(_num_remaining, 0), | |||
var x = control_flow_ops.cond(gen_math_ops.equal(_num_remaining, ops.convert_to_tensor(0)), | |||
() => | |||
{ | |||
return check_ops.assert_equal(_cluster_centers_initialized, true); | |||
@@ -49,7 +49,7 @@ namespace Tensorflow.Contexts | |||
Tensors ExecEagerAction(string OpType, string Name, ExecuteOpArgs args) | |||
{ | |||
var opExecInfo = new FastPathOpExecInfo(OpType, Name, args.OpInputArgs) | |||
var opExecInfo = new FastPathOpExecInfo(tf.Context, OpType, Name, args.OpInputArgs) | |||
{ | |||
attrs = args.OpAttrs | |||
}; | |||
@@ -68,7 +68,8 @@ namespace Tensorflow.Eager | |||
var input_arg = op_def.InputArg[i]; | |||
if (!string.IsNullOrEmpty(input_arg.NumberAttr)) | |||
{ | |||
int len = (input as object[]).Length; | |||
var fast_input_array = input is Tensors tensors ? (object[])tensors : (object[])input; | |||
int len = fast_input_array.Length; | |||
c_api.TFE_OpSetAttrInt(op, input_arg.NumberAttr, len); | |||
if (op_exec_info.run_callbacks) | |||
{ | |||
@@ -79,7 +80,6 @@ namespace Tensorflow.Eager | |||
if (len > 0) | |||
{ | |||
var fast_input_array = (object[])op_exec_info.args[i]; | |||
// First item adds the type attr. | |||
if (!AddInputToOp(fast_input_array[i], true, input_arg, flattened_attrs, flattened_inputs, op, status)) | |||
return null; | |||
@@ -17,8 +17,9 @@ namespace Tensorflow | |||
public bool run_callbacks { get; set; } | |||
public Action callbacks { get; set; } | |||
public FastPathOpExecInfo(string opName, string name, params object[] inputArgs) | |||
public FastPathOpExecInfo(Context ctx, string opName, string name, params object[] inputArgs) | |||
{ | |||
this.ctx = ctx; | |||
this.op_name = opName; | |||
this.name = name; | |||
this.args = inputArgs; | |||
@@ -7,10 +7,11 @@ using Tensorflow.Contexts; | |||
using static Tensorflow.ApiDef.Types; | |||
using static Tensorflow.CostGraphDef.Types; | |||
using static Tensorflow.Binding; | |||
using Tensorflow.Gradients; | |||
namespace Tensorflow.Eager | |||
{ | |||
internal static class execute | |||
internal static class _execute | |||
{ | |||
public static (DataType[], Tensor[]) onvert_to_mixed_eager_tensors(Tensor[] values, Context ctx) | |||
{ | |||
@@ -18,7 +19,7 @@ namespace Tensorflow.Eager | |||
var types = v.Select(t => t.dtype.as_datatype_enum()); | |||
return (types.ToArray(), v.ToArray()); | |||
} | |||
public static Tensor[] executes(string op_name, int num_outputs, Tensor[] inputs, object[] attrs, Context ctx, string name = null) | |||
public static Tensor[] execute(string op_name, int num_outputs, Tensor[] inputs, object[] attrs, Context ctx, string name = null) | |||
{ | |||
return quick_execute(op_name, num_outputs, inputs, attrs, ctx, name); | |||
} | |||
@@ -33,7 +34,12 @@ namespace Tensorflow.Eager | |||
} | |||
public static bool must_record_gradient() | |||
{ | |||
return false; | |||
return tf.GetTapeSet().Count != 0; | |||
} | |||
public static bool record_gradient(string op_name, Tensor[] inputs, object[] attrs, Tensor[] results) | |||
{ | |||
return tf.Runner.RecordGradient(op_name, inputs, attrs, results); | |||
} | |||
} | |||
} |
@@ -147,7 +147,7 @@ namespace Tensorflow.Functions | |||
Tensor[] outputs; | |||
if (executing_eagerly) | |||
{ | |||
outputs = execute.executes( | |||
outputs = _execute.execute( | |||
Signature.Name, | |||
_num_outputs, | |||
args, | |||
@@ -44,6 +44,15 @@ namespace Tensorflow.Gradients | |||
return tape; | |||
} | |||
public void PushTape(ITape tape) | |||
{ | |||
// Enters a context inside which operations are recorded on this tape. | |||
if (tf.Context.executing_eagerly()) | |||
tf.Context.ensure_initialized(); | |||
_tapeSet.Push(tape); | |||
} | |||
ITape PopTape() | |||
{ | |||
_tape.StopRecord(); | |||
@@ -36,8 +36,7 @@ namespace Tensorflow.Gradients | |||
var input_value = op.inputs[0]; | |||
var broadcast_shape = op.inputs[1]; | |||
var input_value_shape = array_ops.shape(input_value); | |||
var (_, reduction_axes) = gen_array_ops.broadcast_gradient_args(broadcast_shape, | |||
input_value_shape); | |||
var reduction_axes = gen_array_ops.broadcast_gradient_args(broadcast_shape, input_value_shape)[1]; | |||
var updates_grad_reshaped = math_ops.reduce_sum(grad, | |||
axis: reduction_axes, | |||
keepdims: true); | |||
@@ -351,16 +350,16 @@ namespace Tensorflow.Gradients | |||
null, | |||
null, | |||
null, | |||
gen_array_ops.strided_slice( | |||
array_ops.strided_slice( | |||
grad, | |||
begin, | |||
end, | |||
strides, | |||
begin_mask: op.get_attr<long>("begin_mask"), | |||
end_mask: op.get_attr<long>("end_mask"), | |||
ellipsis_mask: op.get_attr<long>("ellipsis_mask"), | |||
new_axis_mask: op.get_attr<long>("new_axis_mask"), | |||
shrink_axis_mask: op.get_attr<long>("shrink_axis_mask")) | |||
begin_mask: (int)op.get_attr<long>("begin_mask"), | |||
end_mask: (int)op.get_attr<long>("end_mask"), | |||
ellipsis_mask: (int)op.get_attr<long>("ellipsis_mask"), | |||
new_axis_mask: (int)op.get_attr<long>("new_axis_mask"), | |||
shrink_axis_mask: (int)op.get_attr<long>("shrink_axis_mask")) | |||
}; | |||
} | |||
@@ -53,7 +53,8 @@ namespace Tensorflow.Gradients | |||
var sx = array_ops.shape(x); | |||
var sy = array_ops.shape(y); | |||
var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy); | |||
var args = gen_array_ops.broadcast_gradient_args(sx, sy); | |||
var (rx, ry) = (args[0], args[1]); | |||
var sum1 = math_ops.reduce_sum(grad, rx); | |||
var r1 = gen_array_ops.reshape(sum1, sx); | |||
@@ -101,7 +102,8 @@ namespace Tensorflow.Gradients | |||
var y = op.inputs[1]; | |||
var sx = array_ops.shape(x); | |||
var sy = array_ops.shape(y); | |||
var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy); | |||
var args = gen_array_ops.broadcast_gradient_args(sx, sy); | |||
var (rx, ry) = (args[0], args[1]); | |||
x = math_ops.conj(x); | |||
y = math_ops.conj(y); | |||
@@ -427,7 +429,8 @@ namespace Tensorflow.Gradients | |||
isMaximum | |||
? gen_math_ops.greater_equal(x, y) | |||
: gen_math_ops.less_equal(x, y); | |||
var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy); | |||
var args = gen_array_ops.broadcast_gradient_args(sx, sy); | |||
var (rx, ry) = (args[0], args[1]); | |||
var xgrad = array_ops.where(xmask, grad, zeros); | |||
var gx = array_ops.reshape(math_ops.reduce_sum(xgrad, rx), sx); | |||
var ygrad = array_ops.where(xmask, zeros, grad); | |||
@@ -458,7 +461,7 @@ namespace Tensorflow.Gradients | |||
private static Tensor _safe_shape_div(Tensor x, Tensor y) | |||
{ | |||
return math_ops.floordiv(x, gen_math_ops.maximum(y, 1)); | |||
return math_ops.floordiv(x, gen_math_ops.maximum(y, ops.convert_to_tensor(1))); | |||
} | |||
[RegisterGradient("Sub")] | |||
@@ -573,7 +576,8 @@ namespace Tensorflow.Gradients | |||
var sx = array_ops.shape(x); | |||
var sy = array_ops.shape(y); | |||
var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy); | |||
var args = gen_array_ops.broadcast_gradient_args(sx, sy); | |||
var (rx, ry) = (args[0], args[1]); | |||
x = math_ops.conj(x); | |||
y = math_ops.conj(y); | |||
@@ -824,7 +828,7 @@ namespace Tensorflow.Gradients | |||
mask = x > 0.0f; | |||
var ones = array_ops.ones_like(x); | |||
var safe_x = array_ops.where(mask, x, ones); | |||
var x1 = gen_array_ops.log(safe_x); | |||
var x1 = math_ops.log(safe_x); | |||
var y1 = array_ops.zeros_like(x); | |||
var log_x = array_ops.where(mask, x1, y1); | |||
var mul1 = grad * z * log_x; | |||
@@ -855,7 +859,8 @@ namespace Tensorflow.Gradients | |||
sy = array_ops.shape_internal(y, optimize: false); | |||
} | |||
var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy); | |||
var args = gen_array_ops.broadcast_gradient_args(sx, sy); | |||
var (rx, ry) = (args[0], args[1]); | |||
return new[] | |||
{ | |||
(sx, rx, !x.shape.Equals(grad.shape)), | |||
@@ -47,8 +47,8 @@ namespace Tensorflow.Gradients | |||
{ | |||
return new Tensor[] | |||
{ | |||
gen_math_ops.mul(grad, y), | |||
gen_math_ops.mul(grad, x) | |||
math_ops.multiply(grad, y), | |||
math_ops.multiply(grad, x) | |||
}; | |||
} | |||
@@ -192,17 +192,8 @@ namespace Tensorflow.Gradients | |||
explicit_paddings: explicit_paddings, | |||
dilations: dilations, | |||
data_format: data_format), | |||
gen_nn_ops.conv2d(new Conv2dParams | |||
{ | |||
Input = grad, | |||
Filter = op.inputs[1], | |||
Strides = strides, | |||
Padding = padding, | |||
DataFormat = data_format, | |||
Dilations = dilations, | |||
ExplicitPaddings = explicit_paddings, | |||
UseCudnnOnGpu = use_cudnn_on_gpu | |||
}) | |||
gen_nn_ops.conv2d(grad, op.inputs[1], strides, padding, | |||
use_cudnn_on_gpu, explicit_paddings, data_format, dilations) | |||
}; | |||
} | |||
@@ -265,20 +256,27 @@ namespace Tensorflow.Gradients | |||
var epsilon = op.get_attr<float>("epsilon"); | |||
var data_format = op.get_attr<string>("data_format"); | |||
var is_training = op.get_attr<bool>("is_training"); | |||
Func<FusedBatchNormParams, Tensor[]> grad_fun = null; | |||
switch (version) | |||
Func<FusedBatchNormParams, Tensor[]> grad_fun = (p) => | |||
{ | |||
case 2: | |||
grad_fun = gen_nn_ops.fused_batch_norm_grad_v3; | |||
break; | |||
case 1: | |||
// grad_fun = gen_nn_ops.fused_batch_norm_grad_v2; | |||
throw new NotImplementedException(""); | |||
default: | |||
grad_fun = gen_nn_ops.fused_batch_norm_grad; | |||
break; | |||
} | |||
if(version == 2) | |||
{ | |||
return gen_nn_ops.fused_batch_norm_grad_v3(p.YBackprop, p.X, p.Scale, | |||
p.ReserveSpace1, p.ReserveSpace2, p.ReserveSpace3, p.Epsilon, | |||
p.DataFormat, p.IsTraining, p.Name); | |||
} | |||
else if(version == 1) | |||
{ | |||
return gen_nn_ops.fused_batch_norm_grad_v2(p.YBackprop, p.X, p.Scale, | |||
p.ReserveSpace1, p.ReserveSpace2, p.Epsilon, p.DataFormat, | |||
p.IsTraining, p.Name); | |||
} | |||
else | |||
{ | |||
return gen_nn_ops.fused_batch_norm_grad(p.YBackprop, p.X, p.Scale, | |||
p.ReserveSpace1, p.ReserveSpace2, p.Epsilon, p.DataFormat, | |||
p.IsTraining, p.Name); | |||
} | |||
}; | |||
if (is_training) | |||
{ | |||
@@ -406,7 +404,7 @@ namespace Tensorflow.Gradients | |||
// finally reshaping it to the original input shape. | |||
var scatter = gen_array_ops.scatter_nd(array_ops.expand_dims(ind, -1), | |||
array_ops.reshape(grad, new int[] { -1 }), | |||
new Tensor[] { math_ops.reduce_prod(in_shape) }); | |||
math_ops.reduce_prod(in_shape)); | |||
return new Tensor[] | |||
{ | |||
@@ -34,7 +34,7 @@ namespace Tensorflow.Operations | |||
{ | |||
name = scope; | |||
value = ops.convert_to_tensor(value, name: "input"); | |||
return gen_nn_ops.average_pool( | |||
return gen_nn_ops.avg_pool( | |||
value, | |||
ksize: ksize, | |||
strides: strides, | |||
@@ -67,16 +67,15 @@ namespace Tensorflow.Operations | |||
var dilations = _get_sequence(args.DilationRate, num_spatial_dims, channel_index).ToArray(); | |||
var strides = _get_sequence(args.Strides, num_spatial_dims, channel_index).ToArray(); | |||
result = gen_nn_ops.conv2d(new Conv2dParams | |||
{ | |||
Input = input, | |||
Filter = filters, | |||
Strides = strides, | |||
Padding = padding, | |||
DataFormat = data_format, | |||
Dilations = dilations, | |||
Name = name | |||
}); | |||
result = gen_nn_ops.conv2d( | |||
input, | |||
filters, | |||
strides, | |||
padding, | |||
data_format: data_format, | |||
dilations: dilations, | |||
name: name | |||
); | |||
} | |||
else | |||
{ | |||
@@ -93,16 +92,15 @@ namespace Tensorflow.Operations | |||
input = array_ops.expand_dims(input, spatial_start_dim); | |||
filters = array_ops.expand_dims(filters, 0); | |||
result = gen_nn_ops.conv2d(new Conv2dParams | |||
{ | |||
Input = input, | |||
Filter = filters, | |||
Strides = strides.ToArray(), | |||
Padding = padding, | |||
DataFormat = channel_first ? "NCHW" : "NHWC", | |||
Dilations = dilations.ToArray(), | |||
Name = name | |||
}); | |||
result = gen_nn_ops.conv2d( | |||
input, | |||
filters, | |||
strides.ToArray(), | |||
padding, | |||
data_format: channel_first ? "NCHW" : "NHWC", | |||
dilations: dilations.ToArray(), | |||
name: name | |||
); | |||
result = array_ops.squeeze(result, new[] { spatial_start_dim }); | |||
} | |||
}); | |||
@@ -1,373 +0,0 @@ | |||
/***************************************************************************** | |||
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
******************************************************************************/ | |||
using System.Linq; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Operations | |||
{ | |||
public class gen_nn_ops | |||
{ | |||
/// <summary> | |||
/// Computes a 2-D convolution given 4-D `input` and `filter` tensors. | |||
/// | |||
/// Given an input tensor of shape `[batch, in_height, in_width, in_channels]` | |||
/// and a filter / kernel tensor of shape | |||
/// `[filter_height, filter_width, in_channels, out_channels]`, this op | |||
/// performs the following: | |||
/// | |||
/// 1. Flattens the filter to a 2-D matrix with shape | |||
/// `[filter_height * filter_width * in_channels, output_channels]`. | |||
/// 2. Extracts image patches from the input tensor to form a *virtual* | |||
/// tensor of shape `[batch, out_height, out_width, | |||
/// filter_height * filter_width * in_channels]`. | |||
/// 3. For each patch, right-multiplies the filter matrix and the image patch | |||
/// vector. | |||
/// </summary> | |||
/// <param name="parameters"></param> | |||
/// <returns></returns> | |||
public static Tensor conv2d(Conv2dParams parameters) | |||
=> tf.Context.ExecuteOp("Conv2D", parameters.Name, new ExecuteOpArgs(parameters.Input, parameters.Filter) | |||
.SetAttributes(new | |||
{ | |||
strides = parameters.Strides, | |||
padding = parameters.Padding, | |||
use_cudnn_on_gpu = parameters.UseCudnnOnGpu, | |||
explicit_paddings = parameters.ExplicitPaddings, | |||
data_format = parameters.DataFormat, | |||
dilations = parameters.Dilations | |||
})); | |||
/// <summary> | |||
/// Computes the gradients of convolution with respect to the filter. | |||
/// </summary> | |||
/// <param name="parameters"></param> | |||
/// <returns></returns> | |||
public static Tensor conv2d_backprop_filter(Tensor input, Tensor filter_sizes, Tensor out_backprop, | |||
int[] strides, string padding, bool use_cudnn_on_gpu = true, | |||
int[] explicit_paddings = null, | |||
string data_format = "NHWC", | |||
int[] dilations = null, | |||
string name = null) | |||
=> tf.Context.ExecuteOp("Conv2DBackpropFilter", name, new ExecuteOpArgs(input, filter_sizes, out_backprop) | |||
.SetAttributes(new | |||
{ | |||
strides, | |||
padding, | |||
use_cudnn_on_gpu, | |||
explicit_paddings = explicit_paddings ?? new int[0], | |||
data_format, | |||
dilations = dilations ?? new int[] { 1, 1, 1, 1 } | |||
})); | |||
/// <summary> | |||
/// Computes the gradients of convolution with respect to the input. | |||
/// </summary> | |||
/// <param name="parameters"></param> | |||
/// <returns></returns> | |||
public static Tensor conv2d_backprop_input(Tensor input_sizes, Tensor filter, Tensor out_backprop, | |||
int[] strides, string padding, bool use_cudnn_on_gpu = true, | |||
int[] explicit_paddings = null, | |||
string data_format = "NHWC", | |||
int[] dilations = null, | |||
string name = null) | |||
=> tf.Context.ExecuteOp("Conv2DBackpropInput", name, new ExecuteOpArgs(input_sizes, filter, out_backprop) | |||
.SetAttributes(new | |||
{ | |||
strides, | |||
padding, | |||
use_cudnn_on_gpu, | |||
explicit_paddings = explicit_paddings ?? new int[0], | |||
data_format, | |||
dilations = dilations ?? new int[] { 1, 1, 1, 1 } | |||
})); | |||
public static Tensor bias_add(Tensor value, | |||
IVariableV1 bias, | |||
string data_format = null, | |||
string name = null) | |||
=> tf.Context.ExecuteOp("BiasAdd", name, new ExecuteOpArgs(value, bias) | |||
.SetAttributes(new { data_format = data_format ?? "NHWC" })); | |||
public static Tensor bias_add_grad(Tensor out_backprop, | |||
string data_format = "NHWC", | |||
string name = null) | |||
=> tf.Context.ExecuteOp("BiasAddGrad", name, new ExecuteOpArgs(out_backprop) | |||
.SetAttributes(new { data_format = data_format ?? "NHWC" })); | |||
/// <summary> | |||
/// Computes exponential linear: <c>exp(features) - 1</c> if &lt; 0, <c>features</c> otherwise. | |||
/// </summary> | |||
/// <param name="features"> | |||
/// </param> | |||
/// <param name="name"> | |||
/// If specified, the created operation in the graph will be this one, otherwise it will be named 'Elu'. | |||
/// </param> | |||
/// <returns> | |||
/// The Operation can be fetched from the resulting Tensor, by fetching the Operation property from the result. | |||
/// </returns> | |||
/// <remarks> | |||
/// See [Fast and Accurate Deep Network Learning by Exponential Linear Units (ELUs) | |||
/// ](http://arxiv.org/abs/1511.07289) | |||
/// </remarks> | |||
public static Tensor elu(Tensor features, string name = "Elu") | |||
{ | |||
var op = tf.OpDefLib._apply_op_helper("Elu", name: name, args: new { features }); | |||
return op.output; | |||
} | |||
/// <summary> | |||
/// Gradient for batch normalization. | |||
/// </summary> | |||
/// <param name="params"></param> | |||
/// <returns></returns> | |||
public static Tensor[] fused_batch_norm_grad(FusedBatchNormParams @params) | |||
{ | |||
var op = tf.OpDefLib._apply_op_helper("FusedBatchNormGrad", name: @params.Name, args: new | |||
{ | |||
y_backprop = @params.YBackprop, | |||
x = @params.X, | |||
scale = @params.Scale, | |||
reserve_space_1 = @params.ReserveSpace1, | |||
reserve_space_2 = @params.ReserveSpace2, | |||
epsilon = @params.Epsilon, | |||
data_format = @params.DataFormat, | |||
is_training = @params.IsTraining | |||
}); | |||
return op.outputs; | |||
} | |||
public static Tensor[] fused_batch_norm_grad_v3(FusedBatchNormParams @params) | |||
=> tf.Context.ExecuteOp("FusedBatchNormGradV3", @params.Name, | |||
new ExecuteOpArgs(@params.YBackprop, | |||
@params.X, | |||
@params.Scale, | |||
@params.ReserveSpace1, | |||
@params.ReserveSpace2, | |||
@params.ReserveSpace3) | |||
.SetAttributes(new | |||
{ | |||
epsilon = @params.Epsilon, | |||
data_format = @params.DataFormat, | |||
is_training = @params.IsTraining | |||
})); | |||
public static Tensor[] fused_batch_norm(Tensor x, | |||
Tensor scale, | |||
Tensor offset, | |||
Tensor mean, | |||
Tensor variance, | |||
float epsilon = 0.0001f, | |||
string data_format = "NHWC", | |||
bool is_training = true, | |||
string name = null) | |||
{ | |||
var _op = tf.OpDefLib._apply_op_helper("FusedBatchNorm", name: name, args: new | |||
{ | |||
x, | |||
scale, | |||
offset, | |||
mean, | |||
variance, | |||
epsilon, | |||
data_format, | |||
is_training | |||
}); | |||
return _op.outputs; | |||
} | |||
public static Tensors fused_batch_norm_v3(Tensor x, | |||
Tensor scale, | |||
Tensor offset, | |||
Tensor mean, | |||
Tensor variance, | |||
float epsilon = 0.0001f, | |||
float exponential_avg_factor = 1.0f, | |||
string data_format = "NHWC", | |||
bool is_training = true, | |||
string name = null) | |||
=> tf.Context.ExecuteOp("FusedBatchNormV3", name, new ExecuteOpArgs(x, scale, offset, mean, variance) | |||
.SetAttributes(new { epsilon, data_format, is_training })); | |||
/// <summary> | |||
/// Local Response Normalization. | |||
/// </summary> | |||
/// <param name="input"></param> | |||
/// <param name="depth_radius"></param> | |||
/// <param name="bias"></param> | |||
/// <param name="alpha"></param> | |||
/// <param name="beta"></param> | |||
/// <param name="name"></param> | |||
/// <returns></returns> | |||
public static Tensor local_response_normalization(Tensor input, int depth_radius = 5, int bias = 1, | |||
int alpha = 1, float beta = 0.5f, string name = null) | |||
{ | |||
var _op = tf.OpDefLib._apply_op_helper("LRN", name: name, args: new | |||
{ | |||
input, | |||
depth_radius, | |||
bias, | |||
alpha, | |||
beta | |||
}); | |||
return _op.output; | |||
} | |||
public static Tensor log_softmax(Tensor logits, string name = null) | |||
=> tf.Context.ExecuteOp("LogSoftmax", name, new ExecuteOpArgs(logits)); | |||
/// <summary> | |||
/// Says whether the targets are in the top `K` predictions. | |||
/// </summary> | |||
/// <param name="predictions"></param> | |||
/// <param name="targets"></param> | |||
/// <param name="k"></param> | |||
/// <param name="name"></param> | |||
/// <returns>A `Tensor` of type `bool`.</returns> | |||
public static Tensor in_top_kv2(Tensor predictions, Tensor targets, int k, string name = null) | |||
=> tf.Context.ExecuteOp("InTopKV2", name, | |||
new ExecuteOpArgs(predictions, targets, k)); | |||
public static Tensor leaky_relu(Tensor features, float alpha = 0.2f, string name = null) | |||
=> tf.Context.ExecuteOp("LeakyRelu", name, | |||
new ExecuteOpArgs(features).SetAttributes(new { alpha })); | |||
public static Tensor average_pool(Tensor input, | |||
int[] ksize, | |||
int[] strides, | |||
string padding, | |||
string data_format = "NHWC", | |||
string name = null) | |||
=> tf.Context.ExecuteOp("AvgPool", name, new ExecuteOpArgs(input) | |||
.SetAttributes(new | |||
{ | |||
ksize, | |||
strides, | |||
padding, | |||
data_format | |||
})); | |||
public static Tensor max_pool(Tensor input, | |||
int[] ksize, | |||
int[] strides, | |||
string padding, | |||
string data_format = "NHWC", | |||
string name = null) | |||
=> tf.Context.ExecuteOp("MaxPool", name, new ExecuteOpArgs(input) | |||
.SetAttributes(new | |||
{ | |||
ksize, | |||
strides, | |||
padding, | |||
data_format | |||
})); | |||
public static Tensor max_pool_grad(Tensor orig_input, Tensor orig_output, Tensor grad, int[] ksize, int[] strides, string padding, | |||
string data_format = "NHWC", string name = null) | |||
=> tf.Context.ExecuteOp("MaxPoolGrad", name, new ExecuteOpArgs(orig_input, orig_output, grad) | |||
.SetAttributes(new | |||
{ | |||
ksize, | |||
strides, | |||
padding, | |||
data_format | |||
})); | |||
public static Tensor[] top_kv2<T>(Tensor input, T k, bool sorted = true, string name = null) | |||
{ | |||
var _op = tf.OpDefLib._apply_op_helper("TopKV2", name: name, args: new | |||
{ | |||
input, | |||
k, | |||
sorted | |||
}); | |||
return _op.outputs; | |||
} | |||
public static Tensor relu_grad(Tensor gradients, Tensor features, string name = null) | |||
=> tf.Context.ExecuteOp("ReluGrad", name, new ExecuteOpArgs(gradients, features)); | |||
public static Tensor leaky_relu_grad(Tensor gradients, Tensor features, float alpha = 0.2f, string name = null) | |||
=> tf.Context.ExecuteOp("LeakyReluGrad", name, new ExecuteOpArgs(gradients, features) | |||
.SetAttributes(new { alpha })); | |||
public static Tensor softmax(Tensor logits, string name = null) | |||
=> tf.Context.ExecuteOp("Softmax", name, new ExecuteOpArgs(logits)); | |||
/// <summary> | |||
/// Computes softmax cross entropy cost and gradients to backpropagate. | |||
/// </summary> | |||
/// <param name="features"></param> | |||
/// <param name="labels"></param> | |||
/// <param name="name"></param> | |||
/// <returns></returns> | |||
public static (Tensor, Tensor) softmax_cross_entropy_with_logits(Tensor features, Tensor labels, string name = null) | |||
{ | |||
var results = tf.Context.ExecuteOp("SoftmaxCrossEntropyWithLogits", name, new ExecuteOpArgs(features, labels)); | |||
return (results[0], results[1]); | |||
} | |||
/// <summary> | |||
/// Computes softmax cross entropy cost and gradients to backpropagate. | |||
/// </summary> | |||
/// <param name="features"> | |||
/// batch_size x num_classes matrix | |||
/// </param> | |||
/// <param name="labels"> | |||
/// batch_size vector with values in [0, num_classes). | |||
/// This is the label for the given minibatch entry. | |||
/// </param> | |||
/// <param name="name"> | |||
/// If specified, the created operation in the graph will be this one, otherwise it will be named 'SparseSoftmaxCrossEntropyWithLogits'. | |||
/// </param> | |||
/// <returns> | |||
/// Returns a tuple with multiple values, as follows: | |||
/// loss : Per example loss (batch_size vector). | |||
/// backprop : backpropagated gradients (batch_size x num_classes matrix). | |||
/// The Operation can be fetched from any of the Tensorreturned in the tuple values, by fetching the Operation property. | |||
/// </returns> | |||
/// <remarks> | |||
/// Unlike <c>SoftmaxCrossEntropyWithLogits</c>, this operation does not accept | |||
/// a matrix of label probabilities, but rather a single label per row | |||
/// of features. This label is considered to have probability 1.0 for the | |||
/// given row. | |||
/// | |||
/// Inputs are the logits, not probabilities. | |||
/// </remarks> | |||
public static (Tensor loss, Tensor backprop) sparse_softmax_cross_entropy_with_logits(Tensor features, Tensor labels, string name = "SparseSoftmaxCrossEntropyWithLogits") | |||
{ | |||
var results = tf.Context.ExecuteOp("SparseSoftmaxCrossEntropyWithLogits", name, new ExecuteOpArgs(features, labels)); | |||
return (results[0], results[1]); | |||
} | |||
/// <summary> | |||
/// Computes rectified linear: `max(features, 0)`. | |||
/// </summary> | |||
/// <param name="features">A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `uint8`, `int16`, `int8`, `int64`, `bfloat16`, `uint16`, `half`, `uint32`, `uint64`, `qint8`.</param> | |||
/// <param name="name">A name for the operation (optional).</param> | |||
/// <returns>A `Tensor`. Has the same type as `features`.</returns> | |||
public static Tensor relu(Tensor features, string name = null) | |||
=> tf.Context.ExecuteOp("Relu", name, new ExecuteOpArgs(features)); | |||
public static Tensor tanh(Tensor x, string name = null) | |||
=> tf.Context.ExecuteOp("Tanh", name, new ExecuteOpArgs(x)); | |||
} | |||
} |
@@ -103,6 +103,11 @@ namespace Tensorflow | |||
DataType dtype = DataType.DtInvalid; | |||
DataType default_dtype = DataType.DtInvalid; | |||
if (values is Tensors tensors) | |||
{ | |||
values = (Tensor[])tensors; | |||
} | |||
if (_IsListParameter(input_arg)) | |||
{ | |||
if (!_IsListValue(values)) | |||
@@ -187,6 +187,33 @@ namespace Tensorflow | |||
public virtual T get_attr<T>(string name) | |||
=> (T)get_attr(name); | |||
internal unsafe TF_DataType _get_attr_type(string name) | |||
{ | |||
Status status = new(); | |||
TF_DataType result; | |||
c_api.TF_OperationGetAttrType(_handle, name, new IntPtr(&result), status); | |||
status.Check(true); | |||
return result; | |||
} | |||
internal unsafe int _get_attr_int(string name) | |||
{ | |||
Status status = new(); | |||
int result; | |||
c_api.TF_OperationGetAttrInt(_handle, name, new IntPtr(&result), status); | |||
status.Check(true); | |||
return result; | |||
} | |||
internal unsafe bool _get_attr_bool(string name) | |||
{ | |||
Status status = new(); | |||
bool result; | |||
c_api.TF_OperationGetAttrBool(_handle, name, new IntPtr(&result), status); | |||
status.Check(true); | |||
return result; | |||
} | |||
public virtual T[] get_attr_list<T>(string name) | |||
{ | |||
if (tf.executing_eagerly()) | |||
@@ -229,7 +256,42 @@ namespace Tensorflow | |||
if(oneof_value == AttrValue.ValueOneofCase.List) | |||
{ | |||
throw new NotImplementedException($"Unsupported field type in {oneof_value}"); | |||
if (x.List.S is not null && x.List.S.Count > 0) | |||
{ | |||
return x.List.S.Select(x => x.ToStringUtf8()).ToArray(); | |||
} | |||
else if (x.List.I is not null && x.List.I.Count > 0) | |||
{ | |||
return x.List.I.ToArray(); | |||
} | |||
else if (x.List.F is not null && x.List.F.Count > 0) | |||
{ | |||
return x.List.F.ToArray(); | |||
} | |||
else if (x.List.B is not null && x.List.B.Count > 0) | |||
{ | |||
return x.List.B.ToArray(); | |||
} | |||
else if (x.List.Shape is not null && x.List.Shape.Count > 0) | |||
{ | |||
return x.List.Shape.ToArray(); | |||
} | |||
else if (x.List.Tensor is not null && x.List.Tensor.Count > 0) | |||
{ | |||
return x.List.Tensor.ToArray(); | |||
} | |||
else if (x.List.Func is not null && x.List.Func.Count > 0) | |||
{ | |||
return x.List.Func.ToArray(); | |||
} | |||
else if (x.List.Type is not null && x.List.Type.Count > 0) | |||
{ | |||
return x.List.Type.Select(x => x.as_tf_dtype()).ToArray(); | |||
} | |||
else | |||
{ | |||
return null; | |||
} | |||
} | |||
if(oneof_value == AttrValue.ValueOneofCase.Type) | |||
{ | |||
@@ -22,12 +22,13 @@ using Tensorflow.Contexts; | |||
using Tensorflow.Eager; | |||
using Tensorflow.Framework; | |||
using static Tensorflow.Binding; | |||
using System.Diagnostics; | |||
namespace Tensorflow | |||
{ | |||
public class array_ops | |||
{ | |||
public static Tensor placeholder_with_default<T>(T input, int[] shape, string name = null) | |||
public static Tensor placeholder_with_default(Tensor input, int[] shape, string name = null) | |||
=> gen_array_ops.placeholder_with_default(input, shape, name); | |||
/// <summary> | |||
@@ -132,7 +133,7 @@ namespace Tensorflow | |||
if (ndims_mask < 1) | |||
throw new ValueError("mask cannot be scalar."); | |||
var leading_size = gen_math_ops.prod(shape(tensor_tensor)[$"{axis}:{axis + ndims_mask}"], new[] { 0 }); | |||
var leading_size = gen_math_ops.prod(shape(tensor_tensor)[$"{axis}:{axis + ndims_mask}"], ops.convert_to_tensor(new[] { 0 })); | |||
var shape1 = concat(new[] | |||
{ | |||
shape(tensor_tensor)[$":{axis}"], | |||
@@ -153,7 +154,7 @@ namespace Tensorflow | |||
private static Tensor _apply_mask_1d(Tensor reshaped_tensor, Tensor mask, int axis = 0) | |||
{ | |||
var indices = squeeze(where(mask), axis: new[] { 1 }); | |||
return gather(reshaped_tensor, indices, axis: axis); | |||
return gather(reshaped_tensor, indices, axis: ops.convert_to_tensor(axis)); | |||
} | |||
public static Tensor zeros(Tensor shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) | |||
@@ -293,7 +294,7 @@ namespace Tensorflow | |||
} | |||
public static Tensor expand_dims(Tensor input, int axis = -1, string name = null) | |||
=> gen_array_ops.expand_dims(input, axis, name); | |||
=> gen_array_ops.expand_dims(input, ops.convert_to_tensor(axis), name); | |||
/// <summary> | |||
/// Creates a tensor filled with a scalar value. | |||
@@ -304,7 +305,7 @@ namespace Tensorflow | |||
/// <param name="name">Optional string. The name of the output `tf.Tensor`.</param> | |||
/// <returns>A `tf.Tensor` with shape `dims` and the same dtype as `value`.</returns> | |||
public static Tensor fill<T>(Shape dims, T value, string name = null) | |||
=> gen_array_ops.fill(dims, value, name: name); | |||
=> gen_array_ops.fill(dims, ops.convert_to_tensor(value), name: name); | |||
/// <summary> | |||
/// Returns the rank of a tensor. | |||
@@ -368,7 +369,7 @@ namespace Tensorflow | |||
=> gen_array_ops.reshape(tensor, shape, name: name); | |||
public static Tensor reshape(Tensor tensor, object[] shape, string name = null) | |||
=> gen_array_ops.reshape(tensor, shape, name: name); | |||
=> gen_array_ops.reshape(tensor, ops.convert_to_tensor(shape), name: name); | |||
private static Tensor ones_like_impl<T>(T tensor, TF_DataType dtype, string name, bool optimize = true) | |||
{ | |||
@@ -466,7 +467,11 @@ namespace Tensorflow | |||
} | |||
public static (Tensor, Tensor) unique(Tensor x, TF_DataType out_idx = TF_DataType.TF_INT32, string name = null) | |||
=> gen_array_ops.unique(x, out_idx: out_idx, name: name); | |||
{ | |||
var res = gen_array_ops.unique(x, out_idx: out_idx, name: name); | |||
Debug.Assert(res.Length == 2); | |||
return (res[0], res[1]); | |||
} | |||
public static Tensor stack(Tensor[] values, int axis = 0, string name = "stack") | |||
{ | |||
@@ -492,12 +497,12 @@ namespace Tensorflow | |||
{ | |||
name = scope; | |||
condition = ops.convert_to_tensor(condition, preferred_dtype: dtypes.@bool, name: "condition"); | |||
return gen_array_ops.where(condition: condition, name: name); | |||
return gen_array_ops.where(condition, name: name); | |||
}); | |||
} | |||
else if (x != null && y != null) | |||
{ | |||
return gen_array_ops.select(condition, x, y, name); | |||
return gen_math_ops.select(condition, ops.convert_to_tensor(x), ops.convert_to_tensor(y), name); | |||
} | |||
else | |||
{ | |||
@@ -505,7 +510,6 @@ namespace Tensorflow | |||
} | |||
} | |||
public static Tensor where_v2(Tensor condition, object x = null, object y = null, string name = null) | |||
{ | |||
if (x == null && y == null) | |||
@@ -514,18 +518,19 @@ namespace Tensorflow | |||
{ | |||
name = scope; | |||
condition = ops.convert_to_tensor(condition, preferred_dtype: dtypes.@bool, name: "condition"); | |||
return gen_array_ops.where(condition: condition, name: name); | |||
return gen_array_ops.where(condition, name: name); | |||
}); | |||
} | |||
else if (x != null && y != null) | |||
{ | |||
return gen_array_ops.select_v2(condition, x, y, name); | |||
return gen_math_ops.select_v2(condition, ops.convert_to_tensor(x), ops.convert_to_tensor(y), name); | |||
} | |||
else | |||
{ | |||
throw new ValueError("x and y must both be non-None or both be None."); | |||
} | |||
} | |||
/// <summary> | |||
/// Returns the shape of a tensor. | |||
/// </summary> | |||
@@ -634,7 +639,13 @@ namespace Tensorflow | |||
/// <param name="name"></param> | |||
/// <returns></returns> | |||
public static Tensor stop_gradient(Tensor input, string name = null) | |||
=> tf.Context.ExecuteOp("StopGradient", name, new ExecuteOpArgs(input)); | |||
{ | |||
var tape = tf.GradientTape().stop_recording(); | |||
var result = gen_array_ops.stop_gradient(input, name); | |||
tape.StartRecord(); | |||
tf.GradientTape().PushTape(tape); | |||
return result; | |||
} | |||
/// <summary> | |||
/// Extracts a strided slice of a tensor (generalized python array indexing). | |||
@@ -858,7 +869,7 @@ namespace Tensorflow | |||
}); | |||
} | |||
return gen_array_ops.concat_v2(values, axis, name: name); | |||
return gen_array_ops.concat_v2(values, ops.convert_to_tensor(axis), name: name); | |||
} | |||
public static Tensor concat(Tensor[] values, Tensor axis, string name = "concat") | |||
@@ -868,7 +879,7 @@ namespace Tensorflow | |||
public static Tensor concat(object[] values, int axis, string name = "concat") | |||
{ | |||
return gen_array_ops.concat_v2(values, axis, name: name); | |||
return tf.Context.ExecuteOp("ConcatV2", name, new ExecuteOpArgs(values, axis)); | |||
} | |||
/// <summary> | |||
@@ -886,18 +897,33 @@ namespace Tensorflow | |||
/// </param> | |||
/// <param name="batch_dims">An integer. The number of batch dimensions. Must be less than or equal to rank(indices).</param> | |||
/// <returns></returns> | |||
public static Tensor gather<T1, T2>(T1 @params, T2 indices, string name = null, int axis = 0, int batch_dims = 0) | |||
public static Tensor gather(Tensor @params, Tensor indices, string name = null, Tensor axis = null, int batch_dims = 0) | |||
{ | |||
if (axis != 0) | |||
return gen_array_ops.gather_v2(@params, indices, axis, name: name); | |||
if (@params is ResourceVariable variable && | |||
indices is Tensor indices_tensor) | |||
return variable.sparse_read(indices_tensor, name); | |||
if (axis is null) | |||
axis = tf.convert_to_tensor(batch_dims); | |||
if(tensor_util.constant_value(axis) != 0) | |||
{ | |||
return gen_array_ops.gather_v2(@params, indices, axis, batch_dims: batch_dims, name: name); | |||
} | |||
return gen_array_ops.gather_v2(@params, indices, axis, name: name); | |||
} | |||
public static Tensor gather(Tensor @params, Tensor indices, int axis, string name = null, int batch_dims = 0) | |||
=> gather(@params, indices, name, ops.convert_to_tensor(axis), batch_dims); | |||
public static Tensor gather(ResourceVariable @params, Tensor indices, string name = null, Tensor axis = null, int batch_dims = 0) | |||
{ | |||
if (axis is null) | |||
axis = tf.convert_to_tensor(batch_dims); | |||
if (tensor_util.constant_value(axis) != 0) | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
return @params.sparse_read(indices, name); | |||
} | |||
public static Tensor transpose<T1>(T1 a, Axis perm, string name = "transpose", bool conjugate = false) | |||
{ | |||
return tf_with(ops.name_scope(name, "transpose", new { a }), scope => | |||
@@ -927,7 +953,7 @@ namespace Tensorflow | |||
if (num == -1) | |||
num = (int)size_splits.shape[0]; | |||
return gen_array_ops.split_v(value, size_splits, axis, num, name: name); | |||
return gen_array_ops.split_v(value, size_splits, tf.convert_to_tensor(axis), num, name: name); | |||
} | |||
public static Tensor[] split<T>(Tensor value, int num_split, T axis, | |||
@@ -956,20 +982,10 @@ namespace Tensorflow | |||
} | |||
public static Tensor slice(Tensor input, Tensor[] begin, Tensor[] size, string name = null) | |||
=> gen_array_ops.slice(input, begin, size, name: name); | |||
public static Tensor slice<Tb, Ts>(Tensor input, Tb begin, Ts size, string name = null) | |||
=> gen_array_ops.slice(input, begin, size, name: name); | |||
=> gen_array_ops.slice(input, ops.convert_to_tensor(begin), ops.convert_to_tensor(size), name: name); | |||
public static Tensor slice(Tensor input, Tensor begin, Tensor size, string name = null) | |||
=> tf.Context.ExecuteOp("Slice", name, new ExecuteOpArgs(input, begin, size) | |||
{ | |||
GetGradientAttrs = (op) => new | |||
{ | |||
T = op.get_attr<TF_DataType>("T"), | |||
Index = op.get_attr<int>("Index") | |||
} | |||
}); | |||
=> gen_array_ops.slice(input, begin, size, name: name); | |||
public static Tensor stack(object values, int axis = 0, string name = "stack") | |||
@@ -233,7 +233,7 @@ namespace Tensorflow | |||
{ | |||
try | |||
{ | |||
var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("AnonymousIteratorV3", name) | |||
var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(tf.Context, "AnonymousIteratorV3", name) | |||
{ | |||
attrs = attrs | |||
}); | |||
@@ -250,7 +250,7 @@ namespace Tensorflow | |||
public Tensor anonymous_iterator_v3_eager_fallback(TF_DataType[] output_types, Shape[] output_shapes, string name, Context ctx) | |||
{ | |||
object[] attrs = new object[] { output_types, output_shapes }; | |||
var result = execute.quick_execute("AnonymousIteratorV3", 1, new Tensor[] { }, attrs, ctx, name); | |||
var result = _execute.quick_execute("AnonymousIteratorV3", 1, new Tensor[] { }, attrs, ctx, name); | |||
return result[0]; | |||
} | |||
@@ -19,7 +19,7 @@ namespace Tensorflow.Operations | |||
{ | |||
try | |||
{ | |||
return tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("PartitionedCall", name, | |||
return tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(tf.Context, "PartitionedCall", name, | |||
args, tout, f, config, config_proto, executor_type)); | |||
} | |||
catch (Exception) | |||
@@ -50,7 +50,7 @@ namespace Tensorflow.Operations | |||
var output = tf.OpDefLib._apply_op_helper("PartitionedCall", | |||
name, kwargs); | |||
var result = output.outputs; | |||
if (execute.must_record_gradient()) | |||
if (_execute.must_record_gradient()) | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
@@ -88,7 +88,7 @@ namespace Tensorflow.Operations | |||
try | |||
{ | |||
var _result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo( | |||
"SymbolicGradient", name, input, Tout, f)); | |||
tf.Context, "SymbolicGradient", name, input, Tout, f)); | |||
return _result; | |||
} | |||
catch (Exception) | |||
@@ -107,7 +107,7 @@ namespace Tensorflow.Operations | |||
} | |||
var op = tf.OpDefLib._apply_op_helper("SymbolicGradient", name, new object[] { input, Tout, f }); | |||
var result = op.outputs; | |||
if (execute.must_record_gradient()) | |||
if (_execute.must_record_gradient()) | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
@@ -117,8 +117,8 @@ namespace Tensorflow.Operations | |||
public static Tensor[] symbolic_gradient_eager_fallback(Tensor[] input, TF_DataType[] Tout, NameAttrList f, string name, Context ctx) | |||
{ | |||
object[] attrs = new object[] { "Tin", input, "Tout", Tout, "f", f }; | |||
var result = execute.executes("SymbolicGradient", Tout.Length, input, attrs, ctx, name); | |||
if (execute.must_record_gradient()) | |||
var result = _execute.execute("SymbolicGradient", Tout.Length, input, attrs, ctx, name); | |||
if (_execute.must_record_gradient()) | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
@@ -26,7 +26,7 @@ namespace Tensorflow | |||
if (tf.Context.executing_eagerly()) | |||
{ | |||
var results = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo( | |||
"Assert", name, | |||
tf.Context, "Assert", name, | |||
new object[] { condition, data, summarize })); | |||
return results[0]; | |||
@@ -1,11 +0,0 @@ | |||
using System; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow | |||
{ | |||
public static partial class gen_math_ops | |||
{ | |||
public static Tensor mul(IntPtr x, IntPtr y, string name = null) | |||
=> tf.Context.ExecuteOp("Mul", name, new ExecuteOpArgs(x, y)); | |||
} | |||
} |
@@ -10055,7 +10055,7 @@ namespace Tensorflow.Operations | |||
{ | |||
try | |||
{ | |||
var _result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("EnsureShape", name, input, shape)); | |||
var _result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(tf.Context, "EnsureShape", name, input, shape)); | |||
return _result[0]; | |||
} | |||
catch (Exception) | |||
@@ -10076,7 +10076,7 @@ namespace Tensorflow.Operations | |||
dict["input"] = input; | |||
dict["shape"] = shape; | |||
var op = tf.OpDefLib._apply_op_helper("EnsureShape", name: name, keywords: dict); | |||
if (execute.must_record_gradient()) | |||
if (_execute.must_record_gradient()) | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
@@ -10086,9 +10086,9 @@ namespace Tensorflow.Operations | |||
public static Tensor ensure_shape_eager_fallback(Tensor input, Shape shape, string name, Context ctx) | |||
{ | |||
object[] attrs = new object[4] { "shape", shape, "T", input.dtype.as_datatype_enum() }; | |||
var _result = execute.executes("EnsureShape", 1, new Tensor[] { input }, | |||
var _result = _execute.execute("EnsureShape", 1, new Tensor[] { input }, | |||
attrs, ctx, name); | |||
if (execute.must_record_gradient()) | |||
if (_execute.must_record_gradient()) | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
@@ -17194,7 +17194,7 @@ namespace Tensorflow.Operations | |||
var ctx = tf.Context; | |||
if (ctx.executing_eagerly()) | |||
{ | |||
var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("MergeV2Checkpoints", name, | |||
var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(tf.Context, "MergeV2Checkpoints", name, | |||
checkpoint_prefixes, destination_prefix, "delete_old_dirs", delete_old_dirs, "allow_missing_files", allow_missing_files)); | |||
result = null; | |||
return null; | |||
@@ -24297,7 +24297,7 @@ namespace Tensorflow.Operations | |||
var ctx = tf.Context; | |||
if (ctx.executing_eagerly()) | |||
{ | |||
var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("RegexFullMatch", name, input, pattern)); | |||
var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(tf.Context, "RegexFullMatch", name, input, pattern)); | |||
return result[0]; | |||
} | |||
var dict = new Dictionary<string, object>(); | |||
@@ -27201,7 +27201,7 @@ namespace Tensorflow.Operations | |||
Dictionary<string, object> attrs = new(); | |||
attrs["dtypes"] = dtypes; | |||
var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo( | |||
"RestoreV2", name, prefix, tensor_names, shape_and_slices | |||
tf.Context, "RestoreV2", name, prefix, tensor_names, shape_and_slices | |||
) | |||
{ attrs = attrs }); | |||
return result; | |||
@@ -27236,9 +27236,9 @@ namespace Tensorflow.Operations | |||
var shape_and_slices_tensor = ops.convert_to_tensor(shape_and_slices, TF_DataType.TF_STRING); | |||
object[] attrs = new object[] { "dtypes", dtypes }; | |||
Tensor[] inputs_flat = new Tensor[] { prefix, tensor_names_tensor, shape_and_slices_tensor }; | |||
var result = execute.quick_execute("RestoreV2", dtypes.Length, inputs_flat, attrs, ctx, name); | |||
var result = _execute.quick_execute("RestoreV2", dtypes.Length, inputs_flat, attrs, ctx, name); | |||
if (execute.must_record_gradient()) | |||
if (_execute.must_record_gradient()) | |||
{ | |||
// TODO(Rinne); record the gradient | |||
} | |||
@@ -29829,7 +29829,7 @@ namespace Tensorflow.Operations | |||
var ctx = tf.Context; | |||
if (ctx.executing_eagerly()) | |||
{ | |||
var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("ShardedFilename", name, basename, shard, num_shards)); | |||
var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(tf.Context, "ShardedFilename", name, basename, shard, num_shards)); | |||
return result[0]; | |||
} | |||
var dict = new Dictionary<string, object>(); | |||
@@ -34759,7 +34759,7 @@ namespace Tensorflow.Operations | |||
var ctx = tf.Context; | |||
if (ctx.executing_eagerly()) | |||
{ | |||
var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("StringJoin", name, inputs, "separator", separator)); | |||
var result = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(tf.Context, "StringJoin", name, inputs, "separator", separator)); | |||
return result[0]; | |||
} | |||
var dict = new Dictionary<string, object>(); | |||
@@ -25,7 +25,7 @@ namespace Tensorflow | |||
if (tf.Context.executing_eagerly()) | |||
{ | |||
tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo( | |||
"AssignSubVariableOp", name, resource, value)); | |||
tf.Context, "AssignSubVariableOp", name, resource, value)); | |||
return null; | |||
} | |||
@@ -44,7 +44,7 @@ namespace Tensorflow | |||
{ | |||
if (tf.Context.executing_eagerly()) | |||
{ | |||
tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("AssignAddVariableOp", name, | |||
tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(tf.Context, "AssignAddVariableOp", name, | |||
resource, value)); | |||
return null; | |||
@@ -59,7 +59,7 @@ namespace Tensorflow | |||
{ | |||
if (tf.Context.executing_eagerly()) | |||
{ | |||
tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("AssignVariableOp", name, | |||
tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(tf.Context, "AssignVariableOp", name, | |||
resource, value)); | |||
return null; | |||
@@ -74,7 +74,7 @@ namespace Tensorflow | |||
{ | |||
if (tf.Context.executing_eagerly()) | |||
{ | |||
var results = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("VarIsInitializedOp", name, | |||
var results = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(tf.Context, "VarIsInitializedOp", name, | |||
resource)); | |||
return results[0]; | |||
@@ -99,7 +99,7 @@ namespace Tensorflow | |||
{ | |||
if (tf.Context.executing_eagerly()) | |||
{ | |||
var results = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("VarHandleOp", name) | |||
var results = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo(tf.Context, "VarHandleOp", name) | |||
{ | |||
attrs = ConvertToDict(new | |||
{ | |||
@@ -177,11 +177,11 @@ namespace Tensorflow | |||
if (shape.ndim == 3 || shape.ndim == Unknown) | |||
{ | |||
Tensor uniform_random = random_ops.random_uniform(new int[] { }, 0f, 1.0f, seed: seed); | |||
var mirror_cond = gen_math_ops.less(uniform_random, .5); | |||
var mirror_cond = gen_math_ops.less(uniform_random, ops.convert_to_tensor(.5)); | |||
var result = control_flow_ops.cond( | |||
pred: mirror_cond, | |||
true_fn: () => gen_array_ops.reverse(image, new { flip_index }), | |||
true_fn: () => gen_array_ops.reverse(image, ops.convert_to_tensor(new int[] { flip_index })), | |||
false_fn: () => image, | |||
name: scope | |||
); | |||
@@ -197,7 +197,7 @@ namespace Tensorflow | |||
var flips = math_ops.round( | |||
array_ops.reshape(uniform_random, shape: array_ops.constant(value: new object[] { batch_size[0], 1, 1, 1 }))); | |||
flips = math_ops.cast(flips, image.dtype); | |||
var flipped_input = gen_array_ops.reverse(image, new int[] { flip_index + 1 }); | |||
var flipped_input = gen_array_ops.reverse(image, ops.convert_to_tensor(new int[] { flip_index + 1 })); | |||
return flips * flipped_input + (1 - flips) * image; | |||
} | |||
else | |||
@@ -222,11 +222,11 @@ namespace Tensorflow | |||
Shape shape = image.shape; | |||
if (shape.ndim == 3 || shape.ndim == Unknown) | |||
{ | |||
return fix_image_flip_shape(image, gen_array_ops.reverse(image, new { flip_index })); | |||
return fix_image_flip_shape(image, gen_array_ops.reverse(image, ops.convert_to_tensor(new int[] { flip_index }))); | |||
} | |||
else if (shape.ndim == 4) | |||
{ | |||
return gen_array_ops.reverse(image, new[] { flip_index + 1 }); | |||
return gen_array_ops.reverse(image, ops.convert_to_tensor(new[] { flip_index + 1 })); | |||
} | |||
else | |||
{ | |||
@@ -268,15 +268,15 @@ namespace Tensorflow | |||
{ | |||
Tensor _rot90() | |||
{ | |||
return array_ops.transpose(gen_array_ops.reverse(image, new[] { 1, 0, 2 }), new int[] { 1 }); | |||
return array_ops.transpose(gen_array_ops.reverse(image, ops.convert_to_tensor(new[] { 1, 0, 2 })), new int[] { 1 }); | |||
}; | |||
Tensor _rot180() | |||
{ | |||
return gen_array_ops.reverse(image, new[] { 0, 1 }); | |||
return gen_array_ops.reverse(image, ops.convert_to_tensor(new[] { 0, 1 })); | |||
}; | |||
Tensor _rot270() | |||
{ | |||
return gen_array_ops.reverse(array_ops.transpose(image, new[] { 1, 0, 2 }), new[] { 1 }); | |||
return gen_array_ops.reverse(array_ops.transpose(image, new[] { 1, 0, 2 }), ops.convert_to_tensor(new[] { 1 })); | |||
}; | |||
var cases = new[] {math_ops.equal(k, 1), _rot90(), | |||
@@ -1389,7 +1389,7 @@ new_height, new_width"); | |||
Operation[] checks = new Operation[] { }; | |||
checks.append( | |||
control_flow_ops.Assert( | |||
gen_math_ops.greater_equal(array_ops.size(shape1_tensor), 3), new[] { shape1, shape2 }, | |||
gen_math_ops.greater_equal(array_ops.size(shape1_tensor), ops.convert_to_tensor(3)), new[] { shape1, shape2 }, | |||
summarize: 10)); | |||
checks.append( | |||
control_flow_ops.Assert( | |||
@@ -1762,8 +1762,8 @@ new_height, new_width"); | |||
{ | |||
var batch_size = array_ops.shape(boxes)[0]; | |||
var new_slice = array_ops.slice( | |||
boxes, new object[] { 0, inner_idx * tile_size, 0 }, | |||
new object[] { batch_size, tile_size, 4 }); | |||
boxes, new Tensor[] { ops.convert_to_tensor(0), ops.convert_to_tensor(inner_idx * tile_size), ops.convert_to_tensor(0) }, | |||
new Tensor[] { ops.convert_to_tensor(batch_size), ops.convert_to_tensor(tile_size), ops.convert_to_tensor(4) }); | |||
var iou = _bbox_overlap(new_slice, box_slice); | |||
var box_slice_after_suppression = array_ops.expand_dims( | |||
math_ops.cast(math_ops.reduce_all(iou < iou_threshold, new(1)), | |||
@@ -1816,8 +1816,8 @@ new_height, new_width"); | |||
(Tensor, Tensor, Tensor, Tensor) cross_suppression_func(Tensor boxes, Tensor box_slice, Tensor iou_threshold, Tensor inner_idx, int tile_size) | |||
=> _cross_suppression(boxes, box_slice, iou_threshold, inner_idx, tile_size); | |||
var box_slice = array_ops.slice(boxes, new[] { 0, idx * tile_size, 0 }, | |||
new[] { batch_size, tile_size, 4 }); | |||
var box_slice = array_ops.slice(boxes, new Tensor[]{ ops.convert_to_tensor(0), ops.convert_to_tensor(idx * tile_size), ops.convert_to_tensor(0) }, | |||
new Tensor[] { ops.convert_to_tensor(batch_size), ops.convert_to_tensor(tile_size), ops.convert_to_tensor(4) }); | |||
var iou = _bbox_overlap(box_slice, box_slice); | |||
var mask = array_ops.expand_dims( | |||
@@ -31,7 +31,7 @@ namespace Tensorflow | |||
try | |||
{ | |||
var result = tf.Runner.TFE_FastPathExecute( | |||
new FastPathOpExecInfo("SaveV2", name, new object[] { prefix, tensor_names, shape_and_slices, tensors })); | |||
new FastPathOpExecInfo(tf.Context, "SaveV2", name, new object[] { prefix, tensor_names, shape_and_slices, tensors })); | |||
result = null; | |||
return null; | |||
} | |||
@@ -48,14 +48,14 @@ namespace Tensorflow | |||
public Operation save_v2_eager_fallback(Tensor prefix, string[] tensor_names, string[] shape_and_slices, Tensor[] tensors, string name, Context ctx) | |||
{ | |||
DataType[] attr_dtypes; | |||
(attr_dtypes, tensors) = execute.onvert_to_mixed_eager_tensors(tensors, ctx); | |||
(attr_dtypes, tensors) = _execute.onvert_to_mixed_eager_tensors(tensors, ctx); | |||
prefix = ops.convert_to_tensor(prefix, TF_DataType.TF_STRING); | |||
var tensor_names_tensor = ops.convert_to_tensor(tensor_names, TF_DataType.TF_STRING); | |||
var shape_and_slices_tensor = ops.convert_to_tensor(shape_and_slices, TF_DataType.TF_STRING); | |||
var inputs_flat = tensors.Concat(new Tensor[] { prefix, tensor_names_tensor, shape_and_slices_tensor }).ToArray(); | |||
var attrs = new object[] { "dtypes", attr_dtypes }; | |||
var result = execute.quick_execute("SaveV2", 0, inputs_flat, attrs, ctx, name); | |||
var result = _execute.quick_execute("SaveV2", 0, inputs_flat, attrs, ctx, name); | |||
result = null; | |||
return null; | |||
} | |||
@@ -21,6 +21,7 @@ using System.Linq; | |||
using Tensorflow.Framework; | |||
using static Tensorflow.Binding; | |||
using Tensorflow.Operations; | |||
using System.Runtime.CompilerServices; | |||
namespace Tensorflow | |||
{ | |||
@@ -39,18 +40,18 @@ namespace Tensorflow | |||
{ | |||
return gen_ops.complex_abs(x, Tout: x.dtype.real_dtype(), name: name); | |||
} | |||
return gen_math_ops._abs(x, name: name); | |||
return gen_math_ops.abs(x, name: name); | |||
}); | |||
} | |||
public static Tensor add<Tx, Ty>(Tx x, Ty y, string name = null) | |||
=> gen_math_ops.add(x, y, name); | |||
=> gen_math_ops.add(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name); | |||
public static Tensor add_v2(Tensor x, Tensor y, string name = null) | |||
=> tf.Context.ExecuteOp("AddV2", name, new ExecuteOpArgs(x, y)); | |||
public static Tensor add_v2<Tx, Ty>(Tx x, Ty y, string name = null) | |||
=> gen_math_ops.add_v2(x, y, name); | |||
=> gen_math_ops.add_v2(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name); | |||
/// <summary> | |||
/// Adds all input tensors element-wise. | |||
@@ -254,9 +255,9 @@ namespace Tensorflow | |||
} | |||
public static Tensor greater_equal<Tx, Ty>(Tx x, Ty y, string name = null) | |||
=> gen_math_ops.greater_equal<Tx, Ty>(x, y, name: name); | |||
=> gen_math_ops.greater_equal(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name: name); | |||
public static Tensor equal<Tx, Ty>(Tx x, Ty y, string name = null) | |||
=> gen_math_ops.equal(x, y, name: name); | |||
=> gen_math_ops.equal(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name: name); | |||
/// <summary> | |||
/// Computes the Gauss error function of `x` element-wise. | |||
@@ -274,13 +275,13 @@ namespace Tensorflow | |||
=> tf.Context.ExecuteOp("Mul", name, new ExecuteOpArgs(x, y)); | |||
public static Tensor multiply<Tx, Ty>(Tx x, Ty y, string name = null) | |||
=> gen_math_ops.mul(x, y, name: name); | |||
=> gen_math_ops.mul(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name: name); | |||
public static Tensor not_equal<Tx, Ty>(Tx x, Ty y, string name = null) | |||
=> gen_math_ops.not_equal(x, y, name: name); | |||
=> gen_math_ops.not_equal(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name: name); | |||
public static Tensor mul_no_nan<Tx, Ty>(Tx x, Ty y, string name = null) | |||
=> gen_math_ops.mul_no_nan(x, y, name: name); | |||
=> gen_math_ops.mul_no_nan(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name: name); | |||
public static Tensor scalar_mul<Tscale, Tx>(Tscale scale, Tx x, string name = null) | |||
=> tf.Context.ExecuteOp("Mul", name, new ExecuteOpArgs(scale, x)); | |||
@@ -396,7 +397,7 @@ namespace Tensorflow | |||
}); | |||
public static Tensor sign<T>(T x, string name = null) | |||
=> gen_math_ops.sign(x, name: name); | |||
=> gen_math_ops.sign(ops.convert_to_tensor(x), name: name); | |||
public static Tensor sin(Tensor x, string name = null) | |||
=> tf.Context.ExecuteOp("Sin", name, new ExecuteOpArgs(x)); | |||
@@ -421,7 +422,7 @@ namespace Tensorflow | |||
public static Tensor subtract<Tx, Ty>(Tx x, Ty y, string name = null) | |||
{ | |||
return gen_math_ops.sub(x, y, name); | |||
return gen_math_ops.sub(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name); | |||
} | |||
public static Tensor log(Tensor x, string name = null) | |||
@@ -455,8 +456,8 @@ namespace Tensorflow | |||
var axis_tensor = array_ops.where_v2(constant_op.constant(axis >= 0), x: axis, y: ndims + axis); | |||
// The purpose is to avoid having negative values when repeating. | |||
var num_fill = gen_math_ops.maximum(num_int_tensor - 2, 0); | |||
var n_steps = gen_math_ops.maximum(num_int_tensor - 1, 1); | |||
var num_fill = gen_math_ops.maximum(num_int_tensor - 2, ops.convert_to_tensor(0)); | |||
var n_steps = gen_math_ops.maximum(num_int_tensor - 1, ops.convert_to_tensor(1)); | |||
var delta = (expanded_stop - expanded_start) / cast(n_steps, expanded_stop.dtype); | |||
var range_end = array_ops.where_v2(num_int_tensor >= 0, n_steps, -1); | |||
@@ -503,7 +504,7 @@ namespace Tensorflow | |||
var axes_shape = array_ops.shape(axes); | |||
var rng = math_ops.range(input_rank); | |||
var a1 = new Tensor[] { rng, axes }; | |||
var fill = gen_array_ops.fill(axes_shape, 1); | |||
var fill = gen_array_ops.fill(axes_shape, ops.convert_to_tensor(1)); | |||
var a2 = new Tensor[] { input_shape, fill }; | |||
return gen_data_flow_ops.dynamic_stitch(a1, a2); | |||
@@ -528,7 +529,7 @@ namespace Tensorflow | |||
/// <returns></returns> | |||
public static Tensor reduce_all(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null) | |||
{ | |||
var all = gen_math_ops._all(input_tensor, | |||
var all = gen_math_ops.all(input_tensor, | |||
_ReductionDims(input_tensor, axis), | |||
keepdims, | |||
name: name); | |||
@@ -581,23 +582,23 @@ namespace Tensorflow | |||
public static Tensor reduce_any(Tensor input_tensor, Axis axis = null, bool keepdims = false, string name = null) | |||
{ | |||
var r = _ReductionDims(input_tensor, axis); | |||
var max = (axis != null) ? gen_math_ops._any(input_tensor, axis, keepdims, name) : | |||
gen_math_ops._any(input_tensor, r, keepdims, name); | |||
var max = (axis != null) ? gen_math_ops.any(input_tensor, axis, keepdims, name) : | |||
gen_math_ops.any(input_tensor, r, keepdims, name); | |||
return _may_reduce_to_scalar(keepdims, axis, max); | |||
} | |||
public static Tensor reduce_max(Tensor input_tensor, Axis axis = null, bool keepdims = false, string name = null) | |||
{ | |||
var r = _ReductionDims(input_tensor, axis); | |||
var max = (axis != null) ? gen_math_ops._max(input_tensor, axis, keepdims, name) : | |||
gen_math_ops._max(input_tensor, r, keepdims, name); | |||
var max = (axis != null) ? gen_math_ops.max(input_tensor, axis, keepdims, name) : | |||
gen_math_ops.max(input_tensor, r, keepdims, name); | |||
return _may_reduce_to_scalar(keepdims, axis, max); | |||
} | |||
public static Tensor reduce_min(Tensor input_tensor, Axis axis = null, bool keepdims = false, string name = null) | |||
{ | |||
var r = _ReductionDims(input_tensor, axis); | |||
var min = gen_math_ops._min(input_tensor, r, keepdims, name); | |||
var min = gen_math_ops.min(input_tensor, r, keepdims, name); | |||
return _may_reduce_to_scalar(keepdims, axis, min); | |||
} | |||
@@ -643,7 +644,7 @@ namespace Tensorflow | |||
public static Tensor reduce_sum(Tensor input_tensor, Tensor axis = null, bool keepdims = false, string name = null) | |||
{ | |||
var r = _ReductionDims(input_tensor, axis); | |||
var m = gen_math_ops._sum(input_tensor, r, keep_dims: keepdims, name: name); | |||
var m = gen_math_ops.sum(input_tensor, r, keep_dims: keepdims, name: name); | |||
return _may_reduce_to_scalar(keepdims, axis, m); | |||
} | |||
@@ -752,10 +753,10 @@ namespace Tensorflow | |||
} | |||
public static Tensor minimum<Tx, Ty>(Tx x, Ty y, string name = null) | |||
=> gen_math_ops.minimum(x, y, name: name); | |||
=> gen_math_ops.minimum(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name: name); | |||
public static Tensor maximum<Tx, Ty>(Tx x, Ty y, string name = null) | |||
=> gen_math_ops.maximum(x, y, name: name); | |||
=> gen_math_ops.maximum(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name: name); | |||
/// <summary> | |||
/// Multiplies matrix `a` by matrix `b`, producing `a` * `b`. | |||
@@ -236,7 +236,7 @@ namespace Tensorflow | |||
Tensor size = array_ops.size(value, out_type: dtypes.int64); | |||
Tensor zero_fraction_float32 = null; | |||
size = gen_math_ops.less_equal(size, dtypes.int32.max()); | |||
size = gen_math_ops.less_equal(size, ops.convert_to_tensor(dtypes.int32.max())); | |||
Tensor num_nonzero = control_flow_ops.cond( | |||
size, | |||
() => math_ops.cast(_count_nonzero(value, dtype: dtypes.int32), TF_DataType.TF_INT64), | |||
@@ -55,7 +55,7 @@ namespace Tensorflow | |||
return tf_with(ops.name_scope(name, "BiasAdd", new { value, bias }), scope => | |||
{ | |||
name = scope; | |||
return gen_nn_ops.bias_add(value, bias, data_format: data_format, name: name); | |||
return gen_nn_ops.bias_add(value, ops.convert_to_tensor(bias), data_format: data_format, name: name); | |||
}); | |||
} | |||
@@ -117,7 +117,7 @@ namespace Tensorflow | |||
{ | |||
return tf_with(ops.name_scope(name, "in_top_k"), delegate | |||
{ | |||
return gen_nn_ops.in_top_kv2(predictions, targets, k, name: name); | |||
return gen_nn_ops.in_top_kv2(predictions, targets, ops.convert_to_tensor(k), name: name); | |||
}); | |||
} | |||
@@ -222,8 +222,8 @@ namespace Tensorflow | |||
// Check if no reshapes are required. | |||
if (logits.shape.ndim == 2) | |||
{ | |||
var (cost, _) = gen_nn_ops.sparse_softmax_cross_entropy_with_logits( | |||
precise_logits, labels, name: name); | |||
var cost = gen_nn_ops.sparse_softmax_cross_entropy_with_logits( | |||
precise_logits, labels, name: name)[0]; | |||
if (logits.dtype == dtypes.float16) | |||
return math_ops.cast(cost, dtypes.float32); | |||
else | |||
@@ -261,7 +261,8 @@ namespace Tensorflow | |||
// The second output tensor contains the gradients. We use it in | |||
// _CrossEntropyGrad() in nn_grad but not here. | |||
var (cost, unused_backprop) = gen_nn_ops.softmax_cross_entropy_with_logits(precise_logits, labels, name: name); | |||
var entropy = gen_nn_ops.softmax_cross_entropy_with_logits(precise_logits, labels, name: name); | |||
var (cost, unused_backprop) = (entropy[0], entropy[1]); | |||
// The output cost shape should be the input minus axis. | |||
var output_shape = array_ops.slice(input_shape, | |||
@@ -78,7 +78,7 @@ namespace Tensorflow | |||
minlength: nrows_int32, | |||
maxlength: nrows_int32, | |||
dtype: value_rowids.dtype); | |||
var row_splits = array_ops.concat(new object[] | |||
var row_splits = array_ops.concat(new Tensor[] | |||
{ | |||
ops.convert_to_tensor(new long[] { 0 }), | |||
tf.cumsum(row_lengths) | |||
@@ -154,103 +154,103 @@ namespace Tensorflow | |||
public static Tensor operator >(Tensor lhs, Tensor rhs) => gen_math_ops.greater(lhs, rhs); | |||
public static Tensor operator >(Tensor lhs, NDArray rhs) => gen_math_ops.greater(lhs, rhs); | |||
public static Tensor operator >(NDArray lhs, Tensor rhs) => gen_math_ops.greater(lhs, rhs); | |||
public static Tensor operator >(Tensor lhs, sbyte rhs) => gen_math_ops.greater(lhs, rhs); | |||
public static Tensor operator >(sbyte lhs, Tensor rhs) => gen_math_ops.greater(lhs, rhs); | |||
public static Tensor operator >(Tensor lhs, byte rhs) => gen_math_ops.greater(lhs, rhs); | |||
public static Tensor operator >(byte lhs, Tensor rhs) => gen_math_ops.greater(lhs, rhs); | |||
public static Tensor operator >(Tensor lhs, short rhs) => gen_math_ops.greater(lhs, rhs); | |||
public static Tensor operator >(short lhs, Tensor rhs) => gen_math_ops.greater(lhs, rhs); | |||
public static Tensor operator >(Tensor lhs, ushort rhs) => gen_math_ops.greater(lhs, rhs); | |||
public static Tensor operator >(ushort lhs, Tensor rhs) => gen_math_ops.greater(lhs, rhs); | |||
public static Tensor operator >(Tensor lhs, int rhs) => gen_math_ops.greater(lhs, rhs); | |||
public static Tensor operator >(int lhs, Tensor rhs) => gen_math_ops.greater(lhs, rhs); | |||
public static Tensor operator >(Tensor lhs, uint rhs) => gen_math_ops.greater(lhs, rhs); | |||
public static Tensor operator >(uint lhs, Tensor rhs) => gen_math_ops.greater(lhs, rhs); | |||
public static Tensor operator >(Tensor lhs, ulong rhs) => gen_math_ops.greater(lhs, rhs); | |||
public static Tensor operator >(ulong lhs, Tensor rhs) => gen_math_ops.greater(lhs, rhs); | |||
public static Tensor operator >(Tensor lhs, long rhs) => gen_math_ops.greater(lhs, rhs); | |||
public static Tensor operator >(long lhs, Tensor rhs) => gen_math_ops.greater(lhs, rhs); | |||
public static Tensor operator >(Tensor lhs, float rhs) => gen_math_ops.greater(lhs, rhs); | |||
public static Tensor operator >(float lhs, Tensor rhs) => gen_math_ops.greater(lhs, rhs); | |||
public static Tensor operator >(Tensor lhs, double rhs) => gen_math_ops.greater(lhs, rhs); | |||
public static Tensor operator >(double lhs, Tensor rhs) => gen_math_ops.greater(lhs, rhs); | |||
public static Tensor operator >(Tensor lhs, Complex rhs) => gen_math_ops.greater(lhs, rhs); | |||
public static Tensor operator >(Complex lhs, Tensor rhs) => gen_math_ops.greater(lhs, rhs); | |||
public static Tensor operator >(Tensor lhs, sbyte rhs) => gen_math_ops.greater(lhs, ops.convert_to_tensor(rhs)); | |||
public static Tensor operator >(sbyte lhs, Tensor rhs) => gen_math_ops.greater(ops.convert_to_tensor(lhs), rhs); | |||
public static Tensor operator >(Tensor lhs, byte rhs) => gen_math_ops.greater(lhs, ops.convert_to_tensor(rhs)); | |||
public static Tensor operator >(byte lhs, Tensor rhs) => gen_math_ops.greater(ops.convert_to_tensor(lhs), rhs); | |||
public static Tensor operator >(Tensor lhs, short rhs) => gen_math_ops.greater(lhs, ops.convert_to_tensor(rhs)); | |||
public static Tensor operator >(short lhs, Tensor rhs) => gen_math_ops.greater(ops.convert_to_tensor(lhs), rhs); | |||
public static Tensor operator >(Tensor lhs, ushort rhs) => gen_math_ops.greater(lhs, ops.convert_to_tensor(rhs)); | |||
public static Tensor operator >(ushort lhs, Tensor rhs) => gen_math_ops.greater(ops.convert_to_tensor(lhs), rhs); | |||
public static Tensor operator >(Tensor lhs, int rhs) => gen_math_ops.greater(lhs, ops.convert_to_tensor(rhs)); | |||
public static Tensor operator >(int lhs, Tensor rhs) => gen_math_ops.greater(ops.convert_to_tensor(lhs), rhs); | |||
public static Tensor operator >(Tensor lhs, uint rhs) => gen_math_ops.greater(lhs, ops.convert_to_tensor(rhs)); | |||
public static Tensor operator >(uint lhs, Tensor rhs) => gen_math_ops.greater(ops.convert_to_tensor(lhs), rhs); | |||
public static Tensor operator >(Tensor lhs, ulong rhs) => gen_math_ops.greater(lhs, ops.convert_to_tensor(rhs)); | |||
public static Tensor operator >(ulong lhs, Tensor rhs) => gen_math_ops.greater(ops.convert_to_tensor(lhs), ops.convert_to_tensor(rhs)); | |||
public static Tensor operator >(Tensor lhs, long rhs) => gen_math_ops.greater(lhs, ops.convert_to_tensor(rhs)); | |||
public static Tensor operator >(long lhs, Tensor rhs) => gen_math_ops.greater(ops.convert_to_tensor(lhs), rhs); | |||
public static Tensor operator >(Tensor lhs, float rhs) => gen_math_ops.greater(lhs, ops.convert_to_tensor(rhs)); | |||
public static Tensor operator >(float lhs, Tensor rhs) => gen_math_ops.greater(ops.convert_to_tensor(lhs), rhs); | |||
public static Tensor operator >(Tensor lhs, double rhs) => gen_math_ops.greater(lhs, ops.convert_to_tensor(rhs)); | |||
public static Tensor operator >(double lhs, Tensor rhs) => gen_math_ops.greater(ops.convert_to_tensor(lhs), rhs); | |||
public static Tensor operator >(Tensor lhs, Complex rhs) => gen_math_ops.greater(lhs, ops.convert_to_tensor(rhs)); | |||
public static Tensor operator >(Complex lhs, Tensor rhs) => gen_math_ops.greater(ops.convert_to_tensor(lhs), rhs); | |||
public static Tensor operator <(Tensor lhs, Tensor rhs) => gen_math_ops.less(lhs, rhs); | |||
public static Tensor operator <(Tensor lhs, NDArray rhs) => gen_math_ops.less(lhs, rhs); | |||
public static Tensor operator <(NDArray lhs, Tensor rhs) => gen_math_ops.less(lhs, rhs); | |||
public static Tensor operator <(Tensor lhs, sbyte rhs) => gen_math_ops.less(lhs, rhs); | |||
public static Tensor operator <(sbyte lhs, Tensor rhs) => gen_math_ops.less(lhs, rhs); | |||
public static Tensor operator <(Tensor lhs, byte rhs) => gen_math_ops.less(lhs, rhs); | |||
public static Tensor operator <(byte lhs, Tensor rhs) => gen_math_ops.less(lhs, rhs); | |||
public static Tensor operator <(Tensor lhs, short rhs) => gen_math_ops.less(lhs, rhs); | |||
public static Tensor operator <(short lhs, Tensor rhs) => gen_math_ops.less(lhs, rhs); | |||
public static Tensor operator <(Tensor lhs, ushort rhs) => gen_math_ops.less(lhs, rhs); | |||
public static Tensor operator <(ushort lhs, Tensor rhs) => gen_math_ops.less(lhs, rhs); | |||
public static Tensor operator <(Tensor lhs, int rhs) => gen_math_ops.less(lhs, rhs); | |||
public static Tensor operator <(int lhs, Tensor rhs) => gen_math_ops.less(lhs, rhs); | |||
public static Tensor operator <(Tensor lhs, uint rhs) => gen_math_ops.less(lhs, rhs); | |||
public static Tensor operator <(uint lhs, Tensor rhs) => gen_math_ops.less(lhs, rhs); | |||
public static Tensor operator <(Tensor lhs, ulong rhs) => gen_math_ops.less(lhs, rhs); | |||
public static Tensor operator <(ulong lhs, Tensor rhs) => gen_math_ops.less(lhs, rhs); | |||
public static Tensor operator <(Tensor lhs, long rhs) => gen_math_ops.less(lhs, rhs); | |||
public static Tensor operator <(long lhs, Tensor rhs) => gen_math_ops.less(lhs, rhs); | |||
public static Tensor operator <(Tensor lhs, float rhs) => gen_math_ops.less(lhs, rhs); | |||
public static Tensor operator <(float lhs, Tensor rhs) => gen_math_ops.less(lhs, rhs); | |||
public static Tensor operator <(Tensor lhs, double rhs) => gen_math_ops.less(lhs, rhs); | |||
public static Tensor operator <(double lhs, Tensor rhs) => gen_math_ops.less(lhs, rhs); | |||
public static Tensor operator <(Tensor lhs, Complex rhs) => gen_math_ops.less(lhs, rhs); | |||
public static Tensor operator <(Complex lhs, Tensor rhs) => gen_math_ops.less(lhs, rhs); | |||
public static Tensor operator <(Tensor lhs, sbyte rhs) => gen_math_ops.less(lhs, ops.convert_to_tensor(rhs)); | |||
public static Tensor operator <(sbyte lhs, Tensor rhs) => gen_math_ops.less(ops.convert_to_tensor(lhs), rhs); | |||
public static Tensor operator <(Tensor lhs, byte rhs) => gen_math_ops.less(lhs, ops.convert_to_tensor(rhs)); | |||
public static Tensor operator <(byte lhs, Tensor rhs) => gen_math_ops.less(ops.convert_to_tensor(lhs), rhs); | |||
public static Tensor operator <(Tensor lhs, short rhs) => gen_math_ops.less(lhs, ops.convert_to_tensor(rhs)); | |||
public static Tensor operator <(short lhs, Tensor rhs) => gen_math_ops.less(ops.convert_to_tensor(lhs), rhs); | |||
public static Tensor operator <(Tensor lhs, ushort rhs) => gen_math_ops.less(lhs, ops.convert_to_tensor(rhs)); | |||
public static Tensor operator <(ushort lhs, Tensor rhs) => gen_math_ops.less(ops.convert_to_tensor(lhs), rhs); | |||
public static Tensor operator <(Tensor lhs, int rhs) => gen_math_ops.less(lhs, ops.convert_to_tensor(rhs)); | |||
public static Tensor operator <(int lhs, Tensor rhs) => gen_math_ops.less(ops.convert_to_tensor(lhs), rhs); | |||
public static Tensor operator <(Tensor lhs, uint rhs) => gen_math_ops.less(lhs, ops.convert_to_tensor(rhs)); | |||
public static Tensor operator <(uint lhs, Tensor rhs) => gen_math_ops.less(ops.convert_to_tensor(lhs), rhs); | |||
public static Tensor operator <(Tensor lhs, ulong rhs) => gen_math_ops.less(lhs, ops.convert_to_tensor(rhs)); | |||
public static Tensor operator <(ulong lhs, Tensor rhs) => gen_math_ops.less(ops.convert_to_tensor(lhs), rhs); | |||
public static Tensor operator <(Tensor lhs, long rhs) => gen_math_ops.less(lhs, ops.convert_to_tensor(rhs)); | |||
public static Tensor operator <(long lhs, Tensor rhs) => gen_math_ops.less(ops.convert_to_tensor(lhs), rhs); | |||
public static Tensor operator <(Tensor lhs, float rhs) => gen_math_ops.less(lhs, ops.convert_to_tensor(rhs)); | |||
public static Tensor operator <(float lhs, Tensor rhs) => gen_math_ops.less(ops.convert_to_tensor(lhs), rhs); | |||
public static Tensor operator <(Tensor lhs, double rhs) => gen_math_ops.less(lhs, ops.convert_to_tensor(rhs)); | |||
public static Tensor operator <(double lhs, Tensor rhs) => gen_math_ops.less(ops.convert_to_tensor(lhs), rhs); | |||
public static Tensor operator <(Tensor lhs, Complex rhs) => gen_math_ops.less(lhs, ops.convert_to_tensor(rhs)); | |||
public static Tensor operator <(Complex lhs, Tensor rhs) => gen_math_ops.less(ops.convert_to_tensor(lhs), rhs); | |||
public static Tensor operator >=(Tensor lhs, Tensor rhs) => gen_math_ops.greater_equal(lhs, rhs); | |||
public static Tensor operator >=(Tensor lhs, NDArray rhs) => gen_math_ops.greater_equal(lhs, rhs); | |||
public static Tensor operator >=(NDArray lhs, Tensor rhs) => gen_math_ops.greater_equal(lhs, rhs); | |||
public static Tensor operator >=(Tensor lhs, sbyte rhs) => gen_math_ops.greater_equal(lhs, rhs); | |||
public static Tensor operator >=(sbyte lhs, Tensor rhs) => gen_math_ops.greater_equal(lhs, rhs); | |||
public static Tensor operator >=(Tensor lhs, byte rhs) => gen_math_ops.greater_equal(lhs, rhs); | |||
public static Tensor operator >=(byte lhs, Tensor rhs) => gen_math_ops.greater_equal(lhs, rhs); | |||
public static Tensor operator >=(Tensor lhs, short rhs) => gen_math_ops.greater_equal(lhs, rhs); | |||
public static Tensor operator >=(short lhs, Tensor rhs) => gen_math_ops.greater_equal(lhs, rhs); | |||
public static Tensor operator >=(Tensor lhs, ushort rhs) => gen_math_ops.greater_equal(lhs, rhs); | |||
public static Tensor operator >=(ushort lhs, Tensor rhs) => gen_math_ops.greater_equal(lhs, rhs); | |||
public static Tensor operator >=(Tensor lhs, int rhs) => gen_math_ops.greater_equal(lhs, rhs); | |||
public static Tensor operator >=(int lhs, Tensor rhs) => gen_math_ops.greater_equal(lhs, rhs); | |||
public static Tensor operator >=(Tensor lhs, uint rhs) => gen_math_ops.greater_equal(lhs, rhs); | |||
public static Tensor operator >=(uint lhs, Tensor rhs) => gen_math_ops.greater_equal(lhs, rhs); | |||
public static Tensor operator >=(Tensor lhs, ulong rhs) => gen_math_ops.greater_equal(lhs, rhs); | |||
public static Tensor operator >=(ulong lhs, Tensor rhs) => gen_math_ops.greater_equal(lhs, rhs); | |||
public static Tensor operator >=(Tensor lhs, long rhs) => gen_math_ops.greater_equal(lhs, rhs); | |||
public static Tensor operator >=(long lhs, Tensor rhs) => gen_math_ops.greater_equal(lhs, rhs); | |||
public static Tensor operator >=(Tensor lhs, float rhs) => gen_math_ops.greater_equal(lhs, rhs); | |||
public static Tensor operator >=(float lhs, Tensor rhs) => gen_math_ops.greater_equal(lhs, rhs); | |||
public static Tensor operator >=(Tensor lhs, double rhs) => gen_math_ops.greater_equal(lhs, rhs); | |||
public static Tensor operator >=(double lhs, Tensor rhs) => gen_math_ops.greater_equal(lhs, rhs); | |||
public static Tensor operator >=(Tensor lhs, Complex rhs) => gen_math_ops.greater_equal(lhs, rhs); | |||
public static Tensor operator >=(Complex lhs, Tensor rhs) => gen_math_ops.greater_equal(lhs, rhs); | |||
public static Tensor operator >=(Tensor lhs, sbyte rhs) => gen_math_ops.greater_equal(lhs, ops.convert_to_tensor(rhs)); | |||
public static Tensor operator >=(sbyte lhs, Tensor rhs) => gen_math_ops.greater_equal(ops.convert_to_tensor(lhs), rhs); | |||
public static Tensor operator >=(Tensor lhs, byte rhs) => gen_math_ops.greater_equal(lhs, ops.convert_to_tensor(rhs)); | |||
public static Tensor operator >=(byte lhs, Tensor rhs) => gen_math_ops.greater_equal(ops.convert_to_tensor(lhs), rhs); | |||
public static Tensor operator >=(Tensor lhs, short rhs) => gen_math_ops.greater_equal(lhs, ops.convert_to_tensor(rhs)); | |||
public static Tensor operator >=(short lhs, Tensor rhs) => gen_math_ops.greater_equal(ops.convert_to_tensor(lhs), rhs); | |||
public static Tensor operator >=(Tensor lhs, ushort rhs) => gen_math_ops.greater_equal(lhs, ops.convert_to_tensor(rhs)); | |||
public static Tensor operator >=(ushort lhs, Tensor rhs) => gen_math_ops.greater_equal(ops.convert_to_tensor(lhs), rhs); | |||
public static Tensor operator >=(Tensor lhs, int rhs) => gen_math_ops.greater_equal(lhs, ops.convert_to_tensor(rhs)); | |||
public static Tensor operator >=(int lhs, Tensor rhs) => gen_math_ops.greater_equal(ops.convert_to_tensor(lhs), rhs); | |||
public static Tensor operator >=(Tensor lhs, uint rhs) => gen_math_ops.greater_equal(lhs, ops.convert_to_tensor(rhs)); | |||
public static Tensor operator >=(uint lhs, Tensor rhs) => gen_math_ops.greater_equal(ops.convert_to_tensor(lhs), rhs); | |||
public static Tensor operator >=(Tensor lhs, ulong rhs) => gen_math_ops.greater_equal(lhs, ops.convert_to_tensor(rhs)); | |||
public static Tensor operator >=(ulong lhs, Tensor rhs) => gen_math_ops.greater_equal(ops.convert_to_tensor(lhs), rhs); | |||
public static Tensor operator >=(Tensor lhs, long rhs) => gen_math_ops.greater_equal(lhs, ops.convert_to_tensor(rhs)); | |||
public static Tensor operator >=(long lhs, Tensor rhs) => gen_math_ops.greater_equal(ops.convert_to_tensor(lhs), rhs); | |||
public static Tensor operator >=(Tensor lhs, float rhs) => gen_math_ops.greater_equal(lhs, ops.convert_to_tensor(rhs)); | |||
public static Tensor operator >=(float lhs, Tensor rhs) => gen_math_ops.greater_equal(ops.convert_to_tensor(lhs), rhs); | |||
public static Tensor operator >=(Tensor lhs, double rhs) => gen_math_ops.greater_equal(lhs, ops.convert_to_tensor(rhs)); | |||
public static Tensor operator >=(double lhs, Tensor rhs) => gen_math_ops.greater_equal(ops.convert_to_tensor(lhs), rhs); | |||
public static Tensor operator >=(Tensor lhs, Complex rhs) => gen_math_ops.greater_equal(lhs, ops.convert_to_tensor(rhs)); | |||
public static Tensor operator >=(Complex lhs, Tensor rhs) => gen_math_ops.greater_equal(ops.convert_to_tensor(lhs), rhs); | |||
public static Tensor operator <=(Tensor lhs, Tensor rhs) => gen_math_ops.less_equal(lhs, rhs); | |||
public static Tensor operator <=(Tensor lhs, NDArray rhs) => gen_math_ops.less_equal(lhs, rhs); | |||
public static Tensor operator <=(NDArray lhs, Tensor rhs) => gen_math_ops.less_equal(lhs, rhs); | |||
public static Tensor operator <=(Tensor lhs, sbyte rhs) => gen_math_ops.less_equal(lhs, rhs); | |||
public static Tensor operator <=(sbyte lhs, Tensor rhs) => gen_math_ops.less_equal(lhs, rhs); | |||
public static Tensor operator <=(Tensor lhs, byte rhs) => gen_math_ops.less_equal(lhs, rhs); | |||
public static Tensor operator <=(byte lhs, Tensor rhs) => gen_math_ops.less_equal(lhs, rhs); | |||
public static Tensor operator <=(Tensor lhs, short rhs) => gen_math_ops.less_equal(lhs, rhs); | |||
public static Tensor operator <=(short lhs, Tensor rhs) => gen_math_ops.less_equal(lhs, rhs); | |||
public static Tensor operator <=(Tensor lhs, ushort rhs) => gen_math_ops.less_equal(lhs, rhs); | |||
public static Tensor operator <=(ushort lhs, Tensor rhs) => gen_math_ops.less_equal(lhs, rhs); | |||
public static Tensor operator <=(Tensor lhs, int rhs) => gen_math_ops.less_equal(lhs, rhs); | |||
public static Tensor operator <=(int lhs, Tensor rhs) => gen_math_ops.less_equal(lhs, rhs); | |||
public static Tensor operator <=(Tensor lhs, uint rhs) => gen_math_ops.less_equal(lhs, rhs); | |||
public static Tensor operator <=(uint lhs, Tensor rhs) => gen_math_ops.less_equal(lhs, rhs); | |||
public static Tensor operator <=(Tensor lhs, ulong rhs) => gen_math_ops.less_equal(lhs, rhs); | |||
public static Tensor operator <=(ulong lhs, Tensor rhs) => gen_math_ops.less_equal(lhs, rhs); | |||
public static Tensor operator <=(Tensor lhs, long rhs) => gen_math_ops.less_equal(lhs, rhs); | |||
public static Tensor operator <=(long lhs, Tensor rhs) => gen_math_ops.less_equal(lhs, rhs); | |||
public static Tensor operator <=(Tensor lhs, float rhs) => gen_math_ops.less_equal(lhs, rhs); | |||
public static Tensor operator <=(float lhs, Tensor rhs) => gen_math_ops.less_equal(lhs, rhs); | |||
public static Tensor operator <=(Tensor lhs, double rhs) => gen_math_ops.less_equal(lhs, rhs); | |||
public static Tensor operator <=(double lhs, Tensor rhs) => gen_math_ops.less_equal(lhs, rhs); | |||
public static Tensor operator <=(Tensor lhs, Complex rhs) => gen_math_ops.less_equal(lhs, rhs); | |||
public static Tensor operator <=(Complex lhs, Tensor rhs) => gen_math_ops.less_equal(lhs, rhs); | |||
public static Tensor operator <=(Tensor lhs, sbyte rhs) => gen_math_ops.less_equal(lhs, ops.convert_to_tensor(rhs)); | |||
public static Tensor operator <=(sbyte lhs, Tensor rhs) => gen_math_ops.less_equal(ops.convert_to_tensor(lhs), rhs); | |||
public static Tensor operator <=(Tensor lhs, byte rhs) => gen_math_ops.less_equal(lhs, ops.convert_to_tensor(rhs)); | |||
public static Tensor operator <=(byte lhs, Tensor rhs) => gen_math_ops.less_equal(ops.convert_to_tensor(lhs), rhs); | |||
public static Tensor operator <=(Tensor lhs, short rhs) => gen_math_ops.less_equal(lhs, ops.convert_to_tensor(rhs)); | |||
public static Tensor operator <=(short lhs, Tensor rhs) => gen_math_ops.less_equal(ops.convert_to_tensor(lhs), rhs); | |||
public static Tensor operator <=(Tensor lhs, ushort rhs) => gen_math_ops.less_equal(lhs, ops.convert_to_tensor(rhs)); | |||
public static Tensor operator <=(ushort lhs, Tensor rhs) => gen_math_ops.less_equal(ops.convert_to_tensor(lhs), rhs); | |||
public static Tensor operator <=(Tensor lhs, int rhs) => gen_math_ops.less_equal(lhs, ops.convert_to_tensor(rhs)); | |||
public static Tensor operator <=(int lhs, Tensor rhs) => gen_math_ops.less_equal(ops.convert_to_tensor(lhs), rhs); | |||
public static Tensor operator <=(Tensor lhs, uint rhs) => gen_math_ops.less_equal(lhs, ops.convert_to_tensor(rhs)); | |||
public static Tensor operator <=(uint lhs, Tensor rhs) => gen_math_ops.less_equal(ops.convert_to_tensor(lhs), rhs); | |||
public static Tensor operator <=(Tensor lhs, ulong rhs) => gen_math_ops.less_equal(lhs, ops.convert_to_tensor(rhs)); | |||
public static Tensor operator <=(ulong lhs, Tensor rhs) => gen_math_ops.less_equal(ops.convert_to_tensor(lhs), rhs); | |||
public static Tensor operator <=(Tensor lhs, long rhs) => gen_math_ops.less_equal(lhs, ops.convert_to_tensor(rhs)); | |||
public static Tensor operator <=(long lhs, Tensor rhs) => gen_math_ops.less_equal(ops.convert_to_tensor(lhs), rhs); | |||
public static Tensor operator <=(Tensor lhs, float rhs) => gen_math_ops.less_equal(lhs, ops.convert_to_tensor(rhs)); | |||
public static Tensor operator <=(float lhs, Tensor rhs) => gen_math_ops.less_equal(ops.convert_to_tensor(lhs), rhs); | |||
public static Tensor operator <=(Tensor lhs, double rhs) => gen_math_ops.less_equal(lhs, ops.convert_to_tensor(rhs)); | |||
public static Tensor operator <=(double lhs, Tensor rhs) => gen_math_ops.less_equal(ops.convert_to_tensor(lhs), rhs); | |||
public static Tensor operator <=(Tensor lhs, Complex rhs) => gen_math_ops.less_equal(lhs, ops.convert_to_tensor(rhs)); | |||
public static Tensor operator <=(Complex lhs, Tensor rhs) => gen_math_ops.less_equal(ops.convert_to_tensor(lhs), rhs); | |||
public static Tensor operator -(Tensor x) => gen_math_ops.neg(x); | |||
@@ -161,6 +161,9 @@ namespace Tensorflow | |||
EnsureSingleTensor(tensor, "explicit conversion to string"); | |||
return (string)tensor[0]; | |||
} | |||
public static explicit operator object[](Tensors tensors) | |||
=> tensors.items.ToArray(); | |||
#endregion | |||
#region Implicit Conversions | |||
@@ -106,7 +106,7 @@ namespace Tensorflow | |||
name = scope; | |||
// Add a placeholder string tensor for the filename. | |||
var filename_tensor = array_ops.placeholder_with_default(string.IsNullOrEmpty(filename) ? "model" : filename, shape: new int[0], name: "filename"); | |||
var filename_tensor = array_ops.placeholder_with_default(tf.convert_to_tensor(string.IsNullOrEmpty(filename) ? "model" : filename), shape: new int[0], name: "filename"); | |||
// Keep the name "Const" for backwards compatibility. | |||
filename_tensor = gen_array_ops.placeholder_with_default(filename_tensor, shape: new int[0], name: "Const"); | |||
@@ -57,7 +57,8 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||
IDatasetV2 slice_batch_indices(Tensor indices) | |||
{ | |||
var num_in_full_batch = num_full_batches * _batch_size; | |||
var first_k_indices = array_ops.slice(indices, new int[] { 0 }, new int[] { num_in_full_batch }); | |||
var first_k_indices = array_ops.slice(indices, new Tensor[] { ops.convert_to_tensor(0) }, | |||
new Tensor[] { ops.convert_to_tensor(num_in_full_batch) }); | |||
first_k_indices = array_ops.reshape(first_k_indices, new int[] { num_full_batches, _batch_size }); | |||
var flat_dataset = tf.data.Dataset.from_tensor_slices(first_k_indices); | |||
if (_partial_batch_size > 0) | |||
@@ -81,7 +82,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||
{ | |||
var indices = inputs[0]; | |||
var results = inputs.Skip(1) | |||
.Select(x => gen_array_ops.gather_v2(x, indices, 0)) | |||
.Select(x => array_ops.gather(x, indices, axis: 0)) | |||
.ToArray(); | |||
return new Tensors(results); | |||
}, -1); | |||
@@ -79,7 +79,7 @@ namespace Tensorflow.Keras.Layers | |||
} | |||
else | |||
{ | |||
outputs = gen_math_ops.mat_mul(inputs, kernel.AsTensor()); | |||
outputs = math_ops.matmul(inputs, kernel.AsTensor()); | |||
} | |||
if (args.UseBias) | |||
@@ -30,7 +30,7 @@ namespace Tensorflow.Keras.Losses | |||
return gen_math_ops.mean(array_ops.where_v2(abs_error <= delta, | |||
half * math_ops.pow(error, 2), | |||
half * math_ops.pow(delta, 2) + delta * (abs_error - delta)), | |||
axis: -1); | |||
ops.convert_to_tensor(-1)); | |||
} | |||
} | |||
} |
@@ -20,7 +20,8 @@ namespace Tensorflow.Keras.Losses | |||
Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype); | |||
Tensor x = y_pred_dispatch - y_true_cast; | |||
return gen_math_ops.mean(x + gen_math_ops.softplus(-2.0 * x) - math_ops.cast(math_ops.log(tf.Variable(2.0)), x.dtype), axis: -1); | |||
return gen_math_ops.mean(x + gen_nn_ops.softplus(-2.0 * x) - math_ops.cast(math_ops.log(tf.Variable(2.0)), x.dtype), | |||
ops.convert_to_tensor(-1)); | |||
} | |||
} | |||
} |
@@ -17,7 +17,7 @@ namespace Tensorflow.Keras.Losses | |||
{ | |||
Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred); | |||
Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype); | |||
return gen_math_ops.mean(math_ops.abs(y_pred_dispatch - y_true_cast), axis: -1); | |||
return gen_math_ops.mean(math_ops.abs(y_pred_dispatch - y_true_cast), ops.convert_to_tensor(-1)); | |||
} | |||
} | |||
} |
@@ -18,7 +18,7 @@ namespace Tensorflow.Keras.Losses | |||
Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred); | |||
Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype); | |||
Tensor diff = math_ops.abs(y_true_cast - y_pred_dispatch) / gen_math_ops.maximum(math_ops.abs(y_true_cast), gen_math_ops.cast(tf.constant(1e-7), y_pred_dispatch.dtype)); | |||
return gen_math_ops.cast(tf.constant(100), y_pred_dispatch.dtype) * gen_math_ops.mean(diff, axis: -1); | |||
return gen_math_ops.cast(tf.constant(100), y_pred_dispatch.dtype) * gen_math_ops.mean(diff, ops.convert_to_tensor(-1)); | |||
} | |||
} | |||
} |
@@ -17,7 +17,7 @@ namespace Tensorflow.Keras.Losses | |||
{ | |||
Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred); | |||
Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype); | |||
return gen_math_ops.mean(gen_math_ops.squared_difference(y_pred_dispatch, y_true_cast), axis: -1); | |||
return gen_math_ops.mean(gen_math_ops.squared_difference(y_pred_dispatch, y_true_cast), ops.convert_to_tensor(-1)); | |||
} | |||
} | |||
} |
@@ -20,14 +20,14 @@ namespace Tensorflow.Keras.Losses | |||
Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype); | |||
Tensor first_log=null, second_log=null; | |||
if (y_pred_dispatch.dtype == TF_DataType.TF_DOUBLE) { | |||
first_log = math_ops.log(gen_math_ops.maximum(y_pred_dispatch, 1e-7) + 1.0); | |||
second_log = math_ops.log(gen_math_ops.maximum(y_true_cast, 1e-7) + 1.0); | |||
first_log = math_ops.log(math_ops.maximum(y_pred_dispatch, 1e-7) + 1.0); | |||
second_log = math_ops.log(math_ops.maximum(y_true_cast, 1e-7) + 1.0); | |||
} | |||
else { | |||
first_log = math_ops.log(gen_math_ops.maximum(y_pred_dispatch, 1e-7f) + 1.0f); | |||
second_log = math_ops.log(gen_math_ops.maximum(y_true_cast, 1e-7f) + 1.0f); | |||
first_log = math_ops.log(math_ops.maximum(y_pred_dispatch, 1e-7f) + 1.0f); | |||
second_log = math_ops.log(math_ops.maximum(y_true_cast, 1e-7f) + 1.0f); | |||
} | |||
return gen_math_ops.mean(gen_math_ops.squared_difference(first_log, second_log), axis: -1); | |||
return gen_math_ops.mean(gen_math_ops.squared_difference(first_log, second_log), ops.convert_to_tensor(-1)); | |||
} | |||
} | |||
} |
@@ -25,8 +25,8 @@ namespace TensorFlowNET.UnitTest.ControlFlowTest | |||
// TODO: implement missing code dependencies | |||
var sess = this.cached_session(); | |||
var i = constant_op.constant(0, name: "i"); | |||
var c = new Func<Tensor, Tensor>(x => gen_math_ops.less(x, 10, name: "c")); | |||
var b = new Func<Tensor, Tensor>(x => gen_math_ops.add(x, 1, name: "c")); | |||
var c = new Func<Tensor, Tensor>(x => gen_math_ops.less(x, ops.convert_to_tensor(10), name: "c")); | |||
var b = new Func<Tensor, Tensor>(x => math_ops.add(x, 1, name: "c")); | |||
//control_flow_ops.while_loop( | |||
// c, b, i , maximum_iterations: tf.constant(maximum_iterations)); | |||
foreach (Operation op in sess.graph.get_operations()) | |||
@@ -260,7 +260,7 @@ namespace TensorFlowNET.UnitTest.Gradient | |||
public void testStopGradientFunction() | |||
{ | |||
var ap = tf.constant(1f); | |||
var b = tf.tanh(ap) + gen_array_ops.stop_gradient(ap); | |||
var b = tf.tanh(ap) + array_ops.stop_gradient(ap); | |||
var g = tf.gradients(b, ap); | |||
var sess = tf.Session(); | |||
var result = sess.run(g); | |||
@@ -18,7 +18,7 @@ namespace TensorFlowNET.UnitTest.ManagedAPI | |||
var input_array = tf.constant(np.array(new int[] { 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6 }).reshape((3,2,3))); | |||
var indices = tf.constant(np.array(new int[] { 0, 2 })); | |||
var r1 = array_ops.slice(input_array, new int[] { 1, 0, 0 }, new int[] { 1, 1, 3 }); | |||
var r1 = array_ops.slice(input_array, ops.convert_n_to_tensor(new object[] { 1, 0, 0 }), ops.convert_n_to_tensor(new object[] { 1, 1, 3 })); | |||
Assert.AreEqual(new Shape(1,1,3), r1.shape); | |||
var r1np = r1.numpy(); | |||
Assert.AreEqual(r1np[0, 0, 0], 3); | |||
@@ -26,7 +26,7 @@ namespace TensorFlowNET.UnitTest.ManagedAPI | |||
Assert.AreEqual(r1np[0, 0, 2], 3); | |||
var r2 = array_ops.slice(input_array, new int[] { 1, 0, 0 }, new int[] { 1, 2, 3 }); | |||
var r2 = array_ops.slice(input_array, ops.convert_n_to_tensor(new object[] { 1, 0, 0 }), ops.convert_n_to_tensor(new object[] { 1, 2, 3 })); | |||
Assert.AreEqual(new Shape(1, 2, 3), r2.shape); | |||
var r2np = r2.numpy(); | |||
Assert.AreEqual(r2np[0, 0, 0], 3); | |||
@@ -36,7 +36,7 @@ namespace TensorFlowNET.UnitTest.ManagedAPI | |||
Assert.AreEqual(r2np[0, 1, 1], 4); | |||
Assert.AreEqual(r2np[0, 1, 2], 4); | |||
var r3 = array_ops.slice(input_array, new int[] { 1, 0, 0 }, new int[] { 2, 1, 3 }); | |||
var r3 = array_ops.slice(input_array, ops.convert_n_to_tensor(new object[] { 1, 0, 0 }), ops.convert_n_to_tensor(new object[] { 2, 1, 3 })); | |||
Assert.AreEqual(new Shape(2, 1, 3), r3.shape); | |||
var r3np = r3.numpy(); | |||
Assert.AreEqual(r3np[0, 0, 0], 3); | |||