Browse Source

Add Metrics architecture.

tags/v0.30
Oceania2018 5 years ago
parent
commit
006eeaa454
22 changed files with 339 additions and 25 deletions
  1. +11
    -5
      src/TensorFlowNET.Core/Binding.Util.cs
  2. +15
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/OptimizerV2Args.cs
  3. +1
    -3
      src/TensorFlowNET.Core/Keras/ArgsDefinition/RMSpropArgs.cs
  4. +17
    -0
      src/TensorFlowNET.Core/Keras/Engine/Container.cs
  5. +28
    -0
      src/TensorFlowNET.Core/Keras/Engine/Functional.cs
  6. +29
    -0
      src/TensorFlowNET.Core/Keras/Engine/Layer.FlattenLayers.cs
  7. +1
    -1
      src/TensorFlowNET.Core/Keras/Engine/Layer.Layers.cs
  8. +12
    -4
      src/TensorFlowNET.Core/Keras/Engine/Layer.State.cs
  9. +25
    -0
      src/TensorFlowNET.Core/Keras/Engine/LossesContainer.cs
  10. +20
    -0
      src/TensorFlowNET.Core/Keras/Engine/MetricsContainer.cs
  11. +53
    -2
      src/TensorFlowNET.Core/Keras/Engine/Model.cs
  12. +19
    -0
      src/TensorFlowNET.Core/Keras/Metrics/Mean.cs
  13. +50
    -0
      src/TensorFlowNET.Core/Keras/Metrics/Metric.cs
  14. +28
    -0
      src/TensorFlowNET.Core/Keras/Metrics/Reduce.cs
  15. +10
    -0
      src/TensorFlowNET.Core/Keras/Metrics/Sum.cs
  16. +2
    -1
      src/TensorFlowNET.Core/Keras/Optimizers/Adam.cs
  17. +7
    -1
      src/TensorFlowNET.Core/Keras/Optimizers/OptimizerV2.cs
  18. +3
    -1
      src/TensorFlowNET.Core/Keras/Optimizers/RMSprop.cs
  19. +2
    -1
      src/TensorFlowNET.Core/Keras/Optimizers/SGD.cs
  20. +2
    -2
      src/TensorFlowNET.Core/Operations/Losses/Reduction.cs
  21. +1
    -1
      src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs
  22. +3
    -3
      src/TensorFlowNET.Core/Tensorflow.Binding.csproj

+ 11
- 5
src/TensorFlowNET.Core/Binding.Util.cs View File

@@ -424,24 +424,30 @@ namespace Tensorflow
return true; return true;
} }


public static void extendleft<T>(this Queue<T> queue, IEnumerable<T> elements)
{
foreach (var element in elements.Reverse())
queue.Enqueue(element);
}

public static bool empty<T>(this Queue<T> queue) public static bool empty<T>(this Queue<T> queue)
=> queue.Count == 0; => queue.Count == 0;


public static TValue SetDefault<TKey, TValue>(this Dictionary<TKey, TValue> dic, TKey key, TValue value)
public static TValue SetDefault<TKey, TValue>(this Dictionary<TKey, TValue> dic, TKey key, TValue defaultValue)
{ {
if (dic.ContainsKey(key)) if (dic.ContainsKey(key))
return dic[key]; return dic[key];


dic[key] = value;
return value;
dic[key] = defaultValue;
return defaultValue;
} }


public static TValue Get<TKey, TValue>(this Dictionary<TKey, TValue> dic, TKey key, TValue value)
public static TValue Get<TKey, TValue>(this Dictionary<TKey, TValue> dic, TKey key, TValue defaultValue)
{ {
if (dic.ContainsKey(key)) if (dic.ContainsKey(key))
return dic[key]; return dic[key];


return value;
return defaultValue;
} }
} }
} }

+ 15
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/OptimizerV2Args.cs View File

@@ -0,0 +1,15 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.ArgsDefinition
{
public class OptimizerV2Args
{
public string Name { get; set; }
public float LearningRate { get; set; } = 0.001f;
public float InitialDecay { get; set; }
public float ClipNorm { get; set; }
public float ClipValue { get; set; }
}
}

+ 1
- 3
src/TensorFlowNET.Core/Keras/ArgsDefinition/RMSpropArgs.cs View File

