@@ -424,24 +424,30 @@ namespace Tensorflow | |||
return true; | |||
} | |||
public static void extendleft<T>(this Queue<T> queue, IEnumerable<T> elements) | |||
{ | |||
foreach (var element in elements.Reverse()) | |||
queue.Enqueue(element); | |||
} | |||
public static bool empty<T>(this Queue<T> queue) | |||
=> queue.Count == 0; | |||
public static TValue SetDefault<TKey, TValue>(this Dictionary<TKey, TValue> dic, TKey key, TValue value) | |||
public static TValue SetDefault<TKey, TValue>(this Dictionary<TKey, TValue> dic, TKey key, TValue defaultValue) | |||
{ | |||
if (dic.ContainsKey(key)) | |||
return dic[key]; | |||
dic[key] = value; | |||
return value; | |||
dic[key] = defaultValue; | |||
return defaultValue; | |||
} | |||
public static TValue Get<TKey, TValue>(this Dictionary<TKey, TValue> dic, TKey key, TValue value) | |||
public static TValue Get<TKey, TValue>(this Dictionary<TKey, TValue> dic, TKey key, TValue defaultValue) | |||
{ | |||
if (dic.ContainsKey(key)) | |||
return dic[key]; | |||
return value; | |||
return defaultValue; | |||
} | |||
} | |||
} |
@@ -0,0 +1,15 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
public class OptimizerV2Args | |||
{ | |||
public string Name { get; set; } | |||
public float LearningRate { get; set; } = 0.001f; | |||
public float InitialDecay { get; set; } | |||
public float ClipNorm { get; set; } | |||
public float ClipValue { get; set; } | |||
} | |||
} |
@@ -4,13 +4,11 @@ using System.Text; | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
public class RMSpropArgs | |||
public class RMSpropArgs : OptimizerV2Args | |||
{ | |||
public float LearningRate { get; set; } = 0.001f; | |||
public float RHO { get; set; } = 0.9f; | |||
public float Momentum { get; set; } = 0.0f; | |||
public float Epsilon { get; set; } = 1e-7f; | |||
public bool Centered { get; set; } = false; | |||
public string Name { get; set; } = "RMSprop"; | |||
} | |||
} |
@@ -0,0 +1,17 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Keras.Engine | |||
{ | |||
public class Container | |||
{ | |||
protected string[] _output_names; | |||
protected bool _built; | |||
public Container(string[] output_names) | |||
{ | |||
_output_names = output_names; | |||
} | |||
} | |||
} |
@@ -96,9 +96,37 @@ namespace Tensorflow.Keras.Engine | |||
NodesByDepth = nodes_by_depth; | |||
_layers = layers; | |||
// Build self.input_names and self.output_names. | |||
_set_output_names(); | |||
ComputeTensorUsageCount(); | |||
} | |||
/// <summary> | |||
/// Assigns unique names to the Network's outputs. | |||
/// </summary> | |||
void _set_output_names() | |||
{ | |||
var uniquified = new List<string>(); | |||
var output_names = new List<string>(); | |||
var prefix_count = new Dictionary<string, int>(); | |||
foreach (var layer in _output_layers) | |||
{ | |||
var proposal = layer.Name; | |||
while (output_names.Contains(proposal)) | |||
{ | |||
var existing_count = prefix_count.Get(layer.Name, 1); | |||
proposal = $"{layer.Name}_{existing_count}"; | |||
prefix_count[layer.Name] = existing_count + 1; | |||
} | |||
output_names.add(proposal); | |||
uniquified.append(proposal); | |||
} | |||
this.output_names = uniquified.ToArray(); | |||
} | |||
void ComputeTensorUsageCount() | |||
{ | |||
var available_tensors = inputs.Select(x => x.GetHashCode()).ToList(); | |||
@@ -0,0 +1,29 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Keras.Engine | |||
{ | |||
public partial class Layer | |||
{ | |||
public IEnumerable<Layer> _flatten_layers(bool recursive = true, bool include_self = true) | |||
{ | |||
if (include_self) | |||
yield return this; | |||
var seen_object_ids = new List<int>(); | |||
var deque = new Queue<Layer>(_layers); | |||
while (!deque.empty()) | |||
{ | |||
var layer_or_container = deque.Dequeue(); | |||
var layer_or_container_id = layer_or_container.GetHashCode(); | |||
if (seen_object_ids.Contains(layer_or_container_id)) | |||
continue; | |||
seen_object_ids.Add(layer_or_container_id); | |||
yield return layer_or_container; | |||
if (recursive) | |||
deque.extendleft(layer_or_container._layers); | |||
} | |||
} | |||
} | |||
} |
@@ -12,7 +12,7 @@ namespace Tensorflow.Keras.Engine | |||
{ | |||
protected List<Layer> _layers = new List<Layer>(); | |||
public List<Layer> Layers => _layers; | |||
protected Layer Dense(int units, | |||
Activation activation = null, | |||
TensorShape input_shape = null) | |||
@@ -6,11 +6,19 @@ namespace Tensorflow.Keras.Engine | |||
{ | |||
public partial class Layer | |||
{ | |||
Dictionary<Layer, object> trainable_state; | |||
Dictionary<Layer, object> _get_trainable_state() | |||
protected Dictionary<Layer, bool> trainable_state; | |||
protected Dictionary<Layer, bool> _compiled_trainable_state; | |||
/// <summary> | |||
/// Get the `trainable` state of each sublayer. | |||
/// </summary> | |||
/// <returns></returns> | |||
protected Dictionary<Layer, bool> _get_trainable_state() | |||
{ | |||
trainable_state = new Dictionary<Layer, object>(); | |||
throw new NotImplementedException(""); | |||
trainable_state = new Dictionary<Layer, bool>(); | |||
foreach (var layer in _flatten_layers()) | |||
trainable_state[layer] = layer.Trainable; | |||
return trainable_state; | |||
} | |||
void _set_trainable_state(Dictionary<Layer, object> trainable_state) | |||
@@ -0,0 +1,25 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Losses; | |||
using Tensorflow.Keras.Metrics; | |||
namespace Tensorflow.Keras.Engine | |||
{ | |||
public class LossesContainer : Container | |||
{ | |||
ILossFunc _user_losses; | |||
ILossFunc _losses; | |||
Mean _loss_metric; | |||
public LossesContainer(ILossFunc losses, string[] output_names = null) | |||
: base(output_names) | |||
{ | |||
_user_losses = losses; | |||
_losses = losses; | |||
_loss_metric = new Mean(name: "loss"); | |||
_built = false; | |||
} | |||
} | |||
} |
@@ -0,0 +1,20 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Keras.Engine | |||
{ | |||
public class MetricsContainer : Container | |||
{ | |||
string[] _user_metrics; | |||
string[] _metrics; | |||
public MetricsContainer(string[] metrics, string[] output_names = null) | |||
: base(output_names) | |||
{ | |||
_user_metrics = metrics; | |||
_metrics = metrics; | |||
_built = false; | |||
} | |||
} | |||
} |
@@ -4,6 +4,7 @@ using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Engine.DataAdapters; | |||
using Tensorflow.Keras.Losses; | |||
using Tensorflow.Keras.Optimizers; | |||
using NumSharp; | |||
namespace Tensorflow.Keras.Engine | |||
{ | |||
@@ -20,12 +21,17 @@ namespace Tensorflow.Keras.Engine | |||
bool _is_compiled; | |||
#pragma warning restore CS0414 // The field 'Model._is_compiled' is assigned but its value is never used | |||
#pragma warning restore CS0108 // Member hides inherited member; missing new keyword | |||
string loss; | |||
ILossFunc loss; | |||
IOptimizer optimizer; | |||
IVariableV1 _steps_per_execution; | |||
protected bool _is_graph_network; | |||
protected Tensors inputs; | |||
protected Tensors outputs; | |||
public string[] output_names; | |||
IVariableV1 _train_counter; | |||
IVariableV1 _test_counter; | |||
IVariableV1 _predict_counter; | |||
bool _base_model_initialized; | |||
public Model(ModelArgs args) | |||
: base(args) | |||
@@ -35,7 +41,17 @@ namespace Tensorflow.Keras.Engine | |||
public void compile(ILossFunc loss, OptimizerV2 optimizer, string[] metrics) | |||
{ | |||
this.optimizer = optimizer; | |||
var compiled_loss = new LossesContainer(loss, output_names: output_names); | |||
var compiled_metrics = new MetricsContainer(metrics, output_names: output_names); | |||
int experimental_steps_per_execution = 1; | |||
_configure_steps_per_execution(experimental_steps_per_execution); | |||
// Initialize cache attrs. | |||
_reset_compile_cache(); | |||
_is_compiled = true; | |||
this.loss = loss; | |||
} | |||
public void compile(string optimizerName, string lossName) | |||
@@ -55,10 +71,29 @@ namespace Tensorflow.Keras.Engine | |||
_reset_compile_cache(); | |||
loss = lossName; | |||
_is_compiled = true; | |||
} | |||
/// <summary> | |||
/// Trains the model for a fixed number of epochs (iterations on a dataset). | |||
/// </summary> | |||
/// <param name="x"></param> | |||
/// <param name="y"></param> | |||
/// <param name="batch_size"></param> | |||
/// <param name="epochs"></param> | |||
/// <param name="verbose"></param> | |||
/// <param name="validation_split"></param> | |||
/// <param name="shuffle"></param> | |||
public void fit(NDArray x, NDArray y, | |||
int batch_size = -1, | |||
int epochs = 1, | |||
int verbose = 1, | |||
float validation_split = 0f, | |||
bool shuffle = true) | |||
{ | |||
} | |||
void _configure_steps_per_execution(int steps_per_execution) | |||
{ | |||
_steps_per_execution = tf.Variable(steps_per_execution, | |||
@@ -68,7 +103,23 @@ namespace Tensorflow.Keras.Engine | |||
void _reset_compile_cache() | |||
{ | |||
// Used to cache `trainable` attr of `Layer`s for `fit`. | |||
_compiled_trainable_state = _get_trainable_state(); | |||
} | |||
void _init_batch_counters() | |||
{ | |||
_train_counter = tf.Variable(0, | |||
dtype: TF_DataType.TF_INT64, | |||
aggregation: VariableAggregation.OnlyFirstReplica); | |||
_test_counter = tf.Variable(0, | |||
dtype: TF_DataType.TF_INT64, | |||
aggregation: VariableAggregation.OnlyFirstReplica); | |||
_predict_counter = tf.Variable(0, | |||
dtype: TF_DataType.TF_INT64, | |||
aggregation: VariableAggregation.OnlyFirstReplica); | |||
} | |||
public void compile(string optimizerName, ILossFunc lossName) | |||
@@ -0,0 +1,19 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
namespace Tensorflow.Keras.Metrics | |||
{ | |||
/// <summary> | |||
/// Computes the (weighted) mean of the given values. | |||
/// </summary> | |||
public class Mean : Reduce | |||
{ | |||
public Mean(string name = "mean", TF_DataType dtype = TF_DataType.DtInvalid) | |||
: base(Reduction.WEIGHTED_MEAN, name, dtype: dtype) | |||
{ | |||
} | |||
} | |||
} |
@@ -0,0 +1,50 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Engine; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Keras.Metrics | |||
{ | |||
/// <summary> | |||
/// Encapsulates metric logic and state. | |||
/// </summary> | |||
public class Metric : Layer | |||
{ | |||
public Metric(string name = null, TF_DataType dtype = TF_DataType.DtInvalid) | |||
: base(new LayerArgs | |||
{ | |||
Name = name, | |||
DType = dtype | |||
}) | |||
{ | |||
stateful = true; | |||
built = true; | |||
} | |||
protected override IVariableV1 add_weight(string name, | |||
TensorShape shape = null, | |||
TF_DataType dtype = TF_DataType.TF_FLOAT, | |||
IInitializer initializer = null, | |||
IRegularizer regularizer = null, | |||
VariableSynchronization synchronization = VariableSynchronization.OnRead, | |||
VariableAggregation aggregation = VariableAggregation.Sum, | |||
bool trainable = true, | |||
Func<VariableArgs, IVariableV1> getter = null) | |||
{ | |||
if (shape == null) | |||
shape = new TensorShape(new int[0]); | |||
return tf_with(ops.init_scope(), delegate | |||
{ | |||
return base.add_weight(name, shape, | |||
dtype: dtype, | |||
trainable: false, | |||
initializer: initializer, | |||
synchronization: synchronization, | |||
aggregation: aggregation); | |||
}); | |||
} | |||
} | |||
} |
@@ -0,0 +1,28 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow.Keras.Metrics | |||
{ | |||
/// <summary> | |||
/// Encapsulates metrics that perform a reduce operation on the values. | |||
/// </summary> | |||
public class Reduce : Metric | |||
{ | |||
IVariableV1 total; | |||
IVariableV1 count; | |||
public Reduce(string reduction, string name, TF_DataType dtype = TF_DataType.DtInvalid) | |||
: base(name: name, dtype: dtype) | |||
{ | |||
total = add_weight("total", initializer: tf.zeros_initializer); | |||
if (reduction == Reduction.WEIGHTED_MEAN || | |||
reduction == Reduction.SUM_OVER_BATCH_SIZE) | |||
{ | |||
count = add_weight("count", initializer: tf.zeros_initializer); | |||
} | |||
} | |||
} | |||
} |
@@ -0,0 +1,10 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
namespace Tensorflow.Keras.Metrics | |||
{ | |||
class Sum | |||
{ | |||
} | |||
} |
@@ -3,6 +3,7 @@ using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using Tensorflow.Eager; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
namespace Tensorflow.Keras.Optimizers | |||
{ | |||
@@ -22,7 +23,7 @@ namespace Tensorflow.Keras.Optimizers | |||
float beta_2 = 0.999f, | |||
float epsilon = 1e-7f, | |||
bool amsgrad = false, | |||
string name = "Adam") | |||
string name = "Adam") : base(new OptimizerV2Args { }) | |||
{ | |||
_set_hyper("learning_rate", learning_rate); | |||
// _set_hyper("decay", _initial_decay); | |||
@@ -7,6 +7,7 @@ using Tensorflow.Train; | |||
using static Tensorflow.Binding; | |||
using Tensorflow; | |||
using Tensorflow.Eager; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
namespace Tensorflow.Keras.Optimizers | |||
{ | |||
@@ -15,6 +16,7 @@ namespace Tensorflow.Keras.Optimizers | |||
/// </summary> | |||
public class OptimizerV2 : Trackable, IOptimizer | |||
{ | |||
OptimizerV2Args args; | |||
protected bool _hypers_created; | |||
protected virtual string _name { get; } | |||
@@ -30,13 +32,17 @@ namespace Tensorflow.Keras.Optimizers | |||
Dictionary<string, Dictionary<string, IVariableV1>> _slots; | |||
List<string> _slot_names; | |||
public OptimizerV2() : base() | |||
public OptimizerV2(OptimizerV2Args args) : base() | |||
{ | |||
this.args = args; | |||
_weights = new List<IVariableV1>(); | |||
_hyper = new Dictionary<string, float>(); | |||
_hyper_variables = new Dictionary<string, IVariableV1>(); | |||
_slots = new Dictionary<string, Dictionary<string, IVariableV1>>(); | |||
_slot_names = new List<string>(); | |||
_set_hyper("learning_rate", args.LearningRate); | |||
_set_hyper("decay", args.InitialDecay); | |||
} | |||
public void apply_gradients((Tensor, ResourceVariable) grads_and_vars, | |||
@@ -12,9 +12,11 @@ namespace Tensorflow.Keras.Optimizers | |||
{ | |||
RMSpropArgs args; | |||
public RMSprop(RMSpropArgs args) | |||
public RMSprop(RMSpropArgs args) : base(args) | |||
{ | |||
this.args = args; | |||
_set_hyper("rho", args.RHO); | |||
_set_hyper("momentum", args.Momentum); | |||
} | |||
} | |||
} |
@@ -3,6 +3,7 @@ using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using Tensorflow.Eager; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
namespace Tensorflow.Keras.Optimizers | |||
{ | |||
@@ -17,7 +18,7 @@ namespace Tensorflow.Keras.Optimizers | |||
public SGD(float learning_rate, | |||
float momentum = 0.0f, | |||
bool nesterov = false, | |||
float decay = 0.0f) : base() | |||
float decay = 0.0f) : base(new OptimizerV2Args { }) | |||
{ | |||
_set_hyper("learning_rate", learning_rate); | |||
_set_hyper("decay", decay); | |||
@@ -3,9 +3,9 @@ | |||
public class Reduction | |||
{ | |||
public const string NONE = "none"; | |||
public const string SUM = "weighted_sum"; | |||
public const string WEIGHTED_SUM = "weighted_sum"; | |||
public const string SUM_OVER_BATCH_SIZE = "weighted_sum_over_batch_size"; | |||
public const string MEAN = "weighted_mean"; | |||
public const string WEIGHTED_MEAN = "weighted_mean"; | |||
public const string SUM_BY_NONZERO_WEIGHTS = "weighted_sum_by_nonzero_weights"; | |||
public const string SUM_OVER_NONZERO_WEIGHTS = SUM_BY_NONZERO_WEIGHTS; | |||
} | |||
@@ -47,7 +47,7 @@ namespace Tensorflow | |||
else | |||
{ | |||
loss = math_ops.reduce_sum(weighted_losses); | |||
if (reduction == Reduction.MEAN) | |||
if (reduction == Reduction.WEIGHTED_MEAN) | |||
loss = _safe_mean( | |||
loss, math_ops.reduce_sum(array_ops.ones_like(losses) * weights)); | |||
else if (reduction == Reduction.SUM_BY_NONZERO_WEIGHTS || | |||
@@ -5,7 +5,7 @@ | |||
<AssemblyName>TensorFlow.NET</AssemblyName> | |||
<RootNamespace>Tensorflow</RootNamespace> | |||
<TargetTensorFlow>2.2.0</TargetTensorFlow> | |||
<Version>0.21.0</Version> | |||
<Version>0.30.0</Version> | |||
<LangVersion>8.0</LangVersion> | |||
<Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors> | |||
<Company>SciSharp STACK</Company> | |||
@@ -19,14 +19,14 @@ | |||
<Description>Google's TensorFlow full binding in .NET Standard. | |||
Building, training and infering deep learning models. | |||
https://tensorflownet.readthedocs.io</Description> | |||
<AssemblyVersion>0.21.0.0</AssemblyVersion> | |||
<AssemblyVersion>0.30.0.0</AssemblyVersion> | |||
<PackageReleaseNotes>tf.net 0.20.x and above are based on tensorflow native 2.x. | |||
* Eager Mode is added finally. | |||
* tf.keras is partially working. | |||
* tf.data is added. | |||
* autograph works partially.</PackageReleaseNotes> | |||
<FileVersion>0.21.0.0</FileVersion> | |||
<FileVersion>0.30.0.0</FileVersion> | |||
<PackageLicenseFile>LICENSE</PackageLicenseFile> | |||
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance> | |||
<SignAssembly>true</SignAssembly> | |||