Browse Source

add tf.where api.

tags/v0.9
Oceania2018 6 years ago
parent
commit
a5ae56affc
9 changed files with 166 additions and 124 deletions
  1. +7
    -0
      src/TensorFlowNET.Core/APIs/tf.array.cs
  2. +1
    -1
      src/TensorFlowNET.Core/APIs/tf.random.cs
  3. +20
    -4
      src/TensorFlowNET.Core/Layers/Layer.cs
  4. +1
    -1
      src/TensorFlowNET.Core/Operations/array_ops.py.cs
  5. +1
    -1
      src/TensorFlowNET.Core/Operations/gen_array_ops.cs
  6. +1
    -1
      src/TensorFlowNET.Core/Operations/random_ops.py.cs
  7. +1
    -1
      src/TensorFlowNET.Core/tf.cs
  8. +133
    -114
      test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs
  9. +1
    -1
      test/TensorFlowNET.Examples/TextProcess/cnn_models/VdCnn.cs

+ 7
- 0
src/TensorFlowNET.Core/APIs/tf.array.cs View File

@@ -20,6 +20,13 @@ namespace Tensorflow
public static Tensor expand_dims(Tensor input, int axis = -1, string name = null, int dim = -1)
=> array_ops.expand_dims(input, axis, name, dim);

/// <summary>
/// Return the elements, either from `x` or `y`, depending on the `condition`.
/// </summary>
/// <returns></returns>
public static Tensor where<Tx, Ty>(Tensor condition, Tx x, Ty y, string name = null)
=> array_ops.where(condition, x, y, name);

/// <summary>
/// Transposes `a`. Permutes the dimensions according to `perm`.
/// </summary>


+ 1
- 1
src/TensorFlowNET.Core/APIs/tf.random.cs View File

@@ -25,7 +25,7 @@ namespace Tensorflow

public static Tensor random_uniform(int[] shape,
float minval = 0,
float? maxval = null,
float maxval = 1,
TF_DataType dtype = TF_DataType.TF_FLOAT,
int? seed = null,
string name = null) => random_ops.random_uniform(shape, minval, maxval, dtype, seed, name);


+ 20
- 4
src/TensorFlowNET.Core/Layers/Layer.cs View File

@@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Linq;
using System.Text;
using static Tensorflow.Python;

namespace Tensorflow.Layers
{
@@ -50,7 +51,7 @@ namespace Tensorflow.Layers
auxiliary_name_scope: false);
}

Python.with(scope_context_manager, scope2 => _current_scope = scope2);
with(scope_context_manager, scope2 => _current_scope = scope2);
// Actually call layer
var outputs = base.__call__(new Tensor[] { inputs }, training: training);

@@ -60,6 +61,13 @@ namespace Tensorflow.Layers
return outputs;
}

protected override void _init_set_name(string name, bool zero_based = true)
{
// Determine layer name (non-unique).
base._init_set_name(name, zero_based: zero_based);
_base_name = this.name;
}

protected virtual void _add_elements_to_collection(Operation[] elements, string[] collection_list)
{
foreach(var name in collection_list)
@@ -140,10 +148,18 @@ namespace Tensorflow.Layers
{
if (_scope == null)
{
Python.with(tf.variable_scope(scope, default_name: _base_name), captured_scope =>
if(_reuse.HasValue && _reuse.Value)
{
_scope = captured_scope;
});
throw new NotImplementedException("_set_scope _reuse.HasValue");
/*with(tf.variable_scope(scope == null ? _base_name : scope),
captured_scope => _scope = captured_scope);*/
}
else
{
with(tf.variable_scope(scope, default_name: _base_name),
captured_scope => _scope = captured_scope);
}

}
}
}


+ 1
- 1
src/TensorFlowNET.Core/Operations/array_ops.py.cs View File

