@@ -513,10 +513,13 @@ namespace Tensorflow | |||
if (data is NDArray nd) | |||
return nd.shape; | |||
if (data is Tensor tensor) | |||
else if (data is Tensor tensor) | |||
return tensor.shape; | |||
if (!data.GetType().IsArray) | |||
else if (data is Axis axis) | |||
return axis.IsScalar ? Shape.Scalar : new Shape(axis.axis); | |||
else if (!data.GetType().IsArray) | |||
return Shape.Scalar; | |||
switch (data) | |||
@@ -17,7 +17,7 @@ namespace Tensorflow | |||
NumOfExamples = (int)images.dims[0]; | |||
images = images.reshape((images.dims[0], images.dims[1] * images.dims[2])); | |||
// images = images.reshape((images.dims[0], images.dims[1] * images.dims[2])); | |||
images = images.astype(dataType); | |||
// for debug np.multiply performance | |||
var sw = new Stopwatch(); | |||
@@ -123,9 +123,7 @@ namespace Tensorflow | |||
bytestream.Read(buf, 0, buf.Length); | |||
var data = np.frombuffer(buf, new Shape(buf.Length), np.@byte); | |||
data = data.reshape((num_images, rows, cols, 1)); | |||
var data = np.frombuffer(buf, (num_images, rows * cols), np.@byte); | |||
return data; | |||
} | |||
} | |||
@@ -24,6 +24,7 @@ namespace Tensorflow | |||
public record Axis(params int[] axis) | |||
{ | |||
public int size => axis == null ? -1 : axis.Length; | |||
public bool IsScalar { get; init; } | |||
public int this[int index] => axis[index]; | |||
@@ -34,7 +35,7 @@ namespace Tensorflow | |||
=> axis.axis[0]; | |||
public static implicit operator Axis(int axis) | |||
=> new Axis(axis); | |||
=> new Axis(axis) { IsScalar = true }; | |||
public static implicit operator Axis((int, int) axis) | |||
=> new Axis(axis.Item1, axis.Item2); | |||
@@ -15,7 +15,7 @@ namespace Tensorflow.NumPy | |||
Start = x, | |||
Stop = x + 1, | |||
IsIndex = true | |||
})); | |||
}).ToArray()); | |||
set => SetData(indices.Select(x => | |||
{ | |||
@@ -55,21 +55,58 @@ namespace Tensorflow.NumPy | |||
} | |||
} | |||
NDArray GetData(IEnumerable<Slice> slices) | |||
unsafe NDArray GetData(Slice[] slices) | |||
{ | |||
if (shape.IsScalar) | |||
return GetScalar(); | |||
if (SliceHelper.AreAllIndex(slices, out var indices1)) | |||
{ | |||
var newshape = ShapeHelper.GetShape(shape, slices); | |||
if (newshape.IsScalar) | |||
{ | |||
var offset = ShapeHelper.GetOffset(shape, indices1); | |||
return GetScalar((ulong)offset); | |||
} | |||
else | |||
{ | |||
return GetArrayData(newshape, indices1); | |||
} | |||
} | |||
else if (slices.Count() == 1) | |||
{ | |||
var slice = slices[0]; | |||
if (slice.Step == 1) | |||
{ | |||
var newshape = ShapeHelper.GetShape(shape, slice); | |||
var array = new NDArray(newshape, dtype: dtype); | |||
var new_dims = new int[shape.ndim]; | |||
new_dims[0] = slice.Start ?? 0; | |||
//for (int i = 1; i < shape.ndim; i++) | |||
//new_dims[i] = (int)shape.dims[i]; | |||
var offset = ShapeHelper.GetOffset(shape, new_dims); | |||
var src = (byte*)data + (ulong)offset * dtypesize; | |||
var dst = (byte*)array.data; | |||
var len = (ulong)newshape.size * dtypesize; | |||
System.Buffer.MemoryCopy(src, dst, len, len); | |||
return array; | |||
} | |||
} | |||
// default, performance is bad | |||
var tensor = base[slices.ToArray()]; | |||
if (tensor.Handle == null) | |||
{ | |||
if (tf.executing_eagerly()) | |||
tensor = tf.defaultSession.eval(tensor); | |||
else | |||
return new NDArray(tensor); | |||
} | |||
return new NDArray(tensor); | |||
return new NDArray(tensor, tf.executing_eagerly()); | |||
} | |||
unsafe T GetAtIndex<T>(params int[] indices) where T : unmanaged | |||
@@ -78,17 +115,26 @@ namespace Tensorflow.NumPy | |||
return *((T*)data + offset); | |||
} | |||
NDArray GetScalar() | |||
unsafe NDArray GetScalar(ulong offset = 0) | |||
{ | |||
var array = new NDArray(Shape.Scalar, dtype: dtype); | |||
unsafe | |||
{ | |||
var src = (byte*)data + dtypesize; | |||
System.Buffer.MemoryCopy(src, array.buffer.ToPointer(), bytesize, bytesize); | |||
} | |||
var src = (byte*)data + offset * dtypesize; | |||
System.Buffer.MemoryCopy(src, array.buffer.ToPointer(), dtypesize, dtypesize); | |||
return array; | |||
} | |||
unsafe NDArray GetArrayData(Shape newshape, int[] indices) | |||
{ | |||
var offset = ShapeHelper.GetOffset(shape, indices); | |||
var len = (ulong)newshape.size * dtypesize; | |||
var array = new NDArray(newshape, dtype: dtype); | |||
var src = (byte*)data + (ulong)offset * dtypesize; | |||
System.Buffer.MemoryCopy(src, array.data.ToPointer(), len, len); | |||
return array; | |||
} | |||
NDArray GetData(int[] indices, int axis = 0) | |||
{ | |||
if (shape.IsScalar) | |||
@@ -5,7 +5,7 @@ using System.Text; | |||
namespace Tensorflow.NumPy | |||
{ | |||
internal class ShapeHelper | |||
public class ShapeHelper | |||
{ | |||
public static long GetSize(Shape shape) | |||
{ | |||
@@ -41,6 +41,34 @@ namespace Tensorflow.NumPy | |||
return strides; | |||
} | |||
public static Shape GetShape(Shape shape1, params Slice[] slices) | |||
{ | |||
var new_dims = shape1.dims.ToArray(); | |||
slices = SliceHelper.AlignWithShape(shape1, slices); | |||
for (int i = 0; i < shape1.dims.Length; i++) | |||
{ | |||
Slice slice = slices[i]; | |||
if (slice.Equals(Slice.All)) | |||
new_dims[i] = shape1.dims[i]; | |||
else if (slice.IsIndex) | |||
new_dims[i] = 1; | |||
else // range | |||
new_dims[i] = (slice.Stop ?? shape1.dims[i]) - (slice.Start ?? 0); | |||
} | |||
// strip first dim if is index | |||
var return_dims = new List<long>(); | |||
for (int i = 0; i< new_dims.Length; i++) | |||
{ | |||
if (slices[i].IsIndex) | |||
continue; | |||
return_dims.add(new_dims[i]); | |||
} | |||
return new Shape(return_dims.ToArray()); | |||
} | |||
public static bool Equals(Shape shape, object target) | |||
{ | |||
switch (target) | |||
@@ -0,0 +1,56 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
namespace Tensorflow.NumPy | |||
{ | |||
public class SliceHelper | |||
{ | |||
public static Slice[] AlignWithShape(Shape shape, Slice[] slices) | |||
{ | |||
// align slices | |||
var ndim = shape.ndim; | |||
var new_slices = new List<Slice>(); | |||
var slice_index = 0; | |||
for (int i = 0; i < ndim; i++) | |||
{ | |||
if (slice_index > slices.Length - 1) | |||
{ | |||
new_slices.Add(Slice.All); | |||
continue; | |||
} | |||
if (slices[slice_index] == Slice.All) | |||
{ | |||
new_slices.Add(Slice.All); | |||
for (int j = 0; j < ndim - slices.Length; j++) | |||
{ | |||
new_slices.Add(Slice.All); | |||
i++; | |||
} | |||
} | |||
else | |||
{ | |||
new_slices.Add(slices[slice_index]); | |||
} | |||
slice_index++; | |||
} | |||
return new_slices.ToArray(); | |||
} | |||
public static bool AreAllIndex(Slice[] slices, out int[] indices) | |||
{ | |||
indices = new int[slices.Length]; | |||
for (int i = 0; i< slices.Length; i++) | |||
{ | |||
indices[i] = slices[i].Start ?? 0; | |||
if (!slices[i].IsIndex) | |||
return false; | |||
} | |||
return true; | |||
} | |||
} | |||
} |
@@ -28,7 +28,7 @@ namespace Tensorflow.NumPy | |||
public NDArray(IntPtr address, Shape shape, TF_DataType dtype) | |||
: base(address, shape, dtype) { NewEagerTensorHandle(); } | |||
public NDArray(Tensor tensor) : base(tensor.Handle) | |||
public NDArray(Tensor tensor, bool eval = true) : base(tensor.Handle) | |||
{ | |||
if (_handle is null) | |||
{ | |||
@@ -53,9 +53,12 @@ namespace Tensorflow.NumPy | |||
void NewEagerTensorHandle() | |||
{ | |||
_id = ops.uid(); | |||
_eagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.Status.Handle); | |||
tf.Status.Check(true); | |||
if(_handle is not null) | |||
{ | |||
_id = ops.uid(); | |||
_eagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.Status.Handle); | |||
tf.Status.Check(true); | |||
} | |||
} | |||
} | |||
} |
@@ -115,11 +115,12 @@ namespace Tensorflow | |||
/// <param name="start">Start index of the slice, null means from the start of the array</param> | |||
/// <param name="stop">Stop index (first index after end of slice), null means to the end of the array</param> | |||
/// <param name="step">Optional step to select every n-th element, defaults to 1</param> | |||
public Slice(int? start = null, int? stop = null, int step = 1) | |||
public Slice(int? start = null, int? stop = null, int step = 1, bool isIndex = false) | |||
{ | |||
Start = start; | |||
Stop = stop; | |||
Step = step; | |||
IsIndex = isIndex; | |||
} | |||
public Slice(string slice_notation) | |||
@@ -8,7 +8,7 @@ namespace Tensorflow | |||
public sealed class SafeStringTensorHandle : SafeTensorHandle | |||
{ | |||
Shape _shape; | |||
SafeTensorHandle _handle; | |||
IntPtr _handle; | |||
const int TF_TSRING_SIZE = 24; | |||
protected SafeStringTensorHandle() | |||
@@ -18,7 +18,7 @@ namespace Tensorflow | |||
public SafeStringTensorHandle(SafeTensorHandle handle, Shape shape) | |||
: base(handle.DangerousGetHandle()) | |||
{ | |||
_handle = handle; | |||
_handle = c_api.TF_TensorData(handle); | |||
_shape = shape; | |||
} | |||
@@ -28,15 +28,10 @@ namespace Tensorflow | |||
print($"Delete StringTensorHandle 0x{handle.ToString("x16")}"); | |||
#endif | |||
long size = 1; | |||
foreach (var s in _shape.dims) | |||
size *= s; | |||
var tstr = c_api.TF_TensorData(_handle); | |||
for (int i = 0; i < size; i++) | |||
for (int i = 0; i < _shape.size; i++) | |||
{ | |||
c_api.TF_StringDealloc(tstr); | |||
tstr += TF_TSRING_SIZE; | |||
c_api.TF_StringDealloc(_handle); | |||
_handle += TF_TSRING_SIZE; | |||
} | |||
SetHandle(IntPtr.Zero); | |||
@@ -23,7 +23,7 @@ namespace Tensorflow | |||
public SafeStringTensorHandle StringTensor(byte[][] buffer, Shape shape) | |||
{ | |||
var handle = c_api.TF_AllocateTensor(TF_DataType.TF_STRING, | |||
shape.ndim == 0 ? null : shape.dims, | |||
shape.dims, | |||
shape.ndim, | |||
(ulong)shape.size * TF_TSRING_SIZE); | |||
@@ -472,6 +472,9 @@ would not be rank 1.", tensor.op.get_attr("axis"))); | |||
public static string to_numpy_string(Tensor tensor) | |||
{ | |||
if (tensor.buffer == IntPtr.Zero) | |||
return "Empty"; | |||
var dtype = tensor.dtype; | |||
var shape = tensor.shape; | |||
@@ -161,7 +161,7 @@ namespace Tensorflow | |||
IEnumerable<Tensor> tensors => array_ops._autopacking_helper(tensors, dtype, name == null ? "packed" : name), | |||
RefVariable varVal => varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref), | |||
ResourceVariable varVal => varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref), | |||
Axis ts => constant_op.constant(ts.axis, dtype: dtype, name: name), | |||
Axis ts => constant_op.constant(ts, dtype: dtype, name: name), | |||
Shape ts => constant_op.constant(ts.dims, dtype: dtype, name: name), | |||
string str => constant_op.constant(str, dtype: tf.@string, name: name), | |||
string[] str => constant_op.constant(str, dtype: tf.@string, name: name), | |||
@@ -63,5 +63,43 @@ namespace TensorFlowNET.UnitTest.NumPy | |||
input_shape_val[(int)input_shape.size - 1] = 1; | |||
input_shape.Dispose(); | |||
} | |||
[TestMethod] | |||
public void shape_helper_get_shape_3dim() | |||
{ | |||
var x = np.arange(24).reshape((4, 3, 2)); | |||
var shape1 = ShapeHelper.GetShape(x.shape, new Slice(1, isIndex: true)); | |||
Assert.AreEqual(shape1, (3, 2)); | |||
var shape2 = ShapeHelper.GetShape(x.shape, new Slice(1)); | |||
Assert.AreEqual(shape2, (3, 3, 2)); | |||
var shape3 = ShapeHelper.GetShape(x.shape, new Slice(2), Slice.All); | |||
Assert.AreEqual(shape3, (2, 3, 2)); | |||
var shape4 = ShapeHelper.GetShape(x.shape, new Slice(1, isIndex: true), new Slice(2)); | |||
Assert.AreEqual(shape4, (1, 2)); | |||
var shape5 = ShapeHelper.GetShape(x.shape, new Slice(1, isIndex: true), new Slice(1)); | |||
Assert.AreEqual(shape5, (2, 2)); | |||
var shape6 = ShapeHelper.GetShape(x.shape, new Slice(1), new Slice(1, isIndex: true), new Slice(1)); | |||
Assert.AreEqual(shape6, (3, 1)); | |||
} | |||
[TestMethod] | |||
public void shape_helper_get_shape_4dim() | |||
{ | |||
var x = np.arange(120).reshape((4, 3, 2, 5)); | |||
var slices = new[] { new Slice(1, isIndex: true), new Slice(1), new Slice(0, isIndex: true), new Slice(1) }; | |||
var shape1 = ShapeHelper.GetShape(x.shape, slices); | |||
Assert.AreEqual(shape1, (2, 4)); | |||
var shape2 = ShapeHelper.GetShape(x.shape, Slice.All); | |||
Assert.AreEqual(shape2, (4, 3, 2, 5)); | |||
var shape3 = ShapeHelper.GetShape(x.shape, Slice.All, new Slice(0, isIndex: true)); | |||
Assert.AreEqual(shape3, (4, 3, 2)); | |||
} | |||
} | |||
} |