@@ -14,17 +14,18 @@ | |||
limitations under the License. | |||
******************************************************************************/ | |||
using System; | |||
using Tensorflow.Framework; | |||
namespace Tensorflow | |||
{ | |||
public partial class tensorflow | |||
{ | |||
public SparseTensor<T> SparseTensor<T>(long[,] indices, T[] values, long[] dense_shape) | |||
=> new SparseTensor<T>(indices, values, dense_shape); | |||
public SparseTensor SparseTensor(long[,] indices, Array values, long[] dense_shape) | |||
=> new SparseTensor(indices, values, dense_shape); | |||
public Tensor sparse_tensor_to_dense<T>(SparseTensor<T> sp_input, | |||
T default_value = default, | |||
public Tensor sparse_tensor_to_dense(SparseTensor sp_input, | |||
Array default_value = default, | |||
bool validate_indices = true, | |||
string name = null) | |||
=> gen_sparse_ops.sparse_to_dense(sp_input.indices, | |||
@@ -14,6 +14,8 @@ | |||
limitations under the License. | |||
******************************************************************************/ | |||
using Tensorflow.Framework; | |||
namespace Tensorflow | |||
{ | |||
public partial class tensorflow | |||
@@ -65,7 +67,7 @@ namespace Tensorflow | |||
string name = null, string @uint = "BYTE") | |||
=> ops.substr(input, pos, len, @uint: @uint, name: name); | |||
public Tensor split(Tensor input, string sep = "", int maxsplit = -1, string name = null) | |||
public SparseTensor split(Tensor input, string sep = "", int maxsplit = -1, string name = null) | |||
=> ops.string_split_v2(input, sep: sep, maxsplit : maxsplit, name : name); | |||
} | |||
} | |||
@@ -1,63 +0,0 @@ | |||
using System; | |||
using System.Linq; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Framework | |||
{ | |||
/// <summary> | |||
/// Represents a sparse tensor. | |||
/// </summary> | |||
public class SparseTensor<T> : CompositeTensor, _TensorLike | |||
{ | |||
long[,] _indices; | |||
public Tensor indices; | |||
T[] _values; | |||
public Tensor values; | |||
long[] _dense_shape; | |||
public Tensor dense_shape; | |||
TensorShape _shape; | |||
public TensorShape shape => _shape; | |||
public TF_DataType dtype => dtypes.as_dtype(typeof(T)); | |||
public SparseTensor(long[,] indices_, T[] values_, long[] dense_shape_) | |||
{ | |||
tf_with(ops.name_scope(null, "SparseTensor", new { }), delegate | |||
{ | |||
indices = ops.convert_to_tensor( | |||
indices_, name: "indices", dtype: dtypes.int64); | |||
values = ops.convert_to_tensor(values_, name: "values"); | |||
dense_shape = ops.convert_to_tensor( | |||
dense_shape_, name: "dense_shape", dtype: dtypes.int64); | |||
}); | |||
_indices = indices_; | |||
_values = values_; | |||
_dense_shape = dense_shape_; | |||
var indices_shape = indices.TensorShape.with_rank(2); | |||
var values_shape = values.TensorShape.with_rank(1); | |||
var dense_shape_shape = dense_shape.TensorShape.with_rank(1); | |||
indices_shape["0"].merge_with(values_shape[0]); | |||
indices_shape["1"].merge_with(dense_shape_shape[0]); | |||
_shape = new TensorShape(_dense_shape.Select(x => Convert.ToInt32(x)).ToArray()); | |||
} | |||
} | |||
public interface _TensorLike | |||
{ | |||
} | |||
public static class sparse_tensor_extension | |||
{ | |||
public static bool is_sparse(this _TensorLike x) | |||
{ | |||
return x.GetType().Name.Contains("SparseTensor"); | |||
} | |||
} | |||
} |
@@ -44,14 +44,14 @@ namespace Tensorflow.Framework | |||
return true; | |||
} | |||
if (other.is_sparse()) | |||
if (other.IsSparseTensor) | |||
{ | |||
return self.dtype.is_compatible_with(other.dtype); | |||
} | |||
return self.dtype.is_compatible_with(other.dtype) && | |||
_shape_is_compatible_0dim(self.shape, other.shape) && | |||
!self.is_sparse(); | |||
!self.IsSparseTensor; | |||
} | |||
public static Dimension dimension_at_index(TensorShape shape, int index) | |||
@@ -14,6 +14,7 @@ | |||
limitations under the License. | |||
******************************************************************************/ | |||
using Tensorflow.Framework; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow | |||
@@ -42,9 +43,25 @@ namespace Tensorflow | |||
=> tf.Context.ExecuteOp("Substr", name, new ExecuteOpArgs(input, pos, len) | |||
.SetAttributes(new { unit = @uint })); | |||
public Tensor string_split_v2(Tensor input, string sep = "", int maxsplit = -1, string name = null) | |||
public SparseTensor string_split_v2(Tensor input, string sep = "", int maxsplit = -1, string name = null) | |||
{ | |||
return null; | |||
return tf_with(ops.name_scope(name, "StringSplit"), scope => | |||
{ | |||
var sep_tensor = ops.convert_to_tensor(sep, dtype: TF_DataType.TF_STRING); | |||
var result = tf.Context.ExecuteOp("StringSplitV2", name, | |||
new ExecuteOpArgs(input, sep) | |||
{ | |||
GetGradientAttrs = op => new | |||
{ | |||
maxsplit = op.get_attr<int>("maxsplit") | |||
} | |||
}.SetAttributes(new { maxsplit })); | |||
var (indices, values, shape) = (result[0], result[1], result[2]); | |||
indices.set_shape(new TensorShape(-1, 2)); | |||
values.set_shape(new TensorShape(-1)); | |||
shape.set_shape(new TensorShape(2)); | |||
return new SparseTensor(indices, values, shape); | |||
}); | |||
} | |||
} | |||
} |
@@ -7,7 +7,7 @@ using static Tensorflow.Binding; | |||
namespace Tensorflow | |||
{ | |||
public class EagerTensorV2 : DisposableObject, ITensor | |||
public class EagerTensorV2 : DisposableObject | |||
{ | |||
SafeTensorHandleHandle EagerTensorHandle; | |||
public string Device | |||
@@ -1,7 +0,0 @@ | |||
namespace Tensorflow | |||
{ | |||
public interface ITensor | |||
{ | |||
} | |||
} |
@@ -0,0 +1,56 @@ | |||
/***************************************************************************** | |||
Copyright 2021 Haiping Chen. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
******************************************************************************/ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Framework; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow | |||
{ | |||
/// <summary> | |||
/// Represents a ragged tensor. | |||
/// </summary> | |||
public class RaggedTensor : CompositeTensor | |||
{ | |||
public RaggedTensor(Tensor values, RowPartition row_partition, bool validate = true) | |||
{ | |||
} | |||
/// <summary> | |||
/// Creates a `RaggedTensor` with rows partitioned by `value_rowids`. | |||
/// </summary> | |||
/// <param name="values"></param> | |||
/// <param name="value_rowids"></param> | |||
/// <param name="nrows"></param> | |||
/// <param name="name"></param> | |||
/// <param name="validate"></param> | |||
/// <returns></returns> | |||
public static RaggedTensor from_value_rowids(Tensor values, Tensor value_rowids, | |||
Tensor nrows = null, string name = null, bool validate = true) | |||
{ | |||
return tf_with(ops.name_scope(name, "RaggedFromValueRowIds"), scope => | |||
{ | |||
var row_partition = RowPartition.from_value_rowids(value_rowids, | |||
nrows: nrows, | |||
validate: validate); | |||
return new RaggedTensor(values, row_partition, validate: validate); | |||
}); | |||
} | |||
} | |||
} |
@@ -0,0 +1,59 @@ | |||
/***************************************************************************** | |||
Copyright 2021 Haiping Chen. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
******************************************************************************/ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Framework; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow | |||
{ | |||
/// <summary> | |||
/// Partitioning of a sequence of values into contiguous subsequences ("rows"). | |||
/// </summary> | |||
public class RowPartition : CompositeTensor | |||
{ | |||
public RowPartition(Tensor row_splits, | |||
Tensor row_lengths = null, Tensor value_rowids = null, Tensor nrows = null, | |||
Tensor uniform_row_length = null) | |||
{ | |||
} | |||
/// <summary> | |||
/// Creates a `RowPartition` with rows partitioned by `value_rowids`. | |||
/// </summary> | |||
/// <param name="value_rowids"></param> | |||
/// <param name="nrows"></param> | |||
/// <param name="validate"></param> | |||
/// <param name="preferred_dtype"></param> | |||
/// <returns></returns> | |||
public static RowPartition from_value_rowids(Tensor value_rowids, | |||
Tensor nrows = null, bool validate = true, TF_DataType preferred_dtype = TF_DataType.DtInvalid) | |||
{ | |||
return tf_with(ops.name_scope(null, "RowPartitionFromValueRowIds"), scope => | |||
{ | |||
Tensor row_lengths = null; | |||
Tensor row_splits = null; | |||
return new RowPartition(row_splits, | |||
row_lengths: row_lengths, | |||
value_rowids: value_rowids, | |||
nrows: nrows); | |||
}); | |||
} | |||
} | |||
} |
@@ -0,0 +1,76 @@ | |||
/***************************************************************************** | |||
Copyright 2021 Haiping Chen. All Rights Reserved. | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
******************************************************************************/ | |||
using System; | |||
using System.Linq; | |||
using Tensorflow.Framework; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow | |||
{ | |||
/// <summary> | |||
/// Represents a sparse tensor. | |||
/// </summary> | |||
public class SparseTensor : CompositeTensor | |||
{ | |||
public Tensor indices; | |||
public Tensor values; | |||
public Tensor dense_shape; | |||
public SparseTensor(Tensor indices, Tensor values, Tensor dense_shape) | |||
{ | |||
this.indices = indices; | |||
this.values = values; | |||
this.dense_shape = dense_shape; | |||
_init(); | |||
} | |||
public SparseTensor(long[,] indices_, Array values_, long[] dense_shape_) | |||
{ | |||
tf_with(ops.name_scope(null, "SparseTensor", new { }), delegate | |||
{ | |||
indices = ops.convert_to_tensor( | |||
indices_, name: "indices", dtype: dtypes.int64); | |||
values = ops.convert_to_tensor(values_, name: "values"); | |||
dense_shape = ops.convert_to_tensor( | |||
dense_shape_, name: "dense_shape", dtype: dtypes.int64); | |||
}); | |||
_init(); | |||
} | |||
void _init() | |||
{ | |||
var indices_shape = indices.TensorShape.with_rank(2); | |||
var values_shape = values.TensorShape.with_rank(1); | |||
var dense_shape_shape = dense_shape.TensorShape.with_rank(1); | |||
indices_shape["0"].merge_with(values_shape[0]); | |||
indices_shape["1"].merge_with(dense_shape_shape[0]); | |||
} | |||
public static implicit operator Tensor(SparseTensor indexedSlices) | |||
{ | |||
return indexedSlices.values; | |||
} | |||
public static implicit operator SparseTensor(Tensor tensor) | |||
{ | |||
return tensor.Tag as SparseTensor; | |||
} | |||
} | |||
} |
@@ -33,9 +33,7 @@ namespace Tensorflow | |||
/// </summary> | |||
[SuppressMessage("ReSharper", "ConvertToAutoProperty")] | |||
public partial class Tensor : DisposableObject, | |||
ITensor, | |||
ITensorOrOperation, | |||
_TensorLike, | |||
ITensorOrTensorArray, | |||
IPackable<Tensor>, | |||
ICanBeFlattened | |||
@@ -97,6 +95,7 @@ namespace Tensorflow | |||
public SafeTensorHandleHandle EagerTensorHandle { get; set; } | |||
public bool IsEagerTensor => this is EagerTensor; | |||
public bool IsSparseTensor => this is SparseTensor; | |||
/// <summary> | |||
/// Returns the shape of a tensor. | |||
@@ -47,14 +47,16 @@ namespace Tensorflow.Keras.Layers | |||
Tensors _preprocess(Tensors inputs) | |||
{ | |||
Tensor input_tensor = null; | |||
if (args.Standardize != null) | |||
inputs = args.Standardize(inputs); | |||
input_tensor = args.Standardize(inputs); | |||
if (!string.IsNullOrEmpty(args.Split)) | |||
{ | |||
if (inputs.shape.ndim > 1) | |||
inputs = array_ops.squeeze(inputs, axis: new[] { -1 }); | |||
input_tensor = array_ops.squeeze(inputs, axis: new[] { -1 }); | |||
if (args.Split == "whitespace") | |||
inputs = tf.strings.split(inputs); | |||
input_tensor = tf.strings.split(inputs); | |||
} | |||
return inputs; | |||
} | |||
@@ -58,5 +58,12 @@ namespace TensorFlowNET.UnitTest.ManagedAPI | |||
Assert.AreEqual(strings[1], stringData[1]); | |||
Assert.AreEqual(strings[2], stringData[2]); | |||
} | |||
[TestMethod] | |||
public void StringSplit() | |||
{ | |||
var tensor = tf.constant(new[] { "hello world", "tensorflow .net" }); | |||
tf.strings.split(tensor); | |||
} | |||
} | |||
} |