Browse Source

Merge branch 'master' of https://github.com/SciSharp/TensorFlow.NET

tags/v0.8.0
Bo Peng 6 years ago
parent
commit
5e8908ea17
22 changed files with 414 additions and 153 deletions
  1. +5
    -11
      TensorFlow.NET.sln
  2. +1
    -123
      src/TensorFlowNET.Core/APIs/tf.init.cs
  3. +7
    -0
      src/TensorFlowNET.Core/APIs/tf.random.cs
  4. +20
    -0
      src/TensorFlowNET.Core/Keras/Initializers.cs
  5. +15
    -0
      src/TensorFlowNET.Core/Keras/tf.keras.cs
  6. +30
    -0
      src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs
  7. +1
    -1
      src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs
  8. +41
    -0
      src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs
  9. +82
    -0
      src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs
  10. +29
    -0
      src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs
  11. +14
    -0
      src/TensorFlowNET.Core/Operations/embedding_ops.cs
  12. +1
    -1
      src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
  13. +1
    -1
      src/TensorFlowNET.Core/Variables/VariableScope.cs
  14. +63
    -5
      src/TensorFlowNET.Core/Variables/_VariableStore.cs
  15. +1
    -1
      src/TensorFlowNET.Core/Variables/tf.variable.cs
  16. +11
    -4
      src/TensorFlowNET.Core/tf.cs
  17. +0
    -1
      test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj
  18. +1
    -1
      test/TensorFlowNET.Examples/TextClassification/DataHelpers.cs
  19. +43
    -1
      test/TensorFlowNET.Examples/TextClassification/TextClassificationTrain.cs
  20. +44
    -0
      test/TensorFlowNET.Examples/TextClassification/cnn_models/VdCnn.cs
  21. +3
    -2
      test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj
  22. +1
    -1
      test/TensorFlowNET.UnitTest/VariableTest.cs

+ 5
- 11
TensorFlow.NET.sln View File

@@ -11,9 +11,7 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Core", "src\T
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Visualization", "TensorFlowNET.Visualization\TensorFlowNET.Visualization.csproj", "{4BB2ABD1-635E-41E4-B534-CB5B6A2D754D}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{E8340C61-12C1-4BEE-A340-403E7C1ACD82}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "scikit-learn", "..\scikit-learn.net\src\scikit-learn\scikit-learn.csproj", "{199DDAD8-4A6F-43B3-A560-C0393619E304}"
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{1F4C5683-48B4-4328-8171-E9F7ABAA2E72}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
@@ -37,14 +35,10 @@ Global
{4BB2ABD1-635E-41E4-B534-CB5B6A2D754D}.Debug|Any CPU.Build.0 = Debug|Any CPU
{4BB2ABD1-635E-41E4-B534-CB5B6A2D754D}.Release|Any CPU.ActiveCfg = Release|Any CPU
{4BB2ABD1-635E-41E4-B534-CB5B6A2D754D}.Release|Any CPU.Build.0 = Release|Any CPU
{E8340C61-12C1-4BEE-A340-403E7C1ACD82}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{E8340C61-12C1-4BEE-A340-403E7C1ACD82}.Debug|Any CPU.Build.0 = Debug|Any CPU
{E8340C61-12C1-4BEE-A340-403E7C1ACD82}.Release|Any CPU.ActiveCfg = Release|Any CPU
{E8340C61-12C1-4BEE-A340-403E7C1ACD82}.Release|Any CPU.Build.0 = Release|Any CPU
{199DDAD8-4A6F-43B3-A560-C0393619E304}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{199DDAD8-4A6F-43B3-A560-C0393619E304}.Debug|Any CPU.Build.0 = Debug|Any CPU
{199DDAD8-4A6F-43B3-A560-C0393619E304}.Release|Any CPU.ActiveCfg = Release|Any CPU
{199DDAD8-4A6F-43B3-A560-C0393619E304}.Release|Any CPU.Build.0 = Release|Any CPU
{1F4C5683-48B4-4328-8171-E9F7ABAA2E72}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{1F4C5683-48B4-4328-8171-E9F7ABAA2E72}.Debug|Any CPU.Build.0 = Debug|Any CPU
{1F4C5683-48B4-4328-8171-E9F7ABAA2E72}.Release|Any CPU.ActiveCfg = Release|Any CPU
{1F4C5683-48B4-4328-8171-E9F7ABAA2E72}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE


