Browse Source

add Tensor[] pattern match for ops.name_scope.

tags/v0.9
Oceania2018 6 years ago
parent
commit
5a73e698b0
4 changed files with 10 additions and 35 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Layers/Layer.cs
  2. +6
    -2
      src/TensorFlowNET.Core/ops.name_scope.cs
  3. +3
    -1
      src/TensorFlowNET.Core/ops.py.cs
  4. +0
    -31
      test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs

+ 1
- 1
src/TensorFlowNET.Core/Layers/Layer.cs View File

@@ -37,7 +37,7 @@ namespace Tensorflow.Layers
VariableScope scope = null)
{
_set_scope(scope);
_graph = ops._get_graph_from_inputs(new List<Tensor> { inputs }, graph: _graph);
_graph = ops._get_graph_from_inputs(new Tensor[] { inputs }, graph: _graph);

variable_scope scope_context_manager = null;
if (built)


+ 6
- 2
src/TensorFlowNET.Core/ops.name_scope.cs View File

@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Eager;

@@ -37,8 +38,11 @@ namespace Tensorflow
_name = _name == null ? _default_name : _name;

Graph g = null;
if (_values is List<Tensor> values)
g = _get_graph_from_inputs(values);

if (_values is List<Tensor> vList)
g = _get_graph_from_inputs(vList.ToArray());
else if (_values is Tensor[] vArray)
g = _get_graph_from_inputs(vArray);

if (g == null)
g = get_default_graph();


+ 3
- 1
src/TensorFlowNET.Core/ops.py.cs View File

@@ -102,8 +102,10 @@ namespace Tensorflow
default_graph = tf.Graph();
}

public static Graph _get_graph_from_inputs(params Tensor[] op_input_list)
=> _get_graph_from_inputs(op_input_list: op_input_list);

public static Graph _get_graph_from_inputs(List<Tensor> op_input_list, Graph graph = null)
public static Graph _get_graph_from_inputs(Tensor[] op_input_list, Graph graph = null)
{
foreach(var op_input in op_input_list)
{


+ 0
- 31
test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs View File

@@ -203,37 +203,6 @@ namespace TensorFlowNET.Examples.CnnTextClassification
return (train_x, valid_x, train_y, valid_y);
}

//private (int[][], int[][], int[], int[]) train_test_split(int[][] x, int[] y, float test_size = 0.3f)
//{
// Console.WriteLine("Splitting in Training and Testing data...");
// var stopwatch = Stopwatch.StartNew();
// int len = x.Length;
// int train_size = int.Parse((len * (1 - test_size)).ToString());
// var random = new Random(17);

// // we collect indices of labels
// var labels = new Dictionary<int, HashSet<int>>();
// var shuffled_indices = random.Shuffle<int>(range(len).ToArray());
// foreach (var i in shuffled_indices)
// {
// var label = y[i];
// if (!labels.ContainsKey(i))
// labels[label] = new HashSet<int>();
// labels[label].Add(i);
// }

// var train_x = new int[train_size][];
// var valid_x = new int[len - train_size][];
// var train_y = new int[train_size];
// var valid_y = new int[len - train_size];
// FillWithShuffledLabels(x, y, train_x, train_y, random, labels);
// FillWithShuffledLabels(x, y, valid_x, valid_y, random, labels);

// Console.WriteLine("\tDONE " + stopwatch.Elapsed);
// return (train_x, valid_x, train_y, valid_y);
//}

private static void FillWithShuffledLabels(int[][] x, int[] y, int[][] shuffled_x, int[] shuffled_y, Random random, Dictionary<int, HashSet<int>> labels)
{
int i = 0;


Loading…
Cancel
Save