Browse Source

fix Embedding layer.

tags/TimeSeries
Oceania2018 3 years ago
parent
commit
b048b62436
3 changed files with 6 additions and 29 deletions
  1. +0
    -1
      src/TensorFlowNET.Keras/Engine/Model.Predict.cs
  2. +3
    -1
      src/TensorFlowNET.Keras/Layers/Core/Embedding.cs
  3. +3
    -27
      test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs

+ 0
- 1
src/TensorFlowNET.Keras/Engine/Model.Predict.cs View File

@@ -60,7 +60,6 @@ namespace Tensorflow.Keras.Engine
// callbacks.on_predict_batch_end(end_step, {'outputs': batch_outputs})
}
GC.Collect();
GC.WaitForPendingFinalizers();
}
// callbacks.on_predict_end()
return outputs;


+ 3
- 1
src/TensorFlowNET.Keras/Layers/Core/Embedding.cs View File

@@ -38,7 +38,9 @@ namespace Tensorflow.Keras.Layers
: base(new LayerArgs // copy args
{
DType = args.DType,
Name = args.Name
Name = args.Name,
InputShape = args.InputShape,
BatchSize = args.BatchSize
})
{
this.args = args;


+ 3
- 27
test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs View File

@@ -82,37 +82,13 @@ namespace TensorFlowNET.Keras.UnitTest
/// https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding
/// </summary>
[TestMethod]
public void Embedding_Simple()
{
var emb = keras.layers.Embedding(256, 12, input_length: 4);
var input_array = np.arange(12).reshape((3, 4)).astype(np.float32);
var output = emb.Apply(input_array);
Assert.AreEqual((3, 4, 12), output.shape);
}

/// <summary>
/// https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding
/// </summary>
[TestMethod]
[Ignore]
public void Embedding()
{
var model = keras.Sequential();
var layer = keras.layers.Embedding(7, 2, input_length: 4);
var layer = keras.layers.Embedding(1000, 64, input_length: 10);
model.add(layer);
// the model will take as input an integer matrix of size (batch,
// input_length).
// the largest integer (i.e. word index) in the input should be no larger
// than 999 (vocabulary size).
// now model.output_shape == (None, 10, 64), where None is the batch
// dimension.
var input_array = np.array(new int[,]
{
{ 1, 2, 3, 4 },
{ 2, 3, 4, 5 },
{ 3, 4, 5, 6 }
});
// model.compile("rmsprop", "mse");
var input_array = np.random.randint(1000, size: (32, 10));
model.compile("rmsprop", "mse", new[] { "accuracy" });
var output_array = model.predict(input_array);
Assert.AreEqual((32, 10, 64), output_array.shape);
}


Loading…
Cancel
Save