Browse Source

ndarray string comparison.

tags/v0.60-tf.numpy
Oceania2018 4 years ago
parent
commit
6adcfaea3f
8 changed files with 49 additions and 13 deletions
  1. +3
    -0
      src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs
  2. +1
    -0
      src/TensorFlowNET.Core/NumPy/NDArray.Equal.cs
  3. +1
    -1
      src/TensorFlowNET.Core/NumPy/NDArrayRender.cs
  4. +1
    -1
      src/TensorFlowNET.Core/Tensorflow.Binding.csproj
  5. +7
    -4
      src/TensorFlowNET.Core/Tensors/SafeStringTensorHandle.cs
  6. +32
    -2
      src/TensorFlowNET.Core/Tensors/Tensor.String.cs
  7. +1
    -1
      src/TensorFlowNET.Keras/Utils/np_utils.cs
  8. +3
    -4
      test/TensorFlowNET.UnitTest/ManagedAPI/StringsApiTest.cs

+ 3
- 0
src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs View File

@@ -60,6 +60,9 @@ namespace Tensorflow.Eager
{ {
_id = ops.uid(); _id = ops.uid();
_eagerTensorHandle = c_api.TFE_NewTensorHandle(h, tf.Status.Handle); _eagerTensorHandle = c_api.TFE_NewTensorHandle(h, tf.Status.Handle);
#if TRACK_TENSOR_LIFE
Console.WriteLine($"New EagerTensor {_eagerTensorHandle}");
#endif
tf.Status.Check(true); tf.Status.Check(true);
} }




+ 1
- 0
src/TensorFlowNET.Core/NumPy/NDArray.Equal.cs View File

@@ -16,6 +16,7 @@ namespace Tensorflow.NumPy
long val => GetAtIndex<long>(0) == val, long val => GetAtIndex<long>(0) == val,
float val => GetAtIndex<float>(0) == val, float val => GetAtIndex<float>(0) == val,
double val => GetAtIndex<double>(0) == val, double val => GetAtIndex<double>(0) == val,
string val => StringData(0) == val,
NDArray val => Equals(this, val), NDArray val => Equals(this, val),
_ => base.Equals(obj) _ => base.Equals(obj)
}; };


+ 1
- 1
src/TensorFlowNET.Core/NumPy/NDArrayRender.cs View File