@@ -4,13 +4,11 @@ using System.Text;


namespace Tensorflow.Keras.ArgsDefinition namespace Tensorflow.Keras.ArgsDefinition
{ {
public class RMSpropArgs
public class RMSpropArgs : OptimizerV2Args
{ {
public float LearningRate { get; set; } = 0.001f;
public float RHO { get; set; } = 0.9f; public float RHO { get; set; } = 0.9f;
public float Momentum { get; set; } = 0.0f; public float Momentum { get; set; } = 0.0f;
public float Epsilon { get; set; } = 1e-7f; public float Epsilon { get; set; } = 1e-7f;
public bool Centered { get; set; } = false; public bool Centered { get; set; } = false;
public string Name { get; set; } = "RMSprop";
} }
} }

+ 17
- 0
src/TensorFlowNET.Core/Keras/Engine/Container.cs View File

@@ -0,0 +1,17 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Engine
{
public class Container
{
protected string[] _output_names;
protected bool _built;

public Container(string[] output_names)
{
_output_names = output_names;
}
}
}

+ 28
- 0
src/TensorFlowNET.Core/Keras/Engine/Functional.cs View File

@@ -96,9 +96,37 @@ namespace Tensorflow.Keras.Engine
NodesByDepth = nodes_by_depth; NodesByDepth = nodes_by_depth;
_layers = layers; _layers = layers;


// Build self.input_names and self.output_names.
_set_output_names();

ComputeTensorUsageCount(); ComputeTensorUsageCount();
} }


/// <summary>
/// Assigns unique names to the Network's outputs.
/// </summary>
void _set_output_names()
{
var uniquified = new List<string>();
var output_names = new List<string>();
var prefix_count = new Dictionary<string, int>();

foreach (var layer in _output_layers)
{
var proposal = layer.Name;
while (output_names.Contains(proposal))
{
var existing_count = prefix_count.Get(layer.Name, 1);
proposal = $"{layer.Name}_{existing_count}";
prefix_count[layer.Name] = existing_count + 1;
}
output_names.add(proposal);
uniquified.append(proposal);
}

this.output_names = uniquified.ToArray();
}

