Browse Source

add more implicit operator for NDArray and UnitTest for `keras.datasets.imdb`

tags/v0.110.0-LSTM-Model
lingbai-kong 2 years ago
parent
commit
e749aaeaae
2 changed files with 21 additions and 0 deletions
  1. +6
    -0
      src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs
  2. +15
    -0
      test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs

+ 6
- 0
src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs View File

@@ -107,9 +107,15 @@ namespace Tensorflow.NumPy
public static implicit operator NDArray(bool value)
=> new NDArray(value);

public static implicit operator NDArray(byte value)
=> new NDArray(value);

public static implicit operator NDArray(int value)
=> new NDArray(value);

public static implicit operator NDArray(long value)
=> new NDArray(value);

public static implicit operator NDArray(float value)
=> new NDArray(value);



+ 15
- 0
test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs View File

@@ -2,6 +2,7 @@
using System;
using System.Linq;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

namespace TensorFlowNET.UnitTest.Dataset
{
@@ -195,5 +196,19 @@ namespace TensorFlowNET.UnitTest.Dataset

Assert.IsFalse(allEqual);
}
[TestMethod]
public void GetData()
{
var vocab_size = 20000;
var dataset = keras.datasets.imdb.load_data(num_words: vocab_size);
var x_train = dataset.Train.Item1;
Assert.AreEqual(x_train.dims[0], 25000);
var y_train = dataset.Train.Item2;
Assert.AreEqual(y_train.dims[0], 25000);
var x_val = dataset.Test.Item1;
Assert.AreEqual(x_val.dims[0], 25000);
var y_val = dataset.Test.Item2;
Assert.AreEqual(y_val.dims[0], 25000);
}
}
}

Loading…
Cancel
Save