Word2Vec.cs 9.5 kB

using NumSharp;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using Tensorflow;
using TensorFlowNET.Examples.Utility;
using static Tensorflow.Python;

namespace TensorFlowNET.Examples
{
    /// <summary>
    /// Implement the Word2Vec algorithm to compute vector representations of words.
    /// https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/2_BasicModels/word2vec.py
    /// </summary>
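    /// <remarks>
    /// Trains a skip-gram model with noise-contrastive estimation (NCE) loss.
    /// The graph is not built in C#; it is imported from a pre-built word2vec.meta
    /// file, and the input, loss, training, and similarity operations are looked
    /// up by name.
    /// </remarks>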
    public class Word2Vec : IExample
    {
        public bool Enabled { get; set; } = true;
        public string Name => "Word2Vec";
        public bool IsImportingGraph { get; set; } = true;

        // Training Parameters
        float learning_rate = 0.1f;
        int batch_size = 128;
        int num_steps = 30000; //3000000;
        int display_step = 1000; //10000;
        int eval_step = 5000; //200000;

        // Evaluation Parameters
        string[] eval_words = new string[] { "five", "of", "going", "hardware", "american", "britain" };
        string[] text_words;
        List<WordId> word2id;
        int[] data;

        // Word2Vec Parameters
        int embedding_size = 200; // Dimension of the embedding vector
        int max_vocabulary_size = 50000; // Total number of different words in the vocabulary
        int min_occurrence = 10; // Remove all words that do not appear at least n times
        int skip_window = 3; // How many words to consider left and right
        int num_skips = 2; // How many times to reuse an input to generate a label
        int num_sampled = 64; // Number of negative examples to sample

        int data_index = 0;
        int top_k = 8; // number of nearest neighbors
        float average_loss = 0;
        Random rng = new Random(); // randomness for sampling context words in next_batch
        public bool Run()
        {
            PrepareData();

            var graph = tf.Graph().as_default();
            tf.train.import_meta_graph($"graph{Path.DirectorySeparatorChar}word2vec.meta");

            // Input data
            Tensor X = graph.OperationByName("Placeholder");
            // Input label
            Tensor Y = graph.OperationByName("Placeholder_1");

            // Compute the average NCE loss for the batch
            Tensor loss_op = graph.OperationByName("Mean");
            // Define the optimizer
            var train_op = graph.OperationByName("GradientDescent");
            Tensor cosine_sim_op = graph.OperationByName("MatMul_1");

            // Initialize the variables (i.e. assign their default value)
            var init = tf.global_variables_initializer();

            with(tf.Session(graph), sess =>
            {
                // Run the initializer
                sess.run(init);

                // Map each evaluation word to its id (0, i.e. 'UNK', if it is not in the vocabulary)
                var x_test = (from word in eval_words
                              join id in word2id on word equals id.Word into wi
                              from wi2 in wi.DefaultIfEmpty()
                              select wi2 == null ? 0 : wi2.Id).ToArray();

                foreach (var step in range(1, num_steps + 1))
                {
                    // Get a new batch of data
                    var (batch_x, batch_y) = next_batch(batch_size, num_skips, skip_window);

                    var result = sess.run(new ITensorOrOperation[] { train_op, loss_op }, new FeedItem(X, batch_x), new FeedItem(Y, batch_y));
                    average_loss += result[1];

                    if (step % display_step == 0 || step == 1)
                    {
                        if (step > 1)
                            average_loss /= display_step;
                        print($"Step {step}, Average Loss= {average_loss.ToString("F4")}");
                        average_loss = 0;
                    }

                    // Evaluation
                    if (step % eval_step == 0 || step == 1)
                    {
                        print("Evaluation...");
                        var sim = sess.run(cosine_sim_op, new FeedItem(X, x_test));
                        foreach (var i in range(len(eval_words)))
                        {
                            // Negate the similarities so argsort yields the most similar ids first;
                            // skip the first entry, which is the evaluation word itself
                            var nearest = (0f - sim[i]).argsort<float>()
                                .Data<int>()
                                .Skip(1)
                                .Take(top_k)
                                .ToArray();
                            string log_str = $"\"{eval_words[i]}\" nearest neighbors:";
                            foreach (var k in range(top_k))
                                log_str = $"{log_str} {word2id.First(x => x.Id == nearest[k]).Word},";
                            print(log_str);
                        }
                    }
                }
            });

            return average_loss < 100;
        }
        // Generate a training batch for the skip-gram model
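        // For example (illustrative words, not dataset content): with skip_window = 1
        // (span = 3) and num_skips = 2, the window [the, quick, brown] centered on
        // "quick" yields the (input, label) pairs (quick, the) and (quick, brown);
        // the window then slides one word to the right.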
        private (NDArray, NDArray) next_batch(int batch_size, int num_skips, int skip_window)
        {
            var batch = np.ndarray((batch_size), dtype: np.int32);
            var labels = np.ndarray((batch_size, 1), dtype: np.int32);

            // get window size (words left and right + current one)
            int span = 2 * skip_window + 1;
            var buffer = new Queue<int>(span);
            if (data_index + span > data.Length)
                data_index = 0;
            data.Skip(data_index).Take(span).ToList().ForEach(x => buffer.Enqueue(x));
            data_index += span;

            foreach (var i in range(batch_size / num_skips))
            {
                // Candidate context positions: every slot in the window except the center word
                var context_words = range(span).Where(x => x != skip_window).ToArray();
                // Sample 'num_skips' distinct context positions at random
                var words_to_use = context_words.OrderBy(x => rng.Next()).Take(num_skips).ToArray();
                foreach (var (j, context_word) in enumerate(words_to_use))
                {
                    batch[i * num_skips + j] = buffer.ElementAt(skip_window);
                    labels[i * num_skips + j, 0] = buffer.ElementAt(context_word);
                }

                if (data_index == len(data))
                {
                    // Wrapped around: refill the window from the start of the data
                    buffer.Clear();
                    data.Take(span).ToList().ForEach(x => buffer.Enqueue(x));
                    data_index = span;
                }
                else
                {
                    // Slide the window one word to the right
                    buffer.Dequeue();
                    buffer.Enqueue(data[data_index]);
                    data_index += 1;
                }
            }

            // Backtrack a little bit to avoid skipping words in the end of a batch
            data_index = (data_index + len(data) - span) % len(data);

            return (batch, labels);
        }
        public void PrepareData()
        {
            // Download graph meta
            var url = "https://github.com/SciSharp/TensorFlow.NET/raw/master/graph/word2vec.meta";
            Web.Download(url, "graph", "word2vec.meta");

            // Download a small chunk of the Wikipedia articles collection
            url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/data/text8.zip";
            Web.Download(url, "word2vec", "text8.zip");
            // Unzip the dataset file. Text has already been processed
            Compress.UnZip($"word2vec{Path.DirectorySeparatorChar}text8.zip", "word2vec");

            int wordId = 0;
            text_words = File.ReadAllText($"word2vec{Path.DirectorySeparatorChar}text8").Trim().ToLower().Split();

            // Build the dictionary and replace rare words with the UNK token
            word2id = text_words.GroupBy(x => x)
                .Select(x => new WordId
                {
                    Word = x.Key,
                    Occurrence = x.Count()
                })
                .Where(x => x.Occurrence >= min_occurrence) // Remove words with fewer than 'min_occurrence' occurrences
                .OrderByDescending(x => x.Occurrence) // Retrieve the most common words
                .Take(max_vocabulary_size - 1) // Cap the vocabulary, keeping one slot for 'UNK'
                .Select(x => new WordId
                {
                    Word = x.Word,
                    Id = ++wordId, // Assign an id to each word
                    Occurrence = x.Occurrence
                })
                .ToList();

            // Retrieve a word id, or assign it index 0 ('UNK') if not in dictionary
            data = (from word in text_words
                    join id in word2id on word equals id.Word into wi
                    from wi2 in wi.DefaultIfEmpty()
                    select wi2 == null ? 0 : wi2.Id).ToArray();

            word2id.Insert(0, new WordId { Word = "UNK", Id = 0, Occurrence = data.Count(x => x == 0) });

            print($"Words count: {text_words.Length}");
            print($"Unique words: {text_words.Distinct().Count()}");
            print($"Vocabulary size: {word2id.Count}");
            print($"Most common words: {string.Join(", ", word2id.Take(10))}");
        }
        public Graph ImportGraph()
        {
            throw new NotImplementedException();
        }

        public Graph BuildGraph()
        {
            throw new NotImplementedException();
        }

        public void Train(Session sess)
        {
            throw new NotImplementedException();
        }

        public void Predict(Session sess)
        {
            throw new NotImplementedException();
        }

        public void Test(Session sess)
        {
            throw new NotImplementedException();
        }

        private class WordId
        {
            public string Word { get; set; }
            public int Id { get; set; }
            public int Occurrence { get; set; }

            public override string ToString()
            {
                return Word + " " + Id + " " + Occurrence;
            }
        }
    }
}
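
A minimal usage sketch, assuming the example is invoked directly rather than through the repository's example runner (the Program class below is hypothetical and not part of this file), and that the process can write the graph and word2vec download directories:

    using TensorFlowNET.Examples;

    class Program
    {
        static void Main()
        {
            var example = new Word2Vec();
            // Run() downloads the meta graph and the text8 corpus, trains for
            // num_steps batches, and periodically prints nearest neighbors.
            bool converged = example.Run();
        }
    }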