void ComputeTensorUsageCount() void ComputeTensorUsageCount()
{ {
var available_tensors = inputs.Select(x => x.GetHashCode()).ToList(); var available_tensors = inputs.Select(x => x.GetHashCode()).ToList();


+ 29
- 0
src/TensorFlowNET.Core/Keras/Engine/Layer.FlattenLayers.cs View File

@@ -0,0 +1,29 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Engine
{
public partial class Layer
{
public IEnumerable<Layer> _flatten_layers(bool recursive = true, bool include_self = true)
{
if (include_self)
yield return this;

var seen_object_ids = new List<int>();
var deque = new Queue<Layer>(_layers);
while (!deque.empty())
{
var layer_or_container = deque.Dequeue();
var layer_or_container_id = layer_or_container.GetHashCode();
if (seen_object_ids.Contains(layer_or_container_id))
continue;
seen_object_ids.Add(layer_or_container_id);
yield return layer_or_container;
if (recursive)
deque.extendleft(layer_or_container._layers);
}
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Keras/Engine/Layer.Layers.cs View File

@@ -12,7 +12,7 @@ namespace Tensorflow.Keras.Engine
{ {
protected List<Layer> _layers = new List<Layer>(); protected List<Layer> _layers = new List<Layer>();
public List<Layer> Layers => _layers; public List<Layer> Layers => _layers;
protected Layer Dense(int units, protected Layer Dense(int units,
Activation activation = null, Activation activation = null,
TensorShape input_shape = null) TensorShape input_shape = null)


+ 12
- 4
src/TensorFlowNET.Core/Keras/Engine/Layer.State.cs View File

@@ -6,11 +6,19 @@ namespace Tensorflow.Keras.Engine
{ {
public partial class Layer public partial class Layer
{ {
Dictionary<Layer, object> trainable_state;
Dictionary<Layer, object> _get_trainable_state()
protected Dictionary<Layer, bool> trainable_state;
protected Dictionary<Layer, bool> _compiled_trainable_state;

/// <summary>
/// Get the `trainable` state of each sublayer.
/// </summary>
/// <returns></returns>
protected Dictionary<Layer, bool> _get_trainable_state()
{ {
trainable_state = new Dictionary<Layer, object>();
throw new NotImplementedException("");
trainable_state = new Dictionary<Layer, bool>();
foreach (var layer in _flatten_layers())
trainable_state[layer] = layer.Trainable;
return trainable_state;
} }


void _set_trainable_state(Dictionary<Layer, object> trainable_state) void _set_trainable_state(Dictionary<Layer, object> trainable_state)


+ 25
- 0
src/TensorFlowNET.Core/Keras/Engine/LossesContainer.cs View File

@@ -0,0 +1,25 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Losses;
using Tensorflow.Keras.Metrics;

namespace Tensorflow.Keras.Engine
{
public class LossesContainer : Container
{
ILossFunc _user_losses;
ILossFunc _losses;
Mean _loss_metric;

public LossesContainer(ILossFunc losses, string[] output_names = null)
: base(output_names)
{
_user_losses = losses;
_losses = losses;
_loss_metric = new Mean(name: "loss");
_built = false;
}
}
}

+ 20
- 0
src/TensorFlowNET.Core/Keras/Engine/MetricsContainer.cs View File

@@ -0,0 +1,20 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Engine
{
public class MetricsContainer : Container
{
string[] _user_metrics;
string[] _metrics;

public MetricsContainer(string[] metrics, string[] output_names = null)
: base(output_names)
{
_user_metrics = metrics;
_metrics = metrics;
_built = false;
}
}
}

+ 53
- 2
src/TensorFlowNET.Core/Keras/Engine/Model.cs View File

@@ -4,6 +4,7 @@ using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine.DataAdapters; using Tensorflow.Keras.Engine.DataAdapters;
using Tensorflow.Keras.Losses; using Tensorflow.Keras.Losses;
using Tensorflow.Keras.Optimizers; using Tensorflow.Keras.Optimizers;
using NumSharp;


namespace Tensorflow.Keras.Engine namespace Tensorflow.Keras.Engine
{ {
@@ -20,12 +21,17 @@ namespace Tensorflow.Keras.Engine
bool _is_compiled; bool _is_compiled;
#pragma warning restore CS0414 // The field 'Model._is_compiled' is assigned but its value is never used #pragma warning restore CS0414 // The field 'Model._is_compiled' is assigned but its value is never used
#pragma warning restore CS0108 // Member hides inherited member; missing new keyword #pragma warning restore CS0108 // Member hides inherited member; missing new keyword
string loss;
ILossFunc loss;
IOptimizer optimizer; IOptimizer optimizer;
IVariableV1 _steps_per_execution; IVariableV1 _steps_per_execution;
protected bool _is_graph_network; protected bool _is_graph_network;
protected Tensors inputs; protected Tensors inputs;
protected Tensors outputs; protected Tensors outputs;
public string[] output_names;
IVariableV1 _train_counter;
IVariableV1 _test_counter;
IVariableV1 _predict_counter;
bool _base_model_initialized;


public Model(ModelArgs args) public Model(ModelArgs args)
: base(args) : base(args)
@@ -35,7 +41,17 @@ namespace Tensorflow.Keras.Engine


public void compile(ILossFunc loss, OptimizerV2 optimizer, string[] metrics) public void compile(ILossFunc loss, OptimizerV2 optimizer, string[] metrics)
{ {
this.optimizer = optimizer;
var compiled_loss = new LossesContainer(loss, output_names: output_names);
var compiled_metrics = new MetricsContainer(metrics, output_names: output_names);


int experimental_steps_per_execution = 1;
_configure_steps_per_execution(experimental_steps_per_execution);

// Initialize cache attrs.
_reset_compile_cache();
_is_compiled = true;
this.loss = loss;
} }


public void compile(string optimizerName, string lossName) public void compile(string optimizerName, string lossName)
@@ -55,10 +71,29 @@ namespace Tensorflow.Keras.Engine


_reset_compile_cache(); _reset_compile_cache();


loss = lossName;
_is_compiled = true; _is_compiled = true;
} }


/// <summary>
/// Trains the model for a fixed number of epochs (iterations on a dataset).
/// </summary>
/// <param name="x"></param>
/// <param name="y"></param>
/// <param name="batch_size"></param>
/// <param name="epochs"></param>
/// <param name="verbose"></param>
/// <param name="validation_split"></param>
/// <param name="shuffle"></param>
public void fit(NDArray x, NDArray y,
int batch_size = -1,
int epochs = 1,
int verbose = 1,
float validation_split = 0f,
bool shuffle = true)
{

}

void _configure_steps_per_execution(int steps_per_execution) void _configure_steps_per_execution(int steps_per_execution)
{ {
_steps_per_execution = tf.Variable(steps_per_execution, _steps_per_execution = tf.Variable(steps_per_execution,
@@ -68,7 +103,23 @@ namespace Tensorflow.Keras.Engine


void _reset_compile_cache() void _reset_compile_cache()
{ {
// Used to cache `trainable` attr of `Layer`s for `fit`.
_compiled_trainable_state = _get_trainable_state();
}

void _init_batch_counters()
{
_train_counter = tf.Variable(0,
dtype: TF_DataType.TF_INT64,
aggregation: VariableAggregation.OnlyFirstReplica);


_test_counter = tf.Variable(0,
dtype: TF_DataType.TF_INT64,
aggregation: VariableAggregation.OnlyFirstReplica);

_predict_counter = tf.Variable(0,
dtype: TF_DataType.TF_INT64,
aggregation: VariableAggregation.OnlyFirstReplica);
} }


public void compile(string optimizerName, ILossFunc lossName) public void compile(string optimizerName, ILossFunc lossName)


+ 19
- 0
src/TensorFlowNET.Core/Keras/Metrics/Mean.cs View File

@@ -0,0 +1,19 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;

namespace Tensorflow.Keras.Metrics
{
/// <summary>
/// Computes the (weighted) mean of the given values.
/// </summary>
public class Mean : Reduce
{
public Mean(string name = "mean", TF_DataType dtype = TF_DataType.DtInvalid)
: base(Reduction.WEIGHTED_MEAN, name, dtype: dtype)
{

}
}
}

+ 50
- 0
src/TensorFlowNET.Core/Keras/Metrics/Metric.cs View File

@@ -0,0 +1,50 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Metrics
{
/// <summary>
/// Encapsulates metric logic and state.
/// </summary>
public class Metric : Layer
{
public Metric(string name = null, TF_DataType dtype = TF_DataType.DtInvalid)
: base(new LayerArgs
{
Name = name,
DType = dtype
})
{
stateful = true;
built = true;
}

protected override IVariableV1 add_weight(string name,
TensorShape shape = null,
TF_DataType dtype = TF_DataType.TF_FLOAT,
IInitializer initializer = null,
IRegularizer regularizer = null,
VariableSynchronization synchronization = VariableSynchronization.OnRead,
VariableAggregation aggregation = VariableAggregation.Sum,
bool trainable = true,
Func<VariableArgs, IVariableV1> getter = null)
{
if (shape == null)
shape = new TensorShape(new int[0]);

return tf_with(ops.init_scope(), delegate
{
return base.add_weight(name, shape,
dtype: dtype,
trainable: false,
initializer: initializer,
synchronization: synchronization,
aggregation: aggregation);
});
}
}
}

+ 28
- 0
src/TensorFlowNET.Core/Keras/Metrics/Reduce.cs View File

@@ -0,0 +1,28 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.Metrics
{
/// <summary>
/// Encapsulates metrics that perform a reduce operation on the values.
/// </summary>
public class Reduce : Metric
{
IVariableV1 total;
IVariableV1 count;
public Reduce(string reduction, string name, TF_DataType dtype = TF_DataType.DtInvalid)
: base(name: name, dtype: dtype)
{
total = add_weight("total", initializer: tf.zeros_initializer);

if (reduction == Reduction.WEIGHTED_MEAN ||
reduction == Reduction.SUM_OVER_BATCH_SIZE)
{
count = add_weight("count", initializer: tf.zeros_initializer);
}
}
}
}

+ 10
- 0
src/TensorFlowNET.Core/Keras/Metrics/Sum.cs View File

@@ -0,0 +1,10 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Metrics
{
class Sum
{
}
}

+ 2
- 1
src/TensorFlowNET.Core/Keras/Optimizers/Adam.cs View File

@@ -3,6 +3,7 @@ using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Text; using System.Text;
using Tensorflow.Eager; using Tensorflow.Eager;
using Tensorflow.Keras.ArgsDefinition;


namespace Tensorflow.Keras.Optimizers namespace Tensorflow.Keras.Optimizers
{ {
@@ -22,7 +23,7 @@ namespace Tensorflow.Keras.Optimizers
float beta_2 = 0.999f, float beta_2 = 0.999f,
float epsilon = 1e-7f, float epsilon = 1e-7f,
bool amsgrad = false, bool amsgrad = false,
string name = "Adam")
string name = "Adam") : base(new OptimizerV2Args { })
{ {
_set_hyper("learning_rate", learning_rate); _set_hyper("learning_rate", learning_rate);
// _set_hyper("decay", _initial_decay); // _set_hyper("decay", _initial_decay);


+ 7
- 1
src/TensorFlowNET.Core/Keras/Optimizers/OptimizerV2.cs View File

@@ -7,6 +7,7 @@ using Tensorflow.Train;
using static Tensorflow.Binding; using static Tensorflow.Binding;
using Tensorflow; using Tensorflow;
using Tensorflow.Eager; using Tensorflow.Eager;
using Tensorflow.Keras.ArgsDefinition;


namespace Tensorflow.Keras.Optimizers namespace Tensorflow.Keras.Optimizers
{ {
@@ -15,6 +16,7 @@ namespace Tensorflow.Keras.Optimizers
/// </summary> /// </summary>
public class OptimizerV2 : Trackable, IOptimizer public class OptimizerV2 : Trackable, IOptimizer
{ {
OptimizerV2Args args;
protected bool _hypers_created; protected bool _hypers_created;
protected virtual string _name { get; } protected virtual string _name { get; }


@@ -30,13 +32,17 @@ namespace Tensorflow.Keras.Optimizers
Dictionary<string, Dictionary<string, IVariableV1>> _slots; Dictionary<string, Dictionary<string, IVariableV1>> _slots;
List<string> _slot_names; List<string> _slot_names;


public OptimizerV2() : base()
public OptimizerV2(OptimizerV2Args args) : base()
{ {
this.args = args;
_weights = new List<IVariableV1>(); _weights = new List<IVariableV1>();
_hyper = new Dictionary<string, float>(); _hyper = new Dictionary<string, float>();
_hyper_variables = new Dictionary<string, IVariableV1>(); _hyper_variables = new Dictionary<string, IVariableV1>();
_slots = new Dictionary<string, Dictionary<string, IVariableV1>>(); _slots = new Dictionary<string, Dictionary<string, IVariableV1>>();
_slot_names = new List<string>(); _slot_names = new List<string>();

_set_hyper("learning_rate", args.LearningRate);
_set_hyper("decay", args.InitialDecay);
} }


public void apply_gradients((Tensor, ResourceVariable) grads_and_vars, public void apply_gradients((Tensor, ResourceVariable) grads_and_vars,


+ 3
- 1
src/TensorFlowNET.Core/Keras/Optimizers/RMSprop.cs View File

@@ -12,9 +12,11 @@ namespace Tensorflow.Keras.Optimizers
{ {
RMSpropArgs args; RMSpropArgs args;


public RMSprop(RMSpropArgs args)
public RMSprop(RMSpropArgs args) : base(args)
{ {
this.args = args; this.args = args;
_set_hyper("rho", args.RHO);
_set_hyper("momentum", args.Momentum);
} }
} }
} }

+ 2
- 1
src/TensorFlowNET.Core/Keras/Optimizers/SGD.cs View File

@@ -3,6 +3,7 @@ using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Text; using System.Text;
using Tensorflow.Eager; using Tensorflow.Eager;
using Tensorflow.Keras.ArgsDefinition;


namespace Tensorflow.Keras.Optimizers namespace Tensorflow.Keras.Optimizers
{ {
@@ -17,7 +18,7 @@ namespace Tensorflow.Keras.Optimizers
public SGD(float learning_rate, public SGD(float learning_rate,
float momentum = 0.0f, float momentum = 0.0f,
bool nesterov = false, bool nesterov = false,
float decay = 0.0f) : base()
float decay = 0.0f) : base(new OptimizerV2Args { })
{ {
_set_hyper("learning_rate", learning_rate); _set_hyper("learning_rate", learning_rate);
_set_hyper("decay", decay); _set_hyper("decay", decay);


+ 2
- 2
src/TensorFlowNET.Core/Operations/Losses/Reduction.cs View File

@@ -3,9 +3,9 @@
public class Reduction public class Reduction
{ {
public const string NONE = "none"; public const string NONE = "none";
public const string SUM = "weighted_sum";
public const string WEIGHTED_SUM = "weighted_sum";
public const string SUM_OVER_BATCH_SIZE = "weighted_sum_over_batch_size"; public const string SUM_OVER_BATCH_SIZE = "weighted_sum_over_batch_size";
public const string MEAN = "weighted_mean";
public const string WEIGHTED_MEAN = "weighted_mean";
public const string SUM_BY_NONZERO_WEIGHTS = "weighted_sum_by_nonzero_weights"; public const string SUM_BY_NONZERO_WEIGHTS = "weighted_sum_by_nonzero_weights";
public const string SUM_OVER_NONZERO_WEIGHTS = SUM_BY_NONZERO_WEIGHTS; public const string SUM_OVER_NONZERO_WEIGHTS = SUM_BY_NONZERO_WEIGHTS;
} }


+ 1
- 1
src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs View File

@@ -47,7 +47,7 @@ namespace Tensorflow
else else
{ {
loss = math_ops.reduce_sum(weighted_losses); loss = math_ops.reduce_sum(weighted_losses);
if (reduction == Reduction.MEAN)
if (reduction == Reduction.WEIGHTED_MEAN)
loss = _safe_mean( loss = _safe_mean(
loss, math_ops.reduce_sum(array_ops.ones_like(losses) * weights)); loss, math_ops.reduce_sum(array_ops.ones_like(losses) * weights));
else if (reduction == Reduction.SUM_BY_NONZERO_WEIGHTS || else if (reduction == Reduction.SUM_BY_NONZERO_WEIGHTS ||


+ 3
- 3
src/TensorFlowNET.Core/Tensorflow.Binding.csproj View File

@@ -5,7 +5,7 @@
<AssemblyName>TensorFlow.NET</AssemblyName> <AssemblyName>TensorFlow.NET</AssemblyName>
<RootNamespace>Tensorflow</RootNamespace> <RootNamespace>Tensorflow</RootNamespace>
<TargetTensorFlow>2.2.0</TargetTensorFlow> <TargetTensorFlow>2.2.0</TargetTensorFlow>
<Version>0.21.0</Version>
<Version>0.30.0</Version>
<LangVersion>8.0</LangVersion> <LangVersion>8.0</LangVersion>
<Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors> <Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors>
<Company>SciSharp STACK</Company> <Company>SciSharp STACK</Company>
@@ -19,14 +19,14 @@
<Description>Google's TensorFlow full binding in .NET Standard. <Description>Google's TensorFlow full binding in .NET Standard.
Building, training and infering deep learning models. Building, training and infering deep learning models.
https://tensorflownet.readthedocs.io</Description> https://tensorflownet.readthedocs.io</Description>
<AssemblyVersion>0.21.0.0</AssemblyVersion>
<AssemblyVersion>0.30.0.0</AssemblyVersion>
<PackageReleaseNotes>tf.net 0.20.x and above are based on tensorflow native 2.x. <PackageReleaseNotes>tf.net 0.20.x and above are based on tensorflow native 2.x.


* Eager Mode is added finally. * Eager Mode is added finally.
* tf.keras is partially working. * tf.keras is partially working.
* tf.data is added. * tf.data is added.
* autograph works partially.</PackageReleaseNotes> * autograph works partially.</PackageReleaseNotes>
<FileVersion>0.21.0.0</FileVersion>
<FileVersion>0.30.0.0</FileVersion>
<PackageLicenseFile>LICENSE</PackageLicenseFile> <PackageLicenseFile>LICENSE</PackageLicenseFile>
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance> <PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance>
<SignAssembly>true</SignAssembly> <SignAssembly>true</SignAssembly>


Loading…
Cancel
Save