Browse Source

add SafeStringTensorHandle

tags/v0.60-tf.numpy
Oceania2018 4 years ago
parent
commit
3052e1fbd2
6 changed files with 54 additions and 35 deletions
  1. +2
    -2
      src/TensorFlowNET.Core/NumPy/NDArray.Index.cs
  2. +47
    -0
      src/TensorFlowNET.Core/Tensors/SafeStringTensorHandle.cs
  3. +2
    -2
      src/TensorFlowNET.Core/Tensors/SafeTensorHandle.cs
  4. +0
    -16
      src/TensorFlowNET.Core/Tensors/TStringHandle.cs
  5. +3
    -3
      src/TensorFlowNET.Core/Tensors/Tensor.String.cs
  6. +0
    -12
      src/TensorFlowNET.Core/Tensors/Tensor.cs

+ 2
- 2
src/TensorFlowNET.Core/NumPy/NDArray.Index.cs View File

@@ -64,9 +64,9 @@ namespace Tensorflow.NumPy
if (tensor.Handle == null)
{
if (tf.executing_eagerly())
return new NDArray(tensor);
else
tensor = tf.defaultSession.eval(tensor);
else
return new NDArray(tensor);
}
return new NDArray(tensor);


+ 47
- 0
src/TensorFlowNET.Core/Tensors/SafeStringTensorHandle.cs View File

@@ -0,0 +1,47 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Util;

namespace Tensorflow
{
public sealed class SafeStringTensorHandle : SafeTensorHandle
{
Shape _shape;
SafeTensorHandle _handle;
const int TF_TSRING_SIZE = 24;

protected SafeStringTensorHandle()
{
}

public SafeStringTensorHandle(SafeTensorHandle handle, Shape shape)
: base(handle.DangerousGetHandle())
{
_handle = handle;
_shape = shape;
}

protected override bool ReleaseHandle()
{
#if TRACK_TENSOR_LIFE
print($"Delete StringTensorHandle 0x{handle.ToString("x16")}");
#endif

long size = 1;
foreach (var s in _shape.dims)
size *= s;
var tstr = c_api.TF_TensorData(_handle);

for (int i = 0; i < size; i++)
{
c_api.TF_StringDealloc(tstr);
tstr += TF_TSRING_SIZE;
}

SetHandle(IntPtr.Zero);

return true;
}
}
}

+ 2
- 2
src/TensorFlowNET.Core/Tensors/SafeTensorHandle.cs View File

@@ -20,9 +20,9 @@ using static Tensorflow.Binding;

namespace Tensorflow
{
public sealed class SafeTensorHandle : SafeTensorflowHandle
public class SafeTensorHandle : SafeTensorflowHandle
{
private SafeTensorHandle()
protected SafeTensorHandle()
{
}



+ 0
- 16
src/TensorFlowNET.Core/Tensors/TStringHandle.cs View File

@@ -1,16 +0,0 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Util;

namespace Tensorflow
{
public class TStringHandle : SafeTensorflowHandle
{
protected override bool ReleaseHandle()
{
c_api.TF_StringDealloc(handle);
return true;
}
}
}

+ 3
- 3
src/TensorFlowNET.Core/Tensors/Tensor.String.cs View File

@@ -10,7 +10,7 @@ namespace Tensorflow
{
const int TF_TSRING_SIZE = 24;

public SafeTensorHandle StringTensor(string[] strings, Shape shape)
public SafeStringTensorHandle StringTensor(string[] strings, Shape shape)
{
// convert string array to byte[][]
var buffer = new byte[strings.Length][];
@@ -20,7 +20,7 @@ namespace Tensorflow
return StringTensor(buffer, shape);
}

public SafeTensorHandle StringTensor(byte[][] buffer, Shape shape)
public SafeStringTensorHandle StringTensor(byte[][] buffer, Shape shape)
{
var handle = c_api.TF_AllocateTensor(TF_DataType.TF_STRING,
shape.ndim == 0 ? null : shape.dims,
@@ -39,7 +39,7 @@ namespace Tensorflow
tstr += TF_TSRING_SIZE;
}

return handle;
return new SafeStringTensorHandle(handle, shape);
}

public string[] StringData()


+ 0
- 12
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -255,19 +255,7 @@ namespace Tensorflow

protected override void DisposeUnmanagedResources(IntPtr handle)
{
if (dtype == TF_DataType.TF_STRING)
{
long size = 1;
foreach (var s in shape.dims)
size *= s;
var tstr = TensorDataPointer;

for (int i = 0; i < size; i++)
{
c_api.TF_StringDealloc(tstr);
tstr += TF_TSRING_SIZE;
}
}
}

public bool IsDisposed => _disposed;


Loading…
Cancel
Save