You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

Word2Vec.cs 4.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. using System;
  2. using System.Collections.Generic;
  3. using System.IO;
  4. using System.Linq;
  5. using System.Text;
  6. using Tensorflow;
  7. using TensorFlowNET.Examples.Utility;
  8. namespace TensorFlowNET.Examples
  9. {
  /// <summary>
  /// Implement Word2Vec algorithm to compute vector representations of words.
  /// https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/2_BasicModels/word2vec.py
  /// </summary>
  public class Word2Vec : Python, IExample
  {
      public int Priority => 12;
      public bool Enabled { get; set; } = true;
      public string Name => "Word2Vec";
      public bool ImportGraph { get; set; } = true;

      // Training Parameters
      float learning_rate = 0.1f;
      int batch_size = 128;
      int num_steps = 3000000;
      int display_step = 10000;
      int eval_step = 200000;

      // Evaluation Parameters
      string[] eval_words = new string[] { "five", "of", "going", "hardware", "american", "britain" };
      string[] text_words;

      // Word2Vec Parameters
      int embedding_size = 200; // Dimension of the embedding vector
      int max_vocabulary_size = 50000; // Total number of different words in the vocabulary
      int min_occurrence = 10; // Remove all words that do not appear at least n times
      int skip_window = 3; // How many words to consider left and right
      int num_skips = 2; // How many times to reuse an input to generate a label
      int num_sampled = 64; // Number of negative examples to sample

      int data_index;

      /// <summary>
      /// Prepares the dataset, imports the pre-built graph definition from
      /// "graph/word2vec.meta" and runs the global variables initializer in a session.
      /// </summary>
      /// <returns>
      /// Always <c>false</c>: no training or evaluation step is executed here.
      /// </returns>
      public bool Run()
      {
          PrepareData();

          var graph = tf.Graph().as_default();
          tf.train.import_meta_graph("graph/word2vec.meta");

          // Initialize the variables (i.e. assign their default value)
          var init = tf.global_variables_initializer();

          with(tf.Session(graph), sess =>
          {
              sess.run(init);
          });

          // NOTE(review): the example reports failure unconditionally; presumably the
          // training loop is still to be written (next_batch is an empty stub) - confirm.
          return false;
      }

      // Generate training batch for the skip-gram model.
      // Currently an empty placeholder.
      private void next_batch()
      {
      }

      /// <summary>
      /// Downloads the graph definition and the text8 archive, extracts the archive,
      /// tokenizes the text into <c>text_words</c>, builds the word-to-id vocabulary and
      /// prints summary statistics.
      /// </summary>
      public void PrepareData()
      {
          // Download graph meta
          var url = "https://github.com/SciSharp/TensorFlow.NET/raw/master/graph/word2vec.meta";
          Web.Download(url, "graph", "word2vec.meta");

          // Download a small chunk of Wikipedia articles collection
          url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/data/text8.zip";
          Web.Download(url, "word2vec", "text8.zip");

          // Unzip the dataset file. Text has already been processed.
          // Path.Combine keeps the paths valid on non-Windows platforms, where a
          // literal backslash is not a directory separator.
          Compress.UnZip(Path.Combine("word2vec", "text8.zip"), "word2vec");

          // Invariant lower-casing so tokens do not depend on the machine's culture.
          text_words = File.ReadAllText(Path.Combine("word2vec", "text8"))
              .Trim()
              .ToLowerInvariant()
              .Split();

          // Build the dictionary and replace rare words with UNK token
          var word2id = text_words.GroupBy(x => x)
              .Select(x => new WordId
              {
                  Word = x.Key,
                  Occurrence = x.Count()
              })
              .Where(x => x.Occurrence >= min_occurrence) // Remove samples with less than 'min_occurrence' occurrences
              .OrderByDescending(x => x.Occurrence) // Retrieve the most common words
              .Take(max_vocabulary_size - 1) // Cap the vocabulary, leaving one slot for 'UNK'
              .Select((x, index) => new WordId
              {
                  Word = x.Word,
                  Id = index + 1, // Ids start at 1; 0 is reserved for 'UNK'
                  Occurrence = x.Occurrence
              })
              .ToList();

          // Retrieve a word id, or assign it index 0 ('UNK') if not in dictionary
          var data = (from word in text_words
                      join id in word2id on word equals id.Word into wi
                      from wi2 in wi.DefaultIfEmpty()
                      select wi2 == null ? 0 : wi2.Id).ToList();

          word2id.Insert(0, new WordId { Word = "UNK", Id = 0, Occurrence = data.Count(x => x == 0) });

          print($"Words count: {text_words.Length}");
          print($"Unique words: {text_words.Distinct().Count()}");
          print($"Vocabulary size: {word2id.Count}");
          print($"Most common words: {string.Join(", ", word2id.Take(10))}");
      }

      /// <summary>
      /// A vocabulary entry: a word, its numeric id and its occurrence count in the corpus.
      /// </summary>
      private class WordId
      {
          public string Word { get; set; }
          public int Id { get; set; }
          public int Occurrence { get; set; }

          /// <summary>Formats the entry as "word id occurrence", separated by single spaces.</summary>
          public override string ToString() => $"{Word} {Id} {Occurrence}";
      }
  }
  104. }

tensorflow框架的.NET版本,提供了丰富的特性和API,可以借此很方便地在.NET平台下搭建深度学习训练与推理流程。