From a4e52abe3aca9e13d4e9ddd30ce07b69fb5800f0 Mon Sep 17 00:00:00 2001 From: haiping008 Date: Wed, 6 Mar 2019 14:29:22 -0600 Subject: [PATCH] add tf.keras --- TensorFlow.NET.sln | 16 +-- src/TensorFlowNET.Core/APIs/tf.init.cs | 124 +----------------- src/TensorFlowNET.Core/APIs/tf.random.cs | 7 + src/TensorFlowNET.Core/Keras/Initializers.cs | 20 +++ src/TensorFlowNET.Core/Keras/tf.keras.cs | 15 +++ .../Operations/Initializers/GlorotUniform.cs | 30 +++++ .../{ => Initializers}/IInitializer.cs | 2 +- .../Initializers/TruncatedNormal.cs | 41 ++++++ .../Initializers/VarianceScaling.cs | 82 ++++++++++++ .../Operations/Initializers/Zeros.cs | 29 ++++ .../Operations/embedding_ops.cs | 14 ++ .../TensorFlowNET.Core.csproj | 2 +- .../Variables/VariableScope.cs | 2 +- .../Variables/_VariableStore.cs | 68 +++++++++- .../Variables/tf.variable.cs | 2 +- src/TensorFlowNET.Core/tf.cs | 15 ++- .../TensorFlowNET.Examples.csproj | 1 - .../TextClassification/DataHelpers.cs | 2 +- .../TextClassificationTrain.cs | 44 ++++++- .../TextClassification/cnn_models/VdCnn.cs | 44 +++++++ .../TensorFlowNET.UnitTest.csproj | 5 +- test/TensorFlowNET.UnitTest/VariableTest.cs | 2 +- 22 files changed, 414 insertions(+), 153 deletions(-) create mode 100644 src/TensorFlowNET.Core/Keras/Initializers.cs create mode 100644 src/TensorFlowNET.Core/Keras/tf.keras.cs create mode 100644 src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs rename src/TensorFlowNET.Core/Operations/{ => Initializers}/IInitializer.cs (67%) create mode 100644 src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs create mode 100644 src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs create mode 100644 src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs create mode 100644 src/TensorFlowNET.Core/Operations/embedding_ops.cs create mode 100644 test/TensorFlowNET.Examples/TextClassification/cnn_models/VdCnn.cs diff --git a/TensorFlow.NET.sln b/TensorFlow.NET.sln index 25a97e6d..4f866b47 100644 --- a/TensorFlow.NET.sln +++ b/TensorFlow.NET.sln @@ -11,9 +11,7 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Core", "src\T EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Visualization", "TensorFlowNET.Visualization\TensorFlowNET.Visualization.csproj", "{4BB2ABD1-635E-41E4-B534-CB5B6A2D754D}" EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{E8340C61-12C1-4BEE-A340-403E7C1ACD82}" -EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "scikit-learn", "..\scikit-learn.net\src\scikit-learn\scikit-learn.csproj", "{199DDAD8-4A6F-43B3-A560-C0393619E304}" +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{1F4C5683-48B4-4328-8171-E9F7ABAA2E72}" EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution @@ -37,14 +35,10 @@ Global {4BB2ABD1-635E-41E4-B534-CB5B6A2D754D}.Debug|Any CPU.Build.0 = Debug|Any CPU {4BB2ABD1-635E-41E4-B534-CB5B6A2D754D}.Release|Any CPU.ActiveCfg = Release|Any CPU {4BB2ABD1-635E-41E4-B534-CB5B6A2D754D}.Release|Any CPU.Build.0 = Release|Any CPU - {E8340C61-12C1-4BEE-A340-403E7C1ACD82}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {E8340C61-12C1-4BEE-A340-403E7C1ACD82}.Debug|Any CPU.Build.0 = Debug|Any CPU - {E8340C61-12C1-4BEE-A340-403E7C1ACD82}.Release|Any CPU.ActiveCfg = Release|Any CPU - {E8340C61-12C1-4BEE-A340-403E7C1ACD82}.Release|Any CPU.Build.0 = Release|Any CPU - {199DDAD8-4A6F-43B3-A560-C0393619E304}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {199DDAD8-4A6F-43B3-A560-C0393619E304}.Debug|Any CPU.Build.0 = Debug|Any CPU - {199DDAD8-4A6F-43B3-A560-C0393619E304}.Release|Any CPU.ActiveCfg = Release|Any CPU - {199DDAD8-4A6F-43B3-A560-C0393619E304}.Release|Any CPU.Build.0 = Release|Any CPU + {1F4C5683-48B4-4328-8171-E9F7ABAA2E72}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {1F4C5683-48B4-4328-8171-E9F7ABAA2E72}.Debug|Any CPU.Build.0 = Debug|Any CPU + {1F4C5683-48B4-4328-8171-E9F7ABAA2E72}.Release|Any CPU.ActiveCfg = Release|Any CPU + {1F4C5683-48B4-4328-8171-E9F7ABAA2E72}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/src/TensorFlowNET.Core/APIs/tf.init.cs b/src/TensorFlowNET.Core/APIs/tf.init.cs index 19876d62..4863b510 100644 --- a/src/TensorFlowNET.Core/APIs/tf.init.cs +++ b/src/TensorFlowNET.Core/APIs/tf.init.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Text; +using Tensorflow.Operations.Initializers; namespace Tensorflow { @@ -24,128 +25,5 @@ namespace Tensorflow default_name, values, auxiliary_name_scope); - - public class Zeros : IInitializer - { - private TF_DataType dtype; - - public Zeros(TF_DataType dtype = TF_DataType.TF_FLOAT) - { - this.dtype = dtype; - } - - public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid) - { - if (dtype == TF_DataType.DtInvalid) - dtype = this.dtype; - - return array_ops.zeros(shape, dtype); - } - - public object get_config() - { - return new { dtype = dtype.name() }; - } - } - - /// - /// Initializer capable of adapting its scale to the shape of weights tensors. - /// - public class VarianceScaling : IInitializer - { - protected float _scale; - protected string _mode; - protected string _distribution; - protected int? _seed; - protected TF_DataType _dtype; - - public VarianceScaling(float scale = 1.0f, - string mode = "fan_in", - string distribution= "truncated_normal", - int? seed = null, - TF_DataType dtype = TF_DataType.TF_FLOAT) - { - if (scale < 0) - throw new ValueError("`scale` must be positive float."); - _scale = scale; - _mode = mode; - _distribution = distribution; - _seed = seed; - _dtype = dtype; - } - - public Tensor call(TensorShape shape, TF_DataType dtype) - { - var (fan_in, fan_out) = _compute_fans(shape); - if (_mode == "fan_in") - _scale /= Math.Max(1, fan_in); - else if (_mode == "fan_out") - _scale /= Math.Max(1, fan_out); - else - _scale /= Math.Max(1, (fan_in + fan_out) / 2); - - if (_distribution == "normal" || _distribution == "truncated_normal") - { - throw new NotImplementedException("truncated_normal"); - } - else if(_distribution == "untruncated_normal") - { - throw new NotImplementedException("truncated_normal"); - } - else - { - var limit = Math.Sqrt(3.0f * _scale); - return random_ops.random_uniform(shape, (float)-limit, (float)limit, dtype, seed: _seed); - } - } - - private (int, int) _compute_fans(int[] shape) - { - if (shape.Length < 1) - return (1, 1); - if (shape.Length == 1) - return (shape[0], shape[0]); - if (shape.Length == 2) - return (shape[0], shape[1]); - else - throw new NotImplementedException("VarianceScaling._compute_fans"); - } - - public virtual object get_config() - { - return new - { - scale = _scale, - mode = _mode, - distribution = _distribution, - seed = _seed, - dtype = _dtype - }; - } - } - - public class GlorotUniform : VarianceScaling - { - public GlorotUniform(float scale = 1.0f, - string mode = "fan_avg", - string distribution = "uniform", - int? seed = null, - TF_DataType dtype = TF_DataType.TF_FLOAT) : base(scale, mode, distribution, seed, dtype) - { - - } - - public object get_config() - { - return new - { - scale = _scale, - mode = _mode, - distribution = _distribution, - seed = _seed, - dtype = _dtype - }; - } - } } } diff --git a/src/TensorFlowNET.Core/APIs/tf.random.cs b/src/TensorFlowNET.Core/APIs/tf.random.cs index f8eff7e7..3e273e61 100644 --- a/src/TensorFlowNET.Core/APIs/tf.random.cs +++ b/src/TensorFlowNET.Core/APIs/tf.random.cs @@ -22,5 +22,12 @@ namespace Tensorflow TF_DataType dtype = TF_DataType.TF_FLOAT, int? seed = null, string name = null) => random_ops.random_normal(shape, mean, stddev, dtype, seed, name); + + public static Tensor random_uniform(int[] shape, + float minval = 0, + float? maxval = null, + TF_DataType dtype = TF_DataType.TF_FLOAT, + int? seed = null, + string name = null) => random_ops.random_uniform(shape, minval, maxval, dtype, seed, name); } } diff --git a/src/TensorFlowNET.Core/Keras/Initializers.cs b/src/TensorFlowNET.Core/Keras/Initializers.cs new file mode 100644 index 00000000..cea77ae9 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Initializers.cs @@ -0,0 +1,20 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Operations.Initializers; + +namespace Tensorflow.Keras +{ + public class Initializers + { + /// + /// He normal initializer. + /// + /// + /// + public IInitializer he_normal(int? seed = null) + { + return new VarianceScaling(scale: 20f, mode: "fan_in", distribution: "truncated_normal", seed: seed); + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/tf.keras.cs b/src/TensorFlowNET.Core/Keras/tf.keras.cs new file mode 100644 index 00000000..73b8e0a0 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/tf.keras.cs @@ -0,0 +1,15 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras; + +namespace Tensorflow +{ + public static partial class tf + { + public static class keras + { + public static Initializers initializers => new Initializers(); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs b/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs new file mode 100644 index 00000000..5d905583 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs @@ -0,0 +1,30 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Operations.Initializers +{ + public class GlorotUniform : VarianceScaling + { + public GlorotUniform(float scale = 1.0f, + string mode = "fan_avg", + string distribution = "uniform", + int? seed = null, + TF_DataType dtype = TF_DataType.TF_FLOAT) : base(scale, mode, distribution, seed, dtype) + { + + } + + public object get_config() + { + return new + { + scale = _scale, + mode = _mode, + distribution = _distribution, + seed = _seed, + dtype = _dtype + }; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/IInitializer.cs b/src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs similarity index 67% rename from src/TensorFlowNET.Core/Operations/IInitializer.cs rename to src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs index 6382e3e0..422bf95d 100644 --- a/src/TensorFlowNET.Core/Operations/IInitializer.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs @@ -6,7 +6,7 @@ namespace Tensorflow { public interface IInitializer { - Tensor call(TensorShape shape, TF_DataType dtype); + Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid); object get_config(); } } diff --git a/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs b/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs new file mode 100644 index 00000000..4c0a7cee --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs @@ -0,0 +1,41 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Operations.Initializers +{ + public class TruncatedNormal : IInitializer + { + private float mean; + private float stddev; + private int? seed; + private TF_DataType dtype; + + public TruncatedNormal(float mean = 0.0f, + float stddev = 1.0f, + int? seed = null, + TF_DataType dtype = TF_DataType.TF_FLOAT) + { + this.mean = mean; + this.stddev = stddev; + this.seed = seed; + this.dtype = dtype; + } + + public Tensor call(TensorShape shape, TF_DataType dtype) + { + throw new NotImplementedException(""); + } + + public object get_config() + { + return new + { + mean = mean, + stddev = stddev, + seed = seed, + dtype = dtype.name() + }; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs b/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs new file mode 100644 index 00000000..0fcaf392 --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs @@ -0,0 +1,82 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Operations.Initializers +{ + /// + /// Initializer capable of adapting its scale to the shape of weights tensors. + /// + public class VarianceScaling : IInitializer + { + protected float _scale; + protected string _mode; + protected string _distribution; + protected int? _seed; + protected TF_DataType _dtype; + + public VarianceScaling(float scale = 1.0f, + string mode = "fan_in", + string distribution = "truncated_normal", + int? seed = null, + TF_DataType dtype = TF_DataType.TF_FLOAT) + { + if (scale < 0) + throw new ValueError("`scale` must be positive float."); + _scale = scale; + _mode = mode; + _distribution = distribution; + _seed = seed; + _dtype = dtype; + } + + public Tensor call(TensorShape shape, TF_DataType dtype) + { + var (fan_in, fan_out) = _compute_fans(shape); + if (_mode == "fan_in") + _scale /= Math.Max(1, fan_in); + else if (_mode == "fan_out") + _scale /= Math.Max(1, fan_out); + else + _scale /= Math.Max(1, (fan_in + fan_out) / 2); + + if (_distribution == "normal" || _distribution == "truncated_normal") + { + throw new NotImplementedException("truncated_normal"); + } + else if (_distribution == "untruncated_normal") + { + throw new NotImplementedException("truncated_normal"); + } + else + { + var limit = Math.Sqrt(3.0f * _scale); + return random_ops.random_uniform(shape, (float)-limit, (float)limit, dtype, seed: _seed); + } + } + + private (int, int) _compute_fans(int[] shape) + { + if (shape.Length < 1) + return (1, 1); + if (shape.Length == 1) + return (shape[0], shape[0]); + if (shape.Length == 2) + return (shape[0], shape[1]); + else + throw new NotImplementedException("VarianceScaling._compute_fans"); + } + + public virtual object get_config() + { + return new + { + scale = _scale, + mode = _mode, + distribution = _distribution, + seed = _seed, + dtype = _dtype + }; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs b/src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs new file mode 100644 index 00000000..ca1f42df --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs @@ -0,0 +1,29 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Operations.Initializers +{ + public class Zeros : IInitializer + { + private TF_DataType dtype; + + public Zeros(TF_DataType dtype = TF_DataType.TF_FLOAT) + { + this.dtype = dtype; + } + + public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid) + { + if (dtype == TF_DataType.DtInvalid) + dtype = this.dtype; + + return array_ops.zeros(shape, dtype); + } + + public object get_config() + { + return new { dtype = dtype.name() }; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/embedding_ops.cs b/src/TensorFlowNET.Core/Operations/embedding_ops.cs new file mode 100644 index 00000000..27fba04b --- /dev/null +++ b/src/TensorFlowNET.Core/Operations/embedding_ops.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public class embedding_ops + { + public Tensor _embedding_lookup_and_transform() + { + throw new NotImplementedException(""); + } + } +} diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj index 949166c1..1e4f2065 100644 --- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj +++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj @@ -43,7 +43,7 @@ Fixed import name scope issue. - + diff --git a/src/TensorFlowNET.Core/Variables/VariableScope.cs b/src/TensorFlowNET.Core/Variables/VariableScope.cs index 29c03c19..c3453744 100644 --- a/src/TensorFlowNET.Core/Variables/VariableScope.cs +++ b/src/TensorFlowNET.Core/Variables/VariableScope.cs @@ -33,7 +33,7 @@ namespace Tensorflow string name, TensorShape shape = null, TF_DataType dtype = TF_DataType.DtInvalid, - IInitializer initializer = null, + object initializer = null, // IInitializer or Tensor bool? trainable = null, VariableSynchronization synchronization = VariableSynchronization.AUTO, VariableAggregation aggregation= VariableAggregation.NONE) diff --git a/src/TensorFlowNET.Core/Variables/_VariableStore.cs b/src/TensorFlowNET.Core/Variables/_VariableStore.cs index 5bd8c86d..9f067bf8 100644 --- a/src/TensorFlowNET.Core/Variables/_VariableStore.cs +++ b/src/TensorFlowNET.Core/Variables/_VariableStore.cs @@ -23,7 +23,7 @@ namespace Tensorflow public RefVariable get_variable(string name, TensorShape shape = null, TF_DataType dtype = TF_DataType.TF_FLOAT, - IInitializer initializer = null, + object initializer = null, // IInitializer or Tensor bool? trainable = null, bool validate_shape = true, VariableSynchronization synchronization = VariableSynchronization.AUTO, @@ -45,7 +45,7 @@ namespace Tensorflow private RefVariable _true_getter(string name, TensorShape shape = null, TF_DataType dtype = TF_DataType.TF_FLOAT, - IInitializer initializer = null, + object initializer = null, bool? trainable = null, bool validate_shape = true, VariableSynchronization synchronization = VariableSynchronization.AUTO, @@ -53,14 +53,32 @@ namespace Tensorflow { bool is_scalar = shape.NDim == 0; - return _get_single_variable(name: name, - shape: shape, + if (initializer is IInitializer init) + { + return _get_single_variable(name: name, + shape: shape, dtype: dtype, - initializer: initializer, + initializer: init, trainable: trainable, validate_shape: validate_shape, synchronization: synchronization, aggregation: aggregation); + } + else if (initializer is Tensor tensor) + { + return _get_single_variable(name: name, + shape: shape, + dtype: dtype, + initializer: tensor, + trainable: trainable, + validate_shape: validate_shape, + synchronization: synchronization, + aggregation: aggregation); + } + else + { + throw new NotImplementedException("_true_getter"); + } } private RefVariable _get_single_variable(string name, @@ -125,5 +143,45 @@ namespace Tensorflow return v; } + + private RefVariable _get_single_variable(string name, + TensorShape shape = null, + TF_DataType dtype = TF_DataType.DtInvalid, + Tensor initializer = null, + bool reuse = false, + bool? trainable = null, + bool validate_shape = false, + bool? use_resource = null, + VariableSynchronization synchronization = VariableSynchronization.AUTO, + VariableAggregation aggregation = VariableAggregation.NONE) + { + if (use_resource == null) + use_resource = false; + + if (_vars.ContainsKey(name)) + { + if (!reuse) + { + var var = _vars[name]; + + } + throw new NotImplementedException("_get_single_variable"); + } + + RefVariable v = null; + // Create the variable. + ops.init_scope(); + { + var init_val = initializer; + v = new RefVariable(init_val, + name: name, + validate_shape: validate_shape, + trainable: trainable.Value); + } + + _vars[name] = v; + + return v; + } } } diff --git a/src/TensorFlowNET.Core/Variables/tf.variable.cs b/src/TensorFlowNET.Core/Variables/tf.variable.cs index b61b558e..aac58a9a 100644 --- a/src/TensorFlowNET.Core/Variables/tf.variable.cs +++ b/src/TensorFlowNET.Core/Variables/tf.variable.cs @@ -15,7 +15,7 @@ namespace Tensorflow public static RefVariable get_variable(string name, TensorShape shape = null, TF_DataType dtype = TF_DataType.DtInvalid, - IInitializer initializer = null, + object initializer = null, // IInitializer or Tensor bool? trainable = null, VariableSynchronization synchronization = VariableSynchronization.AUTO, VariableAggregation aggregation = VariableAggregation.NONE) diff --git a/src/TensorFlowNET.Core/tf.cs b/src/TensorFlowNET.Core/tf.cs index eeb415f5..e1e1331d 100644 --- a/src/TensorFlowNET.Core/tf.cs +++ b/src/TensorFlowNET.Core/tf.cs @@ -12,20 +12,27 @@ namespace Tensorflow public static TF_DataType float16 = TF_DataType.TF_HALF; public static TF_DataType float32 = TF_DataType.TF_FLOAT; public static TF_DataType float64 = TF_DataType.TF_DOUBLE; + public static TF_DataType boolean = TF_DataType.TF_BOOL; public static TF_DataType chars = TF_DataType.TF_STRING; public static Context context = new Context(new ContextOptions(), new Status()); public static Session defaultSession; - public static RefVariable Variable(T data, string name = null, TF_DataType dtype = TF_DataType.DtInvalid) + public static RefVariable Variable(T data, + bool trainable = true, + string name = null, + TF_DataType dtype = TF_DataType.DtInvalid) { - return Tensorflow.variable_scope.default_variable_creator(data, name: name, dtype: TF_DataType.DtInvalid); + return Tensorflow.variable_scope.default_variable_creator(data, + trainable: trainable, + name: name, + dtype: TF_DataType.DtInvalid); } - public static unsafe Tensor placeholder(TF_DataType dtype, TensorShape shape = null) + public static unsafe Tensor placeholder(TF_DataType dtype, TensorShape shape = null, string name = null) { - return gen_array_ops.placeholder(dtype, shape); + return gen_array_ops.placeholder(dtype, shape, name); } public static void enable_eager_execution() diff --git a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj index 12a92226..c6c052fb 100644 --- a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj +++ b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj @@ -13,7 +13,6 @@ - diff --git a/test/TensorFlowNET.Examples/TextClassification/DataHelpers.cs b/test/TensorFlowNET.Examples/TextClassification/DataHelpers.cs index f705b98c..586a978a 100644 --- a/test/TensorFlowNET.Examples/TextClassification/DataHelpers.cs +++ b/test/TensorFlowNET.Examples/TextClassification/DataHelpers.cs @@ -77,7 +77,7 @@ namespace TensorFlowNET.Examples.CnnTextClassification var positive_labels = positive_examples.Select(x => new int[2] { 0, 1 }).ToArray(); var negative_labels = negative_examples.Select(x => new int[2] { 1, 0 }).ToArray(); - var y = np.array(1);// np.concatenate(new int[][][] { positive_labels, negative_labels }); + var y = np.concatenate(new int[][][] { positive_labels, negative_labels }); return (x_text.ToArray(), y); } diff --git a/test/TensorFlowNET.Examples/TextClassification/TextClassificationTrain.cs b/test/TensorFlowNET.Examples/TextClassification/TextClassificationTrain.cs index 4ea583d4..e6e26ba6 100644 --- a/test/TensorFlowNET.Examples/TextClassification/TextClassificationTrain.cs +++ b/test/TensorFlowNET.Examples/TextClassification/TextClassificationTrain.cs @@ -5,6 +5,7 @@ using System.IO; using System.Linq; using System.Text; using Tensorflow; +using TensorFlowNET.Examples.TextClassification; using TensorFlowNET.Examples.Utility; namespace TensorFlowNET.Examples.CnnTextClassification @@ -18,13 +19,21 @@ namespace TensorFlowNET.Examples.CnnTextClassification private string dataFileName = "dbpedia_csv.tar.gz"; private const int CHAR_MAX_LEN = 1014; + private const int NUM_CLASS = 2; public void Run() { download_dbpedia(); Console.WriteLine("Building dataset..."); var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", "vdcnn", CHAR_MAX_LEN); - //var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15); + + var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f); + + with(tf.Session(), sess => + { + new VdCnn(alphabet_size, CHAR_MAX_LEN, NUM_CLASS); + + }); } public void download_dbpedia() @@ -33,5 +42,38 @@ namespace TensorFlowNET.Examples.CnnTextClassification Web.Download(url, dataDir, dataFileName); Compress.ExtractTGZ(Path.Join(dataDir, dataFileName), dataDir); } + + private (int[][], int[][], int[], int[]) train_test_split(int[][] x, int[] y, float test_size = 0.3f) + { + int len = x.Length; + int classes = y.Distinct().Count(); + int samples = len / classes; + int train_size = int.Parse((samples * (1 - test_size)).ToString()); + + var train_x = new List(); + var valid_x = new List(); + var train_y = new List(); + var valid_y = new List(); + + for (int i = 0; i< classes; i++) + { + for (int j = 0; j < samples; j++) + { + int idx = i * samples + j; + if (idx < train_size + samples * i) + { + train_x.Add(x[idx]); + train_y.Add(y[idx]); + } + else + { + valid_x.Add(x[idx]); + valid_y.Add(y[idx]); + } + } + } + + return (train_x.ToArray(), valid_x.ToArray(), train_y.ToArray(), valid_y.ToArray()); + } } } diff --git a/test/TensorFlowNET.Examples/TextClassification/cnn_models/VdCnn.cs b/test/TensorFlowNET.Examples/TextClassification/cnn_models/VdCnn.cs new file mode 100644 index 00000000..cbdcecee --- /dev/null +++ b/test/TensorFlowNET.Examples/TextClassification/cnn_models/VdCnn.cs @@ -0,0 +1,44 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow; + +namespace TensorFlowNET.Examples.TextClassification +{ + public class VdCnn : Python + { + private int embedding_size; + private int[] filter_sizes; + private int[] num_filters; + private int[] num_blocks; + private float learning_rate; + private IInitializer cnn_initializer; + private Tensor x; + private Tensor y; + private Tensor is_training; + private RefVariable global_step; + private RefVariable embeddings; + private Tensor x_emb; + + public VdCnn(int alphabet_size, int document_max_len, int num_class) + { + embedding_size = 16; + filter_sizes = new int[] { 3, 3, 3, 3, 3 }; + num_filters = new int[] { 64, 64, 128, 256, 512 }; + num_blocks = new int[] { 2, 2, 2, 2 }; + learning_rate = 0.001f; + cnn_initializer = tf.keras.initializers.he_normal(); + x = tf.placeholder(tf.int32, new TensorShape(-1, document_max_len), name: "x"); + y = tf.placeholder(tf.int32, new TensorShape(-1), name: "y"); + is_training = tf.placeholder(tf.boolean, new TensorShape(), name: "is_training"); + global_step = tf.Variable(0, trainable: false); + + with(tf.name_scope("embedding"), delegate + { + var init_embeddings = tf.random_uniform(new int[] { alphabet_size, embedding_size }, -1.0f, 1.0f); + embeddings = tf.get_variable("embeddings", initializer: init_embeddings); + // x_emb = tf.nn.embedding_lookup(embeddings, x); + }); + } + } +} diff --git a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj index f495711c..eaf16ef8 100644 --- a/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj +++ b/test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj @@ -16,14 +16,15 @@ - + - + + diff --git a/test/TensorFlowNET.UnitTest/VariableTest.cs b/test/TensorFlowNET.UnitTest/VariableTest.cs index 7713e774..8254b774 100644 --- a/test/TensorFlowNET.UnitTest/VariableTest.cs +++ b/test/TensorFlowNET.UnitTest/VariableTest.cs @@ -26,7 +26,7 @@ namespace TensorFlowNET.UnitTest [TestMethod] public void StringVar() { - var mammal1 = tf.Variable("Elephant", "var1", tf.chars); + var mammal1 = tf.Variable("Elephant", name: "var1", dtype: tf.chars); var mammal2 = tf.Variable("Tiger"); }