@@ -16,6 +16,7 @@
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Eager;
using Tensorflow.Framework;
using static Tensorflow.Binding;
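// Tensorflow.Eager supplies the EagerTensor type checked in the hunk below.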
@@ -82,7 +83,14 @@ namespace Tensorflow.Gradients
                .ToArray();

            var out_grads = new List<Tensor>();
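            // Accumulates one gradient tensor per concatenated input.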
            if (concat_dim is EagerTensor)
            {
                var non_neg_concat_dim = (int)concat_dim % input_values[0].rank;
                var sizes = input_values.Select(x => x.shape[non_neg_concat_dim]).ToArray();
                var sizes_tensor = constant_op.constant(sizes);
                out_grads = gen_array_ops.split_v(grad, sizes_tensor, sizes[0], non_neg_concat_dim).ToList();
            }
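            // The gradient of concat is the reverse split: e.g. concatenating shapes
            // [2,3] and [2,5] along axis 1 yields [2,8], and splitting grad ([2,8])
            // back into sizes {3,5} along that axis gives one gradient per input.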
            else if (constant_op.is_constant(concat_dim))
            {
                /*If concat_dim is a constant defined in a different context,
                then we duplicate it in the current context to avoid passing it
@@ -97,33 +105,33 @@ namespace Tensorflow.Gradients |
                var value = tensor_util.constant_value(concat_dim);
                concat_dim = constant_op.constant(value: value, dtype: concat_dim.dtype);
            }
            else
            {
                // Using mod here for convenience since concat_dim is already verified
                // in concat implementation to be within the allowed [-rank, rank) range.
                var non_neg_concat_dim = concat_dim % array_ops.rank(input_values[0]);
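                // e.g. for rank-4 inputs a concat_dim of -1 maps to 3 under the
                // floor-mod semantics of the tensor % operator.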

                // Get the inputs' tensor shapes
                var sizes = _ExtractInputShapes(input_values);
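                // For a valid concat these shapes agree on every axis except
                // non_neg_concat_dim; only that extent varies between inputs.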

                /* The magic number of 16 was found through benchmarking a range of sizes
                 on CPUs and a Maxwell TitanX. A speedup was seen in a large majority of
                 cases when switching implementations at N=16, but it is possible that
                 there will be a small number of performance regressions.*/
                if (len(sizes) > 16)
                {
                    // extract the size of each input along the concat dimension
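                    // stack(sizes, axis: 1) forms a [rank, N] matrix of the N input
                    // shapes; slicing out row non_neg_concat_dim reads all N extents
                    // with a single op instead of N separate ones.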
                    var slice = array_ops.slice(array_ops.stack(sizes, axis: 1),
                        new Tensor[] { non_neg_concat_dim, tf.constant(0) },
                        new Tensor[] { tf.constant(1), tf.constant(-1) });
                    var squeeze_sizes = array_ops.squeeze(slice);
                    out_grads = array_ops.split(axis: grad, value: squeeze_sizes, num_split: (int)non_neg_concat_dim).ToList();
                }
                else
                {
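                    // concat_offset returns, for each input, the offset at which that
                    // input starts inside the concatenated output; slicing grad at each
                    // (begin, size) pair carves out that input's gradient.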
                    var offset = gen_array_ops.concat_offset(non_neg_concat_dim, sizes);
                    foreach (var (begin, size) in zip(offset, sizes))
                        out_grads.Add(gen_array_ops.slice(grad, begin, size));
                }
            }

            return (end_value_index <= dim_index ?
|