using NumSharp;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using Tensorflow;
using TensorFlowNET.Examples.Utility;
using static Tensorflow.Python;
namespace TensorFlowNET.Examples
{
    /// <summary>
    /// Implement Word2Vec algorithm to compute vector representations of words.
    /// https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/2_BasicModels/word2vec.py
    /// </summary>
    /// <remarks>
    /// The computation graph is not built here: it is imported from the pre-built
    /// <c>word2vec.meta</c> file that <see cref="PrepareData"/> downloads.
    /// </remarks>
    public class Word2Vec : IExample
    {
        public bool Enabled { get; set; } = true;
        public string Name => "Word2Vec";
        public bool IsImportingGraph { get; set; } = true;

        // Training Parameters
        // NOTE(review): learning_rate, embedding_size and num_sampled are not read by this class;
        // presumably they describe how the imported graph was built — confirm.
        float learning_rate = 0.1f;
        int batch_size = 128;
        int num_steps = 30000; //3000000;
        int display_step = 1000; //10000;
        int eval_step = 5000;//200000;

        // Evaluation Parameters
        string[] eval_words = new string[] { "five", "of", "going", "hardware", "american", "britain" };

        // Raw corpus tokens
        string[] text_words;
        // Vocabulary; list index == WordId.Id, with 'UNK' at index 0 (see PrepareData)
        List<WordId> word2id;
        // Corpus encoded as word ids (0 = 'UNK')
        int[] data;

        // Word2Vec Parameters
        int embedding_size = 200; // Dimension of the embedding vector
        int max_vocabulary_size = 50000; // Total number of different words in the vocabulary
        int min_occurrence = 10; // Remove all words that do not appear at least n times
        int skip_window = 3; // How many words to consider left and right
        int num_skips = 2; // How many times to reuse an input to generate a label
        int num_sampled = 64; // Number of negative examples to sample
        int data_index = 0; // Read cursor into `data`, advanced by next_batch
        int top_k = 8; // number of nearest neighbors
        float average_loss = 0;

        // Randomness used to pick which context words become labels in next_batch
        private static readonly Random rng = new Random();

        /// <summary>
        /// Downloads the corpus and graph, trains the imported skip-gram graph and periodically
        /// prints the average loss and the nearest neighbors of <see cref="eval_words"/>.
        /// </summary>
        /// <returns>True when the last reported average loss is below 100.</returns>
        public bool Run()
        {
            PrepareData();

            var graph = tf.Graph().as_default();

            tf.train.import_meta_graph($"graph{Path.DirectorySeparatorChar}word2vec.meta");

            // Input data
            Tensor X = graph.OperationByName("Placeholder");
            // Input label
            Tensor Y = graph.OperationByName("Placeholder_1");

            // Compute the average NCE loss for the batch
            Tensor loss_op = graph.OperationByName("Mean");

            // Define the optimizer
            var train_op = graph.OperationByName("GradientDescent");
            Tensor cosine_sim_op = graph.OperationByName("MatMul_1");

            // Initialize the variables (i.e. assign their default value)
            var init = tf.global_variables_initializer();

            average_loss = 0;
            // average_loss is zeroed after every report, so keep the last reported value
            // for the success criterion (NaN = nothing was reported).
            float last_reported_loss = float.NaN;

            with(tf.Session(graph), sess =>
            {
                // Run the initializer
                sess.run(init);

                // Ids of the evaluation words; words missing from the vocabulary map to 0 ('UNK')
                var id_by_word = word2id.ToDictionary(w => w.Word, w => w.Id);
                var x_test = eval_words
                    .Select(word => id_by_word.TryGetValue(word, out var word_id) ? word_id : 0)
                    .ToArray();

                foreach (var step in range(1, num_steps + 1))
                {
                    // Get a new batch of data
                    var (batch_x, batch_y) = next_batch(batch_size, num_skips, skip_window);
                    var result = sess.run(new ITensorOrOperation[] { train_op, loss_op }, new FeedItem(X, batch_x), new FeedItem(Y, batch_y));
                    average_loss += result[1];

                    if (step % display_step == 0 || step == 1)
                    {
                        // At step 1 the "average" is just the loss of that single batch
                        if (step > 1)
                            average_loss /= display_step;
                        print($"Step {step}, Average Loss= {average_loss.ToString("F4")}");
                        last_reported_loss = average_loss;
                        average_loss = 0;
                    }

                    // Evaluation
                    if (step % eval_step == 0 || step == 1)
                    {
                        print("Evaluation...");
                        var sim = sess.run(cosine_sim_op, new FeedItem(X, x_test));
                        foreach (var i in range(len(eval_words)))
                        {
                            // Negate to sort by descending similarity; skip the first hit (the word itself)
                            var nearest = (0f - sim[i]).argsort<float>()
                                .Data<int>()
                                .Skip(1)
                                .Take(top_k)
                                .ToArray();
                            var log_str = new StringBuilder($"\"{eval_words[i]}\" nearest neighbors:");
                            // word2id is indexed by Id (see PrepareData), so direct indexing is valid
                            foreach (var neighbor_id in nearest)
                                log_str.Append($" {word2id[neighbor_id].Word},");
                            print(log_str.ToString());
                        }
                    }
                }
            });

            return last_reported_loss < 100;
        }

