@@ -60,6 +60,9 @@ namespace Tensorflow.Eager | |||||
{ | { | ||||
_id = ops.uid(); | _id = ops.uid(); | ||||
_eagerTensorHandle = c_api.TFE_NewTensorHandle(h, tf.Status.Handle); | _eagerTensorHandle = c_api.TFE_NewTensorHandle(h, tf.Status.Handle); | ||||
#if TRACK_TENSOR_LIFE | |||||
Console.WriteLine($"New EagerTensor {_eagerTensorHandle}"); | |||||
#endif | |||||
tf.Status.Check(true); | tf.Status.Check(true); | ||||
} | } | ||||
@@ -16,6 +16,7 @@ namespace Tensorflow.NumPy | |||||
long val => GetAtIndex<long>(0) == val, | long val => GetAtIndex<long>(0) == val, | ||||
float val => GetAtIndex<float>(0) == val, | float val => GetAtIndex<float>(0) == val, | ||||
double val => GetAtIndex<double>(0) == val, | double val => GetAtIndex<double>(0) == val, | ||||
string val => StringData(0) == val, | |||||
NDArray val => Equals(this, val), | NDArray val => Equals(this, val), | ||||
_ => base.Equals(obj) | _ => base.Equals(obj) | ||||
}; | }; | ||||
@@ -91,7 +91,7 @@ namespace Tensorflow.NumPy | |||||
.Take(25) | .Take(25) | ||||
.Select(x => x < 32 || x > 127 ? "\\x" + x.ToString("x") : Convert.ToChar(x).ToString())) + "'"; | .Select(x => x < 32 || x > 127 ? "\\x" + x.ToString("x") : Convert.ToChar(x).ToString())) + "'"; | ||||
else | else | ||||
return $"['{string.Join("', '", array.StringData().Take(25))}']"; | |||||
return $"'{string.Join("', '", array.StringData().Take(25))}'"; | |||||
} | } | ||||
else if (dtype == TF_DataType.TF_VARIANT) | else if (dtype == TF_DataType.TF_VARIANT) | ||||
{ | { | ||||
@@ -51,7 +51,7 @@ tf.net 0.6x.x aligns with TensorFlow v2.6.x native library.</PackageReleaseNotes | |||||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'"> | <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'"> | ||||
<AllowUnsafeBlocks>true</AllowUnsafeBlocks> | <AllowUnsafeBlocks>true</AllowUnsafeBlocks> | ||||
<DefineConstants>TRACE;DEBUG</DefineConstants> | |||||
<DefineConstants>TRACE;DEBUG;TRACK_TENSOR_LIFE1</DefineConstants> | |||||
<PlatformTarget>x64</PlatformTarget> | <PlatformTarget>x64</PlatformTarget> | ||||
<DocumentationFile>TensorFlow.NET.xml</DocumentationFile> | <DocumentationFile>TensorFlow.NET.xml</DocumentationFile> | ||||
</PropertyGroup> | </PropertyGroup> | ||||
@@ -8,7 +8,7 @@ namespace Tensorflow | |||||
public sealed class SafeStringTensorHandle : SafeTensorHandle | public sealed class SafeStringTensorHandle : SafeTensorHandle | ||||
{ | { | ||||
Shape _shape; | Shape _shape; | ||||
IntPtr _handle; | |||||
SafeTensorHandle _tensorHandle; | |||||
const int TF_TSRING_SIZE = 24; | const int TF_TSRING_SIZE = 24; | ||||
protected SafeStringTensorHandle() | protected SafeStringTensorHandle() | ||||
@@ -18,16 +18,18 @@ namespace Tensorflow | |||||
public SafeStringTensorHandle(SafeTensorHandle handle, Shape shape) | public SafeStringTensorHandle(SafeTensorHandle handle, Shape shape) | ||||
: base(handle.DangerousGetHandle()) | : base(handle.DangerousGetHandle()) | ||||
{ | { | ||||
_handle = c_api.TF_TensorData(handle); | |||||
_tensorHandle = handle; | |||||
_shape = shape; | _shape = shape; | ||||
bool success = false; | |||||
_tensorHandle.DangerousAddRef(ref success); | |||||
} | } | ||||
protected override bool ReleaseHandle() | protected override bool ReleaseHandle() | ||||
{ | { | ||||
var _handle = c_api.TF_TensorData(_tensorHandle); | |||||
#if TRACK_TENSOR_LIFE | #if TRACK_TENSOR_LIFE | ||||
print($"Delete StringTensorHandle 0x{handle.ToString("x16")}"); | |||||
Console.WriteLine($"Delete StringTensorData 0x{_handle.ToString("x16")}"); | |||||
#endif | #endif | ||||
for (int i = 0; i < _shape.size; i++) | for (int i = 0; i < _shape.size; i++) | ||||
{ | { | ||||
c_api.TF_StringDealloc(_handle); | c_api.TF_StringDealloc(_handle); | ||||
@@ -35,6 +37,7 @@ namespace Tensorflow | |||||
} | } | ||||
SetHandle(IntPtr.Zero); | SetHandle(IntPtr.Zero); | ||||
_tensorHandle.DangerousRelease(); | |||||
return true; | return true; | ||||
} | } | ||||
@@ -29,13 +29,13 @@ namespace Tensorflow | |||||
var tstr = c_api.TF_TensorData(handle); | var tstr = c_api.TF_TensorData(handle); | ||||
#if TRACK_TENSOR_LIFE | #if TRACK_TENSOR_LIFE | ||||
print($"New TString 0x{handle.ToString("x16")} Data: 0x{tstr.ToString("x16")}"); | |||||
print($"New StringTensor {handle} Data: 0x{tstr.ToString("x16")}"); | |||||
#endif | #endif | ||||
for (int i = 0; i < buffer.Length; i++) | for (int i = 0; i < buffer.Length; i++) | ||||
{ | { | ||||
c_api.TF_StringInit(tstr); | c_api.TF_StringInit(tstr); | ||||
c_api.TF_StringCopy(tstr, buffer[i], buffer[i].Length); | c_api.TF_StringCopy(tstr, buffer[i], buffer[i].Length); | ||||
var data = c_api.TF_StringGetDataPointer(tstr); | |||||
// var data = c_api.TF_StringGetDataPointer(tstr); | |||||
tstr += TF_TSRING_SIZE; | tstr += TF_TSRING_SIZE; | ||||
} | } | ||||
@@ -53,6 +53,36 @@ namespace Tensorflow | |||||
return _str; | return _str; | ||||
} | } | ||||
public string StringData(int index) | |||||
{ | |||||
var bytes = StringBytes(index); | |||||
return Encoding.UTF8.GetString(bytes); | |||||
} | |||||
public byte[] StringBytes(int index) | |||||
{ | |||||
if (dtype != TF_DataType.TF_STRING) | |||||
throw new InvalidOperationException($"Unable to call StringData when dtype != TF_DataType.TF_STRING (dtype is {dtype})"); | |||||
byte[] buffer = new byte[0]; | |||||
var tstrings = TensorDataPointer; | |||||
for (int i = 0; i < shape.size; i++) | |||||
{ | |||||
if(index == i) | |||||
{ | |||||
var data = c_api.TF_StringGetDataPointer(tstrings); | |||||
var len = c_api.TF_StringGetSize(tstrings); | |||||
buffer = new byte[len]; | |||||
// var capacity = c_api.TF_StringGetCapacity(tstrings); | |||||
// var type = c_api.TF_StringGetType(tstrings); | |||||
Marshal.Copy(data, buffer, 0, Convert.ToInt32(len)); | |||||
break; | |||||
} | |||||
tstrings += TF_TSRING_SIZE; | |||||
} | |||||
return buffer; | |||||
} | |||||
public byte[][] StringBytes() | public byte[][] StringBytes() | ||||
{ | { | ||||
if (dtype != TF_DataType.TF_STRING) | if (dtype != TF_DataType.TF_STRING) | ||||
@@ -22,7 +22,7 @@ namespace Tensorflow.Keras.Utils | |||||
// categorical[np.arange(y.size), y] = 1; | // categorical[np.arange(y.size), y] = 1; | ||||
for (var i = 0; i < (int)y.size; i++) | for (var i = 0; i < (int)y.size; i++) | ||||
{ | { | ||||
categorical[i][y1[i]] = 1.0f; | |||||
categorical[i, y1[i]] = 1.0f; | |||||
} | } | ||||
return categorical; | return categorical; | ||||
@@ -51,12 +51,11 @@ namespace TensorFlowNET.UnitTest.ManagedAPI | |||||
{ | { | ||||
var strings = new[] { "map_and_batch_fusion", "noop_elimination", "shuffle_and_repeat_fusion" }; | var strings = new[] { "map_and_batch_fusion", "noop_elimination", "shuffle_and_repeat_fusion" }; | ||||
var tensor = tf.constant(strings, dtype: tf.@string, name: "optimizations"); | var tensor = tf.constant(strings, dtype: tf.@string, name: "optimizations"); | ||||
var stringData = tensor.StringData(); | |||||
Assert.AreEqual(3, tensor.shape[0]); | Assert.AreEqual(3, tensor.shape[0]); | ||||
Assert.AreEqual(strings[0], stringData[0]); | |||||
Assert.AreEqual(strings[1], stringData[1]); | |||||
Assert.AreEqual(strings[2], stringData[2]); | |||||
Assert.AreEqual(tensor[0].numpy(), strings[0]); | |||||
Assert.AreEqual(tensor[1].numpy(), strings[1]); | |||||
Assert.AreEqual(tensor[2].numpy(), strings[2]); | |||||
} | } | ||||
[TestMethod] | [TestMethod] | ||||