Browse Source

Fix string_split_v2 return RaggedTensor.

tags/v0.40-tf2.4-tstring
Oceania2018 4 years ago
parent
commit
559921585d
10 changed files with 144 additions and 26 deletions
  1. +22
    -0
      src/TensorFlowNET.Core/APIs/tf.math.cs
  2. +1
    -1
      src/TensorFlowNET.Core/APIs/tf.strings.cs
  3. +0
    -7
      src/TensorFlowNET.Core/Operations/gen_math_ops.cs
  4. +31
    -9
      src/TensorFlowNET.Core/Operations/math_ops.cs
  5. +8
    -2
      src/TensorFlowNET.Core/Operations/string_ops.cs
  6. +1
    -0
      src/TensorFlowNET.Core/Tensorflow.Binding.csproj
  7. +37
    -2
      src/TensorFlowNET.Core/Tensors/Ragged/RaggedTensor.cs
  8. +37
    -3
      src/TensorFlowNET.Core/Tensors/Ragged/RowPartition.cs
  9. +4
    -0
      src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
  10. +3
    -2
      test/TensorFlowNET.UnitTest/ManagedAPI/StringsApiTest.cs

+ 22
- 0
src/TensorFlowNET.Core/APIs/tf.math.cs View File

@@ -32,6 +32,28 @@ namespace Tensorflow
/// <returns></returns>
public Tensor erf(Tensor x, string name = null)
=> math_ops.erf(x, name);

/// <summary>
///
/// </summary>
/// <param name="arr"></param>
/// <param name="weights"></param>
/// <param name="minlength"></param>
/// <param name="maxlength"></param>
/// <param name="dtype"></param>
/// <param name="name"></param>
/// <param name="axis"></param>
/// <param name="binary_output"></param>
/// <returns></returns>
public Tensor bincount(Tensor arr, Tensor weights = null,
Tensor minlength = null,
Tensor maxlength = null,
TF_DataType dtype = TF_DataType.TF_INT32,
string name = null,
TensorShape axis = null,
bool binary_output = false)
=> math_ops.bincount(arr, weights: weights, minlength: minlength, maxlength: maxlength,
dtype: dtype, name: name, axis: axis, binary_output: binary_output);
}

public Tensor abs(Tensor x, string name = null)


+ 1
- 1
src/TensorFlowNET.Core/APIs/tf.strings.cs View File

@@ -67,7 +67,7 @@ namespace Tensorflow
string name = null, string @uint = "BYTE")
=> ops.substr(input, pos, len, @uint: @uint, name: name);

public SparseTensor split(Tensor input, string sep = "", int maxsplit = -1, string name = null)
public RaggedTensor split(Tensor input, string sep = "", int maxsplit = -1, string name = null)
=> ops.string_split_v2(input, sep: sep, maxsplit : maxsplit, name : name);
}
}


+ 0
- 7
src/TensorFlowNET.Core/Operations/gen_math_ops.cs View File

@@ -249,13 +249,6 @@ namespace Tensorflow
return _op.outputs[0];
}

public static Tensor cumsum<T>(Tensor x, T axis, bool exclusive = false, bool reverse = false, string name = null)
{
var _op = tf.OpDefLib._apply_op_helper("Cumsum", name, args: new { x, axis, exclusive, reverse });

return _op.outputs[0];
}

/// <summary>
/// Computes the sum along segments of a tensor.
/// </summary>


+ 31
- 9
src/TensorFlowNET.Core/Operations/math_ops.cs View File

@@ -168,15 +168,12 @@ namespace Tensorflow
}

public static Tensor cumsum<T>(Tensor x, T axis = default, bool exclusive = false, bool reverse = false, string name = null)
{
return tf_with(ops.name_scope(name, "Cumsum", new { x }), scope =>
{
name = scope;
x = ops.convert_to_tensor(x, name: "x");

return gen_math_ops.cumsum(x, axis: axis, exclusive: exclusive, reverse: reverse, name: name);
});
}
=> tf_with(ops.name_scope(name, "Cumsum", new { x }), scope =>
{
name = scope;
return tf.Context.ExecuteOp("Cumsum", name, new ExecuteOpArgs(x, axis)
.SetAttributes(new { exclusive, reverse }));
});

