You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

PreprocessingTests.cs 4.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using System.Linq;
  4. using System.Collections.Generic;
  5. using System.Text;
  6. using NumSharp;
  7. using static Tensorflow.KerasApi;
  8. using Tensorflow;
  9. using Tensorflow.Keras.Datasets;
  namespace TensorFlowNET.Keras.UnitTest
  {
      /// <summary>
      /// Unit tests for the Keras text <c>Tokenizer</c> and the <c>pad_sequences</c> helper.
      /// </summary>
      [TestClass]
      public class PreprocessingTests : EagerModeTestBase
      {
          // Shared corpus. After lower-casing and punctuation removal, rows 0 and 2 are
          // identical 12-token sentences, row 1 is the longest (20 tokens), and row 3 is
          // the first 12 tokens of row 1. The column indices asserted below derive from this.
          private readonly string[] texts = new string[] {
              "It was the best of times, it was the worst of times.",
              "this is a new dawn, an era to follow the previous era. It can not be said to start anew.",
              "It was the best of times, it was the worst of times.",
              "this is a new dawn, an era to follow the previous era.",
          };

          private const string OOV = "<OOV>";

          [TestMethod]
          public void TokenizeWithNoOOV()
          {
              var tokenizer = keras.preprocessing.text.Tokenizer(lower: true);

              tokenizer.fit_on_texts(texts);

              // 23 distinct words and no reserved out-of-vocabulary entry.
              Assert.AreEqual(23, tokenizer.word_index.Count);
              Assert.IsFalse(tokenizer.word_index.ContainsKey(OOV));
              Assert.AreEqual(7, tokenizer.word_index["worst"]);
              Assert.AreEqual(12, tokenizer.word_index["dawn"]);
              Assert.AreEqual(16, tokenizer.word_index["follow"]);
          }

          [TestMethod]
          public void TokenizeWithOOV()
          {
              var tokenizer = keras.preprocessing.text.Tokenizer(lower: true, oov_token: OOV);

              tokenizer.fit_on_texts(texts);

              // The OOV token takes index 1, shifting every real word up by one
              // relative to TokenizeWithNoOOV.
              Assert.AreEqual(24, tokenizer.word_index.Count);
              Assert.AreEqual(1, tokenizer.word_index[OOV]);
              Assert.AreEqual(8, tokenizer.word_index["worst"]);
              Assert.AreEqual(13, tokenizer.word_index["dawn"]);
              Assert.AreEqual(17, tokenizer.word_index["follow"]);
          }

          [TestMethod]
          public void PadSequencesWithDefaults()
          {
              var tokenizer = keras.preprocessing.text.Tokenizer(lower: true, oov_token: OOV);
              tokenizer.fit_on_texts(texts);
              var sequences = tokenizer.texts_to_sequences(texts);

              var padded = keras.preprocessing.sequence.pad_sequences(sequences);

              // Without maxlen the width equals the longest sequence (row 1, 20 tokens).
              Assert.AreEqual(4, padded.shape[0]);
              Assert.AreEqual(20, padded.shape[1]);

              // Row 0 (12 tokens) is pre-padded with 8 zeros; "worst" is its token #9 -> column 17.
              Assert.AreEqual(tokenizer.word_index["worst"], padded[0, 17].GetInt32());
              AssertPadding(padded, 0, 0, 8);

              // Row 1 fills the full width, so "previous" (token #10) stays at column 10.
              Assert.AreEqual(tokenizer.word_index["previous"], padded[1, 10].GetInt32());
              AssertNoPadding(padded, 1);
          }

          [TestMethod]
          public void PadSequencesPrePaddingTrunc()
          {
              var tokenizer = keras.preprocessing.text.Tokenizer(lower: true, oov_token: OOV);
              tokenizer.fit_on_texts(texts);
              var sequences = tokenizer.texts_to_sequences(texts);

              var padded = keras.preprocessing.sequence.pad_sequences(sequences, maxlen: 15);

              Assert.AreEqual(4, padded.shape[0]);
              Assert.AreEqual(15, padded.shape[1]);

              // Row 0 (12 tokens) gets 3 leading zeros; "worst" (token #9) -> column 12.
              Assert.AreEqual(tokenizer.word_index["worst"], padded[0, 12].GetInt32());
              AssertPadding(padded, 0, 0, 3);

              // Row 1 (20 tokens) loses its first 5 tokens, so "previous" (token #10) -> column 5.
              Assert.AreEqual(tokenizer.word_index["previous"], padded[1, 5].GetInt32());
              AssertNoPadding(padded, 1);
          }

          [TestMethod]
          public void PadSequencesPostPaddingTrunc()
          {
              var tokenizer = keras.preprocessing.text.Tokenizer(lower: true, oov_token: OOV);
              tokenizer.fit_on_texts(texts);
              var sequences = tokenizer.texts_to_sequences(texts);

              var padded = keras.preprocessing.sequence.pad_sequences(sequences, maxlen: 15, padding: "post", truncating: "post");

              Assert.AreEqual(4, padded.shape[0]);
              Assert.AreEqual(15, padded.shape[1]);

              // Row 0 keeps its 12 tokens at the front ("worst" stays at column 9);
              // columns 12..14 are trailing zeros.
              Assert.AreEqual(tokenizer.word_index["worst"], padded[0, 9].GetInt32());
              AssertPadding(padded, 0, 12, 15);

              // Row 1 keeps its first 15 tokens, so "previous" (token #10) stays at column 10.
              Assert.AreEqual(tokenizer.word_index["previous"], padded[1, 10].GetInt32());
              AssertNoPadding(padded, 1);
          }

          /// <summary>
          /// Asserts that columns [<paramref name="fromColumn"/>, <paramref name="toColumn"/>)
          /// of <paramref name="row"/> hold the padding value 0.
          /// </summary>
          private static void AssertPadding(NDArray padded, int row, int fromColumn, int toColumn)
          {
              for (var column = fromColumn; column < toColumn; column++)
              {
                  Assert.AreEqual(0, padded[row, column].GetInt32(), $"Expected padding at [{row}, {column}].");
              }
          }

          /// <summary>
          /// Asserts that every column of <paramref name="row"/> holds a non-zero token id.
          /// </summary>
          private static void AssertNoPadding(NDArray padded, int row)
          {
              for (var column = 0; column < padded.shape[1]; column++)
              {
                  Assert.AreNotEqual(0, padded[row, column].GetInt32(), $"Unexpected padding at [{row}, {column}].");
              }
          }
      }
  }