From 6adcfaea3ff6f5c9d2437e54b66d781c4c16a72a Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 25 Jul 2021 16:40:22 -0500 Subject: [PATCH] ndarray string comparison. --- .../Eager/EagerTensor.Creation.cs | 3 ++ src/TensorFlowNET.Core/NumPy/NDArray.Equal.cs | 1 + src/TensorFlowNET.Core/NumPy/NDArrayRender.cs | 2 +- .../Tensorflow.Binding.csproj | 2 +- .../Tensors/SafeStringTensorHandle.cs | 11 +++--- .../Tensors/Tensor.String.cs | 34 +++++++++++++++++-- src/TensorFlowNET.Keras/Utils/np_utils.cs | 2 +- .../ManagedAPI/StringsApiTest.cs | 7 ++-- 8 files changed, 49 insertions(+), 13 deletions(-) diff --git a/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs b/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs index 8bc10758..d3433ee2 100644 --- a/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs +++ b/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs @@ -60,6 +60,9 @@ namespace Tensorflow.Eager { _id = ops.uid(); _eagerTensorHandle = c_api.TFE_NewTensorHandle(h, tf.Status.Handle); +#if TRACK_TENSOR_LIFE + Console.WriteLine($"New EagerTensor {_eagerTensorHandle}"); +#endif tf.Status.Check(true); } diff --git a/src/TensorFlowNET.Core/NumPy/NDArray.Equal.cs b/src/TensorFlowNET.Core/NumPy/NDArray.Equal.cs index d9ad9ae6..6e3a4c76 100644 --- a/src/TensorFlowNET.Core/NumPy/NDArray.Equal.cs +++ b/src/TensorFlowNET.Core/NumPy/NDArray.Equal.cs @@ -16,6 +16,7 @@ namespace Tensorflow.NumPy long val => GetAtIndex(0) == val, float val => GetAtIndex(0) == val, double val => GetAtIndex(0) == val, + string val => StringData(0) == val, NDArray val => Equals(this, val), _ => base.Equals(obj) }; diff --git a/src/TensorFlowNET.Core/NumPy/NDArrayRender.cs b/src/TensorFlowNET.Core/NumPy/NDArrayRender.cs index 22b4d6ab..3ab6cb34 100644 --- a/src/TensorFlowNET.Core/NumPy/NDArrayRender.cs +++ b/src/TensorFlowNET.Core/NumPy/NDArrayRender.cs @@ -91,7 +91,7 @@ namespace Tensorflow.NumPy .Take(25) .Select(x => x < 32 || x > 127 ? "\\x" + x.ToString("x") : Convert.ToChar(x).ToString())) + "'"; else - return $"['{string.Join("', '", array.StringData().Take(25))}']"; + return $"'{string.Join("', '", array.StringData().Take(25))}'"; } else if (dtype == TF_DataType.TF_VARIANT) { diff --git a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj index 1fee17f3..33d9b699 100644 --- a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj +++ b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj @@ -51,7 +51,7 @@ tf.net 0.6x.x aligns with TensorFlow v2.6.x native library. true - TRACE;DEBUG + TRACE;DEBUG;TRACK_TENSOR_LIFE1 x64 TensorFlow.NET.xml diff --git a/src/TensorFlowNET.Core/Tensors/SafeStringTensorHandle.cs b/src/TensorFlowNET.Core/Tensors/SafeStringTensorHandle.cs index 5730f0cd..d7ece8d2 100644 --- a/src/TensorFlowNET.Core/Tensors/SafeStringTensorHandle.cs +++ b/src/TensorFlowNET.Core/Tensors/SafeStringTensorHandle.cs @@ -8,7 +8,7 @@ namespace Tensorflow public sealed class SafeStringTensorHandle : SafeTensorHandle { Shape _shape; - IntPtr _handle; + SafeTensorHandle _tensorHandle; const int TF_TSRING_SIZE = 24; protected SafeStringTensorHandle() @@ -18,16 +18,18 @@ namespace Tensorflow public SafeStringTensorHandle(SafeTensorHandle handle, Shape shape) : base(handle.DangerousGetHandle()) { - _handle = c_api.TF_TensorData(handle); + _tensorHandle = handle; _shape = shape; + bool success = false; + _tensorHandle.DangerousAddRef(ref success); } protected override bool ReleaseHandle() { + var _handle = c_api.TF_TensorData(_tensorHandle); #if TRACK_TENSOR_LIFE - print($"Delete StringTensorHandle 0x{handle.ToString("x16")}"); + Console.WriteLine($"Delete StringTensorData 0x{_handle.ToString("x16")}"); #endif - for (int i = 0; i < _shape.size; i++) { c_api.TF_StringDealloc(_handle); @@ -35,6 +37,7 @@ namespace Tensorflow } SetHandle(IntPtr.Zero); + _tensorHandle.DangerousRelease(); return true; } diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.String.cs b/src/TensorFlowNET.Core/Tensors/Tensor.String.cs index 50976550..5048d5a5 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.String.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.String.cs @@ -29,13 +29,13 @@ namespace Tensorflow var tstr = c_api.TF_TensorData(handle); #if TRACK_TENSOR_LIFE - print($"New TString 0x{handle.ToString("x16")} Data: 0x{tstr.ToString("x16")}"); + print($"New StringTensor {handle} Data: 0x{tstr.ToString("x16")}"); #endif for (int i = 0; i < buffer.Length; i++) { c_api.TF_StringInit(tstr); c_api.TF_StringCopy(tstr, buffer[i], buffer[i].Length); - var data = c_api.TF_StringGetDataPointer(tstr); + // var data = c_api.TF_StringGetDataPointer(tstr); tstr += TF_TSRING_SIZE; } @@ -53,6 +53,36 @@ namespace Tensorflow return _str; } + public string StringData(int index) + { + var bytes = StringBytes(index); + return Encoding.UTF8.GetString(bytes); + } + + public byte[] StringBytes(int index) + { + if (dtype != TF_DataType.TF_STRING) + throw new InvalidOperationException($"Unable to call StringData when dtype != TF_DataType.TF_STRING (dtype is {dtype})"); + + byte[] buffer = new byte[0]; + var tstrings = TensorDataPointer; + for (int i = 0; i < shape.size; i++) + { + if(index == i) + { + var data = c_api.TF_StringGetDataPointer(tstrings); + var len = c_api.TF_StringGetSize(tstrings); + buffer = new byte[len]; + // var capacity = c_api.TF_StringGetCapacity(tstrings); + // var type = c_api.TF_StringGetType(tstrings); + Marshal.Copy(data, buffer, 0, Convert.ToInt32(len)); + break; + } + tstrings += TF_TSRING_SIZE; + } + return buffer; + } + public byte[][] StringBytes() { if (dtype != TF_DataType.TF_STRING) diff --git a/src/TensorFlowNET.Keras/Utils/np_utils.cs b/src/TensorFlowNET.Keras/Utils/np_utils.cs index 758c287c..ef29b046 100644 --- a/src/TensorFlowNET.Keras/Utils/np_utils.cs +++ b/src/TensorFlowNET.Keras/Utils/np_utils.cs @@ -22,7 +22,7 @@ namespace Tensorflow.Keras.Utils // categorical[np.arange(y.size), y] = 1; for (var i = 0; i < (int)y.size; i++) { - categorical[i][y1[i]] = 1.0f; + categorical[i, y1[i]] = 1.0f; } return categorical; diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/StringsApiTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/StringsApiTest.cs index d98c5207..353d192f 100644 --- a/test/TensorFlowNET.UnitTest/ManagedAPI/StringsApiTest.cs +++ b/test/TensorFlowNET.UnitTest/ManagedAPI/StringsApiTest.cs @@ -51,12 +51,11 @@ namespace TensorFlowNET.UnitTest.ManagedAPI { var strings = new[] { "map_and_batch_fusion", "noop_elimination", "shuffle_and_repeat_fusion" }; var tensor = tf.constant(strings, dtype: tf.@string, name: "optimizations"); - var stringData = tensor.StringData(); Assert.AreEqual(3, tensor.shape[0]); - Assert.AreEqual(strings[0], stringData[0]); - Assert.AreEqual(strings[1], stringData[1]); - Assert.AreEqual(strings[2], stringData[2]); + Assert.AreEqual(tensor[0].numpy(), strings[0]); + Assert.AreEqual(tensor[1].numpy(), strings[1]); + Assert.AreEqual(tensor[2].numpy(), strings[2]); } [TestMethod]