Word2Vec.cs
using NumSharp;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Tensorflow;
using TensorFlowNET.Examples.Utility;
using static Tensorflow.Python;

namespace TensorFlowNET.Examples
{
    /// <summary>
    /// Implements the Word2Vec algorithm to compute vector representations of words.
    /// https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/2_BasicModels/word2vec.py
    /// </summary>
    public class Word2Vec : IExample
    {
        public bool Enabled { get; set; } = true;
        public string Name => "Word2Vec";
        public bool IsImportingGraph { get; set; } = true;

        // Training Parameters
        float learning_rate = 0.1f;
        int batch_size = 128;
        int num_steps = 30000;   // full training run: 3000000
        int display_step = 1000; // full training run: 10000
        int eval_step = 5000;    // full training run: 200000

        // Evaluation Parameters
        string[] eval_words = new string[] { "five", "of", "going", "hardware", "american", "britain" };
        string[] text_words;
        List<WordId> word2id;
        int[] data;

        // Word2Vec Parameters
        int embedding_size = 200;        // Dimension of the embedding vector
        int max_vocabulary_size = 50000; // Total number of different words in the vocabulary
        int min_occurrence = 10;         // Remove all words that do not appear at least n times
        int skip_window = 3;             // How many words to consider left and right
        int num_skips = 2;               // How many times to reuse an input to generate a label
        int num_sampled = 64;            // Number of negative examples to sample

        int data_index = 0;
        int top_k = 8; // number of nearest neighbors
        float average_loss = 0;
        Random rnd = new Random(); // used by next_batch to sample context words
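
        // Worked example with the defaults above: the sliding window spans
        // span = 2 * skip_window + 1 = 7 words (3 to the left, the center word,
        // 3 to the right). Each center word is reused num_skips = 2 times, so a
        // batch of 128 (input, label) pairs consumes 128 / 2 = 64 window positions.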

        public bool Run()
        {
            PrepareData();

            var graph = tf.Graph().as_default();

            tf.train.import_meta_graph($"graph{Path.DirectorySeparatorChar}word2vec.meta");

            // Input data
            Tensor X = graph.OperationByName("Placeholder");
            // Input label
            Tensor Y = graph.OperationByName("Placeholder_1");

            // Compute the average NCE loss for the batch
            Tensor loss_op = graph.OperationByName("Mean");
            // Define the optimizer
            var train_op = graph.OperationByName("GradientDescent");
            Tensor cosine_sim_op = graph.OperationByName("MatMul_1");

            // Initialize the variables (i.e. assign their default value)
            var init = tf.global_variables_initializer();

            using (var sess = tf.Session(graph))
            {
                // Run the initializer
                sess.run(init);

                // Map each evaluation word to its id (0, i.e. 'UNK', when the word is unknown)
                var x_test = (from word in eval_words
                              join id in word2id on word equals id.Word into wi
                              from wi2 in wi.DefaultIfEmpty()
                              select wi2 == null ? 0 : wi2.Id).ToArray();

                foreach (var step in range(1, num_steps + 1))
                {
                    // Get a new batch of data
                    var (batch_x, batch_y) = next_batch(batch_size, num_skips, skip_window);
                    // Run one optimization step and fetch the batch loss
                    (_, float loss) = sess.run((train_op, loss_op), (X, batch_x), (Y, batch_y));
                    average_loss += loss;

                    if (step % display_step == 0 || step == 1)
                    {
                        if (step > 1)
                            average_loss /= display_step;
                        print($"Step {step}, Average Loss= {average_loss.ToString("F4")}");
                        average_loss = 0;
                    }

                    // Evaluation
                    if (step % eval_step == 0 || step == 1)
                    {
                        print("Evaluation...");
                        var sim = sess.run(cosine_sim_op, (X, x_test));
                        foreach (var i in range(len(eval_words)))
                        {
                            // Negate the similarities so argsort yields descending order,
                            // then skip the first hit (the word itself)
                            var nearest = (0f - sim[i]).argsort<float>()
                                .Data<int>()
                                .Skip(1)
                                .Take(top_k)
                                .ToArray();
                            string log_str = $"\"{eval_words[i]}\" nearest neighbors:";
                            foreach (var k in range(top_k))
                                log_str = $"{log_str} {word2id.First(x => x.Id == nearest[k]).Word},";
                            print(log_str);
                        }
                    }
                }
            }

            return average_loss < 100;
        }
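
        // For reference, a framework-free sketch of what the imported 'MatMul_1' op
        // computes for the evaluation words: the cosine similarity between embedding
        // vectors. Illustrative only; it is not called by this example, which runs the
        // similarity inside the imported graph.
        private static float CosineSimilarity(float[] a, float[] b)
        {
            float dot = 0f, normA = 0f, normB = 0f;
            for (int i = 0; i < a.Length; i++)
            {
                dot += a[i] * b[i];   // dot product
                normA += a[i] * a[i]; // squared L2 norm of a
                normB += b[i] * b[i]; // squared L2 norm of b
            }
            return dot / ((float)Math.Sqrt(normA) * (float)Math.Sqrt(normB) + 1e-8f);
        }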

        // Generate a training batch for the skip-gram model
        private (NDArray, NDArray) next_batch(int batch_size, int num_skips, int skip_window)
        {
            var batch = np.ndarray(new Shape(batch_size), dtype: np.int32);
            var labels = np.ndarray((batch_size, 1), dtype: np.int32);
            // get window size (words left and right + current one)
            int span = 2 * skip_window + 1;
            var buffer = new Queue<int>(span);
            if (data_index + span > data.Length)
                data_index = 0;
            data.Skip(data_index).Take(span).ToList().ForEach(x => buffer.Enqueue(x));
            data_index += span;

            foreach (var i in range(batch_size / num_skips))
            {
                // All window positions except the center word
                var context_words = range(span).Where(x => x != skip_window).ToArray();
                // Randomly sample 'num_skips' context positions for the current center word
                var words_to_use = context_words.OrderBy(x => rnd.Next()).Take(num_skips).ToArray();
                foreach (var (j, context_word) in enumerate(words_to_use))
                {
                    batch[i * num_skips + j] = buffer.ElementAt(skip_window);
                    labels[i * num_skips + j, 0] = buffer.ElementAt(context_word);
                }

                if (data_index == len(data))
                {
                    // Reached the end of the data: refill the window from the beginning
                    foreach (var w in data.Take(span))
                    {
                        if (buffer.Count == span)
                            buffer.Dequeue();
                        buffer.Enqueue(w);
                    }
                    data_index = span;
                }
                else
                {
                    // Slide the window one word to the right
                    if (buffer.Count == span)
                        buffer.Dequeue();
                    buffer.Enqueue(data[data_index]);
                    data_index += 1;
                }
            }
            // Backtrack a little bit to avoid skipping words at the end of a batch
            data_index = (data_index + len(data) - span) % len(data);
            return (batch, labels);
        }
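
        // Worked example of next_batch: with skip_window = 1 (span = 3) and num_skips = 2,
        // the id sequence [12, 6, 4, 9, ...] first fills the window [12, 6, 4] with center 6,
        // emitting the pairs (6 -> 12) and (6 -> 4) in random order; the window then slides
        // to [6, 4, 9] with center 4, and so on. Center ids land in 'batch', context ids
        // in 'labels'.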

        public void PrepareData()
        {
            // Download the pre-built graph definition
            var url = "https://github.com/SciSharp/TensorFlow.NET/raw/master/graph/word2vec.meta";
            Web.Download(url, "graph", "word2vec.meta");

            // Download a small chunk of the Wikipedia articles collection
            url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/data/text8.zip";
            Web.Download(url, "word2vec", "text8.zip");
            // Unzip the dataset file. Text has already been processed.
            Compress.UnZip($"word2vec{Path.DirectorySeparatorChar}text8.zip", "word2vec");

            int wordId = 0;
            text_words = File.ReadAllText($"word2vec{Path.DirectorySeparatorChar}text8").Trim().ToLower().Split();

            // Build the dictionary and replace rare words with the 'UNK' token
            word2id = text_words.GroupBy(x => x)
                                .Select(x => new WordId
                                {
                                    Word = x.Key,
                                    Occurrence = x.Count()
                                })
                                .Where(x => x.Occurrence >= min_occurrence) // Remove words with fewer than 'min_occurrence' occurrences
                                .OrderByDescending(x => x.Occurrence) // Retrieve the most common words
                                .Take(max_vocabulary_size - 1) // Cap the vocabulary, leaving room for 'UNK'
                                .Select(x => new WordId
                                {
                                    Word = x.Word,
                                    Id = ++wordId, // Assign an id to each word
                                    Occurrence = x.Occurrence
                                })
                                .ToList();

            // Retrieve a word id, or assign it index 0 ('UNK') if it is not in the dictionary
            data = (from word in text_words
                    join id in word2id on word equals id.Word into wi
                    from wi2 in wi.DefaultIfEmpty()
                    select wi2 == null ? 0 : wi2.Id).ToArray();

            word2id.Insert(0, new WordId { Word = "UNK", Id = 0, Occurrence = data.Count(x => x == 0) });

            print($"Words count: {text_words.Length}");
            print($"Unique words: {text_words.Distinct().Count()}");
            print($"Vocabulary size: {word2id.Count}");
            print($"Most common words: {string.Join(", ", word2id.Take(10))}");
        }
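
        // Toy illustration of the vocabulary pipeline above (with a hypothetical
        // min_occurrence of 2 for brevity): the token stream ["a", "b", "a", "c", "a", "b"]
        // yields word2id = [UNK(Id 0, 1), a(Id 1, 3), b(Id 2, 2)] and data = [1, 2, 1, 0, 1, 2],
        // since "c" falls below the threshold and is mapped to 'UNK'.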

        public Graph ImportGraph()
        {
            throw new NotImplementedException();
        }

        public Graph BuildGraph()
        {
            throw new NotImplementedException();
        }

        public void Train(Session sess)
        {
            throw new NotImplementedException();
        }

        public void Predict(Session sess)
        {
            throw new NotImplementedException();
        }

        public void Test(Session sess)
        {
            throw new NotImplementedException();
        }

        private class WordId
        {
            public string Word { get; set; }
            public int Id { get; set; }
            public int Occurrence { get; set; }

            public override string ToString()
            {
                return Word + " " + Id + " " + Occurrence;
            }
        }
    }
}