Browse Source

waiting for scikit-learn's _split.train_test_split

tags/v0.8.0
haiping008 6 years ago
parent
commit
292233b4a9
10 changed files with 108 additions and 94 deletions
  1. +6
    -0
      TensorFlow.NET.sln
  2. +14
    -6
      src/TensorFlowNET.Utility/Web.cs
  3. +0
    -58
      test/TensorFlowNET.Examples/CnnTextClassification/CnnTextTrain.cs
  4. +0
    -16
      test/TensorFlowNET.Examples/CnnTextClassification/TextCNN.cs
  5. +4
    -5
      test/TensorFlowNET.Examples/ImageRecognition.cs
  6. +4
    -4
      test/TensorFlowNET.Examples/InceptionArchGoogLeNet.cs
  7. +1
    -0
      test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj
  8. +40
    -2
      test/TensorFlowNET.Examples/TextClassification/DataHelpers.cs
  9. +37
    -0
      test/TensorFlowNET.Examples/TextClassification/TextClassificationTrain.cs
  10. +2
    -3
      test/TensorFlowNET.Examples/TextClassificationWithMovieReviews.cs

+ 6
- 0
TensorFlow.NET.sln View File

@@ -15,6 +15,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Visualization
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{E8340C61-12C1-4BEE-A340-403E7C1ACD82}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "scikit-learn", "..\scikit-learn.net\src\scikit-learn\scikit-learn.csproj", "{199DDAD8-4A6F-43B3-A560-C0393619E304}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@@ -45,6 +47,10 @@ Global
{E8340C61-12C1-4BEE-A340-403E7C1ACD82}.Debug|Any CPU.Build.0 = Debug|Any CPU
{E8340C61-12C1-4BEE-A340-403E7C1ACD82}.Release|Any CPU.ActiveCfg = Release|Any CPU
{E8340C61-12C1-4BEE-A340-403E7C1ACD82}.Release|Any CPU.Build.0 = Release|Any CPU
{199DDAD8-4A6F-43B3-A560-C0393619E304}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{199DDAD8-4A6F-43B3-A560-C0393619E304}.Debug|Any CPU.Build.0 = Debug|Any CPU
{199DDAD8-4A6F-43B3-A560-C0393619E304}.Release|Any CPU.ActiveCfg = Release|Any CPU
{199DDAD8-4A6F-43B3-A560-C0393619E304}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE


+ 14
- 6
src/TensorFlowNET.Utility/Web.cs View File

