using NumSharp;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using Tensorflow;
using TensorFlowNET.Examples.Utility;
using static Tensorflow.Python;

namespace TensorFlowNET.Examples
{
    /// <summary>
    /// Implement the Word2Vec algorithm to compute vector representations of words.
    /// https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/2_BasicModels/word2vec.py
    /// </summary>
    public class Word2Vec : IExample
    {
        public bool Enabled { get; set; } = true;
        public string Name => "Word2Vec";
        public bool IsImportingGraph { get; set; } = true;

        // Training Parameters
        float learning_rate = 0.1f;
        int batch_size = 128;
        int num_steps = 30000; // 3000000 in the original example
        int display_step = 1000; // 10000 in the original example
        int eval_step = 5000; // 200000 in the original example

        // Evaluation Parameters
        string[] eval_words = new string[] { "five", "of", "going", "hardware", "american", "britain" };
        string[] text_words;
        List<WordId> word2id;
        int[] data;

        // Word2Vec Parameters
        int embedding_size = 200; // Dimension of the embedding vector
        int max_vocabulary_size = 50000; // Total number of different words in the vocabulary
        int min_occurrence = 10; // Remove all words that do not appear at least n times
        int skip_window = 3; // How many words to consider left and right
        int num_skips = 2; // How many times to reuse an input to generate a label
        int num_sampled = 64; // Number of negative examples to sample

        int data_index = 0;
        int top_k = 8; // number of nearest neighbors
        float average_loss = 0;
        Random rnd = new Random(); // used to sample context words in next_batch

        public bool Run()
        {
            PrepareData();

            var graph = tf.Graph().as_default();

            tf.train.import_meta_graph($"graph{Path.DirectorySeparatorChar}word2vec.meta");

            // Input data
            Tensor X = graph.OperationByName("Placeholder");
            // Input label
            Tensor Y = graph.OperationByName("Placeholder_1");

            // Compute the average NCE loss for the batch
            Tensor loss_op = graph.OperationByName("Mean");
            // Define the optimizer
            var train_op = graph.OperationByName("GradientDescent");
            // Cosine similarity between the fed word ids and the whole vocabulary
            Tensor cosine_sim_op = graph.OperationByName("MatMul_1");

            // Initialize the variables (i.e. assign their default value)
            var init = tf.global_variables_initializer();

            with(tf.Session(graph), sess =>
            {
                // Run the initializer
                sess.run(init);

                // Map each evaluation word to its id, falling back to 0 ('UNK')
                var x_test = (from word in eval_words
                              join id in word2id on word equals id.Word into wi
                              from wi2 in wi.DefaultIfEmpty()
                              select wi2 == null ? 0 : wi2.Id).ToArray();

                foreach (var step in range(1, num_steps + 1))
                {
                    // Get a new batch of data
                    var (batch_x, batch_y) = next_batch(batch_size, num_skips, skip_window);

                    var result = sess.run(new ITensorOrOperation[] { train_op, loss_op },
                        new FeedItem(X, batch_x),
                        new FeedItem(Y, batch_y));
                    average_loss += result[1];

                    if (step % display_step == 0 || step == 1)
                    {
                        if (step > 1)
                            average_loss /= display_step;
                        print($"Step {step}, Average Loss= {average_loss.ToString("F4")}");
                        average_loss = 0;
                    }

                    // Evaluation
                    if (step % eval_step == 0 || step == 1)
                    {
                        print("Evaluation...");
                        var sim = sess.run(cosine_sim_op, new FeedItem(X, x_test));
                        foreach (var i in range(len(eval_words)))
                        {
                            // Negate so argsort ranks by descending similarity;
                            // skip index 0, which is the word itself
                            var nearest = (0f - sim[i]).argsort<float>()
                                .Data<int>()
                                .Skip(1)
                                .Take(top_k)
                                .ToArray();
                            string log_str = $"\"{eval_words[i]}\" nearest neighbors:";
                            foreach (var k in range(top_k))
                                log_str = $"{log_str} {word2id.First(x => x.Id == nearest[k]).Word},";
                            print(log_str);
                        }
                    }
                }
            });

            return average_loss < 100;
        }
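        // Note on the op names looked up in Run(): word2vec.meta was exported by a
        // Python script based on the TensorFlow-Examples word2vec.py linked in the
        // class summary. Presumably (not verified against the exported graph) the
        // names map roughly to:
        //   "Mean"            <- tf.reduce_mean(tf.nn.nce_loss(...))
        //   "GradientDescent" <- tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_op)
        //   "MatMul_1"        <- tf.matmul(x_embed_norm, embedding_norm, transpose_b=True)
        // i.e. the evaluation op computes cosine similarity between the L2-normalized
        // embeddings of the fed word ids and every word in the vocabulary.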
        // Generate a training batch for the skip-gram model
        private (NDArray, NDArray) next_batch(int batch_size, int num_skips, int skip_window)
        {
            var batch = np.ndarray((batch_size), dtype: np.int32);
            var labels = np.ndarray((batch_size, 1), dtype: np.int32);

            // get window size (words left and right + current one)
            int span = 2 * skip_window + 1;
            var buffer = new Queue<int>(span);
            if (data_index + span > data.Length)
                data_index = 0;
            data.Skip(data_index).Take(span).ToList().ForEach(x => buffer.Enqueue(x));
            data_index += span;

            foreach (var i in range(batch_size / num_skips))
            {
                // Pick 'num_skips' random words from the context (every offset in the
                // window except the center), matching random.sample in the Python original
                var context_words = range(span).Where(x => x != skip_window).ToArray();
                var words_to_use = context_words.OrderBy(x => rnd.Next()).Take(num_skips).ToArray();
                foreach (var (j, context_word) in enumerate(words_to_use))
                {
                    batch[i * num_skips + j] = buffer.ElementAt(skip_window);
                    labels[i * num_skips + j, 0] = buffer.ElementAt(context_word);
                }

                if (data_index == len(data))
                {
                    // Reached the end of the data: refill the window from the beginning
                    buffer.Clear();
                    foreach (var word in data.Take(span))
                        buffer.Enqueue(word);
                    data_index = span;
                }
                else
                {
                    // Slide the window one word to the right, dropping the oldest word
                    // so the queue behaves like Python's deque(maxlen=span)
                    buffer.Dequeue();
                    buffer.Enqueue(data[data_index]);
                    data_index += 1;
                }
            }

            // Backtrack a little bit to avoid skipping words in the end of a batch
            data_index = (data_index + len(data) - span) % len(data);
            return (batch, labels);
        }
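        // Illustrative sketch (made-up fragment, not from the dataset): with
        // skip_window = 3 the window spans 2 * 3 + 1 = 7 words. Over
        //   "the quick brown fox jumps over the"
        // the center word is "fox", and next_batch pairs it with num_skips = 2
        // randomly chosen neighbors, e.g. (fox, quick) and (fox, over). Each
        // window thus contributes num_skips (center, context) training pairs
        // before the window slides one word to the right.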
        public void PrepareData()
        {
            // Download the graph meta file
            var url = "https://github.com/SciSharp/TensorFlow.NET/raw/master/graph/word2vec.meta";
            Web.Download(url, "graph", "word2vec.meta");

            // Download a small chunk of the Wikipedia articles collection
            url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/data/text8.zip";
            Web.Download(url, "word2vec", "text8.zip");

            // Unzip the dataset file. The text has already been processed.
            Compress.UnZip($"word2vec{Path.DirectorySeparatorChar}text8.zip", "word2vec");

            int wordId = 0;
            text_words = File.ReadAllText($"word2vec{Path.DirectorySeparatorChar}text8")
                .Trim()
                .ToLower()
                .Split();

            // Build the dictionary and replace rare words with the UNK token
            word2id = text_words.GroupBy(x => x)
                .Select(x => new WordId
                {
                    Word = x.Key,
                    Occurrence = x.Count()
                })
                .Where(x => x.Occurrence >= min_occurrence) // Remove words with fewer than 'min_occurrence' occurrences
                .OrderByDescending(x => x.Occurrence) // Retrieve the most common words first
                .Take(max_vocabulary_size - 1) // Cap the vocabulary, leaving room for the UNK token
                .Select(x => new WordId
                {
                    Word = x.Word,
                    Id = ++wordId, // Assign an id to each word
                    Occurrence = x.Occurrence
                })
                .ToList();

            // Retrieve each word's id, or assign index 0 ('UNK') if it is not in the dictionary
            data = (from word in text_words
                    join id in word2id on word equals id.Word into wi
                    from wi2 in wi.DefaultIfEmpty()
                    select wi2 == null ? 0 : wi2.Id).ToArray();

            word2id.Insert(0, new WordId { Word = "UNK", Id = 0, Occurrence = data.Count(x => x == 0) });

            print($"Words count: {text_words.Length}");
            print($"Unique words: {text_words.Distinct().Count()}");
            print($"Vocabulary size: {word2id.Count}");
            print($"Most common words: {string.Join(", ", word2id.Take(10))}");
        }

        public Graph ImportGraph()
        {
            throw new NotImplementedException();
        }

        public Graph BuildGraph()
        {
            throw new NotImplementedException();
        }

        public void Train(Session sess)
        {
            throw new NotImplementedException();
        }

        public void Predict(Session sess)
        {
            throw new NotImplementedException();
        }

        public void Test(Session sess)
        {
            throw new NotImplementedException();
        }

        private class WordId
        {
            public string Word { get; set; }
            public int Id { get; set; }
            public int Occurrence { get; set; }

            public override string ToString()
            {
                return Word + " " + Id + " " + Occurrence;
            }
        }
    }
}