@@ -91,7 +91,7 @@ namespace Tensorflow.NumPy
.Take(25) .Take(25)
.Select(x => x < 32 || x > 127 ? "\\x" + x.ToString("x") : Convert.ToChar(x).ToString())) + "'"; .Select(x => x < 32 || x > 127 ? "\\x" + x.ToString("x") : Convert.ToChar(x).ToString())) + "'";
else else
return $"['{string.Join("', '", array.StringData().Take(25))}']";
return $"'{string.Join("', '", array.StringData().Take(25))}'";
} }
else if (dtype == TF_DataType.TF_VARIANT) else if (dtype == TF_DataType.TF_VARIANT)
{ {


+ 1
- 1
src/TensorFlowNET.Core/Tensorflow.Binding.csproj View File

@@ -51,7 +51,7 @@ tf.net 0.6x.x aligns with TensorFlow v2.6.x native library.</PackageReleaseNotes


<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'"> <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
<AllowUnsafeBlocks>true</AllowUnsafeBlocks> <AllowUnsafeBlocks>true</AllowUnsafeBlocks>
<DefineConstants>TRACE;DEBUG</DefineConstants>
<DefineConstants>TRACE;DEBUG;TRACK_TENSOR_LIFE1</DefineConstants>
<PlatformTarget>x64</PlatformTarget> <PlatformTarget>x64</PlatformTarget>
<DocumentationFile>TensorFlow.NET.xml</DocumentationFile> <DocumentationFile>TensorFlow.NET.xml</DocumentationFile>
</PropertyGroup> </PropertyGroup>


+ 7
- 4
src/TensorFlowNET.Core/Tensors/SafeStringTensorHandle.cs View File

@@ -8,7 +8,7 @@ namespace Tensorflow
public sealed class SafeStringTensorHandle : SafeTensorHandle public sealed class SafeStringTensorHandle : SafeTensorHandle
{ {
Shape _shape; Shape _shape;
IntPtr _handle;
SafeTensorHandle _tensorHandle;
const int TF_TSRING_SIZE = 24; const int TF_TSRING_SIZE = 24;


protected SafeStringTensorHandle() protected SafeStringTensorHandle()
@@ -18,16 +18,18 @@ namespace Tensorflow
public SafeStringTensorHandle(SafeTensorHandle handle, Shape shape) public SafeStringTensorHandle(SafeTensorHandle handle, Shape shape)
: base(handle.DangerousGetHandle()) : base(handle.DangerousGetHandle())
{ {
_handle = c_api.TF_TensorData(handle);
_tensorHandle = handle;
_shape = shape; _shape = shape;
bool success = false;
_tensorHandle.DangerousAddRef(ref success);
} }


protected override bool ReleaseHandle() protected override bool ReleaseHandle()
{ {
var _handle = c_api.TF_TensorData(_tensorHandle);
#if TRACK_TENSOR_LIFE #if TRACK_TENSOR_LIFE
print($"Delete StringTensorHandle 0x{handle.ToString("x16")}");
Console.WriteLine($"Delete StringTensorData 0x{_handle.ToString("x16")}");
#endif #endif

for (int i = 0; i < _shape.size; i++) for (int i = 0; i < _shape.size; i++)
{ {
c_api.TF_StringDealloc(_handle); c_api.TF_StringDealloc(_handle);
@@ -35,6 +37,7 @@ namespace Tensorflow
} }


SetHandle(IntPtr.Zero); SetHandle(IntPtr.Zero);
_tensorHandle.DangerousRelease();


return true; return true;
} }


+ 32
- 2
src/TensorFlowNET.Core/Tensors/Tensor.String.cs View File

@@ -29,13 +29,13 @@ namespace Tensorflow


var tstr = c_api.TF_TensorData(handle); var tstr = c_api.TF_TensorData(handle);
#if TRACK_TENSOR_LIFE #if TRACK_TENSOR_LIFE
print($"New TString 0x{handle.ToString("x16")} Data: 0x{tstr.ToString("x16")}");
print($"New StringTensor {handle} Data: 0x{tstr.ToString("x16")}");
#endif #endif
for (int i = 0; i < buffer.Length; i++) for (int i = 0; i < buffer.Length; i++)
{ {
c_api.TF_StringInit(tstr); c_api.TF_StringInit(tstr);
c_api.TF_StringCopy(tstr, buffer[i], buffer[i].Length); c_api.TF_StringCopy(tstr, buffer[i], buffer[i].Length);
var data = c_api.TF_StringGetDataPointer(tstr);
// var data = c_api.TF_StringGetDataPointer(tstr);
tstr += TF_TSRING_SIZE; tstr += TF_TSRING_SIZE;
} }


@@ -53,6 +53,36 @@ namespace Tensorflow
return _str; return _str;
} }


public string StringData(int index)
{
var bytes = StringBytes(index);
return Encoding.UTF8.GetString(bytes);
}

public byte[] StringBytes(int index)
{
if (dtype != TF_DataType.TF_STRING)
throw new InvalidOperationException($"Unable to call StringData when dtype != TF_DataType.TF_STRING (dtype is {dtype})");

byte[] buffer = new byte[0];
var tstrings = TensorDataPointer;
for (int i = 0; i < shape.size; i++)
{
if(index == i)
{
var data = c_api.TF_StringGetDataPointer(tstrings);
var len = c_api.TF_StringGetSize(tstrings);
buffer = new byte[len];
// var capacity = c_api.TF_StringGetCapacity(tstrings);
// var type = c_api.TF_StringGetType(tstrings);
Marshal.Copy(data, buffer, 0, Convert.ToInt32(len));
break;
}
tstrings += TF_TSRING_SIZE;
}
return buffer;
}

public byte[][] StringBytes() public byte[][] StringBytes()
{ {
if (dtype != TF_DataType.TF_STRING) if (dtype != TF_DataType.TF_STRING)


+ 1
- 1
src/TensorFlowNET.Keras/Utils/np_utils.cs View File

@@ -22,7 +22,7 @@ namespace Tensorflow.Keras.Utils
// categorical[np.arange(y.size), y] = 1; // categorical[np.arange(y.size), y] = 1;
for (var i = 0; i < (int)y.size; i++) for (var i = 0; i < (int)y.size; i++)
{ {
categorical[i][y1[i]] = 1.0f;
categorical[i, y1[i]] = 1.0f;
} }


return categorical; return categorical;


+ 3
- 4
test/TensorFlowNET.UnitTest/ManagedAPI/StringsApiTest.cs View File

@@ -51,12 +51,11 @@ namespace TensorFlowNET.UnitTest.ManagedAPI
{ {
var strings = new[] { "map_and_batch_fusion", "noop_elimination", "shuffle_and_repeat_fusion" }; var strings = new[] { "map_and_batch_fusion", "noop_elimination", "shuffle_and_repeat_fusion" };
var tensor = tf.constant(strings, dtype: tf.@string, name: "optimizations"); var tensor = tf.constant(strings, dtype: tf.@string, name: "optimizations");
var stringData = tensor.StringData();


Assert.AreEqual(3, tensor.shape[0]); Assert.AreEqual(3, tensor.shape[0]);
Assert.AreEqual(strings[0], stringData[0]);
Assert.AreEqual(strings[1], stringData[1]);
Assert.AreEqual(strings[2], stringData[2]);
Assert.AreEqual(tensor[0].numpy(), strings[0]);
Assert.AreEqual(tensor[1].numpy(), strings[1]);
Assert.AreEqual(tensor[2].numpy(), strings[2]);
} }


[TestMethod] [TestMethod]


Loading…
Cancel
Save