@@ -65,6 +65,93 @@ namespace Tensorflow | |||||
IEnumerator IEnumerable.GetEnumerator() | IEnumerator IEnumerable.GetEnumerator() | ||||
=> GetEnumerator(); | => GetEnumerator(); | ||||
public NDArray numpy() | |||||
{ | |||||
EnsureSingleTensor(this, "nnumpy"); | |||||
return this[0].numpy(); | |||||
} | |||||
public T[] ToArray<T>() where T: unmanaged | |||||
{ | |||||
EnsureSingleTensor(this, $"ToArray<{typeof(T)}>"); | |||||
return this[0].ToArray<T>(); | |||||
} | |||||
#region Explicit Conversions | |||||
public unsafe static explicit operator bool(Tensors tensor) | |||||
{ | |||||
EnsureSingleTensor(tensor, "explicit conversion to bool"); | |||||
return (bool)tensor[0]; | |||||
} | |||||
public unsafe static explicit operator sbyte(Tensors tensor) | |||||
{ | |||||
EnsureSingleTensor(tensor, "explicit conversion to sbyte"); | |||||
return (sbyte)tensor[0]; | |||||
} | |||||
public unsafe static explicit operator byte(Tensors tensor) | |||||
{ | |||||
EnsureSingleTensor(tensor, "explicit conversion to byte"); | |||||
return (byte)tensor[0]; | |||||
} | |||||
public unsafe static explicit operator ushort(Tensors tensor) | |||||
{ | |||||
EnsureSingleTensor(tensor, "explicit conversion to ushort"); | |||||
return (ushort)tensor[0]; | |||||
} | |||||
public unsafe static explicit operator short(Tensors tensor) | |||||
{ | |||||
EnsureSingleTensor(tensor, "explicit conversion to short"); | |||||
return (short)tensor[0]; | |||||
} | |||||
public unsafe static explicit operator int(Tensors tensor) | |||||
{ | |||||
EnsureSingleTensor(tensor, "explicit conversion to int"); | |||||
return (int)tensor[0]; | |||||
} | |||||
public unsafe static explicit operator uint(Tensors tensor) | |||||
{ | |||||
EnsureSingleTensor(tensor, "explicit conversion to uint"); | |||||
return (uint)tensor[0]; | |||||
} | |||||
public unsafe static explicit operator long(Tensors tensor) | |||||
{ | |||||
EnsureSingleTensor(tensor, "explicit conversion to long"); | |||||
return (long)tensor[0]; | |||||
} | |||||
public unsafe static explicit operator ulong(Tensors tensor) | |||||
{ | |||||
EnsureSingleTensor(tensor, "explicit conversion to ulong"); | |||||
return (ulong)tensor[0]; | |||||
} | |||||
public unsafe static explicit operator float(Tensors tensor) | |||||
{ | |||||
EnsureSingleTensor(tensor, "explicit conversion to byte"); | |||||
return (byte)tensor[0]; | |||||
} | |||||
public unsafe static explicit operator double(Tensors tensor) | |||||
{ | |||||
EnsureSingleTensor(tensor, "explicit conversion to double"); | |||||
return (double)tensor[0]; | |||||
} | |||||
public unsafe static explicit operator string(Tensors tensor) | |||||
{ | |||||
EnsureSingleTensor(tensor, "explicit conversion to string"); | |||||
return (string)tensor[0]; | |||||
} | |||||
#endregion | |||||
#region Implicit Conversions | |||||
public static implicit operator Tensors(Tensor tensor) | public static implicit operator Tensors(Tensor tensor) | ||||
=> new Tensors(tensor); | => new Tensors(tensor); | ||||
@@ -87,12 +174,26 @@ namespace Tensorflow | |||||
public static implicit operator Tensor[](Tensors tensors) | public static implicit operator Tensor[](Tensors tensors) | ||||
=> tensors.items.ToArray(); | => tensors.items.ToArray(); | ||||
#endregion | |||||
public void Deconstruct(out Tensor a, out Tensor b) | public void Deconstruct(out Tensor a, out Tensor b) | ||||
{ | { | ||||
a = items[0]; | a = items[0]; | ||||
b = items[1]; | b = items[1]; | ||||
} | } | ||||
private static void EnsureSingleTensor(Tensors tensors, string methodnName) | |||||
{ | |||||
if(tensors.Length == 0) | |||||
{ | |||||
throw new ValueError($"Method `{methodnName}` of `Tensors` cannot be used when `Tensors` contains no Tensor."); | |||||
} | |||||
else if(tensors.Length > 1) | |||||
{ | |||||
throw new ValueError($"Method `{methodnName}` of `Tensors` cannot be used when `Tensors` contains more than one Tensor."); | |||||
} | |||||
} | |||||
public override string ToString() | public override string ToString() | ||||
=> items.Count() == 1 | => items.Count() == 1 | ||||
? items.First().ToString() | ? items.First().ToString() | ||||
@@ -20,7 +20,7 @@ namespace TensorFlowNET.UnitTest.Dataset | |||||
Assert.AreEqual(iStep, step); | Assert.AreEqual(iStep, step); | ||||
iStep++; | iStep++; | ||||
Assert.AreEqual(value, (long)item.Item1[0]); | |||||
Assert.AreEqual(value, (long)item.Item1); | |||||
value++; | value++; | ||||
} | } | ||||
} | } | ||||
@@ -39,7 +39,7 @@ namespace TensorFlowNET.UnitTest.Dataset | |||||
Assert.AreEqual(iStep, step); | Assert.AreEqual(iStep, step); | ||||
iStep++; | iStep++; | ||||
Assert.AreEqual(value, (long)item.Item1[0]); | |||||
Assert.AreEqual(value, (long)item.Item1); | |||||
value += 2; | value += 2; | ||||
} | } | ||||
} | } | ||||
@@ -54,7 +54,7 @@ namespace TensorFlowNET.UnitTest.Dataset | |||||
int n = 0; | int n = 0; | ||||
foreach (var (item_x, item_y) in dataset) | foreach (var (item_x, item_y) in dataset) | ||||
{ | { | ||||
print($"x:{item_x[0].numpy()},y:{item_y[0].numpy()}"); | |||||
print($"x:{item_x.numpy()},y:{item_y.numpy()}"); | |||||
n += 1; | n += 1; | ||||
} | } | ||||
Assert.AreEqual(5, n); | Assert.AreEqual(5, n); | ||||
@@ -69,7 +69,7 @@ namespace TensorFlowNET.UnitTest.Dataset | |||||
int n = 0; | int n = 0; | ||||
foreach (var x in dataset) | foreach (var x in dataset) | ||||
{ | { | ||||
Assert.IsTrue(X.SequenceEqual(x.Item1[0].ToArray<int>())); | |||||
Assert.IsTrue(X.SequenceEqual(x.Item1.ToArray<int>())); | |||||
n += 1; | n += 1; | ||||
} | } | ||||
Assert.AreEqual(1, n); | Assert.AreEqual(1, n); | ||||
@@ -85,7 +85,7 @@ namespace TensorFlowNET.UnitTest.Dataset | |||||
foreach (var item in dataset2) | foreach (var item in dataset2) | ||||
{ | { | ||||
Assert.AreEqual(value, (long)item.Item1[0]); | |||||
Assert.AreEqual(value, (long)item.Item1); | |||||
value += 3; | value += 3; | ||||
} | } | ||||
@@ -93,7 +93,7 @@ namespace TensorFlowNET.UnitTest.Dataset | |||||
var dataset3 = dataset1.shard(num_shards: 3, index: 1); | var dataset3 = dataset1.shard(num_shards: 3, index: 1); | ||||
foreach (var item in dataset3) | foreach (var item in dataset3) | ||||
{ | { | ||||
Assert.AreEqual(value, (long)item.Item1[0]); | |||||
Assert.AreEqual(value, (long)item.Item1); | |||||
value += 3; | value += 3; | ||||
} | } | ||||
} | } | ||||
@@ -108,7 +108,7 @@ namespace TensorFlowNET.UnitTest.Dataset | |||||
foreach (var item in dataset) | foreach (var item in dataset) | ||||
{ | { | ||||
Assert.AreEqual(value, (long)item.Item1[0]); | |||||
Assert.AreEqual(value, (long)item.Item1); | |||||
value++; | value++; | ||||
} | } | ||||
} | } | ||||
@@ -123,7 +123,7 @@ namespace TensorFlowNET.UnitTest.Dataset | |||||
foreach (var item in dataset) | foreach (var item in dataset) | ||||
{ | { | ||||
Assert.AreEqual(value + 10, (long)item.Item1[0]); | |||||
Assert.AreEqual(value + 10, (long)item.Item1); | |||||
value++; | value++; | ||||
} | } | ||||
} | } | ||||
@@ -138,7 +138,7 @@ namespace TensorFlowNET.UnitTest.Dataset | |||||
foreach (var item in dataset) | foreach (var item in dataset) | ||||
{ | { | ||||
Assert.AreEqual(value, (long)item.Item1[0]); | |||||
Assert.AreEqual(value, (long)item.Item1); | |||||
value++; | value++; | ||||
} | } | ||||
} | } | ||||