using System;
using System.Collections.Generic;
using System.Globalization;
using System.IO;
using System.Linq;
using System.Text.RegularExpressions;
using Newtonsoft.Json;
using NumSharp;
using Tensorflow;
using static Tensorflow.Python;
namespace TensorFlowNET.Examples
{
/// <summary>
/// This example classifies movie reviews as positive or negative using the text of the review.
/// This is a binary—or two-class—classification, an important and widely applicable kind of machine learning problem.
/// https://github.com/tensorflow/docs/blob/master/site/en/tutorials/keras/basic_text_classification.ipynb
/// </summary>
public class BinaryTextClassification : IExample
{
    /// <summary>Number of integer word ids stored per review row in the data matrices.</summary>
    private const int SequenceLength = 256;

    /// <summary>Amount added to every index read from the word-index file.</summary>
    private const int WordIndexOffset = 3;

    public bool Enabled { get; set; } = true;
    public string Name => "Binary Text Classification";
    public bool IsImportingGraph { get; set; } = true;

    // Working directory that receives the downloaded archive and its extracted files.
    private readonly string dir = "binary_text_classification";

    // Name of the dataset archive, both on the server and on disk.
    private readonly string dataFile = "imdb.zip";

    // Populated by PrepareData().
    NDArray train_data, train_labels, test_data, test_labels;

    /// <summary>
    /// Prepares the data set, prints its size, loads the word index and starts
    /// assembling a sequential model with an embedding layer.
    /// </summary>
    /// <returns>Always <c>false</c>; the example stops after adding the embedding layer.</returns>
    public bool Run()
    {
        PrepareData();

        Console.WriteLine($"Training entries: {train_data.len}, labels: {train_labels.len}");

        // A dictionary mapping words to an integer index
        var word_index = GetWordIndex();

        // Padding step from the tutorial, currently disabled:
        //   train_data = keras.preprocessing.sequence.pad_sequences(train_data,
        //       value: word_index["<PAD>"],
        //       padding: "post",
        //       maxlen: 256);
        //   test_data = keras.preprocessing.sequence.pad_sequences(test_data,
        //       value: word_index["<PAD>"],
        //       padding: "post",
        //       maxlen: 256);

        // input shape is the vocabulary count used for the movie reviews (10,000 words)
        int vocab_size = 10000;

        var model = keras.Sequential();
        var layer = keras.layers.Embedding(vocab_size, 16);
        model.add(layer);

        return false;
    }

    /// <summary>
    /// Downloads and unpacks the dataset archive, then loads the train and test
    /// splits into <c>train_data</c>, <c>train_labels</c>, <c>test_data</c> and <c>test_labels</c>.
    /// </summary>
    public void PrepareData()
    {
        Directory.CreateDirectory(dir);

        // download and unpack the dataset archive
        string url = $"https://github.com/SciSharp/TensorFlow.NET/raw/master/data/{dataFile}";
        Utility.Web.Download(url, dir, dataFile);
        Utility.Compress.UnZip(Path.Join(dir, dataFile), dir);

        var (x_train, labels_train) = LoadSplit("train");
        var (x_test, labels_test) = LoadSplit("test");

        // turn the comma separated id strings into integer rows
        train_data = ParseSequences(x_train);
        test_data = ParseSequences(x_test);

        train_labels = labels_train;
        test_labels = labels_test;
    }

    /// <summary>
    /// Reads the sample, label and index files of one split and reorders samples
    /// and labels by the stored indices.
    /// </summary>
    /// <param name="split">Split name used in the file names, e.g. "train" or "test".</param>
    /// <returns>The reordered samples and their labels.</returns>
    private (NDArray samples, NDArray labels) LoadSplit(string split)
    {
        var samples = ReadData(Path.Join(dir, $"x_{split}.txt"));
        var labels = ReadData(Path.Join(dir, $"y_{split}.txt"));
        var indices = ReadData(Path.Join(dir, $"indices_{split}.txt"));

        return (samples[indices], labels[indices]);
    }

