Browse Source

Revert IMDB dataset changes.

pull/756/head
Niklas Gustafsson 4 years ago
parent
commit
24699690e7
1 changed files with 24 additions and 57 deletions
  1. +24
    -57
      src/TensorFlowNET.Keras/Datasets/Imdb.cs

+ 24
- 57
src/TensorFlowNET.Keras/Datasets/Imdb.cs View File

@@ -5,8 +5,6 @@ using System.Text;
using Tensorflow.Keras.Utils;
using NumSharp;
using System.Linq;
using NumSharp.Utilities;
using Tensorflow.Queues;

namespace Tensorflow.Keras.Datasets
{
@@ -17,10 +15,8 @@ namespace Tensorflow.Keras.Datasets
/// </summary>
public class Imdb
{
//string origin_folder = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/";
string origin_folder = "http://ai.stanford.edu/~amaas/data/sentiment/";
//string file_name = "imdb.npz";
string file_name = "aclImdb_v1.tar.gz";
string origin_folder = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/";
string file_name = "imdb.npz";
string dest_folder = "imdb";

/// <summary>
@@ -46,61 +42,33 @@ namespace Tensorflow.Keras.Datasets
{
var dst = Download();

var vocab = BuildVocabulary(Path.Combine(dst, "imdb.vocab"), start_char, oov_char, index_from);

var (x_train,y_train) = GetDataSet(Path.Combine(dst, "train"));
var (x_test, y_test) = GetDataSet(Path.Combine(dst, "test"));

return new DatasetPass
var lines = File.ReadAllLines(Path.Combine(dst, "imdb_train.txt"));
var x_train_string = new string[lines.Length];
var y_train = np.zeros(new int[] { lines.Length }, NPTypeCode.Int64);
for (int i = 0; i < lines.Length; i++)
{
Train = (x_train, y_train),
Test = (x_test, y_test)
};
}
y_train[i] = long.Parse(lines[i].Substring(0, 1));
x_train_string[i] = lines[i].Substring(2);
}

private static Dictionary<string, int> BuildVocabulary(string path,
int start_char,
int oov_char,
int index_from)
{
var words = File.ReadAllLines(path);
var result = new Dictionary<string, int>();
var idx = index_from;
var x_train = np.array(x_train_string);

foreach (var word in words)
File.ReadAllLines(Path.Combine(dst, "imdb_test.txt"));
var x_test_string = new string[lines.Length];
var y_test = np.zeros(new int[] { lines.Length }, NPTypeCode.Int64);
for (int i = 0; i < lines.Length; i++)
{
result[word] = idx;
idx += 1;
y_test[i] = long.Parse(lines[i].Substring(0, 1));
x_test_string[i] = lines[i].Substring(2);
}

return result;
}

private static (NDArray, NDArray) GetDataSet(string path)
{
var posFiles = Directory.GetFiles(Path.Combine(path, "pos")).Slice(0,10);
var negFiles = Directory.GetFiles(Path.Combine(path, "neg")).Slice(0,10);

var x_string = new string[posFiles.Length + negFiles.Length];
var y = new int[posFiles.Length + negFiles.Length];
var trg = 0;
var longest = 0;
var x_test = np.array(x_test_string);

for (int i = 0; i < posFiles.Length; i++, trg++)
{
y[trg] = 1;
x_string[trg] = File.ReadAllText(posFiles[i]);
longest = Math.Max(longest, x_string[trg].Length);
}
for (int i = 0; i < posFiles.Length; i++, trg++)
return new DatasetPass
{
y[trg] = 0;
x_string[trg] = File.ReadAllText(negFiles[i]);
longest = Math.Max(longest, x_string[trg].Length);
}
var x = np.array(x_string);

return (x, y);
Train = (x_train, y_train),
Test = (x_test, y_test)
};
}

(NDArray, NDArray) LoadX(byte[] bytes)
@@ -122,9 +90,8 @@ namespace Tensorflow.Keras.Datasets

Web.Download(origin_folder + file_name, dst, file_name);

Tensorflow.Keras.Utils.Compress.ExtractTGZ(Path.Combine(dst, file_name), dst);

return Path.Combine(dst, "aclImdb");
return dst;
// return Path.Combine(dst, file_name);
}
}
}
}

Loading…
Cancel
Save