|
|
@@ -1,4 +1,4 @@ |
|
|
|
/***************************************************************************** |
|
|
|
/***************************************************************************** |
|
|
|
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. |
|
|
|
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
@@ -135,6 +135,23 @@ namespace Tensorflow |
|
|
|
TF_DataType.TF_QINT32 |
|
|
|
}; |
|
|
|
|
|
|
|
private static TOut[,] ConvertArray2D<TIn, TOut>(TIn[,] inputArray, Func<TIn, TOut> converter) |
|
|
|
{ |
|
|
|
var rows = inputArray.GetLength(0); |
|
|
|
var cols = inputArray.GetLength(1); |
|
|
|
var outputArray = new TOut[rows, cols]; |
|
|
|
|
|
|
|
for (var i = 0; i < rows; i++) |
|
|
|
{ |
|
|
|
for (var j = 0; j < cols; j++) |
|
|
|
{ |
|
|
|
outputArray[i, j] = converter(inputArray[i, j]); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
return outputArray; |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Create a TensorProto, invoked in graph mode |
|
|
|
/// </summary> |
|
|
@@ -157,19 +174,16 @@ namespace Tensorflow |
|
|
|
else if(origin_dtype != dtype) |
|
|
|
{ |
|
|
|
var new_system_dtype = dtype.as_system_dtype(); |
|
|
|
if (values is long[] long_values) |
|
|
|
{ |
|
|
|
if (dtype == TF_DataType.TF_INT32) |
|
|
|
values = long_values.Select(x => (int)Convert.ChangeType(x, new_system_dtype)).ToArray(); |
|
|
|
} |
|
|
|
else if (values is double[] double_values) |
|
|
|
|
|
|
|
values = values switch |
|
|
|
{ |
|
|
|
if (dtype == TF_DataType.TF_FLOAT) |
|
|
|
values = double_values.Select(x => (float)Convert.ChangeType(x, new_system_dtype)).ToArray(); |
|
|
|
} |
|
|
|
else |
|
|
|
values = Convert.ChangeType(values, new_system_dtype); |
|
|
|
|
|
|
|
long[] longValues when dtype == TF_DataType.TF_INT32 => longValues.Select(x => (int)x).ToArray(), |
|
|
|
float[] floatValues when dtype == TF_DataType.TF_DOUBLE => floatValues.Select(x => (double)x).ToArray(), |
|
|
|
float[,] float2DValues when dtype == TF_DataType.TF_DOUBLE => ConvertArray2D(float2DValues, Convert.ToDouble), |
|
|
|
double[] doubleValues when dtype == TF_DataType.TF_FLOAT => doubleValues.Select(x => (float)x).ToArray(), |
|
|
|
double[,] double2DValues when dtype == TF_DataType.TF_DOUBLE => ConvertArray2D(double2DValues, Convert.ToSingle), |
|
|
|
_ => Convert.ChangeType(values, new_system_dtype), |
|
|
|
}; |
|
|
|
dtype = values.GetDataType(); |
|
|
|
} |
|
|
|
|
|
|
|