Improve RaggedTensortags/v0.150.0-BERT-Model
@@ -1139,5 +1139,18 @@ namespace Tensorflow | |||||
var _op = tf.OpDefLib._apply_op_helper("Placeholder", name: name, args: new { dtype, shape }); | var _op = tf.OpDefLib._apply_op_helper("Placeholder", name: name, args: new { dtype, shape }); | ||||
return _op.output; | return _op.output; | ||||
} | } | ||||
public static int get_positive_axis(int axis, int ndims=-100, string axis_name="axis", string ndims_name= "ndims") | |||||
{ | |||||
if(ndims != -100) | |||||
{ | |||||
if (axis >= 0 && axis < ndims) return axis; | |||||
else if (-ndims <= axis && axis < 0) return axis + ndims; | |||||
else throw new ValueError($"{axis_name}={axis} out of bounds:expected {-ndims}<={axis_name}<{ndims}"); | |||||
} else if(axis < 0) throw new ValueError($"{axis_name}={axis} may only be negative if {ndims_name} is statically known."); | |||||
return axis; | |||||
} | |||||
} | } | ||||
} | } |
@@ -163,5 +163,38 @@ namespace Tensorflow | |||||
{ | { | ||||
return tensor.Tag as RaggedTensor; | return tensor.Tag as RaggedTensor; | ||||
} | } | ||||
public Tensor nrows(TF_DataType out_type, string name = null) | |||||
{ | |||||
tf_with(ops.name_scope(name, "RaggedNRows"), scope => | |||||
{ | |||||
return math_ops.cast(this._row_partition.nrows(), dtype: out_type); | |||||
}); | |||||
return null; | |||||
} | |||||
public RaggedTensor row_lengths(int axis=-1, string name=null) | |||||
{ | |||||
if (axis == 0) return this._row_partition.nrows(); | |||||
if (axis == 1) return this._row_partition.row_lengths(); | |||||
var values = (RaggedTensor)this._values; | |||||
axis = array_ops.get_positive_axis( | |||||
axis, this.shape.rank, ndims_name: "rank(this)"); | |||||
if (axis == 0) return this.nrows(this._row_partition.GetDataType()); | |||||
else if (axis == 1) | |||||
{ | |||||
var splits = this._row_partition.row_splits; | |||||
return splits[new Slice(start: 1)] - splits[new Slice(stop: -1)]; | |||||
} | |||||
else if (this._values is RaggedTensor) | |||||
{ | |||||
return values.row_lengths(axis - 1); | |||||
} | |||||
else | |||||
{ | |||||
var shape = array_ops.shape(values, out_type: this._row_partition.GetDataType()); | |||||
return array_ops.ones(shape[new Slice(stop:axis - 1)], this._row_partition.GetDataType()) * | |||||
shape[axis - 1]; | |||||
} | |||||
} | |||||
} | } | ||||
} | } |
@@ -14,10 +14,15 @@ | |||||
limitations under the License. | limitations under the License. | ||||
******************************************************************************/ | ******************************************************************************/ | ||||
using Serilog.Debugging; | |||||
using System; | using System; | ||||
using System.Collections.Concurrent; | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
//using System.ComponentModel.DataAnnotations; | |||||
using System.Text; | using System.Text; | ||||
using System.Xml.Linq; | |||||
using Tensorflow.Framework; | using Tensorflow.Framework; | ||||
using Tensorflow.NumPy; | |||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
namespace Tensorflow | namespace Tensorflow | ||||
@@ -99,5 +104,55 @@ namespace Tensorflow | |||||
return new RowPartition(row_splits); | return new RowPartition(row_splits); | ||||
}); | }); | ||||
} | } | ||||
public static RowPartition from_row_lengths(Tensor row_lengths, | |||||
bool validate=true, | |||||
TF_DataType dtype = TF_DataType.TF_INT32, | |||||
TF_DataType dtype_hint= TF_DataType.TF_INT32) | |||||
{ | |||||
row_lengths = _convert_row_partition( | |||||
row_lengths, "row_lengths", dtype_hint: dtype_hint, dtype: dtype); | |||||
Tensor row_limits = math_ops.cumsum<Tensor>(row_lengths, tf.constant(-1)); | |||||
Tensor row_splits = array_ops.concat(new Tensor[] { tf.convert_to_tensor(np.array(new int[] { 0 }, TF_DataType.TF_INT64)), row_limits }, axis:0); | |||||
return new RowPartition(row_splits: row_splits, row_lengths: row_lengths); | |||||
} | |||||
public static Tensor _convert_row_partition(Tensor partition, string name, TF_DataType dtype, | |||||
TF_DataType dtype_hint= TF_DataType.TF_INT64) | |||||
{ | |||||
if (partition is NDArray && partition.GetDataType() == np.int32) partition = ops.convert_to_tensor(partition, name: name); | |||||
if (partition.GetDataType() != np.int32 && partition.GetDataType() != np.int64) throw new ValueError($"{name} must have dtype int32 or int64"); | |||||
return partition; | |||||
} | |||||
public Tensor nrows() | |||||
{ | |||||
/*Returns the number of rows created by this `RowPartition*/ | |||||
if (this._nrows != null) return this._nrows; | |||||
var nsplits = tensor_shape.dimension_at_index(this._row_splits.shape, 0); | |||||
if (nsplits == null) return array_ops.shape(this._row_splits, out_type: this.row_splits.dtype)[0] - 1; | |||||
else return constant_op.constant(nsplits.value - 1, dtype: this.row_splits.dtype); | |||||
} | |||||
public Tensor row_lengths() | |||||
{ | |||||
if (this._row_splits != null) | |||||
{ | |||||
int nrows_plus_one = tensor_shape.dimension_value(this._row_splits.shape[0]); | |||||
return tf.constant(nrows_plus_one - 1); | |||||
} | |||||
if (this._row_lengths != null) | |||||
{ | |||||
var nrows = tensor_shape.dimension_value(this._row_lengths.shape[0]); | |||||
return tf.constant(nrows); | |||||
} | |||||
if(this._nrows != null) | |||||
{ | |||||
return tensor_util.constant_value(this._nrows); | |||||
} | |||||
return tf.constant(-1); | |||||
} | |||||
} | } | ||||
} | } |
@@ -0,0 +1,26 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using System.Text; | |||||
using System.Threading.Tasks; | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
using Tensorflow; | |||||
using Tensorflow.NumPy; | |||||
using static Tensorflow.Binding; | |||||
namespace TensorFlowNET.UnitTest.ManagedAPI | |||||
{ | |||||
public class RaggedTensorTest :EagerModeTestBase | |||||
{ | |||||
[TestMethod] | |||||
public void Test_from_row_lengths() | |||||
{ | |||||
var row_lengths = tf.convert_to_tensor(np.array(new int[] { 2, 0, 3, 1, 1 }, TF_DataType.TF_INT64)); | |||||
var rp = RowPartition.from_row_lengths(row_lengths, validate: false); | |||||
var rp_row_lengths = rp.row_lengths(); | |||||
var rp_nrows = rp.nrows(); | |||||
Assert.IsTrue(rp_nrows.ToArray<long>()[0] == rp.nrows().ToArray<long>()[0]); | |||||
} | |||||
} | |||||
} |