+ 1
- 123
src/TensorFlowNET.Core/APIs/tf.init.cs View File

@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Operations.Initializers;

namespace Tensorflow
{
@@ -24,128 +25,5 @@ namespace Tensorflow
default_name,
values,
auxiliary_name_scope);

public class Zeros : IInitializer
{
private TF_DataType dtype;

public Zeros(TF_DataType dtype = TF_DataType.TF_FLOAT)
{
this.dtype = dtype;
}

public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid)
{
if (dtype == TF_DataType.DtInvalid)
dtype = this.dtype;

return array_ops.zeros(shape, dtype);
}

public object get_config()
{
return new { dtype = dtype.name() };
}
}

/// <summary>
/// Initializer capable of adapting its scale to the shape of weights tensors.
/// </summary>
public class VarianceScaling : IInitializer
{
protected float _scale;
protected string _mode;
protected string _distribution;
protected int? _seed;
protected TF_DataType _dtype;

public VarianceScaling(float scale = 1.0f,
string mode = "fan_in",
string distribution= "truncated_normal",
int? seed = null,
TF_DataType dtype = TF_DataType.TF_FLOAT)
{
if (scale < 0)
throw new ValueError("`scale` must be positive float.");
_scale = scale;
_mode = mode;
_distribution = distribution;
_seed = seed;
_dtype = dtype;
}

public Tensor call(TensorShape shape, TF_DataType dtype)
{
var (fan_in, fan_out) = _compute_fans(shape);
if (_mode == "fan_in")
_scale /= Math.Max(1, fan_in);
else if (_mode == "fan_out")
_scale /= Math.Max(1, fan_out);
else
_scale /= Math.Max(1, (fan_in + fan_out) / 2);

if (_distribution == "normal" || _distribution == "truncated_normal")
{
throw new NotImplementedException("truncated_normal");
}
else if(_distribution == "untruncated_normal")
{
throw new NotImplementedException("truncated_normal");
}
else
{
var limit = Math.Sqrt(3.0f * _scale);
return random_ops.random_uniform(shape, (float)-limit, (float)limit, dtype, seed: _seed);
}
}

private (int, int) _compute_fans(int[] shape)
{
if (shape.Length < 1)
return (1, 1);
if (shape.Length == 1)
return (shape[0], shape[0]);
if (shape.Length == 2)
return (shape[0], shape[1]);
else
throw new NotImplementedException("VarianceScaling._compute_fans");
}

public virtual object get_config()
{
return new
{
scale = _scale,
mode = _mode,
distribution = _distribution,
seed = _seed,
dtype = _dtype
};
}
}

public class GlorotUniform : VarianceScaling
{
public GlorotUniform(float scale = 1.0f,
string mode = "fan_avg",
string distribution = "uniform",
int? seed = null,
TF_DataType dtype = TF_DataType.TF_FLOAT) : base(scale, mode, distribution, seed, dtype)
{

}

public object get_config()
{
return new
{
scale = _scale,
mode = _mode,
distribution = _distribution,
seed = _seed,
dtype = _dtype
};
}
}
}
}

+ 7
- 0
src/TensorFlowNET.Core/APIs/tf.random.cs View File

@@ -22,5 +22,12 @@ namespace Tensorflow
TF_DataType dtype = TF_DataType.TF_FLOAT,
int? seed = null,
string name = null) => random_ops.random_normal(shape, mean, stddev, dtype, seed, name);

public static Tensor random_uniform(int[] shape,
float minval = 0,
float? maxval = null,
TF_DataType dtype = TF_DataType.TF_FLOAT,
int? seed = null,
string name = null) => random_ops.random_uniform(shape, minval, maxval, dtype, seed, name);
}
}

+ 20
- 0
src/TensorFlowNET.Core/Keras/Initializers.cs View File

