Browse Source

Tensor: Added guards for explicit casts.

tags/v0.12
Eli Belash 6 years ago
parent
commit
766fa6fc5d
1 changed files with 18 additions and 6 deletions
  1. +18
    -6
      src/TensorFlowNET.Core/Tensors/Tensor.Explicit.cs

+ 18
- 6
src/TensorFlowNET.Core/Tensors/Tensor.Explicit.cs View File

@@ -10,6 +10,7 @@ namespace Tensorflow
unsafe
{
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_BOOL);
return *(bool*) tensor.buffer;
}
}
@@ -19,6 +20,7 @@ namespace Tensorflow
unsafe
{
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_INT8);
return *(sbyte*) tensor.buffer;
}
}
@@ -28,6 +30,7 @@ namespace Tensorflow
unsafe
{
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_UINT8);
return *(byte*) tensor.buffer;
}
}
@@ -37,6 +40,7 @@ namespace Tensorflow
unsafe
{
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_UINT16);
return *(ushort*) tensor.buffer;
}
}
@@ -46,6 +50,7 @@ namespace Tensorflow
unsafe
{
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_INT16);
return *(short*) tensor.buffer;
}
}
@@ -55,6 +60,7 @@ namespace Tensorflow
unsafe
{
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_INT32);
return *(int*) tensor.buffer;
}
}
@@ -64,6 +70,7 @@ namespace Tensorflow
unsafe
{
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_UINT32);
return *(uint*) tensor.buffer;
}
}
@@ -73,6 +80,7 @@ namespace Tensorflow
unsafe
{
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_INT64);
return *(long*) tensor.buffer;
}
}
@@ -82,6 +90,7 @@ namespace Tensorflow
unsafe
{
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_UINT64);
return *(ulong*) tensor.buffer;
}
}
@@ -91,6 +100,7 @@ namespace Tensorflow
unsafe
{
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_FLOAT);
return *(float*) tensor.buffer;
}
}
@@ -100,27 +110,29 @@ namespace Tensorflow
unsafe
{
EnsureScalar(tensor);
EnsureDType(tensor, TF_DataType.TF_DOUBLE);
return *(double*) tensor.buffer;
}
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static void EnsureDType(Tensor tensor, TF_DataType @is)
{
if (tensor._dtype != @is)
throw new InvalidCastException($"Unable to cast scalar tensor {tensor._dtype} to {@is}");
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static void EnsureScalar(Tensor tensor)
{
if (tensor == null)
{
throw new ArgumentNullException(nameof(tensor));
}

if (tensor.TensorShape.ndim != 0)
{
throw new ArgumentException("Tensor must have 0 dimensions in order to convert to scalar");
}

if (tensor.TensorShape.size != 1)
{
throw new ArgumentException("Tensor must have size 1 in order to convert to scalar");
}
}

}


Loading…
Cancel
Save