diff --git a/src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs b/src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs index 664ba7f9..88dca940 100644 --- a/src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs +++ b/src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs @@ -18,22 +18,22 @@ namespace Tensorflow.NumPy => new NDArray(array); public unsafe static implicit operator bool(NDArray nd) - => *(bool*)nd.data; + => nd.dtype == TF_DataType.TF_BOOL ? *(bool*)nd.data : NDArrayConverter.Scalar(nd); public unsafe static implicit operator byte(NDArray nd) - => *(byte*)nd.data; + => nd.dtype == TF_DataType.TF_INT8 ? *(byte*)nd.data : NDArrayConverter.Scalar(nd); public unsafe static implicit operator int(NDArray nd) - => *(int*)nd.data; + => nd.dtype == TF_DataType.TF_INT32 ? *(int*)nd.data : NDArrayConverter.Scalar(nd); public unsafe static implicit operator long(NDArray nd) - => *(long*)nd.data; + => nd.dtype == TF_DataType.TF_INT64 ? *(long*)nd.data : NDArrayConverter.Scalar(nd); public unsafe static implicit operator float(NDArray nd) - => *(float*)nd.data; + => nd.dtype == TF_DataType.TF_FLOAT ? *(float*)nd.data : NDArrayConverter.Scalar(nd); public unsafe static implicit operator double(NDArray nd) - => *(double*)nd.data; + => nd.dtype == TF_DataType.TF_DOUBLE ? *(double*)nd.data : NDArrayConverter.Scalar(nd); public static implicit operator NDArray(bool value) => new NDArray(value); diff --git a/src/TensorFlowNET.Core/NumPy/NDArray.Index.cs b/src/TensorFlowNET.Core/NumPy/NDArray.Index.cs index 9797e861..28471854 100644 --- a/src/TensorFlowNET.Core/NumPy/NDArray.Index.cs +++ b/src/TensorFlowNET.Core/NumPy/NDArray.Index.cs @@ -175,8 +175,8 @@ namespace Tensorflow.NumPy unsafe void SetData(NDArray src, Slice[] slices, int[] indices, int currentNDim) { if (dtype != src.dtype) - src = src.astype(dtype); - // throw new ArrayTypeMismatchException($"Required dtype {dtype} but {array.dtype} is assigned."); + // src = src.astype(dtype); + throw new ArrayTypeMismatchException($"Required dtype {dtype} but {src.dtype} is assigned."); if (!slices.Any()) return; diff --git a/src/TensorFlowNET.Core/NumPy/NDArrayConverter.cs b/src/TensorFlowNET.Core/NumPy/NDArrayConverter.cs new file mode 100644 index 00000000..6e1b8da1 --- /dev/null +++ b/src/TensorFlowNET.Core/NumPy/NDArrayConverter.cs @@ -0,0 +1,34 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; + +namespace Tensorflow.NumPy +{ + public class NDArrayConverter + { + public unsafe static T Scalar(NDArray nd) where T : unmanaged + => nd.dtype switch + { + TF_DataType.TF_FLOAT => Scalar(*(float*)nd.data), + TF_DataType.TF_INT64 => Scalar(*(long*)nd.data), + _ => throw new NotImplementedException("") + }; + + static T Scalar(float input) + => Type.GetTypeCode(typeof(T)) switch + { + TypeCode.Int32 => (T)Convert.ChangeType(input, TypeCode.Int32), + TypeCode.Single => (T)Convert.ChangeType(input, TypeCode.Single), + _ => throw new NotImplementedException("") + }; + + static T Scalar(long input) + => Type.GetTypeCode(typeof(T)) switch + { + TypeCode.Int32 => (T)Convert.ChangeType(input, TypeCode.Int32), + TypeCode.Single => (T)Convert.ChangeType(input, TypeCode.Single), + _ => throw new NotImplementedException("") + }; + } +} diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs index c9c44b8f..1ddddd11 100644 --- a/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs +++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs @@ -78,7 +78,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters _insufficient_data = false; } - int _infer_steps(int steps_per_epoch, IDatasetV2 dataset) + long _infer_steps(int steps_per_epoch, IDatasetV2 dataset) { if (steps_per_epoch > -1) return steps_per_epoch;