@@ -513,10 +513,13 @@ namespace Tensorflow | |||||
if (data is NDArray nd) | if (data is NDArray nd) | ||||
return nd.shape; | return nd.shape; | ||||
if (data is Tensor tensor) | |||||
else if (data is Tensor tensor) | |||||
return tensor.shape; | return tensor.shape; | ||||
if (!data.GetType().IsArray) | |||||
else if (data is Axis axis) | |||||
return axis.IsScalar ? Shape.Scalar : new Shape(axis.axis); | |||||
else if (!data.GetType().IsArray) | |||||
return Shape.Scalar; | return Shape.Scalar; | ||||
switch (data) | switch (data) | ||||
@@ -17,7 +17,7 @@ namespace Tensorflow | |||||
NumOfExamples = (int)images.dims[0]; | NumOfExamples = (int)images.dims[0]; | ||||
images = images.reshape((images.dims[0], images.dims[1] * images.dims[2])); | |||||
// images = images.reshape((images.dims[0], images.dims[1] * images.dims[2])); | |||||
images = images.astype(dataType); | images = images.astype(dataType); | ||||
// for debug np.multiply performance | // for debug np.multiply performance | ||||
var sw = new Stopwatch(); | var sw = new Stopwatch(); | ||||
@@ -123,9 +123,7 @@ namespace Tensorflow | |||||
bytestream.Read(buf, 0, buf.Length); | bytestream.Read(buf, 0, buf.Length); | ||||
var data = np.frombuffer(buf, new Shape(buf.Length), np.@byte); | |||||
data = data.reshape((num_images, rows, cols, 1)); | |||||
var data = np.frombuffer(buf, (num_images, rows * cols), np.@byte); | |||||
return data; | return data; | ||||
} | } | ||||
} | } | ||||
@@ -24,6 +24,7 @@ namespace Tensorflow | |||||
public record Axis(params int[] axis) | public record Axis(params int[] axis) | ||||
{ | { | ||||
public int size => axis == null ? -1 : axis.Length; | public int size => axis == null ? -1 : axis.Length; | ||||
public bool IsScalar { get; init; } | |||||
public int this[int index] => axis[index]; | public int this[int index] => axis[index]; | ||||
@@ -34,7 +35,7 @@ namespace Tensorflow | |||||
=> axis.axis[0]; | => axis.axis[0]; | ||||
public static implicit operator Axis(int axis) | public static implicit operator Axis(int axis) | ||||
=> new Axis(axis); | |||||
=> new Axis(axis) { IsScalar = true }; | |||||
public static implicit operator Axis((int, int) axis) | public static implicit operator Axis((int, int) axis) | ||||
=> new Axis(axis.Item1, axis.Item2); | => new Axis(axis.Item1, axis.Item2); | ||||
@@ -15,7 +15,7 @@ namespace Tensorflow.NumPy | |||||
Start = x, | Start = x, | ||||
Stop = x + 1, | Stop = x + 1, | ||||
IsIndex = true | IsIndex = true | ||||
})); | |||||
}).ToArray()); | |||||
set => SetData(indices.Select(x => | set => SetData(indices.Select(x => | ||||
{ | { | ||||
@@ -55,21 +55,58 @@ namespace Tensorflow.NumPy | |||||
} | } | ||||
} | } | ||||
NDArray GetData(IEnumerable<Slice> slices) | |||||
unsafe NDArray GetData(Slice[] slices) | |||||
{ | { | ||||
if (shape.IsScalar) | if (shape.IsScalar) | ||||
return GetScalar(); | return GetScalar(); | ||||
if (SliceHelper.AreAllIndex(slices, out var indices1)) | |||||
{ | |||||
var newshape = ShapeHelper.GetShape(shape, slices); | |||||
if (newshape.IsScalar) | |||||
{ | |||||
var offset = ShapeHelper.GetOffset(shape, indices1); | |||||
return GetScalar((ulong)offset); | |||||
} | |||||
else | |||||
{ | |||||
return GetArrayData(newshape, indices1); | |||||
} | |||||
} | |||||
else if (slices.Count() == 1) | |||||
{ | |||||
var slice = slices[0]; | |||||
if (slice.Step == 1) | |||||
{ | |||||
var newshape = ShapeHelper.GetShape(shape, slice); | |||||
var array = new NDArray(newshape, dtype: dtype); | |||||
var new_dims = new int[shape.ndim]; | |||||
new_dims[0] = slice.Start ?? 0; | |||||
//for (int i = 1; i < shape.ndim; i++) | |||||
//new_dims[i] = (int)shape.dims[i]; | |||||
var offset = ShapeHelper.GetOffset(shape, new_dims); | |||||
var src = (byte*)data + (ulong)offset * dtypesize; | |||||
var dst = (byte*)array.data; | |||||
var len = (ulong)newshape.size * dtypesize; | |||||
System.Buffer.MemoryCopy(src, dst, len, len); | |||||
return array; | |||||
} | |||||
} | |||||
// default, performance is bad | |||||
var tensor = base[slices.ToArray()]; | var tensor = base[slices.ToArray()]; | ||||
if (tensor.Handle == null) | if (tensor.Handle == null) | ||||
{ | { | ||||
if (tf.executing_eagerly()) | if (tf.executing_eagerly()) | ||||
tensor = tf.defaultSession.eval(tensor); | tensor = tf.defaultSession.eval(tensor); | ||||
else | |||||
return new NDArray(tensor); | |||||
} | } | ||||
return new NDArray(tensor); | |||||
return new NDArray(tensor, tf.executing_eagerly()); | |||||
} | } | ||||
unsafe T GetAtIndex<T>(params int[] indices) where T : unmanaged | unsafe T GetAtIndex<T>(params int[] indices) where T : unmanaged | ||||
@@ -78,17 +115,26 @@ namespace Tensorflow.NumPy | |||||
return *((T*)data + offset); | return *((T*)data + offset); | ||||
} | } | ||||
NDArray GetScalar() | |||||
unsafe NDArray GetScalar(ulong offset = 0) | |||||
{ | { | ||||
var array = new NDArray(Shape.Scalar, dtype: dtype); | var array = new NDArray(Shape.Scalar, dtype: dtype); | ||||
unsafe | |||||
{ | |||||
var src = (byte*)data + dtypesize; | |||||
System.Buffer.MemoryCopy(src, array.buffer.ToPointer(), bytesize, bytesize); | |||||
} | |||||
var src = (byte*)data + offset * dtypesize; | |||||
System.Buffer.MemoryCopy(src, array.buffer.ToPointer(), dtypesize, dtypesize); | |||||
return array; | return array; | ||||
} | } | ||||
unsafe NDArray GetArrayData(Shape newshape, int[] indices) | |||||
{ | |||||
var offset = ShapeHelper.GetOffset(shape, indices); | |||||
var len = (ulong)newshape.size * dtypesize; | |||||
var array = new NDArray(newshape, dtype: dtype); | |||||
var src = (byte*)data + (ulong)offset * dtypesize; | |||||
System.Buffer.MemoryCopy(src, array.data.ToPointer(), len, len); | |||||
return array; | |||||
} | |||||
NDArray GetData(int[] indices, int axis = 0) | NDArray GetData(int[] indices, int axis = 0) | ||||
{ | { | ||||
if (shape.IsScalar) | if (shape.IsScalar) | ||||
@@ -5,7 +5,7 @@ using System.Text; | |||||
namespace Tensorflow.NumPy | namespace Tensorflow.NumPy | ||||
{ | { | ||||
internal class ShapeHelper | |||||
public class ShapeHelper | |||||
{ | { | ||||
public static long GetSize(Shape shape) | public static long GetSize(Shape shape) | ||||
{ | { | ||||
@@ -41,6 +41,34 @@ namespace Tensorflow.NumPy | |||||
return strides; | return strides; | ||||
} | } | ||||
public static Shape GetShape(Shape shape1, params Slice[] slices) | |||||
{ | |||||
var new_dims = shape1.dims.ToArray(); | |||||
slices = SliceHelper.AlignWithShape(shape1, slices); | |||||
for (int i = 0; i < shape1.dims.Length; i++) | |||||
{ | |||||
Slice slice = slices[i]; | |||||
if (slice.Equals(Slice.All)) | |||||
new_dims[i] = shape1.dims[i]; | |||||
else if (slice.IsIndex) | |||||
new_dims[i] = 1; | |||||
else // range | |||||
new_dims[i] = (slice.Stop ?? shape1.dims[i]) - (slice.Start ?? 0); | |||||
} | |||||
// strip first dim if is index | |||||
var return_dims = new List<long>(); | |||||
for (int i = 0; i< new_dims.Length; i++) | |||||
{ | |||||
if (slices[i].IsIndex) | |||||
continue; | |||||
return_dims.add(new_dims[i]); | |||||
} | |||||
return new Shape(return_dims.ToArray()); | |||||
} | |||||
public static bool Equals(Shape shape, object target) | public static bool Equals(Shape shape, object target) | ||||
{ | { | ||||
switch (target) | switch (target) | ||||
@@ -0,0 +1,56 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using System.Text; | |||||
namespace Tensorflow.NumPy | |||||
{ | |||||
public class SliceHelper | |||||
{ | |||||
public static Slice[] AlignWithShape(Shape shape, Slice[] slices) | |||||
{ | |||||
// align slices | |||||
var ndim = shape.ndim; | |||||
var new_slices = new List<Slice>(); | |||||
var slice_index = 0; | |||||
for (int i = 0; i < ndim; i++) | |||||
{ | |||||
if (slice_index > slices.Length - 1) | |||||
{ | |||||
new_slices.Add(Slice.All); | |||||
continue; | |||||
} | |||||
if (slices[slice_index] == Slice.All) | |||||
{ | |||||
new_slices.Add(Slice.All); | |||||
for (int j = 0; j < ndim - slices.Length; j++) | |||||
{ | |||||
new_slices.Add(Slice.All); | |||||
i++; | |||||
} | |||||
} | |||||
else | |||||
{ | |||||
new_slices.Add(slices[slice_index]); | |||||
} | |||||
slice_index++; | |||||
} | |||||
return new_slices.ToArray(); | |||||
} | |||||
public static bool AreAllIndex(Slice[] slices, out int[] indices) | |||||
{ | |||||
indices = new int[slices.Length]; | |||||
for (int i = 0; i< slices.Length; i++) | |||||
{ | |||||
indices[i] = slices[i].Start ?? 0; | |||||
if (!slices[i].IsIndex) | |||||
return false; | |||||
} | |||||
return true; | |||||
} | |||||
} | |||||
} |
@@ -28,7 +28,7 @@ namespace Tensorflow.NumPy | |||||
public NDArray(IntPtr address, Shape shape, TF_DataType dtype) | public NDArray(IntPtr address, Shape shape, TF_DataType dtype) | ||||
: base(address, shape, dtype) { NewEagerTensorHandle(); } | : base(address, shape, dtype) { NewEagerTensorHandle(); } | ||||
public NDArray(Tensor tensor) : base(tensor.Handle) | |||||
public NDArray(Tensor tensor, bool eval = true) : base(tensor.Handle) | |||||
{ | { | ||||
if (_handle is null) | if (_handle is null) | ||||
{ | { | ||||
@@ -53,9 +53,12 @@ namespace Tensorflow.NumPy | |||||
void NewEagerTensorHandle() | void NewEagerTensorHandle() | ||||
{ | { | ||||
_id = ops.uid(); | |||||
_eagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.Status.Handle); | |||||
tf.Status.Check(true); | |||||
if(_handle is not null) | |||||
{ | |||||
_id = ops.uid(); | |||||
_eagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.Status.Handle); | |||||
tf.Status.Check(true); | |||||
} | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -115,11 +115,12 @@ namespace Tensorflow | |||||
/// <param name="start">Start index of the slice, null means from the start of the array</param> | /// <param name="start">Start index of the slice, null means from the start of the array</param> | ||||
/// <param name="stop">Stop index (first index after end of slice), null means to the end of the array</param> | /// <param name="stop">Stop index (first index after end of slice), null means to the end of the array</param> | ||||
/// <param name="step">Optional step to select every n-th element, defaults to 1</param> | /// <param name="step">Optional step to select every n-th element, defaults to 1</param> | ||||
public Slice(int? start = null, int? stop = null, int step = 1) | |||||
public Slice(int? start = null, int? stop = null, int step = 1, bool isIndex = false) | |||||
{ | { | ||||
Start = start; | Start = start; | ||||
Stop = stop; | Stop = stop; | ||||
Step = step; | Step = step; | ||||
IsIndex = isIndex; | |||||
} | } | ||||
public Slice(string slice_notation) | public Slice(string slice_notation) | ||||
@@ -8,7 +8,7 @@ namespace Tensorflow | |||||
public sealed class SafeStringTensorHandle : SafeTensorHandle | public sealed class SafeStringTensorHandle : SafeTensorHandle | ||||
{ | { | ||||
Shape _shape; | Shape _shape; | ||||
SafeTensorHandle _handle; | |||||
IntPtr _handle; | |||||
const int TF_TSRING_SIZE = 24; | const int TF_TSRING_SIZE = 24; | ||||
protected SafeStringTensorHandle() | protected SafeStringTensorHandle() | ||||
@@ -18,7 +18,7 @@ namespace Tensorflow | |||||
public SafeStringTensorHandle(SafeTensorHandle handle, Shape shape) | public SafeStringTensorHandle(SafeTensorHandle handle, Shape shape) | ||||
: base(handle.DangerousGetHandle()) | : base(handle.DangerousGetHandle()) | ||||
{ | { | ||||
_handle = handle; | |||||
_handle = c_api.TF_TensorData(handle); | |||||
_shape = shape; | _shape = shape; | ||||
} | } | ||||
@@ -28,15 +28,10 @@ namespace Tensorflow | |||||
print($"Delete StringTensorHandle 0x{handle.ToString("x16")}"); | print($"Delete StringTensorHandle 0x{handle.ToString("x16")}"); | ||||
#endif | #endif | ||||
long size = 1; | |||||
foreach (var s in _shape.dims) | |||||
size *= s; | |||||
var tstr = c_api.TF_TensorData(_handle); | |||||
for (int i = 0; i < size; i++) | |||||
for (int i = 0; i < _shape.size; i++) | |||||
{ | { | ||||
c_api.TF_StringDealloc(tstr); | |||||
tstr += TF_TSRING_SIZE; | |||||
c_api.TF_StringDealloc(_handle); | |||||
_handle += TF_TSRING_SIZE; | |||||
} | } | ||||
SetHandle(IntPtr.Zero); | SetHandle(IntPtr.Zero); | ||||
@@ -23,7 +23,7 @@ namespace Tensorflow | |||||
public SafeStringTensorHandle StringTensor(byte[][] buffer, Shape shape) | public SafeStringTensorHandle StringTensor(byte[][] buffer, Shape shape) | ||||
{ | { | ||||
var handle = c_api.TF_AllocateTensor(TF_DataType.TF_STRING, | var handle = c_api.TF_AllocateTensor(TF_DataType.TF_STRING, | ||||
shape.ndim == 0 ? null : shape.dims, | |||||
shape.dims, | |||||
shape.ndim, | shape.ndim, | ||||
(ulong)shape.size * TF_TSRING_SIZE); | (ulong)shape.size * TF_TSRING_SIZE); | ||||
@@ -472,6 +472,9 @@ would not be rank 1.", tensor.op.get_attr("axis"))); | |||||
public static string to_numpy_string(Tensor tensor) | public static string to_numpy_string(Tensor tensor) | ||||
{ | { | ||||
if (tensor.buffer == IntPtr.Zero) | |||||
return "Empty"; | |||||
var dtype = tensor.dtype; | var dtype = tensor.dtype; | ||||
var shape = tensor.shape; | var shape = tensor.shape; | ||||
@@ -161,7 +161,7 @@ namespace Tensorflow | |||||
IEnumerable<Tensor> tensors => array_ops._autopacking_helper(tensors, dtype, name == null ? "packed" : name), | IEnumerable<Tensor> tensors => array_ops._autopacking_helper(tensors, dtype, name == null ? "packed" : name), | ||||
RefVariable varVal => varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref), | RefVariable varVal => varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref), | ||||
ResourceVariable varVal => varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref), | ResourceVariable varVal => varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref), | ||||
Axis ts => constant_op.constant(ts.axis, dtype: dtype, name: name), | |||||
Axis ts => constant_op.constant(ts, dtype: dtype, name: name), | |||||
Shape ts => constant_op.constant(ts.dims, dtype: dtype, name: name), | Shape ts => constant_op.constant(ts.dims, dtype: dtype, name: name), | ||||
string str => constant_op.constant(str, dtype: tf.@string, name: name), | string str => constant_op.constant(str, dtype: tf.@string, name: name), | ||||
string[] str => constant_op.constant(str, dtype: tf.@string, name: name), | string[] str => constant_op.constant(str, dtype: tf.@string, name: name), | ||||
@@ -63,5 +63,43 @@ namespace TensorFlowNET.UnitTest.NumPy | |||||
input_shape_val[(int)input_shape.size - 1] = 1; | input_shape_val[(int)input_shape.size - 1] = 1; | ||||
input_shape.Dispose(); | input_shape.Dispose(); | ||||
} | } | ||||
[TestMethod] | |||||
public void shape_helper_get_shape_3dim() | |||||
{ | |||||
var x = np.arange(24).reshape((4, 3, 2)); | |||||
var shape1 = ShapeHelper.GetShape(x.shape, new Slice(1, isIndex: true)); | |||||
Assert.AreEqual(shape1, (3, 2)); | |||||
var shape2 = ShapeHelper.GetShape(x.shape, new Slice(1)); | |||||
Assert.AreEqual(shape2, (3, 3, 2)); | |||||
var shape3 = ShapeHelper.GetShape(x.shape, new Slice(2), Slice.All); | |||||
Assert.AreEqual(shape3, (2, 3, 2)); | |||||
var shape4 = ShapeHelper.GetShape(x.shape, new Slice(1, isIndex: true), new Slice(2)); | |||||
Assert.AreEqual(shape4, (1, 2)); | |||||
var shape5 = ShapeHelper.GetShape(x.shape, new Slice(1, isIndex: true), new Slice(1)); | |||||
Assert.AreEqual(shape5, (2, 2)); | |||||
var shape6 = ShapeHelper.GetShape(x.shape, new Slice(1), new Slice(1, isIndex: true), new Slice(1)); | |||||
Assert.AreEqual(shape6, (3, 1)); | |||||
} | |||||
[TestMethod] | |||||
public void shape_helper_get_shape_4dim() | |||||
{ | |||||
var x = np.arange(120).reshape((4, 3, 2, 5)); | |||||
var slices = new[] { new Slice(1, isIndex: true), new Slice(1), new Slice(0, isIndex: true), new Slice(1) }; | |||||
var shape1 = ShapeHelper.GetShape(x.shape, slices); | |||||
Assert.AreEqual(shape1, (2, 4)); | |||||
var shape2 = ShapeHelper.GetShape(x.shape, Slice.All); | |||||
Assert.AreEqual(shape2, (4, 3, 2, 5)); | |||||
var shape3 = ShapeHelper.GetShape(x.shape, Slice.All, new Slice(0, isIndex: true)); | |||||
Assert.AreEqual(shape3, (4, 3, 2)); | |||||
} | |||||
} | } | ||||
} | } |