@@ -11,9 +11,7 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Core", "src\T | |||||
EndProject | EndProject | ||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Visualization", "TensorFlowNET.Visualization\TensorFlowNET.Visualization.csproj", "{4BB2ABD1-635E-41E4-B534-CB5B6A2D754D}" | Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Visualization", "TensorFlowNET.Visualization\TensorFlowNET.Visualization.csproj", "{4BB2ABD1-635E-41E4-B534-CB5B6A2D754D}" | ||||
EndProject | EndProject | ||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{E8340C61-12C1-4BEE-A340-403E7C1ACD82}" | |||||
EndProject | |||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "scikit-learn", "..\scikit-learn.net\src\scikit-learn\scikit-learn.csproj", "{199DDAD8-4A6F-43B3-A560-C0393619E304}" | |||||
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{1F4C5683-48B4-4328-8171-E9F7ABAA2E72}" | |||||
EndProject | EndProject | ||||
Global | Global | ||||
GlobalSection(SolutionConfigurationPlatforms) = preSolution | GlobalSection(SolutionConfigurationPlatforms) = preSolution | ||||
@@ -37,14 +35,10 @@ Global | |||||
{4BB2ABD1-635E-41E4-B534-CB5B6A2D754D}.Debug|Any CPU.Build.0 = Debug|Any CPU | {4BB2ABD1-635E-41E4-B534-CB5B6A2D754D}.Debug|Any CPU.Build.0 = Debug|Any CPU | ||||
{4BB2ABD1-635E-41E4-B534-CB5B6A2D754D}.Release|Any CPU.ActiveCfg = Release|Any CPU | {4BB2ABD1-635E-41E4-B534-CB5B6A2D754D}.Release|Any CPU.ActiveCfg = Release|Any CPU | ||||
{4BB2ABD1-635E-41E4-B534-CB5B6A2D754D}.Release|Any CPU.Build.0 = Release|Any CPU | {4BB2ABD1-635E-41E4-B534-CB5B6A2D754D}.Release|Any CPU.Build.0 = Release|Any CPU | ||||
{E8340C61-12C1-4BEE-A340-403E7C1ACD82}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | |||||
{E8340C61-12C1-4BEE-A340-403E7C1ACD82}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||||
{E8340C61-12C1-4BEE-A340-403E7C1ACD82}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||||
{E8340C61-12C1-4BEE-A340-403E7C1ACD82}.Release|Any CPU.Build.0 = Release|Any CPU | |||||
{199DDAD8-4A6F-43B3-A560-C0393619E304}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | |||||
{199DDAD8-4A6F-43B3-A560-C0393619E304}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||||
{199DDAD8-4A6F-43B3-A560-C0393619E304}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||||
{199DDAD8-4A6F-43B3-A560-C0393619E304}.Release|Any CPU.Build.0 = Release|Any CPU | |||||
{1F4C5683-48B4-4328-8171-E9F7ABAA2E72}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | |||||
{1F4C5683-48B4-4328-8171-E9F7ABAA2E72}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||||
{1F4C5683-48B4-4328-8171-E9F7ABAA2E72}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||||
{1F4C5683-48B4-4328-8171-E9F7ABAA2E72}.Release|Any CPU.Build.0 = Release|Any CPU | |||||
EndGlobalSection | EndGlobalSection | ||||
GlobalSection(SolutionProperties) = preSolution | GlobalSection(SolutionProperties) = preSolution | ||||
HideSolutionNode = FALSE | HideSolutionNode = FALSE | ||||
@@ -1,6 +1,7 @@ | |||||
using System; | using System; | ||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using System.Text; | using System.Text; | ||||
using Tensorflow.Operations.Initializers; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
@@ -24,128 +25,5 @@ namespace Tensorflow | |||||
default_name, | default_name, | ||||
values, | values, | ||||
auxiliary_name_scope); | auxiliary_name_scope); | ||||
public class Zeros : IInitializer | |||||
{ | |||||
private TF_DataType dtype; | |||||
public Zeros(TF_DataType dtype = TF_DataType.TF_FLOAT) | |||||
{ | |||||
this.dtype = dtype; | |||||
} | |||||
public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid) | |||||
{ | |||||
if (dtype == TF_DataType.DtInvalid) | |||||
dtype = this.dtype; | |||||
return array_ops.zeros(shape, dtype); | |||||
} | |||||
public object get_config() | |||||
{ | |||||
return new { dtype = dtype.name() }; | |||||
} | |||||
} | |||||
/// <summary> | |||||
/// Initializer capable of adapting its scale to the shape of weights tensors. | |||||
/// </summary> | |||||
public class VarianceScaling : IInitializer | |||||
{ | |||||
protected float _scale; | |||||
protected string _mode; | |||||
protected string _distribution; | |||||
protected int? _seed; | |||||
protected TF_DataType _dtype; | |||||
public VarianceScaling(float scale = 1.0f, | |||||
string mode = "fan_in", | |||||
string distribution= "truncated_normal", | |||||
int? seed = null, | |||||
TF_DataType dtype = TF_DataType.TF_FLOAT) | |||||
{ | |||||
if (scale < 0) | |||||
throw new ValueError("`scale` must be positive float."); | |||||
_scale = scale; | |||||
_mode = mode; | |||||
_distribution = distribution; | |||||
_seed = seed; | |||||
_dtype = dtype; | |||||
} | |||||
public Tensor call(TensorShape shape, TF_DataType dtype) | |||||
{ | |||||
var (fan_in, fan_out) = _compute_fans(shape); | |||||
if (_mode == "fan_in") | |||||
_scale /= Math.Max(1, fan_in); | |||||
else if (_mode == "fan_out") | |||||
_scale /= Math.Max(1, fan_out); | |||||
else | |||||
_scale /= Math.Max(1, (fan_in + fan_out) / 2); | |||||
if (_distribution == "normal" || _distribution == "truncated_normal") | |||||
{ | |||||
throw new NotImplementedException("truncated_normal"); | |||||
} | |||||
else if(_distribution == "untruncated_normal") | |||||
{ | |||||
throw new NotImplementedException("truncated_normal"); | |||||
} | |||||
else | |||||
{ | |||||
var limit = Math.Sqrt(3.0f * _scale); | |||||
return random_ops.random_uniform(shape, (float)-limit, (float)limit, dtype, seed: _seed); | |||||
} | |||||
} | |||||
private (int, int) _compute_fans(int[] shape) | |||||
{ | |||||
if (shape.Length < 1) | |||||
return (1, 1); | |||||
if (shape.Length == 1) | |||||
return (shape[0], shape[0]); | |||||
if (shape.Length == 2) | |||||
return (shape[0], shape[1]); | |||||
else | |||||
throw new NotImplementedException("VarianceScaling._compute_fans"); | |||||
} | |||||
public virtual object get_config() | |||||
{ | |||||
return new | |||||
{ | |||||
scale = _scale, | |||||
mode = _mode, | |||||
distribution = _distribution, | |||||
seed = _seed, | |||||
dtype = _dtype | |||||
}; | |||||
} | |||||
} | |||||
public class GlorotUniform : VarianceScaling | |||||
{ | |||||
public GlorotUniform(float scale = 1.0f, | |||||
string mode = "fan_avg", | |||||
string distribution = "uniform", | |||||
int? seed = null, | |||||
TF_DataType dtype = TF_DataType.TF_FLOAT) : base(scale, mode, distribution, seed, dtype) | |||||
{ | |||||
} | |||||
public object get_config() | |||||
{ | |||||
return new | |||||
{ | |||||
scale = _scale, | |||||
mode = _mode, | |||||
distribution = _distribution, | |||||
seed = _seed, | |||||
dtype = _dtype | |||||
}; | |||||
} | |||||
} | |||||
} | } | ||||
} | } |
@@ -22,5 +22,12 @@ namespace Tensorflow | |||||
TF_DataType dtype = TF_DataType.TF_FLOAT, | TF_DataType dtype = TF_DataType.TF_FLOAT, | ||||
int? seed = null, | int? seed = null, | ||||
string name = null) => random_ops.random_normal(shape, mean, stddev, dtype, seed, name); | string name = null) => random_ops.random_normal(shape, mean, stddev, dtype, seed, name); | ||||
public static Tensor random_uniform(int[] shape, | |||||
float minval = 0, | |||||
float? maxval = null, | |||||
TF_DataType dtype = TF_DataType.TF_FLOAT, | |||||
int? seed = null, | |||||
string name = null) => random_ops.random_uniform(shape, minval, maxval, dtype, seed, name); | |||||
} | } | ||||
} | } |
@@ -0,0 +1,20 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using Tensorflow.Operations.Initializers; | |||||
namespace Tensorflow.Keras | |||||
{ | |||||
public class Initializers | |||||
{ | |||||
/// <summary> | |||||
/// He normal initializer. | |||||
/// </summary> | |||||
/// <param name="seed"></param> | |||||
/// <returns></returns> | |||||
public IInitializer he_normal(int? seed = null) | |||||
{ | |||||
return new VarianceScaling(scale: 20f, mode: "fan_in", distribution: "truncated_normal", seed: seed); | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,15 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using Tensorflow.Keras; | |||||
namespace Tensorflow | |||||
{ | |||||
public static partial class tf | |||||
{ | |||||
public static class keras | |||||
{ | |||||
public static Initializers initializers => new Initializers(); | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,30 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Operations.Initializers | |||||
{ | |||||
public class GlorotUniform : VarianceScaling | |||||
{ | |||||
public GlorotUniform(float scale = 1.0f, | |||||
string mode = "fan_avg", | |||||
string distribution = "uniform", | |||||
int? seed = null, | |||||
TF_DataType dtype = TF_DataType.TF_FLOAT) : base(scale, mode, distribution, seed, dtype) | |||||
{ | |||||
} | |||||
public object get_config() | |||||
{ | |||||
return new | |||||
{ | |||||
scale = _scale, | |||||
mode = _mode, | |||||
distribution = _distribution, | |||||
seed = _seed, | |||||
dtype = _dtype | |||||
}; | |||||
} | |||||
} | |||||
} |
@@ -6,7 +6,7 @@ namespace Tensorflow | |||||
{ | { | ||||
public interface IInitializer | public interface IInitializer | ||||
{ | { | ||||
Tensor call(TensorShape shape, TF_DataType dtype); | |||||
Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid); | |||||
object get_config(); | object get_config(); | ||||
} | } | ||||
} | } |
@@ -0,0 +1,41 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Operations.Initializers | |||||
{ | |||||
public class TruncatedNormal : IInitializer | |||||
{ | |||||
private float mean; | |||||
private float stddev; | |||||
private int? seed; | |||||
private TF_DataType dtype; | |||||
public TruncatedNormal(float mean = 0.0f, | |||||
float stddev = 1.0f, | |||||
int? seed = null, | |||||
TF_DataType dtype = TF_DataType.TF_FLOAT) | |||||
{ | |||||
this.mean = mean; | |||||
this.stddev = stddev; | |||||
this.seed = seed; | |||||
this.dtype = dtype; | |||||
} | |||||
public Tensor call(TensorShape shape, TF_DataType dtype) | |||||
{ | |||||
throw new NotImplementedException(""); | |||||
} | |||||
public object get_config() | |||||
{ | |||||
return new | |||||
{ | |||||
mean = mean, | |||||
stddev = stddev, | |||||
seed = seed, | |||||
dtype = dtype.name() | |||||
}; | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,82 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Operations.Initializers | |||||
{ | |||||
/// <summary> | |||||
/// Initializer capable of adapting its scale to the shape of weights tensors. | |||||
/// </summary> | |||||
public class VarianceScaling : IInitializer | |||||
{ | |||||
protected float _scale; | |||||
protected string _mode; | |||||
protected string _distribution; | |||||
protected int? _seed; | |||||
protected TF_DataType _dtype; | |||||
public VarianceScaling(float scale = 1.0f, | |||||
string mode = "fan_in", | |||||
string distribution = "truncated_normal", | |||||
int? seed = null, | |||||
TF_DataType dtype = TF_DataType.TF_FLOAT) | |||||
{ | |||||
if (scale < 0) | |||||
throw new ValueError("`scale` must be positive float."); | |||||
_scale = scale; | |||||
_mode = mode; | |||||
_distribution = distribution; | |||||
_seed = seed; | |||||
_dtype = dtype; | |||||
} | |||||
public Tensor call(TensorShape shape, TF_DataType dtype) | |||||
{ | |||||
var (fan_in, fan_out) = _compute_fans(shape); | |||||
if (_mode == "fan_in") | |||||
_scale /= Math.Max(1, fan_in); | |||||
else if (_mode == "fan_out") | |||||
_scale /= Math.Max(1, fan_out); | |||||
else | |||||
_scale /= Math.Max(1, (fan_in + fan_out) / 2); | |||||
if (_distribution == "normal" || _distribution == "truncated_normal") | |||||
{ | |||||
throw new NotImplementedException("truncated_normal"); | |||||
} | |||||
else if (_distribution == "untruncated_normal") | |||||
{ | |||||
throw new NotImplementedException("truncated_normal"); | |||||
} | |||||
else | |||||
{ | |||||
var limit = Math.Sqrt(3.0f * _scale); | |||||
return random_ops.random_uniform(shape, (float)-limit, (float)limit, dtype, seed: _seed); | |||||
} | |||||
} | |||||
private (int, int) _compute_fans(int[] shape) | |||||
{ | |||||
if (shape.Length < 1) | |||||
return (1, 1); | |||||
if (shape.Length == 1) | |||||
return (shape[0], shape[0]); | |||||
if (shape.Length == 2) | |||||
return (shape[0], shape[1]); | |||||
else | |||||
throw new NotImplementedException("VarianceScaling._compute_fans"); | |||||
} | |||||
public virtual object get_config() | |||||
{ | |||||
return new | |||||
{ | |||||
scale = _scale, | |||||
mode = _mode, | |||||
distribution = _distribution, | |||||
seed = _seed, | |||||
dtype = _dtype | |||||
}; | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,29 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow.Operations.Initializers | |||||
{ | |||||
public class Zeros : IInitializer | |||||
{ | |||||
private TF_DataType dtype; | |||||
public Zeros(TF_DataType dtype = TF_DataType.TF_FLOAT) | |||||
{ | |||||
this.dtype = dtype; | |||||
} | |||||
public Tensor call(TensorShape shape, TF_DataType dtype = TF_DataType.DtInvalid) | |||||
{ | |||||
if (dtype == TF_DataType.DtInvalid) | |||||
dtype = this.dtype; | |||||
return array_ops.zeros(shape, dtype); | |||||
} | |||||
public object get_config() | |||||
{ | |||||
return new { dtype = dtype.name() }; | |||||
} | |||||
} | |||||
} |
@@ -0,0 +1,14 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
namespace Tensorflow | |||||
{ | |||||
public class embedding_ops | |||||
{ | |||||
public Tensor _embedding_lookup_and_transform() | |||||
{ | |||||
throw new NotImplementedException(""); | |||||
} | |||||
} | |||||
} |
@@ -43,7 +43,7 @@ Fixed import name scope issue.</PackageReleaseNotes> | |||||
</ItemGroup> | </ItemGroup> | ||||
<ItemGroup> | <ItemGroup> | ||||
<PackageReference Include="Google.Protobuf" Version="3.6.1" /> | |||||
<PackageReference Include="Google.Protobuf" Version="3.7.0" /> | |||||
<PackageReference Include="NumSharp" Version="0.7.3" /> | <PackageReference Include="NumSharp" Version="0.7.3" /> | ||||
</ItemGroup> | </ItemGroup> | ||||
@@ -33,7 +33,7 @@ namespace Tensorflow | |||||
string name, | string name, | ||||
TensorShape shape = null, | TensorShape shape = null, | ||||
TF_DataType dtype = TF_DataType.DtInvalid, | TF_DataType dtype = TF_DataType.DtInvalid, | ||||
IInitializer initializer = null, | |||||
object initializer = null, // IInitializer or Tensor | |||||
bool? trainable = null, | bool? trainable = null, | ||||
VariableSynchronization synchronization = VariableSynchronization.AUTO, | VariableSynchronization synchronization = VariableSynchronization.AUTO, | ||||
VariableAggregation aggregation= VariableAggregation.NONE) | VariableAggregation aggregation= VariableAggregation.NONE) | ||||
@@ -23,7 +23,7 @@ namespace Tensorflow | |||||
public RefVariable get_variable(string name, | public RefVariable get_variable(string name, | ||||
TensorShape shape = null, | TensorShape shape = null, | ||||
TF_DataType dtype = TF_DataType.TF_FLOAT, | TF_DataType dtype = TF_DataType.TF_FLOAT, | ||||
IInitializer initializer = null, | |||||
object initializer = null, // IInitializer or Tensor | |||||
bool? trainable = null, | bool? trainable = null, | ||||
bool validate_shape = true, | bool validate_shape = true, | ||||
VariableSynchronization synchronization = VariableSynchronization.AUTO, | VariableSynchronization synchronization = VariableSynchronization.AUTO, | ||||
@@ -45,7 +45,7 @@ namespace Tensorflow | |||||
private RefVariable _true_getter(string name, | private RefVariable _true_getter(string name, | ||||
TensorShape shape = null, | TensorShape shape = null, | ||||
TF_DataType dtype = TF_DataType.TF_FLOAT, | TF_DataType dtype = TF_DataType.TF_FLOAT, | ||||
IInitializer initializer = null, | |||||
object initializer = null, | |||||
bool? trainable = null, | bool? trainable = null, | ||||
bool validate_shape = true, | bool validate_shape = true, | ||||
VariableSynchronization synchronization = VariableSynchronization.AUTO, | VariableSynchronization synchronization = VariableSynchronization.AUTO, | ||||
@@ -53,14 +53,32 @@ namespace Tensorflow | |||||
{ | { | ||||
bool is_scalar = shape.NDim == 0; | bool is_scalar = shape.NDim == 0; | ||||
return _get_single_variable(name: name, | |||||
shape: shape, | |||||
if (initializer is IInitializer init) | |||||
{ | |||||
return _get_single_variable(name: name, | |||||
shape: shape, | |||||
dtype: dtype, | dtype: dtype, | ||||
initializer: initializer, | |||||
initializer: init, | |||||
trainable: trainable, | trainable: trainable, | ||||
validate_shape: validate_shape, | validate_shape: validate_shape, | ||||
synchronization: synchronization, | synchronization: synchronization, | ||||
aggregation: aggregation); | aggregation: aggregation); | ||||
} | |||||
else if (initializer is Tensor tensor) | |||||
{ | |||||
return _get_single_variable(name: name, | |||||
shape: shape, | |||||
dtype: dtype, | |||||
initializer: tensor, | |||||
trainable: trainable, | |||||
validate_shape: validate_shape, | |||||
synchronization: synchronization, | |||||
aggregation: aggregation); | |||||
} | |||||
else | |||||
{ | |||||
throw new NotImplementedException("_true_getter"); | |||||
} | |||||
} | } | ||||
private RefVariable _get_single_variable(string name, | private RefVariable _get_single_variable(string name, | ||||
@@ -125,5 +143,45 @@ namespace Tensorflow | |||||
return v; | return v; | ||||
} | } | ||||
private RefVariable _get_single_variable(string name, | |||||
TensorShape shape = null, | |||||
TF_DataType dtype = TF_DataType.DtInvalid, | |||||
Tensor initializer = null, | |||||
bool reuse = false, | |||||
bool? trainable = null, | |||||
bool validate_shape = false, | |||||
bool? use_resource = null, | |||||
VariableSynchronization synchronization = VariableSynchronization.AUTO, | |||||
VariableAggregation aggregation = VariableAggregation.NONE) | |||||
{ | |||||
if (use_resource == null) | |||||
use_resource = false; | |||||
if (_vars.ContainsKey(name)) | |||||
{ | |||||
if (!reuse) | |||||
{ | |||||
var var = _vars[name]; | |||||
} | |||||
throw new NotImplementedException("_get_single_variable"); | |||||
} | |||||
RefVariable v = null; | |||||
// Create the variable. | |||||
ops.init_scope(); | |||||
{ | |||||
var init_val = initializer; | |||||
v = new RefVariable(init_val, | |||||
name: name, | |||||
validate_shape: validate_shape, | |||||
trainable: trainable.Value); | |||||
} | |||||
_vars[name] = v; | |||||
return v; | |||||
} | |||||
} | } | ||||
} | } |
@@ -15,7 +15,7 @@ namespace Tensorflow | |||||
public static RefVariable get_variable(string name, | public static RefVariable get_variable(string name, | ||||
TensorShape shape = null, | TensorShape shape = null, | ||||
TF_DataType dtype = TF_DataType.DtInvalid, | TF_DataType dtype = TF_DataType.DtInvalid, | ||||
IInitializer initializer = null, | |||||
object initializer = null, // IInitializer or Tensor | |||||
bool? trainable = null, | bool? trainable = null, | ||||
VariableSynchronization synchronization = VariableSynchronization.AUTO, | VariableSynchronization synchronization = VariableSynchronization.AUTO, | ||||
VariableAggregation aggregation = VariableAggregation.NONE) | VariableAggregation aggregation = VariableAggregation.NONE) | ||||
@@ -12,20 +12,27 @@ namespace Tensorflow | |||||
public static TF_DataType float16 = TF_DataType.TF_HALF; | public static TF_DataType float16 = TF_DataType.TF_HALF; | ||||
public static TF_DataType float32 = TF_DataType.TF_FLOAT; | public static TF_DataType float32 = TF_DataType.TF_FLOAT; | ||||
public static TF_DataType float64 = TF_DataType.TF_DOUBLE; | public static TF_DataType float64 = TF_DataType.TF_DOUBLE; | ||||
public static TF_DataType boolean = TF_DataType.TF_BOOL; | |||||
public static TF_DataType chars = TF_DataType.TF_STRING; | public static TF_DataType chars = TF_DataType.TF_STRING; | ||||
public static Context context = new Context(new ContextOptions(), new Status()); | public static Context context = new Context(new ContextOptions(), new Status()); | ||||
public static Session defaultSession; | public static Session defaultSession; | ||||
public static RefVariable Variable<T>(T data, string name = null, TF_DataType dtype = TF_DataType.DtInvalid) | |||||
public static RefVariable Variable<T>(T data, | |||||
bool trainable = true, | |||||
string name = null, | |||||
TF_DataType dtype = TF_DataType.DtInvalid) | |||||
{ | { | ||||
return Tensorflow.variable_scope.default_variable_creator(data, name: name, dtype: TF_DataType.DtInvalid); | |||||
return Tensorflow.variable_scope.default_variable_creator(data, | |||||
trainable: trainable, | |||||
name: name, | |||||
dtype: TF_DataType.DtInvalid); | |||||
} | } | ||||
public static unsafe Tensor placeholder(TF_DataType dtype, TensorShape shape = null) | |||||
public static unsafe Tensor placeholder(TF_DataType dtype, TensorShape shape = null, string name = null) | |||||
{ | { | ||||
return gen_array_ops.placeholder(dtype, shape); | |||||
return gen_array_ops.placeholder(dtype, shape, name); | |||||
} | } | ||||
public static void enable_eager_execution() | public static void enable_eager_execution() | ||||
@@ -13,7 +13,6 @@ | |||||
<ItemGroup> | <ItemGroup> | ||||
<ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" /> | <ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" /> | ||||
<ProjectReference Include="..\..\..\scikit-learn.net\src\scikit-learn\scikit-learn.csproj" /> | |||||
<ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" /> | <ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" /> | ||||
</ItemGroup> | </ItemGroup> | ||||
@@ -77,7 +77,7 @@ namespace TensorFlowNET.Examples.CnnTextClassification | |||||
var positive_labels = positive_examples.Select(x => new int[2] { 0, 1 }).ToArray(); | var positive_labels = positive_examples.Select(x => new int[2] { 0, 1 }).ToArray(); | ||||
var negative_labels = negative_examples.Select(x => new int[2] { 1, 0 }).ToArray(); | var negative_labels = negative_examples.Select(x => new int[2] { 1, 0 }).ToArray(); | ||||
var y = np.array(1);// np.concatenate(new int[][][] { positive_labels, negative_labels }); | |||||
var y = np.concatenate(new int[][][] { positive_labels, negative_labels }); | |||||
return (x_text.ToArray(), y); | return (x_text.ToArray(), y); | ||||
} | } | ||||
@@ -5,6 +5,7 @@ using System.IO; | |||||
using System.Linq; | using System.Linq; | ||||
using System.Text; | using System.Text; | ||||
using Tensorflow; | using Tensorflow; | ||||
using TensorFlowNET.Examples.TextClassification; | |||||
using TensorFlowNET.Examples.Utility; | using TensorFlowNET.Examples.Utility; | ||||
namespace TensorFlowNET.Examples.CnnTextClassification | namespace TensorFlowNET.Examples.CnnTextClassification | ||||
@@ -18,13 +19,21 @@ namespace TensorFlowNET.Examples.CnnTextClassification | |||||
private string dataFileName = "dbpedia_csv.tar.gz"; | private string dataFileName = "dbpedia_csv.tar.gz"; | ||||
private const int CHAR_MAX_LEN = 1014; | private const int CHAR_MAX_LEN = 1014; | ||||
private const int NUM_CLASS = 2; | |||||
public void Run() | public void Run() | ||||
{ | { | ||||
download_dbpedia(); | download_dbpedia(); | ||||
Console.WriteLine("Building dataset..."); | Console.WriteLine("Building dataset..."); | ||||
var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", "vdcnn", CHAR_MAX_LEN); | var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", "vdcnn", CHAR_MAX_LEN); | ||||
//var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15); | |||||
var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f); | |||||
with(tf.Session(), sess => | |||||
{ | |||||
new VdCnn(alphabet_size, CHAR_MAX_LEN, NUM_CLASS); | |||||
}); | |||||
} | } | ||||
public void download_dbpedia() | public void download_dbpedia() | ||||
@@ -33,5 +42,38 @@ namespace TensorFlowNET.Examples.CnnTextClassification | |||||
Web.Download(url, dataDir, dataFileName); | Web.Download(url, dataDir, dataFileName); | ||||
Compress.ExtractTGZ(Path.Join(dataDir, dataFileName), dataDir); | Compress.ExtractTGZ(Path.Join(dataDir, dataFileName), dataDir); | ||||
} | } | ||||
private (int[][], int[][], int[], int[]) train_test_split(int[][] x, int[] y, float test_size = 0.3f) | |||||
{ | |||||
int len = x.Length; | |||||
int classes = y.Distinct().Count(); | |||||
int samples = len / classes; | |||||
int train_size = int.Parse((samples * (1 - test_size)).ToString()); | |||||
var train_x = new List<int[]>(); | |||||
var valid_x = new List<int[]>(); | |||||
var train_y = new List<int>(); | |||||
var valid_y = new List<int>(); | |||||
for (int i = 0; i< classes; i++) | |||||
{ | |||||
for (int j = 0; j < samples; j++) | |||||
{ | |||||
int idx = i * samples + j; | |||||
if (idx < train_size + samples * i) | |||||
{ | |||||
train_x.Add(x[idx]); | |||||
train_y.Add(y[idx]); | |||||
} | |||||
else | |||||
{ | |||||
valid_x.Add(x[idx]); | |||||
valid_y.Add(y[idx]); | |||||
} | |||||
} | |||||
} | |||||
return (train_x.ToArray(), valid_x.ToArray(), train_y.ToArray(), valid_y.ToArray()); | |||||
} | |||||
} | } | ||||
} | } |
@@ -0,0 +1,44 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using Tensorflow; | |||||
namespace TensorFlowNET.Examples.TextClassification | |||||
{ | |||||
public class VdCnn : Python | |||||
{ | |||||
private int embedding_size; | |||||
private int[] filter_sizes; | |||||
private int[] num_filters; | |||||
private int[] num_blocks; | |||||
private float learning_rate; | |||||
private IInitializer cnn_initializer; | |||||
private Tensor x; | |||||
private Tensor y; | |||||
private Tensor is_training; | |||||
private RefVariable global_step; | |||||
private RefVariable embeddings; | |||||
private Tensor x_emb; | |||||
public VdCnn(int alphabet_size, int document_max_len, int num_class) | |||||
{ | |||||
embedding_size = 16; | |||||
filter_sizes = new int[] { 3, 3, 3, 3, 3 }; | |||||
num_filters = new int[] { 64, 64, 128, 256, 512 }; | |||||
num_blocks = new int[] { 2, 2, 2, 2 }; | |||||
learning_rate = 0.001f; | |||||
cnn_initializer = tf.keras.initializers.he_normal(); | |||||
x = tf.placeholder(tf.int32, new TensorShape(-1, document_max_len), name: "x"); | |||||
y = tf.placeholder(tf.int32, new TensorShape(-1), name: "y"); | |||||
is_training = tf.placeholder(tf.boolean, new TensorShape(), name: "is_training"); | |||||
global_step = tf.Variable(0, trainable: false); | |||||
with(tf.name_scope("embedding"), delegate | |||||
{ | |||||
var init_embeddings = tf.random_uniform(new int[] { alphabet_size, embedding_size }, -1.0f, 1.0f); | |||||
embeddings = tf.get_variable("embeddings", initializer: init_embeddings); | |||||
// x_emb = tf.nn.embedding_lookup(embeddings, x); | |||||
}); | |||||
} | |||||
} | |||||
} |
@@ -16,14 +16,15 @@ | |||||
</PropertyGroup> | </PropertyGroup> | ||||
<ItemGroup> | <ItemGroup> | ||||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="15.9.0" /> | |||||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.0.1" /> | |||||
<PackageReference Include="MSTest.TestAdapter" Version="1.4.0" /> | <PackageReference Include="MSTest.TestAdapter" Version="1.4.0" /> | ||||
<PackageReference Include="MSTest.TestFramework" Version="1.4.0" /> | <PackageReference Include="MSTest.TestFramework" Version="1.4.0" /> | ||||
<PackageReference Include="NumSharp" Version="0.7.3" /> | <PackageReference Include="NumSharp" Version="0.7.3" /> | ||||
<PackageReference Include="TensorFlow.NET" Version="0.3.0" /> | |||||
<PackageReference Include="TensorFlow.NET" Version="0.4.2" /> | |||||
</ItemGroup> | </ItemGroup> | ||||
<ItemGroup> | <ItemGroup> | ||||
<ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" /> | |||||
<ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" /> | <ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" /> | ||||
</ItemGroup> | </ItemGroup> | ||||
@@ -26,7 +26,7 @@ namespace TensorFlowNET.UnitTest | |||||
[TestMethod] | [TestMethod] | ||||
public void StringVar() | public void StringVar() | ||||
{ | { | ||||
var mammal1 = tf.Variable("Elephant", "var1", tf.chars); | |||||
var mammal1 = tf.Variable("Elephant", name: "var1", dtype: tf.chars); | |||||
var mammal2 = tf.Variable("Tiger"); | var mammal2 = tf.Variable("Tiger"); | ||||
} | } | ||||