Browse Source

Add more explicit conversion for Tensors.

pull/996/head
Yaohui Liu 2 years ago
parent
commit
6a295b68fc
No known key found for this signature in database GPG Key ID: E86D01E1809BD23E
2 changed files with 110 additions and 9 deletions
  1. +101
    -0
      src/TensorFlowNET.Core/Tensors/Tensors.cs
  2. +9
    -9
      test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs

+ 101
- 0
src/TensorFlowNET.Core/Tensors/Tensors.cs View File

@@ -65,6 +65,93 @@ namespace Tensorflow
IEnumerator IEnumerable.GetEnumerator()
=> GetEnumerator();

public NDArray numpy()
{
EnsureSingleTensor(this, "nnumpy");
return this[0].numpy();
}

public T[] ToArray<T>() where T: unmanaged
{
EnsureSingleTensor(this, $"ToArray<{typeof(T)}>");
return this[0].ToArray<T>();
}

#region Explicit Conversions
public unsafe static explicit operator bool(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to bool");
return (bool)tensor[0];
}

public unsafe static explicit operator sbyte(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to sbyte");
return (sbyte)tensor[0];
}

public unsafe static explicit operator byte(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to byte");
return (byte)tensor[0];
}

public unsafe static explicit operator ushort(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to ushort");
return (ushort)tensor[0];
}

public unsafe static explicit operator short(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to short");
return (short)tensor[0];
}

public unsafe static explicit operator int(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to int");
return (int)tensor[0];
}

public unsafe static explicit operator uint(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to uint");
return (uint)tensor[0];
}

public unsafe static explicit operator long(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to long");
return (long)tensor[0];
}

public unsafe static explicit operator ulong(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to ulong");
return (ulong)tensor[0];
}

public unsafe static explicit operator float(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to byte");
return (byte)tensor[0];
}

public unsafe static explicit operator double(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to double");
return (double)tensor[0];
}

public unsafe static explicit operator string(Tensors tensor)
{
EnsureSingleTensor(tensor, "explicit conversion to string");
return (string)tensor[0];
}
#endregion

#region Implicit Conversions
public static implicit operator Tensors(Tensor tensor)
=> new Tensors(tensor);

@@ -87,12 +174,26 @@ namespace Tensorflow
public static implicit operator Tensor[](Tensors tensors)
=> tensors.items.ToArray();

#endregion

public void Deconstruct(out Tensor a, out Tensor b)
{
a = items[0];
b = items[1];
}

private static void EnsureSingleTensor(Tensors tensors, string methodnName)
{
if(tensors.Length == 0)
{
throw new ValueError($"Method `{methodnName}` of `Tensors` cannot be used when `Tensors` contains no Tensor.");
}
else if(tensors.Length > 1)
{
throw new ValueError($"Method `{methodnName}` of `Tensors` cannot be used when `Tensors` contains more than one Tensor.");
}
}

public override string ToString()
=> items.Count() == 1
? items.First().ToString()


+ 9
- 9
test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs View File

@@ -20,7 +20,7 @@ namespace TensorFlowNET.UnitTest.Dataset
Assert.AreEqual(iStep, step);
iStep++;

Assert.AreEqual(value, (long)item.Item1[0]);
Assert.AreEqual(value, (long)item.Item1);
value++;
}
}
@@ -39,7 +39,7 @@ namespace TensorFlowNET.UnitTest.Dataset
Assert.AreEqual(iStep, step);
iStep++;

Assert.AreEqual(value, (long)item.Item1[0]);
Assert.AreEqual(value, (long)item.Item1);
value += 2;
}
}
@@ -54,7 +54,7 @@ namespace TensorFlowNET.UnitTest.Dataset
int n = 0;
foreach (var (item_x, item_y) in dataset)
{
print($"x:{item_x[0].numpy()},y:{item_y[0].numpy()}");
print($"x:{item_x.numpy()},y:{item_y.numpy()}");
n += 1;
}
Assert.AreEqual(5, n);
@@ -69,7 +69,7 @@ namespace TensorFlowNET.UnitTest.Dataset
int n = 0;
foreach (var x in dataset)
{
Assert.IsTrue(X.SequenceEqual(x.Item1[0].ToArray<int>()));
Assert.IsTrue(X.SequenceEqual(x.Item1.ToArray<int>()));
n += 1;
}
Assert.AreEqual(1, n);
@@ -85,7 +85,7 @@ namespace TensorFlowNET.UnitTest.Dataset

foreach (var item in dataset2)
{
Assert.AreEqual(value, (long)item.Item1[0]);
Assert.AreEqual(value, (long)item.Item1);
value += 3;
}

@@ -93,7 +93,7 @@ namespace TensorFlowNET.UnitTest.Dataset
var dataset3 = dataset1.shard(num_shards: 3, index: 1);
foreach (var item in dataset3)
{
Assert.AreEqual(value, (long)item.Item1[0]);
Assert.AreEqual(value, (long)item.Item1);
value += 3;
}
}
@@ -108,7 +108,7 @@ namespace TensorFlowNET.UnitTest.Dataset

foreach (var item in dataset)
{
Assert.AreEqual(value, (long)item.Item1[0]);
Assert.AreEqual(value, (long)item.Item1);
value++;
}
}
@@ -123,7 +123,7 @@ namespace TensorFlowNET.UnitTest.Dataset

foreach (var item in dataset)
{
Assert.AreEqual(value + 10, (long)item.Item1[0]);
Assert.AreEqual(value + 10, (long)item.Item1);
value++;
}
}
@@ -138,7 +138,7 @@ namespace TensorFlowNET.UnitTest.Dataset

foreach (var item in dataset)
{
Assert.AreEqual(value, (long)item.Item1[0]);
Assert.AreEqual(value, (long)item.Item1);
value++;
}
}


Loading…
Cancel
Save