@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net;
using System.Text;
using System.Threading;
@@ -10,24 +11,31 @@ namespace TensorFlowNET.Utility
{
public class Web
{
// Diff hunk for Download: the previous and the new version of every changed line appear
// back to back (apparently the page rendering dropped the plus and minus markers).
// In both versions the method returns false without downloading when the target file
// already exists, and returns true once the wait loop below has finished.
public static bool Download(string url, string file)
public static bool Download(string url, string destDir, string destFileName)
{
if (File.Exists(file))
// New version: a null destFileName falls back to the last piece of the URL.
// NOTE(review): URLs are separated by forward slashes, while Path.DirectorySeparatorChar
// is platform dependent (a backslash on Windows), so this split presumably yields the
// whole URL there. Confirm, and consider splitting on the forward slash instead.
if (destFileName == null)
destFileName = url.Split(Path.DirectorySeparatorChar).Last();

// New version: create the destination directory before the existence check and the download.
Directory.CreateDirectory(destDir);

// New version: the path that is checked, downloaded to and printed.
string relativeFilePath = Path.Combine(destDir, destFileName);

if (File.Exists(relativeFilePath))
{
Console.WriteLine($"{file} already exists.");
Console.WriteLine($"{relativeFilePath} already exists.");
return false;
}

// NOTE(review): the WebClient is never disposed in the lines shown here.
var wc = new WebClient();
Console.WriteLine($"Downloading {file}");
// The blocking DownloadFile call runs on a task so that this thread can print progress.
var download = Task.Run(() => wc.DownloadFile(url, file));
Console.WriteLine($"Downloading {relativeFilePath}");
var download = Task.Run(() => wc.DownloadFile(url, relativeFilePath));
// Print one dot per second until the task is completed.
// NOTE(review): IsCompleted is also true for a faulted task and the task is never awaited,
// so a failed download appears to be reported as a success. Confirm.
while (!download.IsCompleted)
{
Thread.Sleep(1000);
Console.Write(".");
}
Console.WriteLine("");
Console.WriteLine($"Downloaded {file}");
Console.WriteLine($"Downloaded {relativeFilePath}");

return true;
}


+ 0
- 58
test/TensorFlowNET.Examples/CnnTextClassification/CnnTextTrain.cs View File

@@ -1,58 +0,0 @@
using NumSharp.Core;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow;

namespace TensorFlowNET.Examples.CnnTextClassification
{
/// <summary>
/// Training example for the CNN text classifier. Only the beginning of the preprocessing
/// step exists: <see cref="preprocess"/> loads the data, starts building the vocabulary
/// and then throws <see cref="NotImplementedException"/>.
/// </summary>
public class CnnTextTrain : Python, IExample
{
    // Percentage of the training data to use for validation
    private float dev_sample_percentage = 0.1f;
    // Data source for the positive data.
    private string positive_data_file = "https://raw.githubusercontent.com/dennybritz/cnn-text-classification-tf/master/data/rt-polaritydata/rt-polarity.pos";
    // Data source for the negative data.
    private string negative_data_file = "https://raw.githubusercontent.com/dennybritz/cnn-text-classification-tf/master/data/rt-polaritydata/rt-polarity.neg";
    // Dimensionality of character embedding (default: 128)
    private int embedding_dim = 128;
    // Comma-separated filter sizes (default: '3,4,5')
    private string filter_sizes = "3,4,5";
    // Number of filters per filter size (default: 128)
    private int num_filters = 128;
    // Dropout keep probability (default: 0.5)
    private float dropout_keep_prob = 0.5f;
    // L2 regularization lambda (default: 0.0)
    private float l2_reg_lambda = 0.0f;
    // Batch Size (default: 64)
    private int batch_size = 64;
    // Number of training epochs (default: 200)
    private int num_epochs = 200;
    // Evaluate model on dev set after this many steps (default: 100)
    private int evaluate_every = 100;
    // Save model after this many steps (default: 100)
    private int checkpoint_every = 100;
    // Number of checkpoints to store (default: 5)
    private int num_checkpoints = 5;
    // Allow soft device placement
    private bool allow_soft_placement = true;
    // Log placement of ops on devices
    private bool log_device_placement = false;

    /// <summary>
    /// Runs the example. At present this only calls <see cref="preprocess"/>, which throws.
    /// </summary>
    public void Run()
    {
        var (x_train, y_train, vocab_processor, x_dev, y_dev) = preprocess();
    }

    /// <summary>
    /// Loads the positive and negative data files and determines the longest document,
    /// measured in space-separated tokens.
    /// </summary>
    /// <returns>Never returns normally; the method ends by throwing.</returns>
    /// <exception cref="NotImplementedException">Always thrown once the data has been loaded.</exception>
    public (NDArray, NDArray, NDArray, NDArray, NDArray) preprocess()
    {
        var (x_text, y) = DataHelpers.load_data_and_labels(positive_data_file, negative_data_file);

        // Build vocabulary
        int max_document_length = x_text.Max(x => x.Split(' ').Length);
        // NOTE(review): `learn` is not declared in this class; confirm that it resolves
        // (for example through the Python base class) before relying on this line.
        var vocab_processor = learn.preprocessing.VocabularyProcessor(max_document_length);
        throw new NotImplementedException("");
    }
}
}

+ 0
- 16
test/TensorFlowNET.Examples/CnnTextClassification/TextCNN.cs View File

@@ -1,16 +0,0 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow;

namespace TensorFlowNET.Examples.CnnTextClassification
{
/// <summary>
/// Convolutional Neural Network for Text Classification
/// https://github.com/dennybritz/cnn-text-classification-tf
/// </summary>
/// <remarks>
/// The class derives from <c>Python</c> and declares no members of its own.
/// </remarks>
public class TextCNN : Python
{
}
}

+ 4
- 5
test/TensorFlowNET.Examples/ImageRecognition.cs View File

@@ -85,15 +85,14 @@ namespace TensorFlowNET.Examples
// get model file
string url = "https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip";

string zipFile = Path.Join(dir, "inception5h.zip");
Utility.Web.Download(url, zipFile);
Utility.Web.Download(url, dir, "inception5h.zip");

Utility.Compress.UnZip(zipFile, dir);
Utility.Compress.UnZip(Path.Join(dir, "inception5h.zip"), dir);

// download sample picture
string pic = Path.Join(dir, "img", "grace_hopper.jpg");
Directory.CreateDirectory(Path.Join(dir, "img"));
Utility.Web.Download($"https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/examples/label_image/data/grace_hopper.jpg", pic);
url = $"https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/examples/label_image/data/grace_hopper.jpg";
Utility.Web.Download(url, Path.Join(dir, "img"), "grace_hopper.jpg");
}
}
}