    /// <summary>
    /// Converts an array of comma separated integer strings into an int32 matrix
    /// with one row per input element.
    /// </summary>
    /// <param name="rows">String array as produced by <see cref="ReadData"/> for bracketed files.</param>
    /// <returns>Matrix of shape (rows.size, <see cref="SequenceLength"/>).</returns>
    private static NDArray ParseSequences(NDArray rows)
    {
        // NOTE(review): assumes every row already holds exactly SequenceLength values - confirm against the data files.
        var matrix = new NDArray(np.int32, (rows.size, SequenceLength));
        for (int i = 0; i < rows.size; i++)
        {
            matrix[i] = rows[i].Data<string>(0)
                .Split(',')
                .Select(token => int.Parse(token, CultureInfo.InvariantCulture))
                .ToArray();
        }

        return matrix;
    }

    /// <summary>
    /// Loads a text file into a one-dimensional array with one element per line.
    /// When the first line starts with '[', every line is stored as a string with its
    /// first and last character and all spaces removed; otherwise every line is parsed as an int32.
    /// </summary>
    /// <param name="file">Path of the file to read.</param>
    /// <returns>Array holding one element per line of the file.</returns>
    /// <exception cref="InvalidDataException">The file contains no lines.</exception>
    private NDArray ReadData(string file)
    {
        var lines = File.ReadAllLines(file);
        if (lines.Length == 0)
        {
            // The format is detected from the first line, so an empty file cannot be interpreted.
            throw new InvalidDataException($"Data file '{file}' is empty.");
        }

        bool isBracketed = lines[0].StartsWith("[", StringComparison.Ordinal);
        var nd = new NDArray(isBracketed ? typeof(string) : np.int32, new Shape(lines.Length));

        if (isBracketed)
        {
            for (int i = 0; i < lines.Length; i++)
            {
                // strip the enclosing brackets and the spaces between values
                nd[i] = lines[i].Substring(1, lines[i].Length - 2).Replace(" ", string.Empty);
            }
        }
        else
        {
            for (int i = 0; i < lines.Length; i++)
            {
                nd[i] = int.Parse(lines[i], CultureInfo.InvariantCulture);
            }
        }

        return nd;
    }

    /// <summary>
    /// Loads the word-to-index mapping from "imdb_word_index.json", shifts every index by
    /// <see cref="WordIndexOffset"/> and then assigns the ids 0 to 3 to the sentinel tokens.
    /// </summary>
    /// <returns>Mapping from word (or sentinel token) to integer id.</returns>
    /// <exception cref="InvalidDataException">The JSON file does not contain an object.</exception>
    private Dictionary<string, int> GetWordIndex()
    {
        var json = File.ReadAllText(Path.Join(dir, "imdb_word_index.json"));
        var dict = JsonConvert.DeserializeObject<Dictionary<string, int>>(json);
        if (dict == null)
        {
            throw new InvalidDataException("imdb_word_index.json does not contain a word index.");
        }

        var result = new Dictionary<string, int>(dict.Count + 4);
        foreach (var entry in dict)
        {
            result[entry.Key] = entry.Value + WordIndexOffset;
        }

        // sentinel tokens as used by the referenced tutorial
        result["<PAD>"] = 0;
        result["<START>"] = 1;
        result["<UNK>"] = 2; // unknown
        result["<UNUSED>"] = 3;

        return result;
    }

    /// <summary>Not implemented for this example.</summary>
    public Graph ImportGraph()
    {
        throw new NotImplementedException();
    }

    /// <summary>Not implemented for this example.</summary>
    public Graph BuildGraph()
    {
        throw new NotImplementedException();
    }

    /// <summary>Not implemented for this example.</summary>
    public void Train(Session sess)
    {
        throw new NotImplementedException();
    }

    /// <summary>Not implemented for this example.</summary>
    public void Predict(Session sess)
    {
        throw new NotImplementedException();
    }

    /// <summary>Not implemented for this example.</summary>
    public void Test(Session sess)
    {
        throw new NotImplementedException();
    }
}
}