@@ -424,24 +424,30 @@ namespace Tensorflow
             return true;
         }

+        public static void extendleft<T>(this Queue<T> queue, IEnumerable<T> elements)
+        {
+            foreach (var element in elements.Reverse())
+                queue.Enqueue(element);
+        }
+
         public static bool empty<T>(this Queue<T> queue)
             => queue.Count == 0;

-        public static TValue SetDefault<TKey, TValue>(this Dictionary<TKey, TValue> dic, TKey key, TValue value)
+        public static TValue SetDefault<TKey, TValue>(this Dictionary<TKey, TValue> dic, TKey key, TValue defaultValue)
         {
             if (dic.ContainsKey(key))
                 return dic[key];

-            dic[key] = value;
-            return value;
+            dic[key] = defaultValue;
+            return defaultValue;
         }

-        public static TValue Get<TKey, TValue>(this Dictionary<TKey, TValue> dic, TKey key, TValue value)
+        public static TValue Get<TKey, TValue>(this Dictionary<TKey, TValue> dic, TKey key, TValue defaultValue)
         {
             if (dic.ContainsKey(key))
                 return dic[key];

-            return value;
+            return defaultValue;
         }
     }
 }
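A quick usage sketch of the new helpers (not part of the diff; the variable names are illustrative). Note that, despite its Python-inspired name, `extendleft` on a `Queue<T>` can only enqueue at the tail, so it appends the elements in reverse order:

    var counts = new Dictionary<string, int>();
    counts.SetDefault("dense", 1);          // inserts 1 and returns it
    var n = counts.Get("conv2d", 5);        // returns 5 without inserting anything

    var queue = new Queue<int>(new[] { 1, 2 });
    queue.extendleft(new[] { 3, 4 });       // queue is now 1, 2, 4, 3
    while (!queue.empty())
        Console.WriteLine(queue.Dequeue());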
@@ -0,0 +1,15 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Keras.ArgsDefinition
+{
+    public class OptimizerV2Args
+    {
+        public string Name { get; set; }
+        public float LearningRate { get; set; } = 0.001f;
+        public float InitialDecay { get; set; }
+        public float ClipNorm { get; set; }
+        public float ClipValue { get; set; }
+    }
+}
@@ -4,13 +4,11 @@ using System.Text;
 namespace Tensorflow.Keras.ArgsDefinition
 {
-    public class RMSpropArgs
+    public class RMSpropArgs : OptimizerV2Args
     {
-        public float LearningRate { get; set; } = 0.001f;
         public float RHO { get; set; } = 0.9f;
         public float Momentum { get; set; } = 0.0f;
         public float Epsilon { get; set; } = 1e-7f;
         public bool Centered { get; set; } = false;
-        public string Name { get; set; } = "RMSprop";
     }
 }
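With RMSpropArgs now deriving from OptimizerV2Args, `Name`, `LearningRate`, `InitialDecay`, `ClipNorm` and `ClipValue` come from the base class; the per-optimizer default `Name = "RMSprop"` is gone, while `LearningRate = 0.001f` survives as a base default. A construction sketch (values match the previous defaults):

    var args = new RMSpropArgs
    {
        Name = "RMSprop",       // inherited; no longer defaulted per optimizer
        LearningRate = 0.001f,  // inherited from OptimizerV2Args
        RHO = 0.9f,
        Momentum = 0.0f,
        Epsilon = 1e-7f,
        Centered = false
    };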
@@ -0,0 +1,17 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Keras.Engine
+{
+    public class Container
+    {
+        protected string[] _output_names;
+        protected bool _built;
+
+        public Container(string[] output_names)
+        {
+            _output_names = output_names;
+        }
+    }
+}
@@ -96,9 +96,37 @@ namespace Tensorflow.Keras.Engine
             NodesByDepth = nodes_by_depth;
             _layers = layers;

+            // Build output_names.
+            _set_output_names();
+
             ComputeTensorUsageCount();
         }

+        /// <summary>
+        /// Assigns unique names to the Network's outputs.
+        /// </summary>
+        void _set_output_names()
+        {
+            var uniquified = new List<string>();
+            var output_names = new List<string>();
+            var prefix_count = new Dictionary<string, int>();
+            foreach (var layer in _output_layers)
+            {
+                var proposal = layer.Name;
+                while (output_names.Contains(proposal))
+                {
+                    var existing_count = prefix_count.Get(layer.Name, 1);
+                    proposal = $"{layer.Name}_{existing_count}";
+                    prefix_count[layer.Name] = existing_count + 1;
+                }
+                output_names.Add(proposal);
+                uniquified.Add(proposal);
+            }
+            this.output_names = uniquified.ToArray();
+        }
+
         void ComputeTensorUsageCount()
         {
             var available_tensors = inputs.Select(x => x.GetHashCode()).ToList();
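For illustration of `_set_output_names` above: if three output layers are all named `dense`, the first keeps `dense`; the second finds the name taken, reads `prefix_count.Get("dense", 1) == 1` and becomes `dense_1`; the third reads the bumped count `2` and becomes `dense_2`. `output_names` ends up as `{ "dense", "dense_1", "dense_2" }`.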
@@ -0,0 +1,29 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Keras.Engine
+{
+    public partial class Layer
+    {
+        public IEnumerable<Layer> _flatten_layers(bool recursive = true, bool include_self = true)
+        {
+            if (include_self)
+                yield return this;
+
+            var seen_object_ids = new List<int>();
+            var deque = new Queue<Layer>(_layers);
+            while (!deque.empty())
+            {
+                var layer_or_container = deque.Dequeue();
+                var layer_or_container_id = layer_or_container.GetHashCode();
+                if (seen_object_ids.Contains(layer_or_container_id))
+                    continue;
+                seen_object_ids.Add(layer_or_container_id);
+                yield return layer_or_container;
+
+                if (recursive)
+                    deque.extendleft(layer_or_container._layers);
+            }
+        }
+    }
+}
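A consumption sketch (not part of the diff; `model` stands for any Layer with sublayers). Because the Queue-based `extendleft` appends to the tail, the traversal is effectively breadth-first:

    foreach (var layer in model._flatten_layers())       // the model itself, then sublayers
        Console.WriteLine(layer.Name);

    foreach (var child in model._flatten_layers(recursive: false, include_self: false))
        Console.WriteLine(child.Name);                    // direct children only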
@@ -12,7 +12,7 @@ namespace Tensorflow.Keras.Engine
     {
         protected List<Layer> _layers = new List<Layer>();
         public List<Layer> Layers => _layers;

         protected Layer Dense(int units,
             Activation activation = null,
             TensorShape input_shape = null)
@@ -6,11 +6,19 @@ namespace Tensorflow.Keras.Engine
 {
     public partial class Layer
     {
-        Dictionary<Layer, object> trainable_state;
-        Dictionary<Layer, object> _get_trainable_state()
+        protected Dictionary<Layer, bool> trainable_state;
+        protected Dictionary<Layer, bool> _compiled_trainable_state;
+
+        /// <summary>
+        /// Get the `trainable` state of each sublayer.
+        /// </summary>
+        /// <returns></returns>
+        protected Dictionary<Layer, bool> _get_trainable_state()
         {
-            trainable_state = new Dictionary<Layer, object>();
-            throw new NotImplementedException("");
+            trainable_state = new Dictionary<Layer, bool>();
+            foreach (var layer in _flatten_layers())
+                trainable_state[layer] = layer.Trainable;
+            return trainable_state;
         }

         void _set_trainable_state(Dictionary<Layer, object> trainable_state)
@@ -0,0 +1,25 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Keras.Losses;
+using Tensorflow.Keras.Metrics;
+
+namespace Tensorflow.Keras.Engine
+{
+    public class LossesContainer : Container
+    {
+        ILossFunc _user_losses;
+        ILossFunc _losses;
+        Mean _loss_metric;
+
+        public LossesContainer(ILossFunc losses, string[] output_names = null)
+            : base(output_names)
+        {
+            _user_losses = losses;
+            _losses = losses;
+            _loss_metric = new Mean(name: "loss");
+            _built = false;
+        }
+    }
+}
@@ -0,0 +1,20 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Keras.Engine
+{
+    public class MetricsContainer : Container
+    {
+        string[] _user_metrics;
+        string[] _metrics;
+
+        public MetricsContainer(string[] metrics, string[] output_names = null)
+            : base(output_names)
+        {
+            _user_metrics = metrics;
+            _metrics = metrics;
+            _built = false;
+        }
+    }
+}
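A sketch of how these containers get created, mirroring the `compile` overload below (`someLoss` is a stand-in for any `ILossFunc` implementation and is not part of this diff):

    var compiled_loss = new LossesContainer(someLoss, output_names: output_names);
    var compiled_metrics = new MetricsContainer(new[] { "accuracy" }, output_names: output_names);

Both simply hold the user-supplied objects for now and set `_built = false`; `LossesContainer` additionally creates a `Mean` metric named "loss".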
@@ -4,6 +4,7 @@ using Tensorflow.Keras.ArgsDefinition;
 using Tensorflow.Keras.Engine.DataAdapters;
 using Tensorflow.Keras.Losses;
 using Tensorflow.Keras.Optimizers;
+using NumSharp;

 namespace Tensorflow.Keras.Engine
 {
@@ -20,12 +21,17 @@ namespace Tensorflow.Keras.Engine
         bool _is_compiled;
 #pragma warning restore CS0414 // The field 'Model._is_compiled' is assigned but its value is never used
 #pragma warning restore CS0108 // Member hides inherited member; missing new keyword
-        string loss;
+        ILossFunc loss;
         IOptimizer optimizer;
         IVariableV1 _steps_per_execution;
         protected bool _is_graph_network;
         protected Tensors inputs;
         protected Tensors outputs;
+        public string[] output_names;
+        IVariableV1 _train_counter;
+        IVariableV1 _test_counter;
+        IVariableV1 _predict_counter;
+        bool _base_model_initialized;

         public Model(ModelArgs args)
             : base(args)
@@ -35,7 +41,17 @@ namespace Tensorflow.Keras.Engine
         public void compile(ILossFunc loss, OptimizerV2 optimizer, string[] metrics)
         {
+            this.optimizer = optimizer;
+            var compiled_loss = new LossesContainer(loss, output_names: output_names);
+            var compiled_metrics = new MetricsContainer(metrics, output_names: output_names);
+
+            int experimental_steps_per_execution = 1;
+            _configure_steps_per_execution(experimental_steps_per_execution);
+
+            // Initialize cache attrs.
+            _reset_compile_cache();
+            _is_compiled = true;
+            this.loss = loss;
         }

         public void compile(string optimizerName, string lossName)
@@ -55,10 +71,29 @@ namespace Tensorflow.Keras.Engine
             _reset_compile_cache();
-            loss = lossName;
             _is_compiled = true;
         }

+        /// <summary>
+        /// Trains the model for a fixed number of epochs (iterations on a dataset).
+        /// </summary>
+        /// <param name="x"></param>
+        /// <param name="y"></param>
+        /// <param name="batch_size"></param>
+        /// <param name="epochs"></param>
+        /// <param name="verbose"></param>
+        /// <param name="validation_split"></param>
+        /// <param name="shuffle"></param>
+        public void fit(NDArray x, NDArray y,
+            int batch_size = -1,
+            int epochs = 1,
+            int verbose = 1,
+            float validation_split = 0f,
+            bool shuffle = true)
+        {
+        }
+
         void _configure_steps_per_execution(int steps_per_execution)
         {
             _steps_per_execution = tf.Variable(steps_per_execution,
@@ -68,7 +103,23 @@ namespace Tensorflow.Keras.Engine
         void _reset_compile_cache()
         {
+            // Used to cache `trainable` attr of `Layer`s for `fit`.
+            _compiled_trainable_state = _get_trainable_state();
+        }
+
+        void _init_batch_counters()
+        {
+            _train_counter = tf.Variable(0,
+                dtype: TF_DataType.TF_INT64,
+                aggregation: VariableAggregation.OnlyFirstReplica);
+            _test_counter = tf.Variable(0,
+                dtype: TF_DataType.TF_INT64,
+                aggregation: VariableAggregation.OnlyFirstReplica);
+            _predict_counter = tf.Variable(0,
+                dtype: TF_DataType.TF_INT64,
+                aggregation: VariableAggregation.OnlyFirstReplica);
         }

         public void compile(string optimizerName, ILossFunc lossName)
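A usage sketch of the new compile overload and the `fit` stub above (the `model` variable, the `someLoss` instance and the NumSharp arrays `x`/`y` are assumptions, not part of this diff; `fit` currently has an empty body, so no training happens yet):

    var optimizer = new SGD(0.01f);
    model.compile(someLoss, optimizer, new[] { "accuracy" });
    model.fit(x, y, batch_size: 32, epochs: 2, validation_split: 0.1f);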
@@ -0,0 +1,19 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Keras.ArgsDefinition;
+
+namespace Tensorflow.Keras.Metrics
+{
+    /// <summary>
+    /// Computes the (weighted) mean of the given values.
+    /// </summary>
+    public class Mean : Reduce
+    {
+        public Mean(string name = "mean", TF_DataType dtype = TF_DataType.DtInvalid)
+            : base(Reduction.WEIGHTED_MEAN, name, dtype: dtype)
+        {
+        }
+    }
+}
@@ -0,0 +1,50 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Keras.Engine;
+using static Tensorflow.Binding;
+
+namespace Tensorflow.Keras.Metrics
+{
+    /// <summary>
+    /// Encapsulates metric logic and state.
+    /// </summary>
+    public class Metric : Layer
+    {
+        public Metric(string name = null, TF_DataType dtype = TF_DataType.DtInvalid)
+            : base(new LayerArgs
+            {
+                Name = name,
+                DType = dtype
+            })
+        {
+            stateful = true;
+            built = true;
+        }
+
+        protected override IVariableV1 add_weight(string name,
+            TensorShape shape = null,
+            TF_DataType dtype = TF_DataType.TF_FLOAT,
+            IInitializer initializer = null,
+            IRegularizer regularizer = null,
+            VariableSynchronization synchronization = VariableSynchronization.OnRead,
+            VariableAggregation aggregation = VariableAggregation.Sum,
+            bool trainable = true,
+            Func<VariableArgs, IVariableV1> getter = null)
+        {
+            if (shape == null)
+                shape = new TensorShape(new int[0]);
+
+            return tf_with(ops.init_scope(), delegate
+            {
+                return base.add_weight(name, shape,
+                    dtype: dtype,
+                    trainable: false,
+                    initializer: initializer,
+                    synchronization: synchronization,
+                    aggregation: aggregation);
+            });
+        }
+    }
+}
@@ -0,0 +1,28 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Keras.ArgsDefinition;
+using static Tensorflow.Binding;
+
+namespace Tensorflow.Keras.Metrics
+{
+    /// <summary>
+    /// Encapsulates metrics that perform a reduce operation on the values.
+    /// </summary>
+    public class Reduce : Metric
+    {
+        IVariableV1 total;
+        IVariableV1 count;
+
+        public Reduce(string reduction, string name, TF_DataType dtype = TF_DataType.DtInvalid)
+            : base(name: name, dtype: dtype)
+        {
+            total = add_weight("total", initializer: tf.zeros_initializer);
+            if (reduction == Reduction.WEIGHTED_MEAN ||
+                reduction == Reduction.SUM_OVER_BATCH_SIZE)
+            {
+                count = add_weight("count", initializer: tf.zeros_initializer);
+            }
+        }
+    }
+}
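Tying the two metric classes together: `Mean` passes `Reduction.WEIGHTED_MEAN` to the `Reduce` constructor, so both a `total` and a `count` weight are created, which is what `LossesContainer` relies on for its `_loss_metric`:

    var loss_metric = new Mean(name: "loss");  // creates the "total" and "count" weights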
@@ -0,0 +1,10 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Keras.Metrics
+{
+    class Sum
+    {
+    }
+}
@@ -3,6 +3,7 @@ using System.Collections.Generic;
 using System.Linq;
 using System.Text;
 using Tensorflow.Eager;
+using Tensorflow.Keras.ArgsDefinition;

 namespace Tensorflow.Keras.Optimizers
 {
@@ -22,7 +23,7 @@ namespace Tensorflow.Keras.Optimizers
             float beta_2 = 0.999f,
             float epsilon = 1e-7f,
             bool amsgrad = false,
-            string name = "Adam")
+            string name = "Adam") : base(new OptimizerV2Args { })
         {
             _set_hyper("learning_rate", learning_rate);
             // _set_hyper("decay", _initial_decay);
@@ -7,6 +7,7 @@ using Tensorflow.Train;
 using static Tensorflow.Binding;
 using Tensorflow;
 using Tensorflow.Eager;
+using Tensorflow.Keras.ArgsDefinition;

 namespace Tensorflow.Keras.Optimizers
 {
@@ -15,6 +16,7 @@ namespace Tensorflow.Keras.Optimizers
     /// </summary>
     public class OptimizerV2 : Trackable, IOptimizer
     {
+        OptimizerV2Args args;
         protected bool _hypers_created;
         protected virtual string _name { get; }

@@ -30,13 +32,17 @@ namespace Tensorflow.Keras.Optimizers
         Dictionary<string, Dictionary<string, IVariableV1>> _slots;
         List<string> _slot_names;

-        public OptimizerV2() : base()
+        public OptimizerV2(OptimizerV2Args args) : base()
         {
+            this.args = args;
             _weights = new List<IVariableV1>();
             _hyper = new Dictionary<string, float>();
             _hyper_variables = new Dictionary<string, IVariableV1>();
             _slots = new Dictionary<string, Dictionary<string, IVariableV1>>();
             _slot_names = new List<string>();
+
+            _set_hyper("learning_rate", args.LearningRate);
+            _set_hyper("decay", args.InitialDecay);
         }

         public void apply_gradients((Tensor, ResourceVariable) grads_and_vars,
@@ -12,9 +12,11 @@ namespace Tensorflow.Keras.Optimizers
     {
         RMSpropArgs args;

-        public RMSprop(RMSpropArgs args)
+        public RMSprop(RMSpropArgs args) : base(args)
         {
             this.args = args;
+            _set_hyper("rho", args.RHO);
+            _set_hyper("momentum", args.Momentum);
         }
     }
 }
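With the constructor chaining above, `OptimizerV2` now registers `learning_rate` and `decay` from the args and `RMSprop` adds its own hypers on top. A construction sketch (values shown are the RMSpropArgs defaults):

    var rmsprop = new RMSprop(new RMSpropArgs
    {
        LearningRate = 0.001f, // base ctor: _set_hyper("learning_rate", ...)
        InitialDecay = 0f,     // base ctor: _set_hyper("decay", ...)
        RHO = 0.9f,            // _set_hyper("rho", ...)
        Momentum = 0.0f        // _set_hyper("momentum", ...)
    });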
@@ -3,6 +3,7 @@ using System.Collections.Generic;
 using System.Linq;
 using System.Text;
 using Tensorflow.Eager;
+using Tensorflow.Keras.ArgsDefinition;

 namespace Tensorflow.Keras.Optimizers
 {
@@ -17,7 +18,7 @@ namespace Tensorflow.Keras.Optimizers
         public SGD(float learning_rate,
             float momentum = 0.0f,
             bool nesterov = false,
-            float decay = 0.0f) : base()
+            float decay = 0.0f) : base(new OptimizerV2Args { })
         {
             _set_hyper("learning_rate", learning_rate);
             _set_hyper("decay", decay);
@@ -3,9 +3,9 @@
     public class Reduction
     {
         public const string NONE = "none";
-        public const string SUM = "weighted_sum";
+        public const string WEIGHTED_SUM = "weighted_sum";
         public const string SUM_OVER_BATCH_SIZE = "weighted_sum_over_batch_size";
-        public const string MEAN = "weighted_mean";
+        public const string WEIGHTED_MEAN = "weighted_mean";
         public const string SUM_BY_NONZERO_WEIGHTS = "weighted_sum_by_nonzero_weights";
         public const string SUM_OVER_NONZERO_WEIGHTS = SUM_BY_NONZERO_WEIGHTS;
     }
@@ -47,7 +47,7 @@ namespace Tensorflow
             else
             {
                 loss = math_ops.reduce_sum(weighted_losses);
-                if (reduction == Reduction.MEAN)
+                if (reduction == Reduction.WEIGHTED_MEAN)
                     loss = _safe_mean(
                         loss, math_ops.reduce_sum(array_ops.ones_like(losses) * weights));
                 else if (reduction == Reduction.SUM_BY_NONZERO_WEIGHTS ||
@@ -5,7 +5,7 @@
     <AssemblyName>TensorFlow.NET</AssemblyName>
     <RootNamespace>Tensorflow</RootNamespace>
     <TargetTensorFlow>2.2.0</TargetTensorFlow>
-    <Version>0.21.0</Version>
+    <Version>0.30.0</Version>
     <LangVersion>8.0</LangVersion>
     <Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors>
     <Company>SciSharp STACK</Company>
@@ -19,14 +19,14 @@
     <Description>Google's TensorFlow full binding in .NET Standard.
 Building, training and infering deep learning models.
 https://tensorflownet.readthedocs.io</Description>
-    <AssemblyVersion>0.21.0.0</AssemblyVersion>
+    <AssemblyVersion>0.30.0.0</AssemblyVersion>
     <PackageReleaseNotes>tf.net 0.20.x and above are based on tensorflow native 2.x.
 * Eager Mode is added finally.
 * tf.keras is partially working.
 * tf.data is added.
 * autograph works partially.</PackageReleaseNotes>
-    <FileVersion>0.21.0.0</FileVersion>
+    <FileVersion>0.30.0.0</FileVersion>
     <PackageLicenseFile>LICENSE</PackageLicenseFile>
     <PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance>
     <SignAssembly>true</SignAssembly>