Browse Source

Merge pull request #1257 from novikov-alexander/alnovi/generic-cast

fix: More generic array cast
pull/1281/head
C. Oceania GitHub 1 year ago
parent
commit
7fb73cda3f
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
1 changed files with 59 additions and 29 deletions
  1. +59
    -29
      src/TensorFlowNET.Core/Tensors/tensor_util.cs

+ 59
- 29
src/TensorFlowNET.Core/Tensors/tensor_util.cs View File

@@ -67,7 +67,7 @@ namespace Tensorflow

T[] ExpandArrayToSize<T>(IList<T> src)
{
if(src.Count == 0)
if (src.Count == 0)
{
return new T[0];
}
@@ -77,7 +77,7 @@ namespace Tensorflow
var first_elem = src[0];
var last_elem = src[src.Count - 1];
T[] res = new T[num_elements];
for(long i = 0; i < num_elements; i++)
for (long i = 0; i < num_elements; i++)
{
if (i < pre) res[i] = first_elem;
else if (i >= num_elements - after) res[i] = last_elem;
@@ -121,7 +121,7 @@ namespace Tensorflow
$"https://www.tensorflow.org/api_docs/python/tf/dtypes for supported TF dtypes.");
}

if(values.size == 0)
if (values.size == 0)
{
return np.zeros(shape, tensor_dtype);
}
@@ -135,23 +135,47 @@ namespace Tensorflow
TF_DataType.TF_QINT32
};

private static TOut[,] ConvertArray2D<TIn, TOut>(TIn[,] inputArray, Func<TIn, TOut> converter)
private static Array ConvertArray<TOut>(Array inputArray, Func<object, TOut> converter)
{
var rows = inputArray.GetLength(0);
var cols = inputArray.GetLength(1);
var outputArray = new TOut[rows, cols];
if (inputArray == null)
throw new ArgumentNullException(nameof(inputArray));

for (var i = 0; i < rows; i++)
var elementType = typeof(TOut);
var lengths = new int[inputArray.Rank];
for (var i = 0; i < inputArray.Rank; i++)
{
for (var j = 0; j < cols; j++)
{
outputArray[i, j] = converter(inputArray[i, j]);
}
lengths[i] = inputArray.GetLength(i);
}

var outputArray = Array.CreateInstance(elementType, lengths);

FillArray(inputArray, outputArray, converter, new int[inputArray.Rank], 0);

return outputArray;
}

private static void FillArray<TIn, TOut>(Array inputArray, Array outputArray, Func<TIn, TOut> converter, int[] indices, int dimension)
{
if (dimension == inputArray.Rank - 1)
{
for (int i = 0; i < inputArray.GetLength(dimension); i++)
{
indices[dimension] = i;
var inputValue = (TIn)inputArray.GetValue(indices);
var convertedValue = converter(inputValue);
outputArray.SetValue(convertedValue, indices);
}
}
else
{
for (int i = 0; i < inputArray.GetLength(dimension); i++)
{
indices[dimension] = i;
FillArray(inputArray, outputArray, converter, indices, dimension + 1);
}
}
}

/// <summary>
/// Create a TensorProto, invoked in graph mode
/// </summary>
@@ -171,24 +195,30 @@ namespace Tensorflow
var origin_dtype = values.GetDataType();
if (dtype == TF_DataType.DtInvalid)
dtype = origin_dtype;
else if(origin_dtype != dtype)
else if (origin_dtype != dtype)
{
var new_system_dtype = dtype.as_system_dtype();
values = values switch

if (dtype != TF_DataType.TF_STRING && dtype != TF_DataType.TF_VARIANT && dtype != TF_DataType.TF_RESOURCE)
{
if (values is Array arrayValues)
{
values = dtype switch
{
TF_DataType.TF_INT32 => ConvertArray(arrayValues, Convert.ToInt32),
TF_DataType.TF_FLOAT => ConvertArray(arrayValues, Convert.ToSingle),
TF_DataType.TF_DOUBLE => ConvertArray(arrayValues, Convert.ToDouble),
_ => values,
};
} else
{
values = Convert.ChangeType(values, new_system_dtype);
}
} else
{
long[] longValues when dtype == TF_DataType.TF_INT32 => longValues.Select(x => (int)x).ToArray(),
long[] longValues => values,
float[] floatValues when dtype == TF_DataType.TF_DOUBLE => floatValues.Select(x => (double)x).ToArray(),
float[] floatValues => values,
float[,] float2DValues when dtype == TF_DataType.TF_DOUBLE => ConvertArray2D(float2DValues, Convert.ToDouble),
float[,] float2DValues => values,
double[] doubleValues when dtype == TF_DataType.TF_FLOAT => doubleValues.Select(x => (float)x).ToArray(),
double[] doubleValues => values,
double[,] double2DValues when dtype == TF_DataType.TF_FLOAT => ConvertArray2D(double2DValues, Convert.ToSingle),
double[,] double2DValues => values,
_ => Convert.ChangeType(values, new_system_dtype),
};

}
dtype = values.GetDataType();
}

@@ -306,7 +336,7 @@ namespace Tensorflow

if (tensor is EagerTensor eagerTensor)
{
if(tensor.dtype == tf.int64)
if (tensor.dtype == tf.int64)
return new Shape(tensor.ToArray<long>());
else
return new Shape(tensor.ToArray<int>());
@@ -481,7 +511,7 @@ would not be rank 1.", tensor.op.get_attr("axis")));
var d_ = new int[value.size];
foreach (var (index, d) in enumerate(value.ToArray<int>()))
d_[index] = d >= 0 ? d : -1;
ret = ret.merge_with(new Shape(d_));
}
return ret;


Loading…
Cancel
Save