Browse Source

RowPartition and RaggedTensor

tags/v0.40-tf2.4-tstring
Oceania2018 4 years ago
parent
commit
a1ebd70c46
13 changed files with 234 additions and 85 deletions
  1. +5
    -4
      src/TensorFlowNET.Core/APIs/tf.sparse.cs
  2. +3
    -1
      src/TensorFlowNET.Core/APIs/tf.strings.cs
  3. +0
    -63
      src/TensorFlowNET.Core/Framework/sparse_tensor.py.cs
  4. +2
    -2
      src/TensorFlowNET.Core/Framework/tensor_shape.cs
  5. +19
    -2
      src/TensorFlowNET.Core/Operations/string_ops.cs
  6. +1
    -1
      src/TensorFlowNET.Core/Tensors/EagerTensorV2.cs
  7. +0
    -7
      src/TensorFlowNET.Core/Tensors/ITensor.cs
  8. +56
    -0
      src/TensorFlowNET.Core/Tensors/Ragged/RaggedTensor.cs
  9. +59
    -0
      src/TensorFlowNET.Core/Tensors/Ragged/RowPartition.cs
  10. +76
    -0
      src/TensorFlowNET.Core/Tensors/Ragged/SparseTensor.cs
  11. +1
    -2
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  12. +5
    -3
      src/TensorFlowNET.Keras/Layers/Preprocessing/TextVectorization.cs
  13. +7
    -0
      test/TensorFlowNET.UnitTest/ManagedAPI/StringsApiTest.cs

+ 5
- 4
src/TensorFlowNET.Core/APIs/tf.sparse.cs View File

@@ -14,17 +14,18 @@
limitations under the License.
******************************************************************************/

using System;
using Tensorflow.Framework;

namespace Tensorflow
{
public partial class tensorflow
{
public SparseTensor<T> SparseTensor<T>(long[,] indices, T[] values, long[] dense_shape)
=> new SparseTensor<T>(indices, values, dense_shape);
public SparseTensor SparseTensor(long[,] indices, Array values, long[] dense_shape)
=> new SparseTensor(indices, values, dense_shape);

public Tensor sparse_tensor_to_dense<T>(SparseTensor<T> sp_input,
T default_value = default,
public Tensor sparse_tensor_to_dense(SparseTensor sp_input,
Array default_value = default,
bool validate_indices = true,
string name = null)
=> gen_sparse_ops.sparse_to_dense(sp_input.indices,


+ 3
- 1
src/TensorFlowNET.Core/APIs/tf.strings.cs View File

@@ -14,6 +14,8 @@
limitations under the License.
******************************************************************************/

using Tensorflow.Framework;

namespace Tensorflow
{
public partial class tensorflow
@@ -65,7 +67,7 @@ namespace Tensorflow
string name = null, string @uint = "BYTE")
=> ops.substr(input, pos, len, @uint: @uint, name: name);

public Tensor split(Tensor input, string sep = "", int maxsplit = -1, string name = null)
public SparseTensor split(Tensor input, string sep = "", int maxsplit = -1, string name = null)
=> ops.string_split_v2(input, sep: sep, maxsplit : maxsplit, name : name);
}
}


+ 0
- 63
src/TensorFlowNET.Core/Framework/sparse_tensor.py.cs View File

@@ -1,63 +0,0 @@
using System;
using System.Linq;
using static Tensorflow.Binding;

namespace Tensorflow.Framework
{
/// <summary>
/// Represents a sparse tensor.
/// </summary>
public class SparseTensor<T> : CompositeTensor, _TensorLike
{
long[,] _indices;
public Tensor indices;

T[] _values;
public Tensor values;

long[] _dense_shape;
public Tensor dense_shape;

TensorShape _shape;
public TensorShape shape => _shape;

public TF_DataType dtype => dtypes.as_dtype(typeof(T));

public SparseTensor(long[,] indices_, T[] values_, long[] dense_shape_)
{
tf_with(ops.name_scope(null, "SparseTensor", new { }), delegate
{
indices = ops.convert_to_tensor(
indices_, name: "indices", dtype: dtypes.int64);
values = ops.convert_to_tensor(values_, name: "values");
dense_shape = ops.convert_to_tensor(
dense_shape_, name: "dense_shape", dtype: dtypes.int64);
});

_indices = indices_;
_values = values_;
_dense_shape = dense_shape_;

var indices_shape = indices.TensorShape.with_rank(2);
var values_shape = values.TensorShape.with_rank(1);
var dense_shape_shape = dense_shape.TensorShape.with_rank(1);

indices_shape["0"].merge_with(values_shape[0]);
indices_shape["1"].merge_with(dense_shape_shape[0]);

_shape = new TensorShape(_dense_shape.Select(x => Convert.ToInt32(x)).ToArray());
}
}

public interface _TensorLike
{
}

public static class sparse_tensor_extension
{
public static bool is_sparse(this _TensorLike x)
{
return x.GetType().Name.Contains("SparseTensor");
}
}
}

+ 2
- 2
src/TensorFlowNET.Core/Framework/tensor_shape.cs View File

@@ -44,14 +44,14 @@ namespace Tensorflow.Framework
return true;
}

if (other.is_sparse())
if (other.IsSparseTensor)
{
return self.dtype.is_compatible_with(other.dtype);
}

return self.dtype.is_compatible_with(other.dtype) &&
_shape_is_compatible_0dim(self.shape, other.shape) &&
!self.is_sparse();
!self.IsSparseTensor;
}

