Browse Source

fix inconsistent shape error while training Embedding layer.

tags/v0.110.4-Transformer-Model
lingbai-kong 2 years ago
parent
commit
f61ab520c9
2 changed files with 25 additions and 1 deletions
  1. +14
    -1
      src/TensorFlowNET.Core/Framework/IndexedSlices.cs
  2. +11
    -0
      test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs

+ 14
- 1
src/TensorFlowNET.Core/Framework/IndexedSlices.cs View File

@@ -49,12 +49,25 @@ namespace Tensorflow.Framework

public static implicit operator Tensor(IndexedSlices indexedSlices)
{
return indexedSlices.values;
return _indexed_slices_to_tensor(indexedSlices);
}

public static implicit operator IndexedSlices(Tensor tensor)
{
return tensor.Tag as IndexedSlices;
}

/// <summary>
/// Converts an IndexedSlices object `value` to a Tensor.
/// </summary>
/// <param name="indexedSlices"></param>
/// <param name="dtype"></param>
/// <param name="name"></param>
/// <param name="as_ref"></param>
/// <returns></returns>
public static Tensor _indexed_slices_to_tensor(IndexedSlices indexedSlices, TF_DataType dtype = TF_DataType.DtInvalid, String name = "", bool as_ref = false)
{
return gen_math_ops.unsorted_segment_sum(indexedSlices.values, indexedSlices.indices, indexedSlices.dense_shape.slice(0));
}
}
}

+ 11
- 0
test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs View File

@@ -110,6 +110,17 @@ namespace Tensorflow.Keras.UnitTest.Layers
var output_array = model.predict(input_array);
Assert.AreEqual((32, 10, 64), output_array.shape);
}
[TestMethod]
public void EmbeddingGrad()
{
var inputs = keras.layers.Input(shape: new[] { 32, 10 });
var outputs = keras.layers.Embedding(1000, 64, input_length: 10).Apply(inputs);
var model = keras.Model(inputs: inputs, outputs: outputs);
var input_array = np.random.randint(1000, size: (1, 32, 10));
var output_array = np.random.random(size: (1, 32, 10, 64));
model.compile("rmsprop", "mse", new[] { "accuracy" });
model.fit(input_array, output_array);
}

/// <summary>
/// https://www.tensorflow.org/api_docs/python/tf/keras/layers/Dense


Loading…
Cancel
Save