+ 4
- 4
test/TensorFlowNET.Examples/InceptionArchGoogLeNet.cs View File

@@ -90,14 +90,14 @@ namespace TensorFlowNET.Examples
// get model file
string url = "https://storage.googleapis.com/download.tensorflow.org/models/inception_v3_2016_08_28_frozen.pb.tar.gz";
string zipFile = Path.Join(dir, $"{pbFile}.tar.gz");
Utility.Web.Download(url, zipFile);
Utility.Web.Download(url, dir, $"{pbFile}.tar.gz");

Utility.Compress.ExtractTGZ(zipFile, dir);
Utility.Compress.ExtractTGZ(Path.Join(dir, $"{pbFile}.tar.gz"), dir);

// download sample picture
string pic = "grace_hopper.jpg";
Utility.Web.Download($"https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/examples/label_image/data/{pic}", Path.Join(dir, pic));
url = $"https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/examples/label_image/data/{pic}";
Utility.Web.Download(url, dir, pic);
}
}
}

+ 1
- 0
test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj View File

@@ -12,6 +12,7 @@

<ItemGroup>
<ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" />
<ProjectReference Include="..\..\..\scikit-learn.net\src\scikit-learn\scikit-learn.csproj" />
<ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" />
<ProjectReference Include="..\..\src\TensorFlowNET.Utility\TensorFlowNET.Utility.csproj" />
</ItemGroup>


test/TensorFlowNET.Examples/CnnTextClassification/DataHelpers.cs → test/TensorFlowNET.Examples/TextClassification/DataHelpers.cs View File