/// <summary>
/// Computes Psi, the derivative of Lgamma (the log of the absolute value of
@@ -807,6 +804,31 @@ namespace Tensorflow
.SetAttributes(new { adj_x, adj_y }));
});

public static Tensor bincount(Tensor arr, Tensor weights = null,
Tensor minlength = null,
Tensor maxlength = null,
TF_DataType dtype = TF_DataType.TF_INT32,
string name = null,
TensorShape axis = null,
bool binary_output = false)
=> tf_with(ops.name_scope(name, "bincount"), scope =>
{
name = scope;
if(!binary_output && axis == null)
{
var array_is_nonempty = math_ops.reduce_prod(array_ops.shape(arr)) > 0;
var output_size = math_ops.cast(array_is_nonempty, dtypes.int32) * (math_ops.reduce_max(arr) + 1);
if (minlength != null)
output_size = math_ops.maximum(minlength, output_size);
if (maxlength != null)
output_size = math_ops.minimum(maxlength, output_size);
var weights = constant_op.constant(new long[0], dtype: dtype);
return tf.Context.ExecuteOp("Bincount", name, new ExecuteOpArgs(arr, output_size, weights));
}

throw new NotImplementedException("");
});

/// <summary>
/// Returns the complex conjugate of a complex number.
/// </summary>


+ 8
- 2
src/TensorFlowNET.Core/Operations/string_ops.cs View File

@@ -14,6 +14,7 @@
limitations under the License.
******************************************************************************/

using NumSharp;
using Tensorflow.Framework;
using static Tensorflow.Binding;

@@ -43,7 +44,7 @@ namespace Tensorflow
=> tf.Context.ExecuteOp("Substr", name, new ExecuteOpArgs(input, pos, len)
.SetAttributes(new { unit = @uint }));

public SparseTensor string_split_v2(Tensor input, string sep = "", int maxsplit = -1, string name = null)
public RaggedTensor string_split_v2(Tensor input, string sep = "", int maxsplit = -1, string name = null)
{
return tf_with(ops.name_scope(name, "StringSplit"), scope =>
{
@@ -60,7 +61,12 @@ namespace Tensorflow
indices.set_shape(new TensorShape(-1, 2));
values.set_shape(new TensorShape(-1));
shape.set_shape(new TensorShape(2));
return new SparseTensor(indices, values, shape);

var sparse_result = new SparseTensor(indices, values, shape);
return RaggedTensor.from_value_rowids(sparse_result.values,
value_rowids: sparse_result.indices[Slice.All, 0],
nrows: sparse_result.dense_shape[0],
validate: false);
});
}
}


+ 1
- 0
src/TensorFlowNET.Core/Tensorflow.Binding.csproj View File

@@ -50,6 +50,7 @@ tf.net 0.4x.x aligns with TensorFlow v2.4.1 native library.</PackageReleaseNotes
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
<DefineConstants>TRACE;DEBUG</DefineConstants>
<PlatformTarget>x64</PlatformTarget>
<DocumentationFile>TensorFlow.NET.xml</DocumentationFile>
</PropertyGroup>

<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|AnyCPU'">


+ 37
- 2
src/TensorFlowNET.Core/Tensors/Ragged/RaggedTensor.cs View File

@@ -17,6 +17,7 @@
using System;
using System.Collections.Generic;
using System.Text;
using System.Linq;
using Tensorflow.Framework;
using static Tensorflow.Binding;

