Browse Source

to_numpy_string

tags/v0.60-tf.numpy
Oceania2018 4 years ago
parent
commit
76abb2c6f0
2 changed files with 88 additions and 7 deletions
  1. +78
    -7
      src/TensorFlowNET.Core/Tensors/tensor_util.cs
  2. +10
    -0
      test/TensorFlowNET.UnitTest/Numpy/Array.Creation.Test.cs

+ 78
- 7
src/TensorFlowNET.Core/Tensors/tensor_util.cs View File

@@ -470,7 +470,20 @@ would not be rank 1.", tensor.op.get_attr("axis")));
return ops.convert_to_tensor(shape, dtype: TF_DataType.TF_INT32, name: "shape");
}

public static string to_numpy_string(Tensor tensor)
public static string to_numpy_string(Tensor array)
{
Shape shape = array.shape;
if (shape.ndim == 0)
return array[0].ToString();

var s = new StringBuilder();
s.Append("array(");
PrettyPrint(s, array);
s.Append(")");
return s.ToString();
}

static string Render(Tensor tensor)
{
if (tensor.buffer == IntPtr.Zero)
return "Empty";
@@ -487,7 +500,7 @@ would not be rank 1.", tensor.op.get_attr("axis")));
else
return $"['{string.Join("', '", tensor.StringData().Take(25))}']";
}
else if(dtype == TF_DataType.TF_VARIANT)
else if (dtype == TF_DataType.TF_VARIANT)
{
return "<unprintable>";
}
@@ -515,7 +528,7 @@ would not be rank 1.", tensor.op.get_attr("axis")));
var array = tensor.ToArray<float>();
return DisplayArrayAsString(array, tensor.shape);
}
else if(dtype == TF_DataType.TF_DOUBLE)
else if (dtype == TF_DataType.TF_DOUBLE)
{
var array = tensor.ToArray<double>();
return DisplayArrayAsString(array, tensor.shape);
@@ -532,14 +545,72 @@ would not be rank 1.", tensor.op.get_attr("axis")));
if (shape.ndim == 0)
return array[0].ToString();

var display = "array([";
var display = "";
if (array.Length < 10)
display += string.Join(", ", array);
else
display += string.Join(", ", array.Take(3)) + " ... " + string.Join(", ", array.Skip(array.Length - 3));
return display + "])";
display += string.Join(", ", array.Take(3)) + ", ..., " + string.Join(", ", array.Skip(array.Length - 3));
return display;
}

static void PrettyPrint(StringBuilder s, Tensor array, bool flat = false)
{
var shape = array.shape;

if (shape.Length == 1)
{
s.Append("[");
s.Append(Render(array));
s.Append("]");
return;
}

var len = shape[0];
s.Append("[");

if (len <= 10)
{
for (int i = 0; i < len; i++)
{
PrettyPrint(s, array[i], flat);
if (i < len - 1)
{
s.Append(", ");
if (!flat)
s.AppendLine();
}
}
}
else
{
for (int i = 0; i < 5; i++)
{
PrettyPrint(s, array[i], flat);
if (i < len - 1)
{
s.Append(", ");
if (!flat)
s.AppendLine();
}
}

s.Append(" ... ");
s.AppendLine();

for (int i = (int)array.size - 5; i < len; i++)
{
PrettyPrint(s, array[i], flat);
if (i < len - 1)
{
s.Append(", ");
if (!flat)
s.AppendLine();
}
}
}

s.Append("]");
}

public static ParsedSliceArgs ParseSlices(Slice[] slices)
{


+ 10
- 0
test/TensorFlowNET.UnitTest/Numpy/Array.Creation.Test.cs View File

@@ -3,6 +3,7 @@ using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow;
using Tensorflow.NumPy;

namespace TensorFlowNET.UnitTest.NumPy
@@ -88,5 +89,14 @@ namespace TensorFlowNET.UnitTest.NumPy
AssetSequenceEqual(a.ToArray<int>(), new int[] { 0, 1, 2, 0, 1, 2, 0, 1, 2 });
AssetSequenceEqual(b.ToArray<int>(), new int[] { 0, 0, 0, 1, 1, 1, 2, 2, 2 });
}

[TestMethod]
public void to_numpy_string()
{
var nd = np.arange(10 * 10 * 10 * 10).reshape((10, 10, 10, 10));
var str = tensor_util.to_numpy_string(nd);
Assert.AreEqual("array([[[[0, 1, 2, ..., 7, 8, 9],", str.Substring(0, 33));
Assert.AreEqual("[9990, 9991, 9992, ..., 9997, 9998, 9999]]]])", str.Substring(str.Length - 45));
}
}
}

Loading…
Cancel
Save