@@ -10,6 +10,44 @@ namespace TensorFlowNET.Examples.CnnTextClassification
{
public class DataHelpers
{
private const string TRAIN_PATH = "text_classification/dbpedia_csv/train.csv";
private const string TEST_PATH = "text_classification/dbpedia_csv/test.csv";

/// <summary>
/// Builds a character-level dataset from one of the DBpedia CSV files.
/// Every line is lower-cased and split on <c>,"</c>: the first field is parsed as the
/// integer label and the third field (minus its final character, the closing quote)
/// is encoded character by character.
/// </summary>
/// <param name="step"><c>"train"</c> reads <c>TRAIN_PATH</c>; any other value reads <c>TEST_PATH</c>.</param>
/// <param name="model">Not used by this method.</param>
/// <param name="document_max_len">Length of every encoded row; shorter content is padded, longer content is cut off.</param>
/// <returns>
/// The encoded rows, the labels exactly as parsed from the file, and the alphabet size
/// plus two (for the padding and unknown ids).
/// </returns>
/// <exception cref="FormatException">A line has fewer than three fields or a label that is not an integer.</exception>
public static (int[][], int[], int) build_char_dataset(string step, string model, int document_max_len)
{
    string alphabet = "abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:’'\"/|_#$%ˆ&*˜‘+=<>()[]{} ";

    // Ids 0 and 1 are reserved for "<pad>" and "<unk>"; alphabet characters start at 2.
    const int padId = 0;
    const int unkId = 1;
    var char_dict = new Dictionary<char, int>();
    foreach (char c in alphabet)
        char_dict[c] = char_dict.Count + 2;

    // Honor the requested split instead of always reading the training file.
    string path = step == "train" ? TRAIN_PATH : TEST_PATH;
    var contents = File.ReadAllLines(path);
    var x = new int[contents.Length][];
    var y = new int[contents.Length];
    for (int i = 0; i < contents.Length; i++)
    {
        // Invariant lower-casing keeps the mapping onto the ASCII alphabet independent of the current culture.
        string[] parts = contents[i].ToLowerInvariant().Split(",\"");
        if (parts.Length < 3)
            throw new FormatException($"{path}: line {i + 1} does not have the expected class,\"title\",\"content\" layout.");

        // Drop the last character (the closing quote); an empty field is left as it is.
        string content = parts[2];
        if (content.Length > 0)
            content = content.Substring(0, content.Length - 1);

        var row = new int[document_max_len];
        for (int j = 0; j < document_max_len; j++)
        {
            if (j >= content.Length)
                row[j] = padId;
            else
                row[j] = char_dict.TryGetValue(content[j], out int id) ? id : unkId;
        }
        x[i] = row;
        y[i] = int.Parse(parts[0], System.Globalization.CultureInfo.InvariantCulture);
    }

    return (x, y, alphabet.Length + 2);
}

/// <summary>
/// Loads MR polarity data from files, splits the data into words and generates labels.
/// Returns split sentences and labels.
@@ -20,8 +58,8 @@ namespace TensorFlowNET.Examples.CnnTextClassification
public static (string[], NDArray) load_data_and_labels(string positive_data_file, string negative_data_file)
{
Directory.CreateDirectory("CnnTextClassification");
Utility.Web.Download(positive_data_file, "CnnTextClassification/rt-polarity.pos");
Utility.Web.Download(negative_data_file, "CnnTextClassification/rt-polarity.neg");
Utility.Web.Download(positive_data_file, "CnnTextClassification", "rt -polarity.pos");
Utility.Web.Download(negative_data_file, "CnnTextClassification", "rt-polarity.neg");

// Load data from files
var positive_examples = File.ReadAllLines("CnnTextClassification/rt-polarity.pos")

+ 37
- 0
test/TensorFlowNET.Examples/TextClassification/TextClassificationTrain.cs View File

@@ -0,0 +1,37 @@
using NumSharp.Core;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using Tensorflow;
using TensorFlowNET.Utility;

namespace TensorFlowNET.Examples.CnnTextClassification
{
/// <summary>
/// Text classification training example, modelled on
/// https://github.com/dongjun-Lee/text-classification-models-tf
/// </summary>
public class TextClassificationTrain : Python, IExample
{
    private const int CHAR_MAX_LEN = 1014;

    private string dataDir = "text_classification";
    private string dataFileName = "dbpedia_csv.tar.gz";

    /// <summary>
    /// Fetches and unpacks the archive, builds the character-level "train" dataset and
    /// passes it to train_test_split with a test size of 0.15.
    /// </summary>
    public void Run()
    {
        download_dbpedia();

        Console.WriteLine("Building dataset...");
        var (features, labels, alphabetSize) = DataHelpers.build_char_dataset("train", "vdcnn", CHAR_MAX_LEN);

        var (trainX, validX, trainY, validY) = train_test_split(features, labels, test_size: 0.15);
    }

    /// <summary>
    /// Downloads the DBpedia archive into the data directory and extracts it there.
    /// </summary>
    public void download_dbpedia()
    {
        Web.Download(
            "https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz",
            dataDir,
            dataFileName);

        string archivePath = Path.Join(dataDir, dataFileName);
        Compress.ExtractTGZ(archivePath, dataDir);
    }
}
}

+ 2
- 3
test/TensorFlowNET.Examples/TextClassificationWithMovieReviews.cs View File

@@ -46,9 +46,8 @@ namespace TensorFlowNET.Examples
// get model file
string url = $"https://github.com/SciSharp/TensorFlow.NET/raw/master/data/{dataFile}";

string zipFile = Path.Join(dir, $"imdb.zip");
Utility.Web.Download(url, zipFile);
Utility.Compress.UnZip(zipFile, dir);
Utility.Web.Download(url, dir, "imdb.zip");
Utility.Compress.UnZip(Path.Join(dir, $"imdb.zip"), dir);

// prepare training dataset
var x_train = ReadData(Path.Join(dir, "x_train.txt"));


Loading…
Cancel
Save