@@ -0,0 +1,20 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Operations.Initializers;

namespace Tensorflow.Keras
{
public class Initializers
{
/// <summary>
/// He normal initializer.
/// </summary>
/// <param name="seed"></param>
/// <returns></returns>
public IInitializer he_normal(int? seed = null)
{
return new VarianceScaling(scale: 20f, mode: "fan_in", distribution: "truncated_normal", seed: seed);
}
}
}

+ 15
- 0
src/TensorFlowNET.Core/Keras/tf.keras.cs View File

@@ -0,0 +1,15 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras;

namespace Tensorflow
{
public static partial class tf
{
public static class keras
{
public static Initializers initializers => new Initializers();
}
}
}

+ 30
- 0
src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs View File

@@ -0,0 +1,30 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Operations.Initializers
{
public class GlorotUniform : VarianceScaling
{
public GlorotUniform(float scale = 1.0f,
string mode = "fan_avg",
string distribution = "uniform",
int? seed = null,
TF_DataType dtype = TF_DataType.TF_FLOAT) : base(scale, mode, distribution, seed, dtype)
{

}

public object get_config()
{
return new
{
scale = _scale,
mode = _mode,
distribution = _distribution,
seed = _seed,
dtype = _dtype
};
}
}
}

src/TensorFlowNET.Core/Operations/IInitializer.cs → src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs View File

@@ -6,7 +6,7 @@ namespace Tensorflow
{
public interface IInitializer
{
Tensor call(TensorShape shape, TF_DataType dtype);
Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid);
object get_config();
}
}

+ 41
- 0
src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs View File

@@ -0,0 +1,41 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Operations.Initializers
{
public class TruncatedNormal : IInitializer
{
private float mean;
private float stddev;
private int? seed;
private TF_DataType dtype;

public TruncatedNormal(float mean = 0.0f,
float stddev = 1.0f,
int? seed = null,
TF_DataType dtype = TF_DataType.TF_FLOAT)
{
this.mean = mean;
this.stddev = stddev;
this.seed = seed;
this.dtype = dtype;
}

public Tensor call(TensorShape shape, TF_DataType dtype)
{
throw new NotImplementedException("");
}

public object get_config()
{
return new
{
mean = mean,
stddev = stddev,
seed = seed,
dtype = dtype.name()
};
}
}
}

+ 82
- 0
src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs View File

@@ -0,0 +1,82 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Operations.Initializers
{
/// <summary>
/// Initializer capable of adapting its scale to the shape of weights tensors.
/// </summary>
public class VarianceScaling : IInitializer
{
protected float _scale;
protected string _mode;
protected string _distribution;
protected int? _seed;
protected TF_DataType _dtype;

public VarianceScaling(float scale = 1.0f,
string mode = "fan_in",
string distribution = "truncated_normal",
int? seed = null,
TF_DataType dtype = TF_DataType.TF_FLOAT)
{
if (scale < 0)
throw new ValueError("`scale` must be positive float.");
_scale = scale;
_mode = mode;
_distribution = distribution;
_seed = seed;
_dtype = dtype;
}

public Tensor call(TensorShape shape, TF_DataType dtype)
{
var (fan_in, fan_out) = _compute_fans(shape);
if (_mode == "fan_in")
_scale /= Math.Max(1, fan_in);
else if (_mode == "fan_out")
_scale /= Math.Max(1, fan_out);
else
_scale /= Math.Max(1, (fan_in + fan_out) / 2);

if (_distribution == "normal" || _distribution == "truncated_normal")
{
throw new NotImplementedException("truncated_normal");
}
else if (_distribution == "untruncated_normal")
{
throw new NotImplementedException("truncated_normal");
}
else
{
var limit = Math.Sqrt(3.0f * _scale);
return random_ops.random_uniform(shape, (float)-limit, (float)limit, dtype, seed: _seed);
}
}

private (int, int) _compute_fans(int[] shape)
{
if (shape.Length < 1)
return (1, 1);
if (shape.Length == 1)
return (shape[0], shape[0]);
if (shape.Length == 2)
return (shape[0], shape[1]);
else
throw new NotImplementedException("VarianceScaling._compute_fans");
}

public virtual object get_config()
{
return new
{
scale = _scale,
mode = _mode,
distribution = _distribution,
seed = _seed,
dtype = _dtype
};
}
}
}