        /// <summary>
        /// Generate training batch for the skip-gram model.
        /// </summary>
        /// <param name="batch_size">Number of (center, context) pairs; must be a multiple of <paramref name="num_skips"/>.</param>
        /// <param name="num_skips">Context words sampled per center word; at most 2 * <paramref name="skip_window"/>.</param>
        /// <param name="skip_window">Number of words considered on each side of the center word.</param>
        /// <returns>
        /// <c>batch</c> of shape (batch_size) with center-word ids and <c>labels</c> of shape
        /// (batch_size, 1) with the matching context-word ids.
        /// </returns>
        private (NDArray, NDArray) next_batch(int batch_size, int num_skips, int skip_window)
        {
            if (num_skips <= 0 || batch_size % num_skips != 0)
                throw new ArgumentException("batch_size must be a positive multiple of num_skips", nameof(batch_size));
            if (num_skips > 2 * skip_window)
                throw new ArgumentException("num_skips must not exceed 2 * skip_window", nameof(num_skips));

            // get window size (words left and right + current one)
            int span = 2 * skip_window + 1;
            if (data.Length < span)
                throw new InvalidOperationException("The corpus is shorter than the skip-gram window.");

            var batch = np.ndarray((batch_size), dtype: np.int32);
            var labels = np.ndarray((batch_size, 1), dtype: np.int32);

            // Fixed-size sliding window over `data` (the reference uses deque(maxlen=span));
            // Queue<T>'s constructor argument is only an initial capacity, so we dequeue manually.
            var buffer = new Queue<int>(span);
            if (data_index + span > data.Length)
                data_index = 0;
            foreach (var id in data.Skip(data_index).Take(span))
                buffer.Enqueue(id);
            data_index += span;

            // Every position of the window except the center one
            var context_words = Enumerable.Range(0, span).Where(x => x != skip_window).ToArray();

            for (int i = 0; i < batch_size / num_skips; i++)
            {
                var window = buffer.ToArray();
                // Pick num_skips distinct context positions at random (random.sample in the reference)
                var words_to_use = context_words.OrderBy(_ => rng.Next()).Take(num_skips).ToArray();
                for (int j = 0; j < words_to_use.Length; j++)
                {
                    batch[i * num_skips + j] = window[skip_window];
                    labels[i * num_skips + j, 0] = window[words_to_use[j]];
                }

                if (data_index == data.Length)
                {
                    // Reached the end of the corpus: restart the window from the beginning
                    buffer.Clear();
                    foreach (var id in data.Take(span))
                        buffer.Enqueue(id);
                    data_index = span;
                }
                else
                {
                    // Slide the window one word to the right
                    buffer.Dequeue();
                    buffer.Enqueue(data[data_index]);
                    data_index += 1;
                }
            }

            // Backtrack a little bit to avoid skipping words in the end of a batch
            data_index = (data_index + data.Length - span) % data.Length;
            return (batch, labels);
        }

        /// <summary>
        /// Downloads the graph meta file and the text8 corpus, builds the vocabulary
        /// (<see cref="word2id"/>) and encodes the corpus as word ids (<see cref="data"/>).
        /// </summary>
        public void PrepareData()
        {
            // Download graph meta
            var url = "https://github.com/SciSharp/TensorFlow.NET/raw/master/graph/word2vec.meta";
            Web.Download(url, "graph", "word2vec.meta");

            // Download a small chunk of Wikipedia articles collection
            url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/data/text8.zip";
            Web.Download(url, "word2vec", "text8.zip");
            // Unzip the dataset file. Text has already been processed
            Compress.UnZip($"word2vec{Path.DirectorySeparatorChar}text8.zip", "word2vec");

            // Split on any whitespace and drop empty tokens (same as Python's str.split())
            text_words = File.ReadAllText($"word2vec{Path.DirectorySeparatorChar}text8")
                .ToLowerInvariant()
                .Split((char[])null, StringSplitOptions.RemoveEmptyEntries);

            // Build the dictionary and replace rare words with UNK token
            int wordId = 0;
            word2id = text_words.GroupBy(x => x)
                .Select(g => new { Word = g.Key, Occurrence = g.Count() })
                .Where(x => x.Occurrence >= min_occurrence) // Remove samples with less than 'min_occurrence' occurrences
                .OrderByDescending(x => x.Occurrence) // Retrieve the most common words
                .Take(max_vocabulary_size - 1) // Keep one vocabulary slot for 'UNK'
                .Select(x => new WordId
                {
                    Word = x.Word,
                    Id = ++wordId, // Assign an id to each word (1..n, in list order)
                    Occurrence = x.Occurrence
                })
                .ToList();

            // Retrieve a word id, or assign it index 0 ('UNK') if not in dictionary
            var id_by_word = word2id.ToDictionary(x => x.Word, x => x.Id);
            data = text_words
                .Select(word => id_by_word.TryGetValue(word, out var id) ? id : 0)
                .ToArray();

            // 'UNK' goes first so that list index == Id for every entry
            word2id.Insert(0, new WordId { Word = "UNK", Id = 0, Occurrence = data.Count(x => x == 0) });

            print($"Words count: {text_words.Length}");
            print($"Unique words: {text_words.Distinct().Count()}");
            print($"Vocabulary size: {word2id.Count}");
            print($"Most common words: {string.Join(", ", word2id.Take(10))}");
        }

        public Graph ImportGraph()
        {
            throw new NotImplementedException();
        }

        public Graph BuildGraph()
        {
            throw new NotImplementedException();
        }

        public void Train(Session sess)
        {
            throw new NotImplementedException();
        }

        public void Predict(Session sess)
        {
            throw new NotImplementedException();
        }

        public void Test(Session sess)
        {
            throw new NotImplementedException();
        }

        /// <summary>
        /// A vocabulary entry: the word, its integer id and how often it occurs in the corpus.
        /// </summary>
        private class WordId
        {
            public string Word { get; set; }
            public int Id { get; set; }
            public int Occurrence { get; set; }

            public override string ToString()
            {
                return Word + " " + Id + " " + Occurrence;
            }
        }
    }
}