|
|
@@ -10,6 +10,7 @@ namespace Tensorflow |
|
|
|
unsafe |
|
|
|
{ |
|
|
|
EnsureScalar(tensor); |
|
|
|
EnsureDType(tensor, TF_DataType.TF_BOOL); |
|
|
|
return *(bool*) tensor.buffer; |
|
|
|
} |
|
|
|
} |
|
|
@@ -19,6 +20,7 @@ namespace Tensorflow |
|
|
|
unsafe |
|
|
|
{ |
|
|
|
EnsureScalar(tensor); |
|
|
|
EnsureDType(tensor, TF_DataType.TF_INT8); |
|
|
|
return *(sbyte*) tensor.buffer; |
|
|
|
} |
|
|
|
} |
|
|
@@ -28,6 +30,7 @@ namespace Tensorflow |
|
|
|
unsafe |
|
|
|
{ |
|
|
|
EnsureScalar(tensor); |
|
|
|
EnsureDType(tensor, TF_DataType.TF_UINT8); |
|
|
|
return *(byte*) tensor.buffer; |
|
|
|
} |
|
|
|
} |
|
|
@@ -37,6 +40,7 @@ namespace Tensorflow |
|
|
|
unsafe |
|
|
|
{ |
|
|
|
EnsureScalar(tensor); |
|
|
|
EnsureDType(tensor, TF_DataType.TF_UINT16); |
|
|
|
return *(ushort*) tensor.buffer; |
|
|
|
} |
|
|
|
} |
|
|
@@ -46,6 +50,7 @@ namespace Tensorflow |
|
|
|
unsafe |
|
|
|
{ |
|
|
|
EnsureScalar(tensor); |
|
|
|
EnsureDType(tensor, TF_DataType.TF_INT16); |
|
|
|
return *(short*) tensor.buffer; |
|
|
|
} |
|
|
|
} |
|
|
@@ -55,6 +60,7 @@ namespace Tensorflow |
|
|
|
unsafe |
|
|
|
{ |
|
|
|
EnsureScalar(tensor); |
|
|
|
EnsureDType(tensor, TF_DataType.TF_INT32); |
|
|
|
return *(int*) tensor.buffer; |
|
|
|
} |
|
|
|
} |
|
|
@@ -64,6 +70,7 @@ namespace Tensorflow |
|
|
|
unsafe |
|
|
|
{ |
|
|
|
EnsureScalar(tensor); |
|
|
|
EnsureDType(tensor, TF_DataType.TF_UINT32); |
|
|
|
return *(uint*) tensor.buffer; |
|
|
|
} |
|
|
|
} |
|
|
@@ -73,6 +80,7 @@ namespace Tensorflow |
|
|
|
unsafe |
|
|
|
{ |
|
|
|
EnsureScalar(tensor); |
|
|
|
EnsureDType(tensor, TF_DataType.TF_INT64); |
|
|
|
return *(long*) tensor.buffer; |
|
|
|
} |
|
|
|
} |
|
|
@@ -82,6 +90,7 @@ namespace Tensorflow |
|
|
|
unsafe |
|
|
|
{ |
|
|
|
EnsureScalar(tensor); |
|
|
|
EnsureDType(tensor, TF_DataType.TF_UINT64); |
|
|
|
return *(ulong*) tensor.buffer; |
|
|
|
} |
|
|
|
} |
|
|
@@ -91,6 +100,7 @@ namespace Tensorflow |
|
|
|
unsafe |
|
|
|
{ |
|
|
|
EnsureScalar(tensor); |
|
|
|
EnsureDType(tensor, TF_DataType.TF_FLOAT); |
|
|
|
return *(float*) tensor.buffer; |
|
|
|
} |
|
|
|
} |
|
|
@@ -100,27 +110,29 @@ namespace Tensorflow |
|
|
|
unsafe |
|
|
|
{ |
|
|
|
EnsureScalar(tensor); |
|
|
|
EnsureDType(tensor, TF_DataType.TF_DOUBLE); |
|
|
|
return *(double*) tensor.buffer; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
[MethodImpl(MethodImplOptions.AggressiveInlining)] |
|
|
|
private static void EnsureDType(Tensor tensor, TF_DataType @is) |
|
|
|
{ |
|
|
|
if (tensor._dtype != @is) |
|
|
|
throw new InvalidCastException($"Unable to cast scalar tensor {tensor._dtype} to {@is}"); |
|
|
|
} |
|
|
|
|
|
|
|
[MethodImpl(MethodImplOptions.AggressiveInlining)] |
|
|
|
private static void EnsureScalar(Tensor tensor) |
|
|
|
{ |
|
|
|
if (tensor == null) |
|
|
|
{ |
|
|
|
throw new ArgumentNullException(nameof(tensor)); |
|
|
|
} |
|
|
|
|
|
|
|
if (tensor.TensorShape.ndim != 0) |
|
|
|
{ |
|
|
|
throw new ArgumentException("Tensor must have 0 dimensions in order to convert to scalar"); |
|
|
|
} |
|
|
|
|
|
|
|
if (tensor.TensorShape.size != 1) |
|
|
|
{ |
|
|
|
throw new ArgumentException("Tensor must have size 1 in order to convert to scalar"); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
} |
|
|
|