Browse Source

ndarray string comparison.

tags/v0.60-tf.numpy
Oceania2018 4 years ago
parent
commit
6adcfaea3f
8 changed files with 49 additions and 13 deletions
  1. +3
    -0
      src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs
  2. +1
    -0
      src/TensorFlowNET.Core/NumPy/NDArray.Equal.cs
  3. +1
    -1
      src/TensorFlowNET.Core/NumPy/NDArrayRender.cs
  4. +1
    -1
      src/TensorFlowNET.Core/Tensorflow.Binding.csproj
  5. +7
    -4
      src/TensorFlowNET.Core/Tensors/SafeStringTensorHandle.cs
  6. +32
    -2
      src/TensorFlowNET.Core/Tensors/Tensor.String.cs
  7. +1
    -1
      src/TensorFlowNET.Keras/Utils/np_utils.cs
  8. +3
    -4
      test/TensorFlowNET.UnitTest/ManagedAPI/StringsApiTest.cs

+ 3
- 0
src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs View File

@@ -60,6 +60,9 @@ namespace Tensorflow.Eager
{
_id = ops.uid();
_eagerTensorHandle = c_api.TFE_NewTensorHandle(h, tf.Status.Handle);
#if TRACK_TENSOR_LIFE
Console.WriteLine($"New EagerTensor {_eagerTensorHandle}");
#endif
tf.Status.Check(true);
}



+ 1
- 0
src/TensorFlowNET.Core/NumPy/NDArray.Equal.cs View File

@@ -16,6 +16,7 @@ namespace Tensorflow.NumPy
long val => GetAtIndex<long>(0) == val,
float val => GetAtIndex<float>(0) == val,
double val => GetAtIndex<double>(0) == val,
string val => StringData(0) == val,
NDArray val => Equals(this, val),
_ => base.Equals(obj)
};


+ 1
- 1
src/TensorFlowNET.Core/NumPy/NDArrayRender.cs View File

@@ -91,7 +91,7 @@ namespace Tensorflow.NumPy
.Take(25)
.Select(x => x < 32 || x > 127 ? "\\x" + x.ToString("x") : Convert.ToChar(x).ToString())) + "'";
else
return $"['{string.Join("', '", array.StringData().Take(25))}']";
return $"'{string.Join("', '", array.StringData().Take(25))}'";
}
else if (dtype == TF_DataType.TF_VARIANT)
{


+ 1
- 1
src/TensorFlowNET.Core/Tensorflow.Binding.csproj View File

@@ -51,7 +51,7 @@ tf.net 0.6x.x aligns with TensorFlow v2.6.x native library.</PackageReleaseNotes

<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
<DefineConstants>TRACE;DEBUG</DefineConstants>
<DefineConstants>TRACE;DEBUG;TRACK_TENSOR_LIFE1</DefineConstants>
<PlatformTarget>x64</PlatformTarget>
<DocumentationFile>TensorFlow.NET.xml</DocumentationFile>
</PropertyGroup>


+ 7
- 4
src/TensorFlowNET.Core/Tensors/SafeStringTensorHandle.cs View File

@@ -8,7 +8,7 @@ namespace Tensorflow
public sealed class SafeStringTensorHandle : SafeTensorHandle
{
Shape _shape;
IntPtr _handle;
SafeTensorHandle _tensorHandle;
const int TF_TSRING_SIZE = 24;

protected SafeStringTensorHandle()
@@ -18,16 +18,18 @@ namespace Tensorflow
public SafeStringTensorHandle(SafeTensorHandle handle, Shape shape)
: base(handle.DangerousGetHandle())
{
_handle = c_api.TF_TensorData(handle);
_tensorHandle = handle;
_shape = shape;
bool success = false;
_tensorHandle.DangerousAddRef(ref success);
}

protected override bool ReleaseHandle()
{
var _handle = c_api.TF_TensorData(_tensorHandle);
#if TRACK_TENSOR_LIFE
print($"Delete StringTensorHandle 0x{handle.ToString("x16")}");
Console.WriteLine($"Delete StringTensorData 0x{_handle.ToString("x16")}");
#endif

for (int i = 0; i < _shape.size; i++)
{
c_api.TF_StringDealloc(_handle);
@@ -35,6 +37,7 @@ namespace Tensorflow
}

SetHandle(IntPtr.Zero);
_tensorHandle.DangerousRelease();

return true;
}


+ 32
- 2
src/TensorFlowNET.Core/Tensors/Tensor.String.cs View File

@@ -29,13 +29,13 @@ namespace Tensorflow

var tstr = c_api.TF_TensorData(handle);
#if TRACK_TENSOR_LIFE
print($"New TString 0x{handle.ToString("x16")} Data: 0x{tstr.ToString("x16")}");
print($"New StringTensor {handle} Data: 0x{tstr.ToString("x16")}");
#endif
for (int i = 0; i < buffer.Length; i++)
{
c_api.TF_StringInit(tstr);
c_api.TF_StringCopy(tstr, buffer[i], buffer[i].Length);
var data = c_api.TF_StringGetDataPointer(tstr);
// var data = c_api.TF_StringGetDataPointer(tstr);
tstr += TF_TSRING_SIZE;
}

@@ -53,6 +53,36 @@ namespace Tensorflow
return _str;
}

public string StringData(int index)
{
var bytes = StringBytes(index);
return Encoding.UTF8.GetString(bytes);
}

public byte[] StringBytes(int index)
{
if (dtype != TF_DataType.TF_STRING)
throw new InvalidOperationException($"Unable to call StringData when dtype != TF_DataType.TF_STRING (dtype is {dtype})");

byte[] buffer = new byte[0];
var tstrings = TensorDataPointer;
for (int i = 0; i < shape.size; i++)
{
if(index == i)
{
var data = c_api.TF_StringGetDataPointer(tstrings);
var len = c_api.TF_StringGetSize(tstrings);
buffer = new byte[len];
// var capacity = c_api.TF_StringGetCapacity(tstrings);
// var type = c_api.TF_StringGetType(tstrings);
Marshal.Copy(data, buffer, 0, Convert.ToInt32(len));
break;
}
tstrings += TF_TSRING_SIZE;
}
return buffer;
}

public byte[][] StringBytes()
{
if (dtype != TF_DataType.TF_STRING)


+ 1
- 1
src/TensorFlowNET.Keras/Utils/np_utils.cs View File

@@ -22,7 +22,7 @@ namespace Tensorflow.Keras.Utils
// categorical[np.arange(y.size), y] = 1;
for (var i = 0; i < (int)y.size; i++)
{
categorical[i][y1[i]] = 1.0f;
categorical[i, y1[i]] = 1.0f;
}

return categorical;


+ 3
- 4
test/TensorFlowNET.UnitTest/ManagedAPI/StringsApiTest.cs View File

@@ -51,12 +51,11 @@ namespace TensorFlowNET.UnitTest.ManagedAPI
{
var strings = new[] { "map_and_batch_fusion", "noop_elimination", "shuffle_and_repeat_fusion" };
var tensor = tf.constant(strings, dtype: tf.@string, name: "optimizations");
var stringData = tensor.StringData();

Assert.AreEqual(3, tensor.shape[0]);
Assert.AreEqual(strings[0], stringData[0]);
Assert.AreEqual(strings[1], stringData[1]);
Assert.AreEqual(strings[2], stringData[2]);
Assert.AreEqual(tensor[0].numpy(), strings[0]);
Assert.AreEqual(tensor[1].numpy(), strings[1]);
Assert.AreEqual(tensor[2].numpy(), strings[2]);
}

[TestMethod]


Loading…
Cancel
Save