+ 29
- 0
src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs View File

@@ -0,0 +1,29 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Operations.Initializers
{
public class Zeros : IInitializer
{
private TF_DataType dtype;

public Zeros(TF_DataType dtype = TF_DataType.TF_FLOAT)
{
this.dtype = dtype;
}

public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid)
{
if (dtype == TF_DataType.DtInvalid)
dtype = this.dtype;

return array_ops.zeros(shape, dtype);
}

public object get_config()
{
return new { dtype = dtype.name() };
}
}
}

+ 14
- 0
src/TensorFlowNET.Core/Operations/embedding_ops.cs View File

@@ -0,0 +1,14 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public class embedding_ops
{
public Tensor _embedding_lookup_and_transform()
{
throw new NotImplementedException("");
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/TensorFlowNET.Core.csproj View File

@@ -43,7 +43,7 @@ Fixed import name scope issue.</PackageReleaseNotes>
</ItemGroup>

<ItemGroup>
<PackageReference Include="Google.Protobuf" Version="3.6.1" />
<PackageReference Include="Google.Protobuf" Version="3.7.0" />
<PackageReference Include="NumSharp" Version="0.7.3" />
</ItemGroup>



+ 1
- 1
src/TensorFlowNET.Core/Variables/VariableScope.cs View File

@@ -33,7 +33,7 @@ namespace Tensorflow
string name,
TensorShape shape = null,
TF_DataType dtype = TF_DataType.DtInvalid,
IInitializer initializer = null,
object initializer = null, // IInitializer or Tensor
bool? trainable = null,
VariableSynchronization synchronization = VariableSynchronization.AUTO,
VariableAggregation aggregation= VariableAggregation.NONE)


+ 63
- 5
src/TensorFlowNET.Core/Variables/_VariableStore.cs View File

@@ -23,7 +23,7 @@ namespace Tensorflow
public RefVariable get_variable(string name,
TensorShape shape = null,
TF_DataType dtype = TF_DataType.TF_FLOAT,
IInitializer initializer = null,
object initializer = null, // IInitializer or Tensor
bool? trainable = null,
bool validate_shape = true,
VariableSynchronization synchronization = VariableSynchronization.AUTO,
@@ -45,7 +45,7 @@ namespace Tensorflow
private RefVariable _true_getter(string name,
TensorShape shape = null,
TF_DataType dtype = TF_DataType.TF_FLOAT,
IInitializer initializer = null,
object initializer = null,
bool? trainable = null,
bool validate_shape = true,
VariableSynchronization synchronization = VariableSynchronization.AUTO,
@@ -53,14 +53,32 @@ namespace Tensorflow
{
bool is_scalar = shape.NDim == 0;

return _get_single_variable(name: name,
shape: shape,
if (initializer is IInitializer init)
{
return _get_single_variable(name: name,
shape: shape,
dtype: dtype,
initializer: initializer,
initializer: init,
trainable: trainable,
validate_shape: validate_shape,
synchronization: synchronization,
aggregation: aggregation);
}
else if (initializer is Tensor tensor)
{
return _get_single_variable(name: name,
shape: shape,
dtype: dtype,
initializer: tensor,
trainable: trainable,
validate_shape: validate_shape,
synchronization: synchronization,
aggregation: aggregation);
}
else
{
throw new NotImplementedException("_true_getter");
}
}

private RefVariable _get_single_variable(string name,
@@ -125,5 +143,45 @@ namespace Tensorflow

return v;
}

private RefVariable _get_single_variable(string name,
TensorShape shape = null,
TF_DataType dtype = TF_DataType.DtInvalid,
Tensor initializer = null,
bool reuse = false,
bool? trainable = null,
bool validate_shape = false,
bool? use_resource = null,
VariableSynchronization synchronization = VariableSynchronization.AUTO,
VariableAggregation aggregation = VariableAggregation.NONE)
{
if (use_resource == null)
use_resource = false;

if (_vars.ContainsKey(name))
{
if (!reuse)
{
var var = _vars[name];

}
throw new NotImplementedException("_get_single_variable");
}

RefVariable v = null;
// Create the variable.
ops.init_scope();
{
var init_val = initializer;
v = new RefVariable(init_val,
name: name,
validate_shape: validate_shape,
trainable: trainable.Value);
}

_vars[name] = v;

return v;
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Variables/tf.variable.cs View File

@@ -15,7 +15,7 @@ namespace Tensorflow
public static RefVariable get_variable(string name,
TensorShape shape = null,
TF_DataType dtype = TF_DataType.DtInvalid,
IInitializer initializer = null,
object initializer = null, // IInitializer or Tensor
bool? trainable = null,
VariableSynchronization synchronization = VariableSynchronization.AUTO,
VariableAggregation aggregation = VariableAggregation.NONE)


+ 11
- 4
src/TensorFlowNET.Core/tf.cs View File

@@ -12,20 +12,27 @@ namespace Tensorflow
public static TF_DataType float16 = TF_DataType.TF_HALF;
public static TF_DataType float32 = TF_DataType.TF_FLOAT;
public static TF_DataType float64 = TF_DataType.TF_DOUBLE;
public static TF_DataType boolean = TF_DataType.TF_BOOL;
public static TF_DataType chars = TF_DataType.TF_STRING;

public static Context context = new Context(new ContextOptions(), new Status());

public static Session defaultSession;

public static RefVariable Variable<T>(T data, string name = null, TF_DataType dtype = TF_DataType.DtInvalid)
public static RefVariable Variable<T>(T data,
bool trainable = true,
string name = null,
TF_DataType dtype = TF_DataType.DtInvalid)
{
return Tensorflow.variable_scope.default_variable_creator(data, name: name, dtype: TF_DataType.DtInvalid);
return Tensorflow.variable_scope.default_variable_creator(data,
trainable: trainable,
name: name,
dtype: TF_DataType.DtInvalid);
}

public static unsafe Tensor placeholder(TF_DataType dtype, TensorShape shape = null)
public static unsafe Tensor placeholder(TF_DataType dtype, TensorShape shape = null, string name = null)
{
return gen_array_ops.placeholder(dtype, shape);
return gen_array_ops.placeholder(dtype, shape, name);
}

public static void enable_eager_execution()


+ 0
- 1
test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj View File

@@ -13,7 +13,6 @@

<ItemGroup>
<ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" />
<ProjectReference Include="..\..\..\scikit-learn.net\src\scikit-learn\scikit-learn.csproj" />
<ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" />
</ItemGroup>



+ 1
- 1
test/TensorFlowNET.Examples/TextClassification/DataHelpers.cs View File

@@ -77,7 +77,7 @@ namespace TensorFlowNET.Examples.CnnTextClassification

var positive_labels = positive_examples.Select(x => new int[2] { 0, 1 }).ToArray();
var negative_labels = negative_examples.Select(x => new int[2] { 1, 0 }).ToArray();
var y = np.array(1);// np.concatenate(new int[][][] { positive_labels, negative_labels });
var y = np.concatenate(new int[][][] { positive_labels, negative_labels });
return (x_text.ToArray(), y);
}



+ 43
- 1
test/TensorFlowNET.Examples/TextClassification/TextClassificationTrain.cs View File

@@ -5,6 +5,7 @@ using System.IO;
using System.Linq;
using System.Text;
using Tensorflow;
using TensorFlowNET.Examples.TextClassification;
using TensorFlowNET.Examples.Utility;

namespace TensorFlowNET.Examples.CnnTextClassification
@@ -18,13 +19,21 @@ namespace TensorFlowNET.Examples.CnnTextClassification
private string dataFileName = "dbpedia_csv.tar.gz";

private const int CHAR_MAX_LEN = 1014;
private const int NUM_CLASS = 2;

public void Run()
{
download_dbpedia();
Console.WriteLine("Building dataset...");
var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", "vdcnn", CHAR_MAX_LEN);
//var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15);

var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f);

with(tf.Session(), sess =>
{
new VdCnn(alphabet_size, CHAR_MAX_LEN, NUM_CLASS);

});
}

public void download_dbpedia()
@@ -33,5 +42,38 @@ namespace TensorFlowNET.Examples.CnnTextClassification
Web.Download(url, dataDir, dataFileName);
Compress.ExtractTGZ(Path.Join(dataDir, dataFileName), dataDir);
}

private (int[][], int[][], int[], int[]) train_test_split(int[][] x, int[] y, float test_size = 0.3f)
{
int len = x.Length;
int classes = y.Distinct().Count();
int samples = len / classes;
int train_size = int.Parse((samples * (1 - test_size)).ToString());

var train_x = new List<int[]>();
var valid_x = new List<int[]>();
var train_y = new List<int>();
var valid_y = new List<int>();

for (int i = 0; i< classes; i++)
{
for (int j = 0; j < samples; j++)
{
int idx = i * samples + j;
if (idx < train_size + samples * i)
{
train_x.Add(x[idx]);
train_y.Add(y[idx]);
}
else
{
valid_x.Add(x[idx]);
valid_y.Add(y[idx]);
}
}
}

return (train_x.ToArray(), valid_x.ToArray(), train_y.ToArray(), valid_y.ToArray());
}
}
}