@@ -233,7 +233,7 @@ namespace Tensorflow
});
}
public static Tensor where(Tensor condition, Tensor x = null, Tensor y = null, string name = null)
public static Tensor where(Tensor condition, object x = null, object y = null, string name = null)
{
if( x == null && y == null)
{


+ 1
- 1
src/TensorFlowNET.Core/Operations/gen_array_ops.cs View File

@@ -234,7 +234,7 @@ namespace Tensorflow
return _op.outputs[0];
}

public static Tensor select(Tensor condition, Tensor t, Tensor e, string name = null)
public static Tensor select<Tx, Ty>(Tensor condition, Tx t, Ty e, string name = null)
{
var _op = _op_def_lib._apply_op_helper("Select", name, new { condition, t, e });
return _op.outputs[0];


+ 1
- 1
src/TensorFlowNET.Core/Operations/random_ops.py.cs View File

@@ -49,7 +49,7 @@ namespace Tensorflow
/// <returns>A tensor of the specified shape filled with random uniform values.</returns>
public static Tensor random_uniform(int[] shape,
float minval = 0,
float? maxval = null,
float maxval = 1,
TF_DataType dtype = TF_DataType.TF_FLOAT,
int? seed = null,
string name = null)


+ 1
- 1
src/TensorFlowNET.Core/tf.cs View File

@@ -15,7 +15,7 @@ namespace Tensorflow
public static TF_DataType float16 = TF_DataType.TF_HALF;
public static TF_DataType float32 = TF_DataType.TF_FLOAT;
public static TF_DataType float64 = TF_DataType.TF_DOUBLE;
public static TF_DataType boolean = TF_DataType.TF_BOOL;
public static TF_DataType @bool = TF_DataType.TF_BOOL;
public static TF_DataType chars = TF_DataType.TF_STRING;

public static Context context = new Context(new ContextOptions(), new Status());


+ 133
- 114
test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs View File

@@ -40,43 +40,157 @@ namespace TensorFlowNET.Examples
protected float loss_value = 0;
int vocabulary_size = 50000;
NDArray train_x, valid_x, train_y, valid_y;

public bool Run()
{
PrepareData();

var graph = tf.Graph().as_default();
return with(tf.Session(graph), sess =>
Train();

return true;
}

// TODO: this originally is an SKLearn utility function. it randomizes train and test which we don't do here
private (NDArray, NDArray, NDArray, NDArray) train_test_split(NDArray x, NDArray y, float test_size = 0.3f)
{
Console.WriteLine("Splitting in Training and Testing data...");
int len = x.shape[0];
//int classes = y.Data<int>().Distinct().Count();
//int samples = len / classes;
int train_size = (int)Math.Round(len * (1 - test_size));
var train_x = x[new Slice(stop: train_size), new Slice()];
var valid_x = x[new Slice(start: train_size), new Slice()];
var train_y = y[new Slice(stop: train_size)];
var valid_y = y[new Slice(start: train_size)];
Console.WriteLine("\tDONE");
return (train_x, valid_x, train_y, valid_y);
}

private static void FillWithShuffledLabels(int[][] x, int[] y, int[][] shuffled_x, int[] shuffled_y, Random random, Dictionary<int, HashSet<int>> labels)
{
int i = 0;
var label_keys = labels.Keys.ToArray();
while (i < shuffled_x.Length)
{
if (IsImportingGraph)
return RunWithImportedGraph(sess, graph);
else
return RunWithBuiltGraph(sess, graph);
});
var key = label_keys[random.Next(label_keys.Length)];
var set = labels[key];
var index = set.First();
if (set.Count == 0)
{
labels.Remove(key); // remove the set as it is empty
label_keys = labels.Keys.ToArray();
}
shuffled_x[i] = x[index];
shuffled_y[i] = y[index];
i++;
}
}

protected virtual bool RunWithImportedGraph(Session sess, Graph graph)
private IEnumerable<(NDArray, NDArray, int)> batch_iter(NDArray inputs, NDArray outputs, int batch_size, int num_epochs)
{
var stopwatch = Stopwatch.StartNew();
var num_batches_per_epoch = (len(inputs) - 1) / batch_size + 1;
var total_batches = num_batches_per_epoch * num_epochs;
foreach (var epoch in range(num_epochs))
{
foreach (var batch_num in range(num_batches_per_epoch))
{
var start_index = batch_num * batch_size;
var end_index = Math.Min((batch_num + 1) * batch_size, len(inputs));
if (end_index <= start_index)
break;
yield return (inputs[new Slice(start_index, end_index)], outputs[new Slice(start_index, end_index)], total_batches);
}
}
}

public void PrepareData()
{
// full dataset https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz
var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/data/dbpedia_subset.zip";
Web.Download(url, dataDir, "dbpedia_subset.zip");
Compress.UnZip(Path.Combine(dataDir, "dbpedia_subset.zip"), Path.Combine(dataDir, "dbpedia_csv"));

Console.WriteLine("Building dataset...");
int[][] x = null;
int[] y = null;

int alphabet_size = 0;

var word_dict = DataHelpers.build_word_dict(TRAIN_PATH);
// vocabulary_size = len(word_dict);
(x, y) = DataHelpers.build_word_dataset(TRAIN_PATH, word_dict, WORD_MAX_LEN);
vocabulary_size = len(word_dict);
var (x, y) = DataHelpers.build_word_dataset(TRAIN_PATH, word_dict, WORD_MAX_LEN);

Console.WriteLine("\tDONE ");

var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f);
Console.WriteLine("Training set size: " + train_x.len);
Console.WriteLine("Test set size: " + valid_x.len);
}

