@@ -18,22 +18,22 @@ namespace Tensorflow.NumPy | |||||
=> new NDArray(array); | => new NDArray(array); | ||||
public unsafe static implicit operator bool(NDArray nd) | public unsafe static implicit operator bool(NDArray nd) | ||||
=> *(bool*)nd.data; | |||||
=> nd.dtype == TF_DataType.TF_BOOL ? *(bool*)nd.data : NDArrayConverter.Scalar<bool>(nd); | |||||
public unsafe static implicit operator byte(NDArray nd) | public unsafe static implicit operator byte(NDArray nd) | ||||
=> *(byte*)nd.data; | |||||
=> nd.dtype == TF_DataType.TF_INT8 ? *(byte*)nd.data : NDArrayConverter.Scalar<byte>(nd); | |||||
public unsafe static implicit operator int(NDArray nd) | public unsafe static implicit operator int(NDArray nd) | ||||
=> *(int*)nd.data; | |||||
=> nd.dtype == TF_DataType.TF_INT32 ? *(int*)nd.data : NDArrayConverter.Scalar<int>(nd); | |||||
public unsafe static implicit operator long(NDArray nd) | public unsafe static implicit operator long(NDArray nd) | ||||
=> *(long*)nd.data; | |||||
=> nd.dtype == TF_DataType.TF_INT64 ? *(long*)nd.data : NDArrayConverter.Scalar<long>(nd); | |||||
public unsafe static implicit operator float(NDArray nd) | public unsafe static implicit operator float(NDArray nd) | ||||
=> *(float*)nd.data; | |||||
=> nd.dtype == TF_DataType.TF_FLOAT ? *(float*)nd.data : NDArrayConverter.Scalar<float>(nd); | |||||
public unsafe static implicit operator double(NDArray nd) | public unsafe static implicit operator double(NDArray nd) | ||||
=> *(double*)nd.data; | |||||
=> nd.dtype == TF_DataType.TF_DOUBLE ? *(double*)nd.data : NDArrayConverter.Scalar<double>(nd); | |||||
public static implicit operator NDArray(bool value) | public static implicit operator NDArray(bool value) | ||||
=> new NDArray(value); | => new NDArray(value); | ||||
@@ -175,8 +175,8 @@ namespace Tensorflow.NumPy | |||||
unsafe void SetData(NDArray src, Slice[] slices, int[] indices, int currentNDim) | unsafe void SetData(NDArray src, Slice[] slices, int[] indices, int currentNDim) | ||||
{ | { | ||||
if (dtype != src.dtype) | if (dtype != src.dtype) | ||||
src = src.astype(dtype); | |||||
// throw new ArrayTypeMismatchException($"Required dtype {dtype} but {array.dtype} is assigned."); | |||||
// src = src.astype(dtype); | |||||
throw new ArrayTypeMismatchException($"Required dtype {dtype} but {src.dtype} is assigned."); | |||||
if (!slices.Any()) | if (!slices.Any()) | ||||
return; | return; | ||||
@@ -0,0 +1,34 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Linq; | |||||
using System.Text; | |||||
namespace Tensorflow.NumPy | |||||
{ | |||||
public class NDArrayConverter | |||||
{ | |||||
public unsafe static T Scalar<T>(NDArray nd) where T : unmanaged | |||||
=> nd.dtype switch | |||||
{ | |||||
TF_DataType.TF_FLOAT => Scalar<T>(*(float*)nd.data), | |||||
TF_DataType.TF_INT64 => Scalar<T>(*(long*)nd.data), | |||||
_ => throw new NotImplementedException("") | |||||
}; | |||||
static T Scalar<T>(float input) | |||||
=> Type.GetTypeCode(typeof(T)) switch | |||||
{ | |||||
TypeCode.Int32 => (T)Convert.ChangeType(input, TypeCode.Int32), | |||||
TypeCode.Single => (T)Convert.ChangeType(input, TypeCode.Single), | |||||
_ => throw new NotImplementedException("") | |||||
}; | |||||
static T Scalar<T>(long input) | |||||
=> Type.GetTypeCode(typeof(T)) switch | |||||
{ | |||||
TypeCode.Int32 => (T)Convert.ChangeType(input, TypeCode.Int32), | |||||
TypeCode.Single => (T)Convert.ChangeType(input, TypeCode.Single), | |||||
_ => throw new NotImplementedException("") | |||||
}; | |||||
} | |||||
} |
@@ -78,7 +78,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||||
_insufficient_data = false; | _insufficient_data = false; | ||||
} | } | ||||
int _infer_steps(int steps_per_epoch, IDatasetV2 dataset) | |||||
long _infer_steps(int steps_per_epoch, IDatasetV2 dataset) | |||||
{ | { | ||||
if (steps_per_epoch > -1) | if (steps_per_epoch > -1) | ||||
return steps_per_epoch; | return steps_per_epoch; | ||||