+ 44
- 0
test/TensorFlowNET.Examples/TextClassification/cnn_models/VdCnn.cs View File

@@ -0,0 +1,44 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow;

namespace TensorFlowNET.Examples.TextClassification
{
public class VdCnn : Python
{
private int embedding_size;
private int[] filter_sizes;
private int[] num_filters;
private int[] num_blocks;
private float learning_rate;
private IInitializer cnn_initializer;
private Tensor x;
private Tensor y;
private Tensor is_training;
private RefVariable global_step;
private RefVariable embeddings;
private Tensor x_emb;

public VdCnn(int alphabet_size, int document_max_len, int num_class)
{
embedding_size = 16;
filter_sizes = new int[] { 3, 3, 3, 3, 3 };
num_filters = new int[] { 64, 64, 128, 256, 512 };
num_blocks = new int[] { 2, 2, 2, 2 };
learning_rate = 0.001f;
cnn_initializer = tf.keras.initializers.he_normal();
x = tf.placeholder(tf.int32, new TensorShape(-1, document_max_len), name: "x");
y = tf.placeholder(tf.int32, new TensorShape(-1), name: "y");
is_training = tf.placeholder(tf.boolean, new TensorShape(), name: "is_training");
global_step = tf.Variable(0, trainable: false);

with(tf.name_scope("embedding"), delegate
{
var init_embeddings = tf.random_uniform(new int[] { alphabet_size, embedding_size }, -1.0f, 1.0f);
embeddings = tf.get_variable("embeddings", initializer: init_embeddings);
// x_emb = tf.nn.embedding_lookup(embeddings, x);
});
}
}
}

+ 3
- 2
test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj View File

@@ -16,14 +16,15 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="15.9.0" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.0.1" />
<PackageReference Include="MSTest.TestAdapter" Version="1.4.0" />
<PackageReference Include="MSTest.TestFramework" Version="1.4.0" />
<PackageReference Include="NumSharp" Version="0.7.3" />
<PackageReference Include="TensorFlow.NET" Version="0.3.0" />
<PackageReference Include="TensorFlow.NET" Version="0.4.2" />
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" />
<ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" />
</ItemGroup>



+ 1
- 1
test/TensorFlowNET.UnitTest/VariableTest.cs View File

@@ -26,7 +26,7 @@ namespace TensorFlowNET.UnitTest
[TestMethod]
public void StringVar()
{
var mammal1 = tf.Variable("Elephant", "var1", tf.chars);
var mammal1 = tf.Variable("Elephant", name: "var1", dtype: tf.chars);
var mammal2 = tf.Variable("Tiger");
}



Loading…
Cancel
Save