Console.WriteLine("Import graph...");
public Graph ImportGraph()
{
var graph = tf.Graph().as_default();

// download graph meta data
var meta_file = "word_cnn.meta";
var meta_path = Path.Combine("graph", meta_file);
if (File.GetLastWriteTime(meta_path) < new DateTime(2019, 05, 11))
{
// delete old cached file which contains errors
Console.WriteLine("Discarding cached file: " + meta_path);
File.Delete(meta_path);
}
var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file;
Web.Download(url, "graph", meta_file);

Console.WriteLine("Import graph...");
tf.train.import_meta_graph(Path.Join("graph", meta_file));
Console.WriteLine("\tDONE " + stopwatch.Elapsed);
Console.WriteLine("\tDONE ");

return graph;
}

public Graph BuildGraph()
{
var graph = tf.Graph().as_default();

var embedding_size = 128;
var learning_rate = 0.001f;
var filter_sizes = new int[3, 4, 5];
var num_filters = 100;
var document_max_len = 100;

var x = tf.placeholder(tf.int32, new TensorShape(-1, document_max_len), name: "x");
var y = tf.placeholder(tf.int32, new TensorShape(-1), name: "y");
var is_training = tf.placeholder(tf.@bool, new TensorShape(), name: "is_training");
var global_step = tf.Variable(0, trainable: false);
var keep_prob = tf.where(is_training, 0.5, 1.0);
Tensor x_emb = null;

with(tf.name_scope("embedding"), scope =>
{
var init_embeddings = tf.random_uniform(new int[] { vocabulary_size, embedding_size });
var embeddings = tf.get_variable("embeddings", initializer: init_embeddings);
x_emb = tf.nn.embedding_lookup(embeddings, x);
x_emb = tf.expand_dims(x_emb, -1);
});

foreach(var filter_size in filter_sizes)
{
var conv = tf.layers.conv2d(
x_emb,
filters: num_filters,
kernel_size: new int[] { filter_size, embedding_size },
strides: new int[] { 1, 1 },
padding: "VALID",
activation: tf.nn.relu());
}

return graph;
}

private bool RunWithImportedGraph(Session sess, Graph graph)
{
var stopwatch = Stopwatch.StartNew();

sess.run(tf.global_variables_initializer());
var saver = tf.train.Saver(tf.global_variables());
@@ -149,107 +263,12 @@ namespace TensorFlowNET.Examples
return false;
}

protected virtual bool RunWithBuiltGraph(Session session, Graph graph)
{
Console.WriteLine("Building dataset...");
var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", "word_cnn", CHAR_MAX_LEN, DataLimit);

var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f);