@@ -27,9 +28,30 @@ namespace Tensorflow
/// </summary>
public class RaggedTensor : CompositeTensor
{
public RaggedTensor(Tensor values, RowPartition row_partition, bool validate = true)
Tensor _values;
RowPartition _row_partition;
public TF_DataType dtype => _values.dtype;
public TensorShape shape
{
get
{
var nrows = _row_partition.static_nrows;
var ncols = _row_partition.static_uniform_row_length;
return new TensorShape(nrows, ncols);
}
}

public RaggedTensor(Tensor values,
bool @internal = true,
RowPartition row_partition = null)
{
_values = values;
_row_partition = row_partition;
}

public static RaggedTensor from_row_partition(Tensor values, RowPartition row_partition, bool validate = true)
{
return new RaggedTensor(values, @internal: true, row_partition: row_partition);
}

/// <summary>
@@ -49,8 +71,21 @@ namespace Tensorflow
var row_partition = RowPartition.from_value_rowids(value_rowids,
nrows: nrows,
validate: validate);
return new RaggedTensor(values, row_partition, validate: validate);
return from_row_partition(values, row_partition, validate: validate);
});
}

public override string ToString()
=> $"tf.RaggedTensor: shape={shape} [{string.Join(", ", _values.StringData().Take(10))}]";

public static implicit operator Tensor(RaggedTensor indexedSlices)
{
return indexedSlices._values;
}

public static implicit operator RaggedTensor(Tensor tensor)
{
return tensor.Tag as RaggedTensor;
}
}
}

+ 37
- 3
src/TensorFlowNET.Core/Tensors/Ragged/RowPartition.cs View File

@@ -27,11 +27,35 @@ namespace Tensorflow
/// </summary>
public class RowPartition : CompositeTensor
{
Tensor _row_splits;
Tensor _row_lengths;
Tensor _value_rowids;
Tensor _nrows;

public int static_nrows
{
get
{
return _row_splits.shape[0] - 1;
}
}

public int static_uniform_row_length
{
get
{
return -1;
}
}

public RowPartition(Tensor row_splits,
Tensor row_lengths = null, Tensor value_rowids = null, Tensor nrows = null,
Tensor uniform_row_length = null)
{

_row_splits = row_splits;
_row_lengths = row_lengths;
_value_rowids = value_rowids;
_nrows = nrows;
}

/// <summary>
@@ -47,8 +71,18 @@ namespace Tensorflow
{
return tf_with(ops.name_scope(null, "RowPartitionFromValueRowIds"), scope =>
{
Tensor row_lengths = null;
Tensor row_splits = null;
var value_rowids_int32 = math_ops.cast(value_rowids, dtypes.int32);
var nrows_int32 = math_ops.cast(nrows, dtypes.int32);
var row_lengths = tf.math.bincount(value_rowids_int32,
minlength: nrows_int32,
maxlength: nrows_int32,
dtype: value_rowids.dtype);
var row_splits = array_ops.concat(new object[]
{
ops.convert_to_tensor(new long[] { 0 }),
tf.cumsum(row_lengths)
}, axis: 0);

return new RowPartition(row_splits,
row_lengths: row_lengths,
value_rowids: value_rowids,


+ 4
- 0
src/TensorFlowNET.Keras/Tensorflow.Keras.csproj View File

@@ -49,6 +49,10 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
<AllowUnsafeBlocks>false</AllowUnsafeBlocks>
</PropertyGroup>

<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
<DocumentationFile>Tensorflow.Keras.xml</DocumentationFile>
</PropertyGroup>

<ItemGroup>
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.138" />
<PackageReference Include="Newtonsoft.Json" Version="12.0.3" />


+ 3
- 2
test/TensorFlowNET.UnitTest/ManagedAPI/StringsApiTest.cs View File

@@ -62,8 +62,9 @@ namespace TensorFlowNET.UnitTest.ManagedAPI
[TestMethod]
public void StringSplit()
{
var tensor = tf.constant(new[] { "hello world", "tensorflow .net" });
tf.strings.split(tensor);
var tensor = tf.constant(new[] { "hello world", "tensorflow .net csharp", "fsharp" });
var ragged_tensor = tf.strings.split(tensor);
Assert.AreEqual((3, -1), ragged_tensor.shape);
}
}
}

Loading…
Cancel
Save