public static Dimension dimension_at_index(TensorShape shape, int index)


+ 19
- 2
src/TensorFlowNET.Core/Operations/string_ops.cs View File

@@ -14,6 +14,7 @@
limitations under the License.
******************************************************************************/

using Tensorflow.Framework;
using static Tensorflow.Binding;

namespace Tensorflow
@@ -42,9 +43,25 @@ namespace Tensorflow
=> tf.Context.ExecuteOp("Substr", name, new ExecuteOpArgs(input, pos, len)
.SetAttributes(new { unit = @uint }));

public Tensor string_split_v2(Tensor input, string sep = "", int maxsplit = -1, string name = null)
public SparseTensor string_split_v2(Tensor input, string sep = "", int maxsplit = -1, string name = null)
{
return null;
return tf_with(ops.name_scope(name, "StringSplit"), scope =>
{
var sep_tensor = ops.convert_to_tensor(sep, dtype: TF_DataType.TF_STRING);
var result = tf.Context.ExecuteOp("StringSplitV2", name,
new ExecuteOpArgs(input, sep)
{
GetGradientAttrs = op => new
{
maxsplit = op.get_attr<int>("maxsplit")
}
}.SetAttributes(new { maxsplit }));
var (indices, values, shape) = (result[0], result[1], result[2]);
indices.set_shape(new TensorShape(-1, 2));
values.set_shape(new TensorShape(-1));
shape.set_shape(new TensorShape(2));
return new SparseTensor(indices, values, shape);
});
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Tensors/EagerTensorV2.cs View File

@@ -7,7 +7,7 @@ using static Tensorflow.Binding;

namespace Tensorflow
{
public class EagerTensorV2 : DisposableObject, ITensor
public class EagerTensorV2 : DisposableObject
{
SafeTensorHandleHandle EagerTensorHandle;
public string Device


+ 0
- 7
src/TensorFlowNET.Core/Tensors/ITensor.cs View File

@@ -1,7 +0,0 @@
namespace Tensorflow
{
public interface ITensor
{

}
}

+ 56
- 0
src/TensorFlowNET.Core/Tensors/Ragged/RaggedTensor.cs View File

@@ -0,0 +1,56 @@
/*****************************************************************************
Copyright 2021 Haiping Chen. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Framework;
using static Tensorflow.Binding;

namespace Tensorflow
{
/// <summary>
/// Represents a ragged tensor.
/// </summary>
public class RaggedTensor : CompositeTensor
{
public RaggedTensor(Tensor values, RowPartition row_partition, bool validate = true)
{

}

/// <summary>
/// Creates a `RaggedTensor` with rows partitioned by `value_rowids`.
/// </summary>
/// <param name="values"></param>
/// <param name="value_rowids"></param>
/// <param name="nrows"></param>
/// <param name="name"></param>
/// <param name="validate"></param>
/// <returns></returns>
public static RaggedTensor from_value_rowids(Tensor values, Tensor value_rowids,
Tensor nrows = null, string name = null, bool validate = true)
{
return tf_with(ops.name_scope(name, "RaggedFromValueRowIds"), scope =>
{
var row_partition = RowPartition.from_value_rowids(value_rowids,
nrows: nrows,
validate: validate);
return new RaggedTensor(values, row_partition, validate: validate);
});
}
}
}

+ 59
- 0
src/TensorFlowNET.Core/Tensors/Ragged/RowPartition.cs View File

@@ -0,0 +1,59 @@
/*****************************************************************************
Copyright 2021 Haiping Chen. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Framework;
using static Tensorflow.Binding;

namespace Tensorflow
{
/// <summary>
/// Partitioning of a sequence of values into contiguous subsequences ("rows").
/// </summary>
public class RowPartition : CompositeTensor
{
public RowPartition(Tensor row_splits,
Tensor row_lengths = null, Tensor value_rowids = null, Tensor nrows = null,
Tensor uniform_row_length = null)
{

}

/// <summary>
/// Creates a `RowPartition` with rows partitioned by `value_rowids`.
/// </summary>
/// <param name="value_rowids"></param>
/// <param name="nrows"></param>
/// <param name="validate"></param>
/// <param name="preferred_dtype"></param>
/// <returns></returns>
public static RowPartition from_value_rowids(Tensor value_rowids,
Tensor nrows = null, bool validate = true, TF_DataType preferred_dtype = TF_DataType.DtInvalid)
{
return tf_with(ops.name_scope(null, "RowPartitionFromValueRowIds"), scope =>
{
Tensor row_lengths = null;
Tensor row_splits = null;
return new RowPartition(row_splits,
row_lengths: row_lengths,
value_rowids: value_rowids,
nrows: nrows);
});
}
}
}

+ 76
- 0
src/TensorFlowNET.Core/Tensors/Ragged/SparseTensor.cs View File

@@ -0,0 +1,76 @@
/*****************************************************************************
Copyright 2021 Haiping Chen. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System;
using System.Linq;
using Tensorflow.Framework;
using static Tensorflow.Binding;

namespace Tensorflow
{
/// <summary>
/// Represents a sparse tensor.
/// </summary>
public class SparseTensor : CompositeTensor
{
public Tensor indices;

public Tensor values;

public Tensor dense_shape;

public SparseTensor(Tensor indices, Tensor values, Tensor dense_shape)
{
this.indices = indices;
this.values = values;
this.dense_shape = dense_shape;
_init();
}

public SparseTensor(long[,] indices_, Array values_, long[] dense_shape_)
{
tf_with(ops.name_scope(null, "SparseTensor", new { }), delegate
{
indices = ops.convert_to_tensor(
indices_, name: "indices", dtype: dtypes.int64);
values = ops.convert_to_tensor(values_, name: "values");
dense_shape = ops.convert_to_tensor(
dense_shape_, name: "dense_shape", dtype: dtypes.int64);
});
_init();
}

void _init()
{
var indices_shape = indices.TensorShape.with_rank(2);
var values_shape = values.TensorShape.with_rank(1);
var dense_shape_shape = dense_shape.TensorShape.with_rank(1);

indices_shape["0"].merge_with(values_shape[0]);
indices_shape["1"].merge_with(dense_shape_shape[0]);
}

public static implicit operator Tensor(SparseTensor indexedSlices)
{
return indexedSlices.values;
}

public static implicit operator SparseTensor(Tensor tensor)
{
return tensor.Tag as SparseTensor;
}
}
}

+ 1
- 2
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -33,9 +33,7 @@ namespace Tensorflow
/// </summary>
[SuppressMessage("ReSharper", "ConvertToAutoProperty")]
public partial class Tensor : DisposableObject,
ITensor,
ITensorOrOperation,
_TensorLike,
ITensorOrTensorArray,
IPackable<Tensor>,
ICanBeFlattened
@@ -97,6 +95,7 @@ namespace Tensorflow
public SafeTensorHandleHandle EagerTensorHandle { get; set; }

public bool IsEagerTensor => this is EagerTensor;
public bool IsSparseTensor => this is SparseTensor;

/// <summary>
/// Returns the shape of a tensor.


+ 5
- 3
src/TensorFlowNET.Keras/Layers/Preprocessing/TextVectorization.cs View File

@@ -47,14 +47,16 @@ namespace Tensorflow.Keras.Layers

Tensors _preprocess(Tensors inputs)
{
Tensor input_tensor = null;
if (args.Standardize != null)
inputs = args.Standardize(inputs);
input_tensor = args.Standardize(inputs);
if (!string.IsNullOrEmpty(args.Split))
{
if (inputs.shape.ndim > 1)
inputs = array_ops.squeeze(inputs, axis: new[] { -1 });
input_tensor = array_ops.squeeze(inputs, axis: new[] { -1 });
if (args.Split == "whitespace")
inputs = tf.strings.split(inputs);
input_tensor = tf.strings.split(inputs);

}
return inputs;
}


+ 7
- 0
test/TensorFlowNET.UnitTest/ManagedAPI/StringsApiTest.cs View File

@@ -58,5 +58,12 @@ namespace TensorFlowNET.UnitTest.ManagedAPI
Assert.AreEqual(strings[1], stringData[1]);
Assert.AreEqual(strings[2], stringData[2]);
}

[TestMethod]
public void StringSplit()
{
var tensor = tf.constant(new[] { "hello world", "tensorflow .net" });
tf.strings.split(tensor);
}
}
}

Loading…
Cancel
Save