ITextClassificationModel model = null;
// todo train the model
return false;
}

// TODO: this originally is an SKLearn utility function. it randomizes train and test which we don't do here
private (NDArray, NDArray, NDArray, NDArray) train_test_split(NDArray x, NDArray y, float test_size = 0.3f)
{
Console.WriteLine("Splitting in Training and Testing data...");
int len = x.shape[0];
//int classes = y.Data<int>().Distinct().Count();
//int samples = len / classes;
int train_size = (int)Math.Round(len * (1 - test_size));
var train_x = x[new Slice(stop: train_size), new Slice()];
var valid_x = x[new Slice(start: train_size), new Slice()];
var train_y = y[new Slice(stop: train_size)];
var valid_y = y[new Slice(start: train_size)];
Console.WriteLine("\tDONE");
return (train_x, valid_x, train_y, valid_y);
}

private static void FillWithShuffledLabels(int[][] x, int[] y, int[][] shuffled_x, int[] shuffled_y, Random random, Dictionary<int, HashSet<int>> labels)
{
int i = 0;
var label_keys = labels.Keys.ToArray();
while (i < shuffled_x.Length)
{
var key = label_keys[random.Next(label_keys.Length)];
var set = labels[key];
var index = set.First();
if (set.Count == 0)
{
labels.Remove(key); // remove the set as it is empty
label_keys = labels.Keys.ToArray();
}
shuffled_x[i] = x[index];
shuffled_y[i] = y[index];
i++;
}
}

private IEnumerable<(NDArray, NDArray, int)> batch_iter(NDArray inputs, NDArray outputs, int batch_size, int num_epochs)
{
var num_batches_per_epoch = (len(inputs) - 1) / batch_size + 1;
var total_batches = num_batches_per_epoch * num_epochs;
foreach (var epoch in range(num_epochs))
{
foreach (var batch_num in range(num_batches_per_epoch))
{
var start_index = batch_num * batch_size;
var end_index = Math.Min((batch_num + 1) * batch_size, len(inputs));
if (end_index <= start_index)
break;
yield return (inputs[new Slice(start_index, end_index)], outputs[new Slice(start_index, end_index)], total_batches);
}
}
}

public void PrepareData()
{
// full dataset https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz
var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/data/dbpedia_subset.zip";
Web.Download(url, dataDir, "dbpedia_subset.zip");
Compress.UnZip(Path.Combine(dataDir, "dbpedia_subset.zip"), Path.Combine(dataDir, "dbpedia_csv"));

if (IsImportingGraph)
{
// download graph meta data
var meta_file = "word_cnn.meta";
var meta_path = Path.Combine("graph", meta_file);
if (File.GetLastWriteTime(meta_path) < new DateTime(2019, 05, 11))
{
// delete old cached file which contains errors
Console.WriteLine("Discarding cached file: " + meta_path);
File.Delete(meta_path);
}
url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file;
Web.Download(url, "graph", meta_file);
}
}

public Graph ImportGraph()
{
throw new NotImplementedException();
}

public Graph BuildGraph()
{
throw new NotImplementedException();
}

public bool Train()
{
throw new NotImplementedException();
var graph = IsImportingGraph ? ImportGraph() : BuildGraph();

return with(tf.Session(graph), sess
=> RunWithImportedGraph(sess, graph));
}

public bool Predict()


+ 1
- 1
test/TensorFlowNET.Examples/TextProcess/cnn_models/VdCnn.cs View File

@@ -40,7 +40,7 @@ namespace TensorFlowNET.Examples.TextClassification

x = tf.placeholder(tf.int32, new TensorShape(-1, document_max_len), name: "x");
y = tf.placeholder(tf.int32, new TensorShape(-1), name: "y");
is_training = tf.placeholder(tf.boolean, new TensorShape(), name: "is_training");
is_training = tf.placeholder(tf.@bool, new TensorShape(), name: "is_training");
global_step = tf.Variable(0, trainable: false);

// Embedding Layer


Loading…
Cancel
Save