|
|
@@ -166,6 +166,11 @@ namespace Tensorflow |
|
|
|
throw new ValueError("mask cannot be scalar."); |
|
|
|
|
|
|
|
var leading_size = gen_math_ops.prod(shape(tensor_tensor)[$"{axis}:{axis + ndims_mask}"], ops.convert_to_tensor(new[] { 0 })); |
|
|
|
if (leading_size.rank == 0) |
|
|
|
{ |
|
|
|
leading_size = expand_dims(leading_size, 0); |
|
|
|
} |
|
|
|
|
|
|
|
var shape1 = concat(new[] |
|
|
|
{ |
|
|
|
shape(tensor_tensor)[$":{axis}"], |
|
|
@@ -185,7 +190,7 @@ namespace Tensorflow |
|
|
|
|
|
|
|
private static Tensor _apply_mask_1d(Tensor reshaped_tensor, Tensor mask, int axis = 0) |
|
|
|
{ |
|
|
|
var indices = squeeze(where(mask), axis: new[] { 1 }); |
|
|
|
var indices = squeeze(where_v2(mask), axis: new[] { 1 }); |
|
|
|
return gather(reshaped_tensor, indices, axis: ops.convert_to_tensor(axis)); |
|
|
|
} |
|
|
|
|
|
|
@@ -940,12 +945,12 @@ namespace Tensorflow |
|
|
|
/// <returns></returns> |
|
|
|
public static Tensor concat(Tensor[] values, Tensor axis, string name = "concat") |
|
|
|
{ |
|
|
|
return tf.Context.ExecuteOp("ConcatV2", name, new ExecuteOpArgs(values, axis)); |
|
|
|
return gen_array_ops.concat_v2(values, axis, name: name); |
|
|
|
} |
|
|
|
|
|
|
|
public static Tensor concat(object[] values, int axis, string name = "concat") |
|
|
|
public static Tensor concat(Tensor[] values, Axis axis, string name = "concat") |
|
|
|
{ |
|
|
|
return tf.Context.ExecuteOp("ConcatV2", name, new ExecuteOpArgs(values, axis)); |
|
|
|
return gen_array_ops.concat_v2(values, axis, name: name); |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary> |
|
|
|