Improve RaggedTensor (tags/v0.150.0-BERT-Model)
@@ -1139,5 +1139,18 @@ namespace Tensorflow | |||
var _op = tf.OpDefLib._apply_op_helper("Placeholder", name: name, args: new { dtype, shape }); | |||
return _op.output; | |||
} | |||
/// <summary>
/// Validates <paramref name="axis"/> and converts it to its non-negative form.
/// </summary>
/// <param name="axis">Axis index. Negative values count from the end and are only accepted when <paramref name="ndims"/> is known.</param>
/// <param name="ndims">Number of dimensions. Any negative value (including the default sentinel -100) means "not statically known".</param>
/// <param name="axis_name">Name used for the axis in error messages.</param>
/// <param name="ndims_name">Name used for the dimension count in error messages.</param>
/// <returns>
/// The equivalent axis in <c>[0, ndims)</c> when <paramref name="ndims"/> is known;
/// otherwise the (already non-negative) axis unchanged.
/// </returns>
/// <exception cref="ValueError">
/// The axis is outside <c>[-ndims, ndims)</c>, or it is negative while <paramref name="ndims"/> is unknown.
/// </exception>
public static int get_positive_axis(int axis, int ndims=-100, string axis_name="axis", string ndims_name= "ndims")
{
    // A dimension count can never be negative, so every negative ndims is treated
    // as "unknown" rather than only the -100 sentinel. This keeps callers that pass
    // an unknown rank marker from getting a misleading "out of bounds" error.
    if (ndims < 0)
    {
        if (axis < 0)
        {
            throw new ValueError($"{axis_name}={axis} may only be negative if {ndims_name} is statically known.");
        }
        return axis;
    }

    if (axis >= 0 && axis < ndims)
    {
        return axis;
    }
    if (-ndims <= axis && axis < 0)
    {
        // Negative axes index from the end.
        return axis + ndims;
    }
    throw new ValueError($"{axis_name}={axis} out of bounds:expected {-ndims}<={axis_name}<{ndims}");
}
} | |||
} |
@@ -163,5 +163,38 @@ namespace Tensorflow | |||
{ | |||
return tensor.Tag as RaggedTensor; | |||
} | |||
/// <summary>
/// Returns the number of rows of this ragged tensor, cast to <paramref name="out_type"/>.
/// </summary>
/// <param name="out_type">Data type of the returned row count.</param>
/// <param name="name">Optional name for the op scope; "RaggedNRows" is the default scope name.</param>
/// <returns>The row partition's row count cast to <paramref name="out_type"/>.</returns>
public Tensor nrows(TF_DataType out_type, string name = null)
{
    // Propagate the value produced inside the name scope. Previously the result of
    // tf_with was discarded and the method unconditionally returned null.
    return tf_with(ops.name_scope(name, "RaggedNRows"), scope =>
    {
        return math_ops.cast(this._row_partition.nrows(), dtype: out_type);
    });
}
/// <summary>
/// Returns the lengths of the rows along <paramref name="axis"/> of this ragged tensor.
/// </summary>
/// <param name="axis">
/// Axis whose row lengths are requested. 0 and 1 are answered directly from the row
/// partition; any other value is first normalised with
/// <c>array_ops.get_positive_axis</c> against <c>this.shape.rank</c>.
/// </param>
/// <param name="name">Currently not read by this method.</param>
/// <returns>The row lengths for the requested axis.</returns>
public RaggedTensor row_lengths(int axis=-1, string name=null)
{
    // NOTE(review): several branches return a Tensor through the declared
    // RaggedTensor return type; confirm the implicit conversion keeps the data.

    // Fast paths that do not require the static rank.
    if (axis == 0) return this._row_partition.nrows();
    if (axis == 1) return this._row_partition.row_lengths();

    axis = array_ops.get_positive_axis(
        axis, this.shape.rank, ndims_name: "rank(this)");

    if (axis == 0)
    {
        return this.nrows(this._row_partition.GetDataType());
    }
    if (axis == 1)
    {
        // Row i spans splits[i]..splits[i + 1], so adjacent differences give the lengths.
        var splits = this._row_partition.row_splits;
        return splits[new Slice(start: 1)] - splits[new Slice(stop: -1)];
    }
    if (this._values is RaggedTensor)
    {
        // Cast only in the branch where the values are known to be ragged.
        // NOTE(review): if _values is statically typed as Tensor this test may never
        // succeed; confirm how nested ragged values are represented.
        return ((RaggedTensor)this._values).row_lengths(axis - 1);
    }

    // Dense values: use the values themselves, not a RaggedTensor cast of them,
    // which is meaningless in this branch.
    var shape = array_ops.shape(this._values, out_type: this._row_partition.GetDataType());
    return array_ops.ones(shape[new Slice(stop: axis - 1)], this._row_partition.GetDataType()) *
        shape[axis - 1];
}
} | |||
} |
@@ -14,10 +14,15 @@ | |||
limitations under the License. | |||
******************************************************************************/ | |||
using Serilog.Debugging; | |||
using System; | |||
using System.Collections.Concurrent; | |||
using System.Collections.Generic; | |||
//using System.ComponentModel.DataAnnotations; | |||
using System.Text; | |||
using System.Xml.Linq; | |||
using Tensorflow.Framework; | |||
using Tensorflow.NumPy; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow | |||
@@ -99,5 +104,55 @@ namespace Tensorflow | |||
return new RowPartition(row_splits); | |||
}); | |||
} | |||
/// <summary>
/// Creates a <c>RowPartition</c> from the length of each row.
/// </summary>
/// <param name="row_lengths">Integer tensor holding the length of each row.</param>
/// <param name="validate">Currently not read by this method.</param>
/// <param name="dtype">Forwarded to <c>_convert_row_partition</c>.</param>
/// <param name="dtype_hint">Forwarded to <c>_convert_row_partition</c>.</param>
/// <returns>A partition whose row splits are <c>[0, cumsum(row_lengths)...]</c>.</returns>
public static RowPartition from_row_lengths(Tensor row_lengths,
    bool validate=true,
    TF_DataType dtype = TF_DataType.TF_INT32,
    TF_DataType dtype_hint= TF_DataType.TF_INT32)
{
    row_lengths = _convert_row_partition(
        row_lengths, "row_lengths", dtype_hint: dtype_hint, dtype: dtype);

    Tensor row_limits = math_ops.cumsum<Tensor>(row_lengths, tf.constant(-1));

    // The leading zero must share the dtype of the row limits, otherwise concat is
    // handed mixed int32/int64 inputs whenever row_lengths is int32.
    Tensor leading_zero = math_ops.cast(
        tf.convert_to_tensor(np.array(new int[] { 0 }, TF_DataType.TF_INT64)),
        dtype: row_lengths.dtype);

    Tensor row_splits = array_ops.concat(new Tensor[] { leading_zero, row_limits }, axis:0);
    return new RowPartition(row_splits: row_splits, row_lengths: row_lengths);
}
/// <summary>
/// Normalises a row-partitioning tensor and checks that it has an integer dtype.
/// </summary>
/// <param name="partition">The partitioning tensor to check.</param>
/// <param name="name">Name used for the conversion and in the error message.</param>
/// <param name="dtype">Not read by this method.</param>
/// <param name="dtype_hint">Not read by this method.</param>
/// <returns>The (possibly converted) partition tensor.</returns>
/// <exception cref="ValueError">The partition's dtype is neither int32 nor int64.</exception>
public static Tensor _convert_row_partition(Tensor partition, string name, TF_DataType dtype,
    TF_DataType dtype_hint= TF_DataType.TF_INT64)
{
    // int32 NDArray inputs are converted explicitly.
    bool isInt32NdArray = partition is NDArray && partition.GetDataType() == np.int32;
    if (isInt32NdArray)
    {
        partition = ops.convert_to_tensor(partition, name: name);
    }

    var actualDtype = partition.GetDataType();
    bool isSupported = actualDtype == np.int32 || actualDtype == np.int64;
    if (!isSupported)
    {
        throw new ValueError($"{name} must have dtype int32 or int64");
    }

    return partition;
}
/// <summary>
/// Returns the number of rows created by this <c>RowPartition</c>.
/// </summary>
/// <returns>
/// The cached <c>_nrows</c> when present; otherwise a constant derived from the static
/// size of <c>_row_splits</c>, or a value computed from its shape when that static
/// size is unavailable.
/// </returns>
public Tensor nrows()
{
    if (this._nrows != null)
    {
        return this._nrows;
    }

    var nsplits = tensor_shape.dimension_at_index(this._row_splits.shape, 0);

    // A negative static size cannot be a real length, so treat it as "unknown"
    // instead of baking a negative row count into a constant.
    if (nsplits == null || nsplits.value < 0)
    {
        return array_ops.shape(this._row_splits, out_type: this.row_splits.dtype)[0] - 1;
    }

    // row_splits has one more entry than there are rows.
    return constant_op.constant(nsplits.value - 1, dtype: this.row_splits.dtype);
}
/// <summary>
/// Returns the length of each row in this partition.
/// </summary>
/// <returns>
/// The cached <c>_row_lengths</c> when present; otherwise the differences between
/// consecutive entries of <c>_row_splits</c>.
/// </returns>
/// <exception cref="ValueError">Neither row lengths nor row splits are available.</exception>
public Tensor row_lengths()
{
    // The previous body returned the number of rows (a scalar) rather than the
    // per-row lengths that the method name promises.
    if (this._row_lengths != null)
    {
        return this._row_lengths;
    }

    if (this._row_splits == null)
    {
        throw new ValueError("row_lengths requires either row_lengths or row_splits to be set.");
    }

    // Row i spans splits[i]..splits[i + 1], so adjacent differences give the lengths.
    var splits = this._row_splits;
    return splits[new Slice(start: 1)] - splits[new Slice(stop: -1)];
}
} | |||
} |
@@ -0,0 +1,26 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using System.Threading.Tasks; | |||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
using Tensorflow; | |||
using Tensorflow.NumPy; | |||
using static Tensorflow.Binding; | |||
namespace TensorFlowNET.UnitTest.ManagedAPI
{
    /// <summary>
    /// Tests for building a <c>RowPartition</c> from row lengths.
    /// </summary>
    // MSTest only discovers test methods in classes marked with [TestClass].
    [TestClass]
    public class RaggedTensorTest : EagerModeTestBase
    {
        /// <summary>
        /// Five row lengths must produce a partition that reports five rows.
        /// </summary>
        [TestMethod]
        public void Test_from_row_lengths()
        {
            var row_lengths = tf.convert_to_tensor(np.array(new int[] { 2, 0, 3, 1, 1 }, TF_DataType.TF_INT64));

            var rp = RowPartition.from_row_lengths(row_lengths, validate: false);
            var rp_row_lengths = rp.row_lengths();
            var rp_nrows = rp.nrows();

            Assert.IsNotNull(rp_row_lengths);
            // Compare against the expected count; the previous assertion compared nrows() with itself.
            Assert.AreEqual(5L, rp_nrows.ToArray<long>()[0]);
        }
    }
}