@@ -11,9 +11,7 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Core", "src\T | |||
EndProject | |||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Visualization", "TensorFlowNET.Visualization\TensorFlowNET.Visualization.csproj", "{4BB2ABD1-635E-41E4-B534-CB5B6A2D754D}" | |||
EndProject | |||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{E8340C61-12C1-4BEE-A340-403E7C1ACD82}" | |||
EndProject | |||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "scikit-learn", "..\scikit-learn.net\src\scikit-learn\scikit-learn.csproj", "{199DDAD8-4A6F-43B3-A560-C0393619E304}" | |||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{1F4C5683-48B4-4328-8171-E9F7ABAA2E72}" | |||
EndProject | |||
Global | |||
GlobalSection(SolutionConfigurationPlatforms) = preSolution | |||
@@ -37,14 +35,10 @@ Global | |||
{4BB2ABD1-635E-41E4-B534-CB5B6A2D754D}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||
{4BB2ABD1-635E-41E4-B534-CB5B6A2D754D}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||
{4BB2ABD1-635E-41E4-B534-CB5B6A2D754D}.Release|Any CPU.Build.0 = Release|Any CPU | |||
{E8340C61-12C1-4BEE-A340-403E7C1ACD82}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | |||
{E8340C61-12C1-4BEE-A340-403E7C1ACD82}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||
{E8340C61-12C1-4BEE-A340-403E7C1ACD82}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||
{E8340C61-12C1-4BEE-A340-403E7C1ACD82}.Release|Any CPU.Build.0 = Release|Any CPU | |||
{199DDAD8-4A6F-43B3-A560-C0393619E304}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | |||
{199DDAD8-4A6F-43B3-A560-C0393619E304}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||
{199DDAD8-4A6F-43B3-A560-C0393619E304}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||
{199DDAD8-4A6F-43B3-A560-C0393619E304}.Release|Any CPU.Build.0 = Release|Any CPU | |||
{1F4C5683-48B4-4328-8171-E9F7ABAA2E72}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | |||
{1F4C5683-48B4-4328-8171-E9F7ABAA2E72}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||
{1F4C5683-48B4-4328-8171-E9F7ABAA2E72}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||
{1F4C5683-48B4-4328-8171-E9F7ABAA2E72}.Release|Any CPU.Build.0 = Release|Any CPU | |||
EndGlobalSection | |||
GlobalSection(SolutionProperties) = preSolution | |||
HideSolutionNode = FALSE | |||
@@ -1,6 +1,7 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Operations.Initializers; | |||
namespace Tensorflow | |||
{ | |||
@@ -24,128 +25,5 @@ namespace Tensorflow | |||
default_name, | |||
values, | |||
auxiliary_name_scope); | |||
public class Zeros : IInitializer | |||
{ | |||
private TF_DataType dtype; | |||
public Zeros(TF_DataType dtype = TF_DataType.TF_FLOAT) | |||
{ | |||
this.dtype = dtype; | |||
} | |||
public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid) | |||
{ | |||
if (dtype == TF_DataType.DtInvalid) | |||
dtype = this.dtype; | |||
return array_ops.zeros(shape, dtype); | |||
} | |||
public object get_config() | |||
{ | |||
return new { dtype = dtype.name() }; | |||
} | |||
} | |||
/// <summary> | |||
/// Initializer capable of adapting its scale to the shape of weights tensors. | |||
/// </summary> | |||
public class VarianceScaling : IInitializer | |||
{ | |||
protected float _scale; | |||
protected string _mode; | |||
protected string _distribution; | |||
protected int? _seed; | |||
protected TF_DataType _dtype; | |||
public VarianceScaling(float scale = 1.0f, | |||
string mode = "fan_in", | |||
string distribution= "truncated_normal", | |||
int? seed = null, | |||
TF_DataType dtype = TF_DataType.TF_FLOAT) | |||
{ | |||
if (scale < 0) | |||
throw new ValueError("`scale` must be positive float."); | |||
_scale = scale; | |||
_mode = mode; | |||
_distribution = distribution; | |||
_seed = seed; | |||
_dtype = dtype; | |||
} | |||
public Tensor call(TensorShape shape, TF_DataType dtype) | |||
{ | |||
var (fan_in, fan_out) = _compute_fans(shape); | |||
if (_mode == "fan_in") | |||
_scale /= Math.Max(1, fan_in); | |||
else if (_mode == "fan_out") | |||
_scale /= Math.Max(1, fan_out); | |||
else | |||
_scale /= Math.Max(1, (fan_in + fan_out) / 2); | |||
if (_distribution == "normal" || _distribution == "truncated_normal") | |||
{ | |||
throw new NotImplementedException("truncated_normal"); | |||
} | |||
else if(_distribution == "untruncated_normal") | |||
{ | |||
throw new NotImplementedException("truncated_normal"); | |||
} | |||
else | |||
{ | |||
var limit = Math.Sqrt(3.0f * _scale); | |||
return random_ops.random_uniform(shape, (float)-limit, (float)limit, dtype, seed: _seed); | |||
} | |||
} | |||
private (int, int) _compute_fans(int[] shape) | |||
{ | |||
if (shape.Length < 1) | |||
return (1, 1); | |||
if (shape.Length == 1) | |||
return (shape[0], shape[0]); | |||
if (shape.Length == 2) | |||
return (shape[0], shape[1]); | |||
else | |||
throw new NotImplementedException("VarianceScaling._compute_fans"); | |||
} | |||
public virtual object get_config() | |||
{ | |||
return new | |||
{ | |||
scale = _scale, | |||
mode = _mode, | |||
distribution = _distribution, | |||
seed = _seed, | |||
dtype = _dtype | |||
}; | |||
} | |||
} | |||
public class GlorotUniform : VarianceScaling | |||
{ | |||
public GlorotUniform(float scale = 1.0f, | |||
string mode = "fan_avg", | |||
string distribution = "uniform", | |||
int? seed = null, | |||
TF_DataType dtype = TF_DataType.TF_FLOAT) : base(scale, mode, distribution, seed, dtype) | |||
{ | |||
} | |||
public object get_config() | |||
{ | |||
return new | |||
{ | |||
scale = _scale, | |||
mode = _mode, | |||
distribution = _distribution, | |||
seed = _seed, | |||
dtype = _dtype | |||
}; | |||
} | |||
} | |||
} | |||
} |
@@ -22,5 +22,12 @@ namespace Tensorflow | |||
TF_DataType dtype = TF_DataType.TF_FLOAT, | |||
int? seed = null, | |||
string name = null) => random_ops.random_normal(shape, mean, stddev, dtype, seed, name); | |||
public static Tensor random_uniform(int[] shape, | |||
float minval = 0, | |||
float? maxval = null, | |||
TF_DataType dtype = TF_DataType.TF_FLOAT, | |||
int? seed = null, | |||
string name = null) => random_ops.random_uniform(shape, minval, maxval, dtype, seed, name); | |||
} | |||
} |
@@ -0,0 +1,20 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Operations.Initializers; | |||
namespace Tensorflow.Keras | |||
{ | |||
public class Initializers | |||
{ | |||
/// <summary> | |||
/// He normal initializer. | |||
/// </summary> | |||
/// <param name="seed"></param> | |||
/// <returns></returns> | |||
public IInitializer he_normal(int? seed = null) | |||
{ | |||
return new VarianceScaling(scale: 20f, mode: "fan_in", distribution: "truncated_normal", seed: seed); | |||
} | |||
} | |||
} |
@@ -0,0 +1,15 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Keras; | |||
namespace Tensorflow | |||
{ | |||
public static partial class tf | |||
{ | |||
public static class keras | |||
{ | |||
public static Initializers initializers => new Initializers(); | |||
} | |||
} | |||
} |
@@ -0,0 +1,30 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Operations.Initializers | |||
{ | |||
public class GlorotUniform : VarianceScaling | |||
{ | |||
public GlorotUniform(float scale = 1.0f, | |||
string mode = "fan_avg", | |||
string distribution = "uniform", | |||
int? seed = null, | |||
TF_DataType dtype = TF_DataType.TF_FLOAT) : base(scale, mode, distribution, seed, dtype) | |||
{ | |||
} | |||
public object get_config() | |||
{ | |||
return new | |||
{ | |||
scale = _scale, | |||
mode = _mode, | |||
distribution = _distribution, | |||
seed = _seed, | |||
dtype = _dtype | |||
}; | |||
} | |||
} | |||
} |
@@ -6,7 +6,7 @@ namespace Tensorflow | |||
{ | |||
public interface IInitializer | |||
{ | |||
Tensor call(TensorShape shape, TF_DataType dtype); | |||
Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid); | |||
object get_config(); | |||
} | |||
} |
@@ -0,0 +1,41 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Operations.Initializers | |||
{ | |||
public class TruncatedNormal : IInitializer | |||
{ | |||
private float mean; | |||
private float stddev; | |||
private int? seed; | |||
private TF_DataType dtype; | |||
public TruncatedNormal(float mean = 0.0f, | |||
float stddev = 1.0f, | |||
int? seed = null, | |||
TF_DataType dtype = TF_DataType.TF_FLOAT) | |||
{ | |||
this.mean = mean; | |||
this.stddev = stddev; | |||
this.seed = seed; | |||
this.dtype = dtype; | |||
} | |||
public Tensor call(TensorShape shape, TF_DataType dtype) | |||
{ | |||
throw new NotImplementedException(""); | |||
} | |||
public object get_config() | |||
{ | |||
return new | |||
{ | |||
mean = mean, | |||
stddev = stddev, | |||
seed = seed, | |||
dtype = dtype.name() | |||
}; | |||
} | |||
} | |||
} |
@@ -0,0 +1,82 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Operations.Initializers | |||
{ | |||
/// <summary> | |||
/// Initializer capable of adapting its scale to the shape of weights tensors. | |||
/// </summary> | |||
public class VarianceScaling : IInitializer | |||
{ | |||
protected float _scale; | |||
protected string _mode; | |||
protected string _distribution; | |||
protected int? _seed; | |||
protected TF_DataType _dtype; | |||
public VarianceScaling(float scale = 1.0f, | |||
string mode = "fan_in", | |||
string distribution = "truncated_normal", | |||
int? seed = null, | |||
TF_DataType dtype = TF_DataType.TF_FLOAT) | |||
{ | |||
if (scale < 0) | |||
throw new ValueError("`scale` must be positive float."); | |||
_scale = scale; | |||
_mode = mode; | |||
_distribution = distribution; | |||
_seed = seed; | |||
_dtype = dtype; | |||
} | |||
public Tensor call(TensorShape shape, TF_DataType dtype) | |||
{ | |||
var (fan_in, fan_out) = _compute_fans(shape); | |||
if (_mode == "fan_in") | |||
_scale /= Math.Max(1, fan_in); | |||
else if (_mode == "fan_out") | |||
_scale /= Math.Max(1, fan_out); | |||
else | |||
_scale /= Math.Max(1, (fan_in + fan_out) / 2); | |||
if (_distribution == "normal" || _distribution == "truncated_normal") | |||
{ | |||
throw new NotImplementedException("truncated_normal"); | |||
} | |||
else if (_distribution == "untruncated_normal") | |||
{ | |||
throw new NotImplementedException("truncated_normal"); | |||
} | |||
else | |||
{ | |||
var limit = Math.Sqrt(3.0f * _scale); | |||
return random_ops.random_uniform(shape, (float)-limit, (float)limit, dtype, seed: _seed); | |||
} | |||
} | |||
private (int, int) _compute_fans(int[] shape) | |||
{ | |||
if (shape.Length < 1) | |||
return (1, 1); | |||
if (shape.Length == 1) | |||
return (shape[0], shape[0]); | |||
if (shape.Length == 2) | |||
return (shape[0], shape[1]); | |||
else | |||
throw new NotImplementedException("VarianceScaling._compute_fans"); | |||
} | |||
public virtual object get_config() | |||
{ | |||
return new | |||
{ | |||
scale = _scale, | |||
mode = _mode, | |||
distribution = _distribution, | |||
seed = _seed, | |||
dtype = _dtype | |||
}; | |||
} | |||
} | |||
} |
@@ -0,0 +1,29 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Operations.Initializers | |||
{ | |||
public class Zeros : IInitializer | |||
{ | |||
private TF_DataType dtype; | |||
public Zeros(TF_DataType dtype = TF_DataType.TF_FLOAT) | |||
{ | |||
this.dtype = dtype; | |||
} | |||
public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid) | |||
{ | |||
if (dtype == TF_DataType.DtInvalid) | |||
dtype = this.dtype; | |||
return array_ops.zeros(shape, dtype); | |||
} | |||
public object get_config() | |||
{ | |||
return new { dtype = dtype.name() }; | |||
} | |||
} | |||
} |
@@ -0,0 +1,14 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow | |||
{ | |||
public class embedding_ops | |||
{ | |||
public Tensor _embedding_lookup_and_transform() | |||
{ | |||
throw new NotImplementedException(""); | |||
} | |||
} | |||
} |
@@ -43,7 +43,7 @@ Fixed import name scope issue.</PackageReleaseNotes> | |||
</ItemGroup> | |||
<ItemGroup> | |||
<PackageReference Include="Google.Protobuf" Version="3.6.1" /> | |||
<PackageReference Include="Google.Protobuf" Version="3.7.0" /> | |||
<PackageReference Include="NumSharp" Version="0.7.3" /> | |||
</ItemGroup> | |||
@@ -33,7 +33,7 @@ namespace Tensorflow | |||
string name, | |||
TensorShape shape = null, | |||
TF_DataType dtype = TF_DataType.DtInvalid, | |||
IInitializer initializer = null, | |||
object initializer = null, // IInitializer or Tensor | |||
bool? trainable = null, | |||
VariableSynchronization synchronization = VariableSynchronization.AUTO, | |||
VariableAggregation aggregation= VariableAggregation.NONE) | |||
@@ -23,7 +23,7 @@ namespace Tensorflow | |||
public RefVariable get_variable(string name, | |||
TensorShape shape = null, | |||
TF_DataType dtype = TF_DataType.TF_FLOAT, | |||
IInitializer initializer = null, | |||
object initializer = null, // IInitializer or Tensor | |||
bool? trainable = null, | |||
bool validate_shape = true, | |||
VariableSynchronization synchronization = VariableSynchronization.AUTO, | |||
@@ -45,7 +45,7 @@ namespace Tensorflow | |||
private RefVariable _true_getter(string name, | |||
TensorShape shape = null, | |||
TF_DataType dtype = TF_DataType.TF_FLOAT, | |||
IInitializer initializer = null, | |||
object initializer = null, | |||
bool? trainable = null, | |||
bool validate_shape = true, | |||
VariableSynchronization synchronization = VariableSynchronization.AUTO, | |||
@@ -53,14 +53,32 @@ namespace Tensorflow | |||
{ | |||
bool is_scalar = shape.NDim == 0; | |||
return _get_single_variable(name: name, | |||
shape: shape, | |||
if (initializer is IInitializer init) | |||
{ | |||
return _get_single_variable(name: name, | |||
shape: shape, | |||
dtype: dtype, | |||
initializer: initializer, | |||
initializer: init, | |||
trainable: trainable, | |||
validate_shape: validate_shape, | |||
synchronization: synchronization, | |||
aggregation: aggregation); | |||
} | |||
else if (initializer is Tensor tensor) | |||
{ | |||
return _get_single_variable(name: name, | |||
shape: shape, | |||
dtype: dtype, | |||
initializer: tensor, | |||
trainable: trainable, | |||
validate_shape: validate_shape, | |||
synchronization: synchronization, | |||
aggregation: aggregation); | |||
} | |||
else | |||
{ | |||
throw new NotImplementedException("_true_getter"); | |||
} | |||
} | |||
private RefVariable _get_single_variable(string name, | |||
@@ -125,5 +143,45 @@ namespace Tensorflow | |||
return v; | |||
} | |||
private RefVariable _get_single_variable(string name, | |||
TensorShape shape = null, | |||
TF_DataType dtype = TF_DataType.DtInvalid, | |||
Tensor initializer = null, | |||
bool reuse = false, | |||
bool? trainable = null, | |||
bool validate_shape = false, | |||
bool? use_resource = null, | |||
VariableSynchronization synchronization = VariableSynchronization.AUTO, | |||
VariableAggregation aggregation = VariableAggregation.NONE) | |||
{ | |||
if (use_resource == null) | |||
use_resource = false; | |||
if (_vars.ContainsKey(name)) | |||
{ | |||
if (!reuse) | |||
{ | |||
var var = _vars[name]; | |||
} | |||
throw new NotImplementedException("_get_single_variable"); | |||
} | |||
RefVariable v = null; | |||
// Create the variable. | |||
ops.init_scope(); | |||
{ | |||
var init_val = initializer; | |||
v = new RefVariable(init_val, | |||
name: name, | |||
validate_shape: validate_shape, | |||
trainable: trainable.Value); | |||
} | |||
_vars[name] = v; | |||
return v; | |||
} | |||
} | |||
} |
@@ -15,7 +15,7 @@ namespace Tensorflow | |||
public static RefVariable get_variable(string name, | |||
TensorShape shape = null, | |||
TF_DataType dtype = TF_DataType.DtInvalid, | |||
IInitializer initializer = null, | |||
object initializer = null, // IInitializer or Tensor | |||
bool? trainable = null, | |||
VariableSynchronization synchronization = VariableSynchronization.AUTO, | |||
VariableAggregation aggregation = VariableAggregation.NONE) | |||
@@ -12,20 +12,27 @@ namespace Tensorflow | |||
public static TF_DataType float16 = TF_DataType.TF_HALF; | |||
public static TF_DataType float32 = TF_DataType.TF_FLOAT; | |||
public static TF_DataType float64 = TF_DataType.TF_DOUBLE; | |||
public static TF_DataType boolean = TF_DataType.TF_BOOL; | |||
public static TF_DataType chars = TF_DataType.TF_STRING; | |||
public static Context context = new Context(new ContextOptions(), new Status()); | |||
public static Session defaultSession; | |||
public static RefVariable Variable<T>(T data, string name = null, TF_DataType dtype = TF_DataType.DtInvalid) | |||
public static RefVariable Variable<T>(T data, | |||
bool trainable = true, | |||
string name = null, | |||
TF_DataType dtype = TF_DataType.DtInvalid) | |||
{ | |||
return Tensorflow.variable_scope.default_variable_creator(data, name: name, dtype: TF_DataType.DtInvalid); | |||
return Tensorflow.variable_scope.default_variable_creator(data, | |||
trainable: trainable, | |||
name: name, | |||
dtype: TF_DataType.DtInvalid); | |||
} | |||
public static unsafe Tensor placeholder(TF_DataType dtype, TensorShape shape = null) | |||
public static unsafe Tensor placeholder(TF_DataType dtype, TensorShape shape = null, string name = null) | |||
{ | |||
return gen_array_ops.placeholder(dtype, shape); | |||
return gen_array_ops.placeholder(dtype, shape, name); | |||
} | |||
public static void enable_eager_execution() | |||
@@ -13,7 +13,6 @@ | |||
<ItemGroup> | |||
<ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" /> | |||
<ProjectReference Include="..\..\..\scikit-learn.net\src\scikit-learn\scikit-learn.csproj" /> | |||
<ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" /> | |||
</ItemGroup> | |||
@@ -77,7 +77,7 @@ namespace TensorFlowNET.Examples.CnnTextClassification | |||
var positive_labels = positive_examples.Select(x => new int[2] { 0, 1 }).ToArray(); | |||
var negative_labels = negative_examples.Select(x => new int[2] { 1, 0 }).ToArray(); | |||
var y = np.array(1);// np.concatenate(new int[][][] { positive_labels, negative_labels }); | |||
var y = np.concatenate(new int[][][] { positive_labels, negative_labels }); | |||
return (x_text.ToArray(), y); | |||
} | |||
@@ -5,6 +5,7 @@ using System.IO; | |||
using System.Linq; | |||
using System.Text; | |||
using Tensorflow; | |||
using TensorFlowNET.Examples.TextClassification; | |||
using TensorFlowNET.Examples.Utility; | |||
namespace TensorFlowNET.Examples.CnnTextClassification | |||
@@ -18,13 +19,21 @@ namespace TensorFlowNET.Examples.CnnTextClassification | |||
private string dataFileName = "dbpedia_csv.tar.gz"; | |||
private const int CHAR_MAX_LEN = 1014; | |||
private const int NUM_CLASS = 2; | |||
public void Run() | |||
{ | |||
download_dbpedia(); | |||
Console.WriteLine("Building dataset..."); | |||
var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", "vdcnn", CHAR_MAX_LEN); | |||
//var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15); | |||
var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f); | |||
with(tf.Session(), sess => | |||
{ | |||
new VdCnn(alphabet_size, CHAR_MAX_LEN, NUM_CLASS); | |||
}); | |||
} | |||
public void download_dbpedia() | |||
@@ -33,5 +42,38 @@ namespace TensorFlowNET.Examples.CnnTextClassification | |||
Web.Download(url, dataDir, dataFileName); | |||
Compress.ExtractTGZ(Path.Join(dataDir, dataFileName), dataDir); | |||
} | |||
private (int[][], int[][], int[], int[]) train_test_split(int[][] x, int[] y, float test_size = 0.3f) | |||
{ | |||
int len = x.Length; | |||
int classes = y.Distinct().Count(); | |||
int samples = len / classes; | |||
int train_size = int.Parse((samples * (1 - test_size)).ToString()); | |||
var train_x = new List<int[]>(); | |||
var valid_x = new List<int[]>(); | |||
var train_y = new List<int>(); | |||
var valid_y = new List<int>(); | |||
for (int i = 0; i< classes; i++) | |||
{ | |||
for (int j = 0; j < samples; j++) | |||
{ | |||
int idx = i * samples + j; | |||
if (idx < train_size + samples * i) | |||
{ | |||
train_x.Add(x[idx]); | |||
train_y.Add(y[idx]); | |||
} | |||
else | |||
{ | |||
valid_x.Add(x[idx]); | |||
valid_y.Add(y[idx]); | |||
} | |||
} | |||
} | |||
return (train_x.ToArray(), valid_x.ToArray(), train_y.ToArray(), valid_y.ToArray()); | |||
} | |||
} | |||
} |
@@ -0,0 +1,44 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow; | |||
namespace TensorFlowNET.Examples.TextClassification | |||
{ | |||
public class VdCnn : Python | |||
{ | |||
private int embedding_size; | |||
private int[] filter_sizes; | |||
private int[] num_filters; | |||
private int[] num_blocks; | |||
private float learning_rate; | |||
private IInitializer cnn_initializer; | |||
private Tensor x; | |||
private Tensor y; | |||
private Tensor is_training; | |||
private RefVariable global_step; | |||
private RefVariable embeddings; | |||
private Tensor x_emb; | |||
public VdCnn(int alphabet_size, int document_max_len, int num_class) | |||
{ | |||
embedding_size = 16; | |||
filter_sizes = new int[] { 3, 3, 3, 3, 3 }; | |||
num_filters = new int[] { 64, 64, 128, 256, 512 }; | |||
num_blocks = new int[] { 2, 2, 2, 2 }; | |||
learning_rate = 0.001f; | |||
cnn_initializer = tf.keras.initializers.he_normal(); | |||
x = tf.placeholder(tf.int32, new TensorShape(-1, document_max_len), name: "x"); | |||
y = tf.placeholder(tf.int32, new TensorShape(-1), name: "y"); | |||
is_training = tf.placeholder(tf.boolean, new TensorShape(), name: "is_training"); | |||
global_step = tf.Variable(0, trainable: false); | |||
with(tf.name_scope("embedding"), delegate | |||
{ | |||
var init_embeddings = tf.random_uniform(new int[] { alphabet_size, embedding_size }, -1.0f, 1.0f); | |||
embeddings = tf.get_variable("embeddings", initializer: init_embeddings); | |||
// x_emb = tf.nn.embedding_lookup(embeddings, x); | |||
}); | |||
} | |||
} | |||
} |
@@ -16,14 +16,15 @@ | |||
</PropertyGroup> | |||
<ItemGroup> | |||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="15.9.0" /> | |||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.0.1" /> | |||
<PackageReference Include="MSTest.TestAdapter" Version="1.4.0" /> | |||
<PackageReference Include="MSTest.TestFramework" Version="1.4.0" /> | |||
<PackageReference Include="NumSharp" Version="0.7.3" /> | |||
<PackageReference Include="TensorFlow.NET" Version="0.3.0" /> | |||
<PackageReference Include="TensorFlow.NET" Version="0.4.2" /> | |||
</ItemGroup> | |||
<ItemGroup> | |||
<ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" /> | |||
<ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" /> | |||
</ItemGroup> | |||
@@ -26,7 +26,7 @@ namespace TensorFlowNET.UnitTest | |||
[TestMethod] | |||
public void StringVar() | |||
{ | |||
var mammal1 = tf.Variable("Elephant", "var1", tf.chars); | |||
var mammal1 = tf.Variable("Elephant", name: "var1", dtype: tf.chars); | |||
var mammal2 = tf.Variable("Tiger"); | |||
} | |||