From b3ce158ec3304469bf776bc582b847e685a9df73 Mon Sep 17 00:00:00 2001 From: novikov-alexander <79649566+novikov-alexander@users.noreply.github.com> Date: Fri, 14 Jun 2024 14:40:06 +0300 Subject: [PATCH] Update tensor_util.cs --- src/TensorFlowNET.Core/Tensors/tensor_util.cs | 40 +++++++++++++------ 1 file changed, 27 insertions(+), 13 deletions(-) diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index f688d4d5..f2003c9d 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -1,4 +1,4 @@ -/***************************************************************************** +/***************************************************************************** Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); @@ -135,6 +135,23 @@ namespace Tensorflow TF_DataType.TF_QINT32 }; + private static TOut[,] ConvertArray2D(TIn[,] inputArray, Func converter) + { + var rows = inputArray.GetLength(0); + var cols = inputArray.GetLength(1); + var outputArray = new TOut[rows, cols]; + + for (var i = 0; i < rows; i++) + { + for (var j = 0; j < cols; j++) + { + outputArray[i, j] = converter(inputArray[i, j]); + } + } + + return outputArray; + } + /// /// Create a TensorProto, invoked in graph mode /// @@ -157,19 +174,16 @@ namespace Tensorflow else if(origin_dtype != dtype) { var new_system_dtype = dtype.as_system_dtype(); - if (values is long[] long_values) - { - if (dtype == TF_DataType.TF_INT32) - values = long_values.Select(x => (int)Convert.ChangeType(x, new_system_dtype)).ToArray(); - } - else if (values is double[] double_values) + + values = values switch { - if (dtype == TF_DataType.TF_FLOAT) - values = double_values.Select(x => (float)Convert.ChangeType(x, new_system_dtype)).ToArray(); - } - else - values = Convert.ChangeType(values, new_system_dtype); - + long[] longValues when dtype == TF_DataType.TF_INT32 => longValues.Select(x => (int)x).ToArray(), + float[] floatValues when dtype == TF_DataType.TF_DOUBLE => floatValues.Select(x => (double)x).ToArray(), + float[,] float2DValues when dtype == TF_DataType.TF_DOUBLE => ConvertArray2D(float2DValues, Convert.ToDouble), + double[] doubleValues when dtype == TF_DataType.TF_FLOAT => doubleValues.Select(x => (float)x).ToArray(), + double[,] double2DValues when dtype == TF_DataType.TF_DOUBLE => ConvertArray2D(double2DValues, Convert.ToSingle), + _ => Convert.ChangeType(values, new_system_dtype), + }; dtype = values.GetDataType(); }