Browse Source

throw exception from SetData

tags/TensorFlowOpLayer
Oceania2018 4 years ago
parent
commit
94601f5b89
4 changed files with 43 additions and 9 deletions
  1. +6
    -6
      src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs
  2. +2
    -2
      src/TensorFlowNET.Core/NumPy/NDArray.Index.cs
  3. +34
    -0
      src/TensorFlowNET.Core/NumPy/NDArrayConverter.cs
  4. +1
    -1
      src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs

+ 6
- 6
src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs View File

@@ -18,22 +18,22 @@ namespace Tensorflow.NumPy
=> new NDArray(array);

public unsafe static implicit operator bool(NDArray nd)
=> *(bool*)nd.data;
=> nd.dtype == TF_DataType.TF_BOOL ? *(bool*)nd.data : NDArrayConverter.Scalar<bool>(nd);

public unsafe static implicit operator byte(NDArray nd)
=> *(byte*)nd.data;
=> nd.dtype == TF_DataType.TF_INT8 ? *(byte*)nd.data : NDArrayConverter.Scalar<byte>(nd);

public unsafe static implicit operator int(NDArray nd)
=> *(int*)nd.data;
=> nd.dtype == TF_DataType.TF_INT32 ? *(int*)nd.data : NDArrayConverter.Scalar<int>(nd);

public unsafe static implicit operator long(NDArray nd)
=> *(long*)nd.data;
=> nd.dtype == TF_DataType.TF_INT64 ? *(long*)nd.data : NDArrayConverter.Scalar<long>(nd);

public unsafe static implicit operator float(NDArray nd)
=> *(float*)nd.data;
=> nd.dtype == TF_DataType.TF_FLOAT ? *(float*)nd.data : NDArrayConverter.Scalar<float>(nd);

public unsafe static implicit operator double(NDArray nd)
=> *(double*)nd.data;
=> nd.dtype == TF_DataType.TF_DOUBLE ? *(double*)nd.data : NDArrayConverter.Scalar<double>(nd);

public static implicit operator NDArray(bool value)
=> new NDArray(value);


+ 2
- 2
src/TensorFlowNET.Core/NumPy/NDArray.Index.cs View File

@@ -175,8 +175,8 @@ namespace Tensorflow.NumPy
unsafe void SetData(NDArray src, Slice[] slices, int[] indices, int currentNDim)
{
if (dtype != src.dtype)
src = src.astype(dtype);
// throw new ArrayTypeMismatchException($"Required dtype {dtype} but {array.dtype} is assigned.");
// src = src.astype(dtype);
throw new ArrayTypeMismatchException($"Required dtype {dtype} but {src.dtype} is assigned.");

if (!slices.Any())
return;


+ 34
- 0
src/TensorFlowNET.Core/NumPy/NDArrayConverter.cs View File

@@ -0,0 +1,34 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace Tensorflow.NumPy
{
public class NDArrayConverter
{
public unsafe static T Scalar<T>(NDArray nd) where T : unmanaged
=> nd.dtype switch
{
TF_DataType.TF_FLOAT => Scalar<T>(*(float*)nd.data),
TF_DataType.TF_INT64 => Scalar<T>(*(long*)nd.data),
_ => throw new NotImplementedException("")
};

static T Scalar<T>(float input)
=> Type.GetTypeCode(typeof(T)) switch
{
TypeCode.Int32 => (T)Convert.ChangeType(input, TypeCode.Int32),
TypeCode.Single => (T)Convert.ChangeType(input, TypeCode.Single),
_ => throw new NotImplementedException("")
};

static T Scalar<T>(long input)
=> Type.GetTypeCode(typeof(T)) switch
{
TypeCode.Int32 => (T)Convert.ChangeType(input, TypeCode.Int32),
TypeCode.Single => (T)Convert.ChangeType(input, TypeCode.Single),
_ => throw new NotImplementedException("")
};
}
}

+ 1
- 1
src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs View File

@@ -78,7 +78,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters
_insufficient_data = false;
}

int _infer_steps(int steps_per_epoch, IDatasetV2 dataset)
long _infer_steps(int steps_per_epoch, IDatasetV2 dataset)
{
if (steps_per_epoch > -1)
return steps_per_epoch;


Loading…
Cancel
Save