diff --git a/README.md b/README.md
index 4cb56684..2fd46c6b 100644
--- a/README.md
+++ b/README.md
@@ -26,12 +26,12 @@ In comparison to other projects, like for instance [TensorFlowSharp](https://www
### How to use
-| TensorFlow | tf native1.14 | tf native 1.15 | tf native 2.3 |
-| -------------------------- | ------------- | -------------- | ------------- |
-| tf.net 0.3x, tf.keras 0.2 | | | x |
-| tf.net 0.2x | | x | x |
-| tf.net 0.15 | x | x | |
-| tf.net 0.14 | x | | |
+| TensorFlow | tf native 1.14, cuda 10.0 | tf native 1.15, cuda 10.0 | tf native 2.3, cuda 10.1 | tf native 2.4, cuda 11 |
+| -------------------------- | ------------- | -------------- | ------------- | ------------- |
+| tf.net 0.3x, tf.keras 0.2 | | | x | not compatible |
+| tf.net 0.2x | | x | x | |
+| tf.net 0.15 | x | x | | |
+| tf.net 0.14 | x | | | |
Troubleshooting of running example or installation, please refer [here](tensorflowlib/README.md).
diff --git a/src/SciSharp.TensorFlow.Redist/README.md b/src/SciSharp.TensorFlow.Redist/README.md
index 26eec870..141bba35 100644
--- a/src/SciSharp.TensorFlow.Redist/README.md
+++ b/src/SciSharp.TensorFlow.Redist/README.md
@@ -22,11 +22,19 @@ https://www.nuget.org/packages/SciSharp.TensorFlow.Redist
Related merged [commits](https://github.com/SciSharp/TensorFlow.NET/commit/854a5ba61ad0e400623821236bd117cc24c6cb77).
+
+
+#### Download pre-built package
+
+[Mac OSX CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-darwin-x86_64-2.4.0.tar.gz), [Linux CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-2.4.0.tar.gz), [Linux GPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-linux-x86_64-2.4.0.tar.gz), [Windows CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-windows-x86_64-2.4.0.zip), [Windows GPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-windows-x86_64-2.4.0.zip)
+
+
+
#### Pack and Deploy ####
On Windows, the tar command does not support extracting archives with symlinks. So when `dotnet pack` runs on Windows it will only package the Windows binaries.
1. Run `dotnet pack SciSharp.TensorFlow.Redist.nupkgproj` under `src/SciSharp.TensorFlow.Redist` directory in Linux.
-2. Run `dotnet nuget push SciSharp.TensorFlow.Redist.2.3.1.nupkg -k APIKEY -s https://api.nuget.org/v3/index.json -t 600`
+2. Run `dotnet nuget push SciSharp.TensorFlow.Redist.2.4.0.nupkg -k APIKEY -s https://api.nuget.org/v3/index.json -t 600`
diff --git a/src/TensorFlowNET.Core/Data/DatasetV2.cs b/src/TensorFlowNET.Core/Data/DatasetV2.cs
index 104789df..763baa31 100644
--- a/src/TensorFlowNET.Core/Data/DatasetV2.cs
+++ b/src/TensorFlowNET.Core/Data/DatasetV2.cs
@@ -3,6 +3,7 @@ using System.Collections;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Framework.Models;
+using static Tensorflow.Binding;
namespace Tensorflow
{
@@ -98,6 +99,20 @@ namespace Tensorflow
return dataset;
}
+ /// <summary>Returns the first output of the "DatasetCardinality" op executed on this dataset's variant tensor.</summary>
+ /// <exception cref="NotImplementedException">Thrown when the context is not executing eagerly.</exception>
+ public Tensor dataset_cardinality(string name = null)
+ {
+     if (!tf.Context.executing_eagerly())
+         throw new NotImplementedException("dataset_cardinality is only implemented for eager execution.");
+
+     var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
+         "DatasetCardinality", name,
+         null,
+         variant_tensor);
+     return results[0];
+ }
+
public override string ToString()
=> $"{GetType().Name} shapes: {string.Join(", ", structure.Select(x => x.shape))}, types: {string.Join(", ", structure.Select(x => "tf." + x.dtype.as_numpy_name()))}";
@@ -117,7 +132,9 @@ namespace Tensorflow
break;
}
- yield return (results[0], results.Length == 1 ? null : results[1]);
+ yield return results.Length == 2
+ ? (results[0], results[1])
+ : (null, results[0]);
}
}
diff --git a/src/TensorFlowNET.Core/Data/IDatasetV2.cs b/src/TensorFlowNET.Core/Data/IDatasetV2.cs
index fc47c832..9a31ff51 100644
--- a/src/TensorFlowNET.Core/Data/IDatasetV2.cs
+++ b/src/TensorFlowNET.Core/Data/IDatasetV2.cs
@@ -74,5 +74,7 @@ namespace Tensorflow
///
///
IDatasetV2 apply_options();
+
+ Tensor dataset_cardinality(string name = null);
}
}
diff --git a/src/TensorFlowNET.Core/Data/MapDataset.cs b/src/TensorFlowNET.Core/Data/MapDataset.cs
index 231e613e..c593322b 100644
--- a/src/TensorFlowNET.Core/Data/MapDataset.cs
+++ b/src/TensorFlowNET.Core/Data/MapDataset.cs
@@ -1,5 +1,6 @@
using System;
using Tensorflow.Functions;
+using static Tensorflow.Binding;
namespace Tensorflow
{
@@ -14,7 +15,12 @@ namespace Tensorflow
bool preserve_cardinality = false,
bool use_legacy_function = false) : base(input_dataset)
{
- var func = new ConcreteFunction(map_func, input_dataset.element_spec[0].dtype);
+ using var func = new ConcreteFunction($"autograph_{map_func.Method.Name}");
+ var input = tf.placeholder(input_dataset.element_spec[0].dtype);
+ var output = map_func(input);
+ func.ToGraph(input, output);
+
+ structure = func.OutputStructure;
variant_tensor = ops.map_dataset(input_dataset.variant_tensor,
func,
diff --git a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs
index a3067182..90cb0494 100644
--- a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs
+++ b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs
@@ -109,6 +109,8 @@ namespace Tensorflow.Functions
inputs,
outputs,
null);
+
+ OutputStructure = outputs.Select(x => x.ToTensorSpec()).ToArray();
}
public Tensors Invoke(Tensors inputs)
@@ -128,6 +130,9 @@ namespace Tensorflow.Functions
return new ForwardBackwardCall(functions, args, tape_watching: true);
}
+ /// <summary>Returns the function's <c>Name</c>.</summary>
+ public override string ToString() => Name;
+
public void Dispose()
{
c_api.TFE_ContextRemoveFunction(tf.Context.Handle, Name, tf.Status.Handle);
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/TensorLikeDataAdapterArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs
similarity index 86%
rename from src/TensorFlowNET.Core/Keras/ArgsDefinition/TensorLikeDataAdapterArgs.cs
rename to src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs
index 921a4726..f3cca438 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/TensorLikeDataAdapterArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs
@@ -2,10 +2,11 @@
namespace Tensorflow.Keras.ArgsDefinition
{
- public class TensorLikeDataAdapterArgs
+ public class DataAdapterArgs
{
public Tensor X { get; set; }
public Tensor Y { get; set; }
+ public IDatasetV2 Dataset { get; set; }
public int BatchSize { get; set; } = 32;
public int Steps { get; set; }
public int Epochs { get; set; }
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs
index 63de54ad..b6e6849b 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs
@@ -6,6 +6,7 @@ namespace Tensorflow.Keras.ArgsDefinition
{
public Tensor X { get; set; }
public Tensor Y { get; set; }
+ public IDatasetV2 Dataset { get; set; }
public int BatchSize { get; set; } = 32;
public int StepsPerEpoch { get; set; } = -1;
public int InitialEpoch { get; set; } = 0;
diff --git a/src/TensorFlowNET.Core/Operations/image_ops_impl.cs b/src/TensorFlowNET.Core/Operations/image_ops_impl.cs
index 014e376a..e48cb031 100644
--- a/src/TensorFlowNET.Core/Operations/image_ops_impl.cs
+++ b/src/TensorFlowNET.Core/Operations/image_ops_impl.cs
@@ -1702,74 +1702,79 @@ new_height, new_width");
public static Tensor decode_image(Tensor contents, int channels = 0, TF_DataType dtype = TF_DataType.TF_UINT8,
string name = null, bool expand_animations = true)
{
- Func _jpeg = () =>
+ return tf_with(ops.name_scope(name, "decode_image"), scope =>
{
- int jpeg_channels = channels;
- var good_channels = math_ops.not_equal(jpeg_channels, 4, name: "check_jpeg_channels");
- string channels_msg = "Channels must be in (None, 0, 1, 3) when decoding JPEG 'images'";
- var assert_channels = control_flow_ops.Assert(good_channels, new string[] { channels_msg });
- return tf_with(ops.control_dependencies(new[] { assert_channels }), delegate
+ var substr = tf.strings.substr(contents, 0, 3);
+
+ Func _jpeg = () =>
{
- return convert_image_dtype(gen_image_ops.decode_jpeg(contents, channels), dtype);
- });
- };
+ int jpeg_channels = channels;
+ var good_channels = math_ops.not_equal(jpeg_channels, 4, name: "check_jpeg_channels");
+ string channels_msg = "Channels must be in (None, 0, 1, 3) when decoding JPEG 'images'";
+ var assert_channels = control_flow_ops.Assert(good_channels, new string[] { channels_msg });
+ return tf_with(ops.control_dependencies(new[] { assert_channels }), delegate
+ {
+ return convert_image_dtype(gen_image_ops.decode_jpeg(contents, channels), dtype);
+ });
+ };
- Func _gif = () =>
- {
- int gif_channels = channels;
- var good_channels = math_ops.logical_and(
- math_ops.not_equal(gif_channels, 1, name: "check_gif_channels"),
- math_ops.not_equal(gif_channels, 4, name: "check_gif_channels"));
-
- string channels_msg = "Channels must be in (None, 0, 3) when decoding GIF images";
- var assert_channels = control_flow_ops.Assert(good_channels, new string[] { channels_msg });
- return tf_with(ops.control_dependencies(new[] { assert_channels }), delegate
+ /*Func _gif = () =>
{
- var result = convert_image_dtype(gen_image_ops.decode_gif(contents), dtype);
- if (!expand_animations)
- result = array_ops.gather(result, 0);
- return result;
- });
- };
+ int gif_channels = channels;
+ var good_channels = math_ops.logical_and(
+ math_ops.not_equal(gif_channels, 1, name: "check_gif_channels"),
+ math_ops.not_equal(gif_channels, 4, name: "check_gif_channels"));
+
+ string channels_msg = "Channels must be in (None, 0, 3) when decoding GIF images";
+ var assert_channels = control_flow_ops.Assert(good_channels, new string[] { channels_msg });
+ return tf_with(ops.control_dependencies(new[] { assert_channels }), delegate
+ {
+ var result = convert_image_dtype(gen_image_ops.decode_gif(contents), dtype);
+ if (!expand_animations)
+ result = array_ops.gather(result, 0);
+ return result;
+ });
+ };
- Func _bmp = () =>
- {
- int bmp_channels = channels;
- var signature = tf.strings.substr(contents, 0, 2);
- var is_bmp = math_ops.equal(signature, "BM", name: "is_bmp");
- string decode_msg = "Unable to decode bytes as JPEG, PNG, GIF, or BMP";
- var assert_decode = control_flow_ops.Assert(is_bmp, new string[] { decode_msg });
- var good_channels = math_ops.not_equal(bmp_channels, 1, name: "check_channels");
- string channels_msg = "Channels must be in (None, 0, 3) when decoding BMP images";
- var assert_channels = control_flow_ops.Assert(good_channels, new string[] { channels_msg });
- return tf_with(ops.control_dependencies(new[] { assert_decode, assert_channels }), delegate
+ Func _bmp = () =>
{
- return convert_image_dtype(gen_image_ops.decode_bmp(contents), dtype);
- });
- };
+ int bmp_channels = channels;
+ var signature = tf.strings.substr(contents, 0, 2);
+ var is_bmp = math_ops.equal(signature, "BM", name: "is_bmp");
+ string decode_msg = "Unable to decode bytes as JPEG, PNG, GIF, or BMP";
+ var assert_decode = control_flow_ops.Assert(is_bmp, new string[] { decode_msg });
+ var good_channels = math_ops.not_equal(bmp_channels, 1, name: "check_channels");
+ string channels_msg = "Channels must be in (None, 0, 3) when decoding BMP images";
+ var assert_channels = control_flow_ops.Assert(good_channels, new string[] { channels_msg });
+ return tf_with(ops.control_dependencies(new[] { assert_decode, assert_channels }), delegate
+ {
+ return convert_image_dtype(gen_image_ops.decode_bmp(contents), dtype);
+ });
+ };
- Func _png = () =>
- {
- return convert_image_dtype(gen_image_ops.decode_png(
- contents,
- channels,
- dtype: dtype),
- dtype);
- };
+ Func _png = () =>
+ {
+ return convert_image_dtype(gen_image_ops.decode_png(
+ contents,
+ channels,
+ dtype: dtype),
+ dtype);
+ };
- Func check_gif = () =>
- {
- return control_flow_ops.cond(is_gif(contents), _gif, _bmp, name: "cond_gif");
- };
+ Func check_gif = () =>
+ {
+ var gif = tf.constant(new byte[] { 0x47, 0x49, 0x46 }, TF_DataType.TF_STRING);
+ var is_gif = math_ops.equal(substr, gif, name: name);
+ return control_flow_ops.cond(is_gif, _gif, _bmp, name: "cond_gif");
+ };
- Func check_png = () =>
- {
- return control_flow_ops.cond(is_png(contents), _png, check_gif, name: "cond_png");
- };
+ Func check_png = () =>
+ {
+ return control_flow_ops.cond(is_png(contents), _png, check_gif, name: "cond_png");
+ };*/
- return tf_with(ops.name_scope(name, "decode_image"), scope =>
- {
- return control_flow_ops.cond(is_jpeg(contents), _jpeg, check_png, name: "cond_jpeg");
+ // return control_flow_ops.cond(is_jpeg(contents), _jpeg, check_png, name: "cond_jpeg");
+ return _jpeg() as Tensor;
});
}
diff --git a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj
index fa292679..507f260c 100644
--- a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj
+++ b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj
@@ -5,7 +5,7 @@
TensorFlow.NET
Tensorflow
2.2.0
- 0.31.1
+ 0.31.2
8.0
Haiping Chen, Meinrad Recheis, Eli Belash
SciSharp STACK
@@ -19,7 +19,7 @@
Google's TensorFlow full binding in .NET Standard.
Building, training and infering deep learning models.
https://tensorflownet.readthedocs.io
- 0.31.1.0
+ 0.31.2.0
tf.net 0.20.x and above are based on tensorflow native 2.x.
* Eager Mode is added finally.
@@ -30,7 +30,7 @@ https://tensorflownet.readthedocs.io
TensorFlow .NET v0.30 is focused on making more Keras API work including:
* tf.keras.datasets
* Building keras model in subclass, functional and sequential api
- 0.31.1.0
+ 0.31.2.0
LICENSE
true
true
diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs
index 7a665c23..e2697aeb 100644
--- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs
+++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs
@@ -20,6 +20,7 @@ using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Eager;
+using static Tensorflow.Binding;
namespace Tensorflow
{
@@ -410,14 +411,10 @@ would not be rank 1.", tensor.op.get_attr("axis")));
var value = constant_value(tensor);
if (!(value is null))
{
- int[] d_ = { };
- foreach (int d in value)
- {
- if (d >= 0)
- d_[d_.Length] = d;
- else
- d_[d_.Length] = -1; // None
- }
+ var d_ = new int[value.size];
+ foreach (var (index, d) in enumerate(value.ToArray()))
+ d_[index] = d >= 0 ? d : -1;
+
ret = ret.merge_with(new TensorShape(d_));
}
return ret;
@@ -577,7 +574,7 @@ would not be rank 1.", tensor.op.get_attr("axis")));
return string.Join(string.Empty, nd.ToArray()
.Select(x => x < 32 || x > 127 ? "\\x" + x.ToString("x") : Convert.ToChar(x).ToString()));
case TF_DataType.TF_BOOL:
- return (nd.GetByte(0) > 0).ToString();
+ return nd.GetBoolean(0).ToString();
case TF_DataType.TF_VARIANT:
case TF_DataType.TF_RESOURCE:
return "";
diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs
index de5f4a8c..1bcb9c1d 100644
--- a/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs
+++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs
@@ -37,19 +37,38 @@ namespace Tensorflow.Keras.Engine.DataAdapters
_steps_per_execution_value = args.StepsPerExecution.numpy();
}
- _adapter = new TensorLikeDataAdapter(new TensorLikeDataAdapterArgs
+ if(args.Dataset == null)
{
- X = args.X,
- Y = args.Y,
- BatchSize = args.BatchSize,
- Steps = args.StepsPerEpoch,
- Epochs = args.Epochs - args.InitialEpoch,
- Shuffle = args.Shuffle,
- MaxQueueSize = args.MaxQueueSize,
- Worker = args.Workers,
- UseMultiprocessing = args.UseMultiprocessing,
- Model = args.Model
- });
+ _adapter = new TensorLikeDataAdapter(new DataAdapterArgs
+ {
+ X = args.X,
+ Y = args.Y,
+ BatchSize = args.BatchSize,
+ Steps = args.StepsPerEpoch,
+ Epochs = args.Epochs - args.InitialEpoch,
+ Shuffle = args.Shuffle,
+ MaxQueueSize = args.MaxQueueSize,
+ Worker = args.Workers,
+ UseMultiprocessing = args.UseMultiprocessing,
+ Model = args.Model
+ });
+ }
+ else
+ {
+ _adapter = new DatasetAdapter(new DataAdapterArgs
+ {
+ Dataset = args.Dataset,
+ BatchSize = args.BatchSize,
+ Steps = args.StepsPerEpoch,
+ Epochs = args.Epochs - args.InitialEpoch,
+ Shuffle = args.Shuffle,
+ MaxQueueSize = args.MaxQueueSize,
+ Worker = args.Workers,
+ UseMultiprocessing = args.UseMultiprocessing,
+ Model = args.Model
+ });
+ }
+
_dataset = _adapter.GetDataset();
_inferred_steps = _infer_steps(args.StepsPerEpoch, _dataset);
_current_step = 0;
@@ -66,7 +85,8 @@ namespace Tensorflow.Keras.Engine.DataAdapters
if (adapter_steps > -1)
return adapter_steps;
- throw new NotImplementedException("");
+ var size = dataset.dataset_cardinality();
+ return size.numpy();
}
public IEnumerable<(int, OwnedIterator)> enumerate_epochs()
diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/DatasetAdapter.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/DatasetAdapter.cs
new file mode 100644
index 00000000..d5f9613f
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/DatasetAdapter.cs
@@ -0,0 +1,35 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Keras.ArgsDefinition;
+
+namespace Tensorflow.Keras.Engine.DataAdapters
+{
+ /// <summary>Data adapter that hands back the dataset carried in its <see cref="DataAdapterArgs"/>.</summary>
+ public class DatasetAdapter : IDataAdapter
+ {
+     DataAdapterArgs args;
+     IDatasetV2 _dataset => args.Dataset;
+
+     public DatasetAdapter(DataAdapterArgs args)
+         => this.args = args;
+
+     /// <summary>Not implemented; always throws <see cref="NotImplementedException"/>.</summary>
+     public bool CanHandle(Tensor x, Tensor y = null)
+         => throw new NotImplementedException();
+
+     /// <summary>Returns the dataset supplied through the adapter arguments.</summary>
+     public IDatasetV2 GetDataset() => _dataset;
+
+     /// <summary>Always reports -1 as the size.</summary>
+     public int GetSize() => -1;
+
+     /// <summary>Appends a trailing axis to <paramref name="y"/> when it is rank 1; <paramref name="x"/> is returned as-is.</summary>
+     public (Tensor, Tensor) Expand1d(Tensor x, Tensor y)
+     {
+         var labels = y.TensorShape.ndim == 1
+             ? array_ops.expand_dims(y, axis: -1) : y;
+         return (x, labels);
+     }
+ }
+}
diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs
index 13f634dd..ecf5cbf9 100644
--- a/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs
+++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs
@@ -9,14 +9,14 @@ namespace Tensorflow.Keras.Engine.DataAdapters
///
public class TensorLikeDataAdapter : IDataAdapter
{
- TensorLikeDataAdapterArgs args;
+ DataAdapterArgs args;
int _size;
int _batch_size;
int num_samples;
int num_full_batches;
IDatasetV2 _dataset;
- public TensorLikeDataAdapter(TensorLikeDataAdapterArgs args)
+ public TensorLikeDataAdapter(DataAdapterArgs args)
{
this.args = args;
_process_tensorlike();
diff --git a/src/TensorFlowNET.Keras/Engine/Functional.cs b/src/TensorFlowNET.Keras/Engine/Functional.cs
index 56d0863a..b2b109ba 100644
--- a/src/TensorFlowNET.Keras/Engine/Functional.cs
+++ b/src/TensorFlowNET.Keras/Engine/Functional.cs
@@ -39,10 +39,12 @@ namespace Tensorflow.Keras.Engine
_input_coordinates = new List();
_output_coordinates = new List();
tensor_usage_count = new Dictionary();
+ if (this is Sequential)
+ return;
_init_graph_network(inputs, outputs);
}
- void _init_graph_network(Tensors inputs, Tensors outputs)
+ protected void _init_graph_network(Tensors inputs, Tensors outputs)
{
_is_graph_network = true;
this.inputs = inputs;
diff --git a/src/TensorFlowNET.Keras/Engine/Model.Compile.cs b/src/TensorFlowNET.Keras/Engine/Model.Compile.cs
index 003262d9..dd91a5de 100644
--- a/src/TensorFlowNET.Keras/Engine/Model.Compile.cs
+++ b/src/TensorFlowNET.Keras/Engine/Model.Compile.cs
@@ -9,10 +9,6 @@ namespace Tensorflow.Keras.Engine
{
LossesContainer compiled_loss;
MetricsContainer compiled_metrics;
- public void compile(string optimizerName, ILossFunc lossName)
- {
- throw new NotImplementedException("");
- }
public void compile(ILossFunc loss, OptimizerV2 optimizer, string[] metrics)
{
@@ -29,12 +25,12 @@ namespace Tensorflow.Keras.Engine
this.loss = loss;
}
- public void compile(string optimizerName, string lossName)
+ public void compile(string optimizer, string loss, string[] metrics)
{
- switch (optimizerName)
+ switch (optimizer)
{
case "rmsprop":
- optimizer = new RMSprop(new RMSpropArgs
+ this.optimizer = new RMSprop(new RMSpropArgs
{
});
diff --git a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs
index 951908f5..3c4960e7 100644
--- a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs
+++ b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs
@@ -68,5 +68,49 @@ namespace Tensorflow.Keras.Engine
Console.WriteLine($"epoch: {epoch + 1}, " + string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2}")));
}
}
+
+ /// <summary>
+ /// Trains the model on batches drawn from <paramref name="dataset"/>, printing the last step's results after each epoch.
+ /// </summary>
+ public void fit(IDatasetV2 dataset,
+     IDatasetV2 validation_data = null,
+     int batch_size = -1, int epochs = 1, int verbose = 1,
+     float validation_split = 0f, bool shuffle = true, int initial_epoch = 0,
+     int max_queue_size = 10, int workers = 1, bool use_multiprocessing = false)
+ {
+     // NOTE: validation_data, verbose and validation_split are accepted but not used by this overload yet.
+     data_handler = new DataHandler(new DataHandlerArgs
+     {
+         Dataset = dataset,
+         BatchSize = batch_size,
+         InitialEpoch = initial_epoch,
+         Epochs = epochs,
+         Shuffle = shuffle,
+         MaxQueueSize = max_queue_size,
+         Workers = workers,
+         UseMultiprocessing = use_multiprocessing,
+         Model = this,
+         StepsPerExecution = _steps_per_execution
+     });
+
+     stop_training = false;
+     _train_counter.assign(0);
+     Console.WriteLine("Training...");
+     foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
+     {
+         // reset_metrics();
+         // callbacks.on_epoch_begin(epoch)
+         // data_handler.catch_stop_iteration();
+         IEnumerable<(string, Tensor)> results = null;
+         foreach (var step in data_handler.steps())
+         {
+             // callbacks.on_train_batch_begin(step)
+             results = step_function(iterator);
+         }
+         // results stays null when the handler yields no steps; print an empty metric list instead of crashing.
+         var metrics = results ?? Enumerable.Empty<(string, Tensor)>();
+         Console.WriteLine($"epoch: {epoch + 1}, " + string.Join(", ", metrics.Select(x => $"{x.Item1}: {(float)x.Item2}")));
+     }
+ }
}
}
diff --git a/src/TensorFlowNET.Keras/Engine/Node.cs b/src/TensorFlowNET.Keras/Engine/Node.cs
index d78e5533..fad58534 100644
--- a/src/TensorFlowNET.Keras/Engine/Node.cs
+++ b/src/TensorFlowNET.Keras/Engine/Node.cs
@@ -35,7 +35,7 @@ namespace Tensorflow.Keras.Engine
public int[] node_indices;
public int[] tensor_indices;
- public Tensors input_tensors => args.InputTensors;
+ public Tensors input_tensors => is_input ? Outputs : args.InputTensors;
public Tensors Outputs => args.Outputs;
public TensorShape[] input_shapes;
public TensorShape[] output_shapes;
diff --git a/src/TensorFlowNET.Keras/Engine/Sequential.cs b/src/TensorFlowNET.Keras/Engine/Sequential.cs
index 50974cf7..e2267d53 100644
--- a/src/TensorFlowNET.Keras/Engine/Sequential.cs
+++ b/src/TensorFlowNET.Keras/Engine/Sequential.cs
@@ -17,6 +17,7 @@
using System.Collections.Generic;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Layers;
+using Tensorflow.Keras.Utils;
using static Tensorflow.KerasApi;
namespace Tensorflow.Keras.Engine
@@ -25,36 +26,40 @@ namespace Tensorflow.Keras.Engine
/// `Sequential` groups a linear stack of layers into a `tf.keras.Model`.
/// `Sequential` provides training and inference features on this model.
///
- public class Sequential : Model
+ public class Sequential : Functional
{
SequentialArgs args;
bool _is_graph_network;
- Tensor inputs;
- Tensor outputs;
-
- bool computeOutputAndMaskJointly;
- bool autoTrackSubLayers;
- TensorShape inferredInputShape;
- bool hasExplicitInputShape;
- TF_DataType inputDType;
- List layers => args.Layers;
- public TensorShape output_shape => outputs.TensorShape;
+ Tensors inputs;
+ Tensors outputs;
+
+ bool _compute_output_and_mask_jointly;
+ bool _auto_track_sub_layers;
+ TensorShape _inferred_input_shape;
+ bool _has_explicit_input_shape;
+ TF_DataType _input_dtype;
+
+ public TensorShape output_shape => outputs[0].TensorShape;
bool built = false;
public Sequential(SequentialArgs args)
- : base(new ModelArgs
- {
- Name = args.Name
- })
+ : base(args.Inputs, args.Outputs, name: args.Name)
{
this.args = args;
if (args.Layers == null)
args.Layers = new List();
// SupportsMasking = true;
- computeOutputAndMaskJointly = true;
- autoTrackSubLayers = false;
- hasExplicitInputShape = false;
+ _compute_output_and_mask_jointly = true;
+ _auto_track_sub_layers = false;
+ _has_explicit_input_shape = false;
_is_graph_network = false;
+
+ // Add to the model any layers passed to the constructor.
+ if (args.Layers != null)
+ {
+ foreach (var layer in args.Layers)
+ add(layer as Layer);
+ }
}
public void add(Tensor tensor)
@@ -71,7 +76,7 @@ namespace Tensorflow.Keras.Engine
{
built = false;
var set_inputs = false;
- if (layers.Count == 0)
+ if (_layers.Count == 0)
{
if (layer is InputLayer)
{
@@ -83,7 +88,7 @@ namespace Tensorflow.Keras.Engine
{
// Instantiate an input layer.
var x = keras.Input(
- shape: layer.BatchInputShape,
+ batch_input_shape: layer.BatchInputShape,
dtype: layer.DType,
name: layer.Name + "_input");
@@ -99,36 +104,26 @@ namespace Tensorflow.Keras.Engine
{
// If an input layer (placeholder) is available.
outputs = layer.InboundNodes[^1].Outputs;
+ inputs = layer_utils.get_source_inputs(outputs[0]);
+ built = true;
+ _has_explicit_input_shape = true;
}
-
}
else if (outputs != null)
{
outputs = layer.Apply(outputs);
+ built = true;
}
if (set_inputs || _is_graph_network)
{
_init_graph_network(inputs, outputs);
+ _is_graph_network = true;
}
else
{
}
}
-
- void _init_graph_network(Tensor inputs, Tensor outputs)
- {
- _is_graph_network = true;
- this.inputs = inputs;
- this.outputs = outputs;
- built = true;
- _map_graph_network(inputs, outputs);
- }
-
- void _map_graph_network(Tensor inputs, Tensor outputs)
- {
- layers.add(outputs.KerasHistory.Layer);
- }
}
}
diff --git a/src/TensorFlowNET.Keras/KerasInterface.cs b/src/TensorFlowNET.Keras/KerasInterface.cs
index 6cb733d3..40519ac4 100644
--- a/src/TensorFlowNET.Keras/KerasInterface.cs
+++ b/src/TensorFlowNET.Keras/KerasInterface.cs
@@ -62,16 +62,21 @@ namespace Tensorflow.Keras
///
public Tensor Input(TensorShape shape = null,
int batch_size = -1,
+ TensorShape batch_input_shape = null,
TF_DataType dtype = TF_DataType.DtInvalid,
string name = null,
bool sparse = false,
bool ragged = false,
Tensor tensor = null)
{
+ if (batch_input_shape != null)
+ shape = batch_input_shape.dims[1..];
+
var args = new InputLayerArgs
{
Name = name,
InputShape = shape,
+ BatchInputShape = batch_input_shape,
BatchSize = batch_size,
DType = dtype,
Sparse = sparse,
diff --git a/src/TensorFlowNET.Keras/Layers/Rescaling/Rescaling.cs b/src/TensorFlowNET.Keras/Layers/Rescaling/Rescaling.cs
index f81ee161..7466685f 100644
--- a/src/TensorFlowNET.Keras/Layers/Rescaling/Rescaling.cs
+++ b/src/TensorFlowNET.Keras/Layers/Rescaling/Rescaling.cs
@@ -23,5 +23,10 @@ namespace Tensorflow.Keras.Layers
offset = math_ops.cast(args.Offset, args.DType);
return math_ops.cast(inputs, args.DType) * scale + offset;
}
+
+ /// <summary>Reports the output shape, which is <paramref name="input_shape"/> unchanged.</summary>
+ /// <param name="input_shape">Shape of the incoming tensor.</param>
+ public override TensorShape ComputeOutputShape(TensorShape input_shape)
+     => input_shape;
}
}
diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Flatten.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Flatten.cs
index 316cab8c..66235198 100644
--- a/src/TensorFlowNET.Keras/Layers/Reshaping/Flatten.cs
+++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Flatten.cs
@@ -1,4 +1,5 @@
using System;
+using Tensorflow.Framework;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Utils;
@@ -15,6 +16,7 @@ namespace Tensorflow.Keras.Layers
public Flatten(FlattenArgs args)
: base(args)
{
+ this.args = args;
args.DataFormat = conv_utils.normalize_data_format(args.DataFormat);
input_spec = new InputSpec(min_ndim: 1);
_channels_first = args.DataFormat == "channels_first";
@@ -31,8 +33,29 @@ namespace Tensorflow.Keras.Layers
{
return array_ops.reshape(inputs, new[] { inputs.shape[0], -1 });
}
+ else
+ {
+ var input_shape = inputs.shape;
+ var rank = inputs.shape.rank;
+ if (rank == 1)
+ return array_ops.expand_dims(inputs, axis: 1);
+ var batch_dim = tensor_shape.dimension_value(input_shape[0]);
+ if (batch_dim != -1)
+ {
+ return array_ops.reshape(inputs, new[] { batch_dim, -1 });
+ }
- throw new NotImplementedException("");
+ var non_batch_dims = ((int[])input_shape)[1..];
+ var num = 1;
+ if (non_batch_dims.Length > 0)
+ {
+ for (var i = 0; i < non_batch_dims.Length; i++)
+ {
+ num *= non_batch_dims[i];
+ }
+ }
+ return array_ops.reshape(inputs, new[] { inputs.shape[0], num });
+ }
}
}
}
diff --git a/src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.index_directory.cs b/src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.index_directory.cs
index 33754b00..cf7ef12c 100644
--- a/src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.index_directory.cs
+++ b/src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.index_directory.cs
@@ -40,8 +40,8 @@ namespace Tensorflow.Keras.Preprocessings
labels.AddRange(Enumerable.Range(0, files.Length).Select(x => label));
}
- var return_labels = new int[labels.Count];
- var return_file_paths = new string[file_paths.Count];
+ var return_labels = labels.Select(x => x).ToArray();
+ var return_file_paths = file_paths.Select(x => x).ToArray();
if (shuffle)
{
diff --git a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs
index c9af1915..a57ac73e 100644
--- a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs
+++ b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs
@@ -41,7 +41,7 @@ namespace Tensorflow.Keras
int num_channels = 0;
if (color_mode == "rgb")
num_channels = 3;
- // C:/Users/haipi/.keras/datasets/flower_photos
+
var (image_paths, label_list, class_name_list) = keras.preprocessing.dataset_utils.index_directory(directory,
formats: WHITELIST_FORMATS,
class_names: class_names,
diff --git a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.paths_and_labels_to_dataset.cs b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.paths_and_labels_to_dataset.cs
index ad950fc9..abf07735 100644
--- a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.paths_and_labels_to_dataset.cs
+++ b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.paths_and_labels_to_dataset.cs
@@ -16,27 +16,11 @@ namespace Tensorflow.Keras
var path_ds = tf.data.Dataset.from_tensor_slices(image_paths);
var img_ds = path_ds.map(x => path_to_image(x, image_size, num_channels, interpolation));
- /*Shape shape = (image_paths.Length, image_size.dims[0], image_size.dims[1], num_channels);
- Console.WriteLine($"Allocating memory for shape{shape}, {NPTypeCode.Float}");
- var data = np.zeros(shape, NPTypeCode.Float);
-
- for (var i = 0; i < image_paths.Length; i++)
- {
- var image = path_to_image(image_paths[i], image_size, num_channels, interpolation);
- data[i] = image.numpy();
- if (i % 100 == 0)
- Console.WriteLine($"Filled {i}/{image_paths.Length} data into ndarray.");
- }
-
- var img_ds = tf.data.Dataset.from_tensor_slices(data);
-
if (label_mode == "int")
{
- var label_ds = tf.keras.preprocessing.dataset_utils.labels_to_dataset(labels, label_mode, num_classes);
+ var label_ds = dataset_utils.labels_to_dataset(labels, label_mode, num_classes);
img_ds = tf.data.Dataset.zip(img_ds, label_ds);
}
- else*/
- throw new NotImplementedException("");
return img_ds;
}
@@ -47,6 +31,7 @@ namespace Tensorflow.Keras
img = tf.image.decode_image(
img, channels: num_channels, expand_animations: false);
img = tf.image.resize_images_v2(img, image_size, method: interpolation);
+ // img.set_shape((image_size[0], image_size[1], num_channels));
return img;
}
}
diff --git a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
index aec973ad..b97fdb8d 100644
--- a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
+++ b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
@@ -6,7 +6,7 @@
8.0
Tensorflow.Keras
AnyCPU;x64
- 0.2.1
+ 0.3.0
Haiping Chen
Keras for .NET
Apache 2.0, Haiping Chen 2020
@@ -25,11 +25,13 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
SciSharp STACK
true
tensorflow, keras, deep learning, machine learning
- false
+ true
Git
true
Open.snk
- 0.2.1.0
+ 0.3.0.0
+ 0.3.0.0
+ LICENSE
@@ -55,4 +57,11 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
+
+
+ True
+
+
+
+
diff --git a/src/TensorFlowNET.Keras/Utils/layer_utils.cs b/src/TensorFlowNET.Keras/Utils/layer_utils.cs
index 34f553d4..4166fd51 100644
--- a/src/TensorFlowNET.Keras/Utils/layer_utils.cs
+++ b/src/TensorFlowNET.Keras/Utils/layer_utils.cs
@@ -187,5 +187,34 @@ namespace Tensorflow.Keras.Utils
var total = weight_shapes.Select(p => (int)np.prod(p.dims)).Sum();
return total;
}
+
+ /// <summary>
+ /// Walks the Keras history of <paramref name="tensor"/> back to the input tensors of its originating input nodes.
+ /// Each source tensor is returned once, even when it is reachable through several inbound paths.
+ /// </summary>
+ public static Tensors get_source_inputs(Tensor tensor, ILayer layer = null, int node_index = -1)
+ {
+     if (layer == null)
+         (layer, node_index, _) = tensor.KerasHistory;
+     if (layer.InboundNodes == null || layer.InboundNodes.Count == 0)
+         return tensor;
+
+     var node = layer.InboundNodes[node_index];
+     if (node.is_input)
+         return node.input_tensors;
+
+     var source_tensors = new List<Tensor>();
+     foreach (var inbound in node.iterate_inbound())
+     {
+         var previous_sources = get_source_inputs(inbound.Item4, inbound.Item1, inbound.Item2);
+         foreach (var x in previous_sources)
+         {
+             // Avoid input redundancy: compare by reference so the same tensor object is listed only once.
+             if (!source_tensors.Any(t => ReferenceEquals(t, x)))
+                 source_tensors.Add(x);
+         }
+     }
+     return source_tensors;
+ }
}
}
diff --git a/tensorflowlib/README.md b/tensorflowlib/README.md
index eb957b46..20d30f6f 100644
--- a/tensorflowlib/README.md
+++ b/tensorflowlib/README.md
@@ -24,7 +24,7 @@ More information about [System.Drawing on Linux]( Install-Package SciSharp.TensorFlow.Redist-Windows-GPU
PM> Install-Package SciSharp.TensorFlow.Redist-Linux-GPU
```
+Because NuGet limits package size to 250 MB, we can't ship the Linux GPU version as a NuGet package; you can download the library from [Google TensorFlow Storage](https://storage.googleapis.com/tensorflow) instead.
+
### Download prebuild binary manually
-Tensorflow packages are built nightly and uploaded to GCS for all supported platforms. They are uploaded to the [libtensorflow-nightly](https://www.tensorflow.org/install/lang_c) GCS bucket and are indexed by operating system and date built.
+TensorFlow packages are built nightly and uploaded to GCS for all supported platforms. They are uploaded to the [libtensorflow-nightly](https://www.tensorflow.org/install/lang_c) GCS bucket and are indexed by operating system and date built.
### Build from source for Windows
diff --git a/test/TensorFlowNET.UnitTest/ImageTest.cs b/test/TensorFlowNET.UnitTest/ImageTest.cs
index 48794650..b32c659e 100644
--- a/test/TensorFlowNET.UnitTest/ImageTest.cs
+++ b/test/TensorFlowNET.UnitTest/ImageTest.cs
@@ -28,7 +28,7 @@ namespace TensorFlowNET.UnitTest.Basics
public void decode_image()
{
var img = tf.image.decode_image(contents);
- Assert.AreEqual(img.name, "decode_image/cond_jpeg/Merge:0");
+ Assert.AreEqual(img.name, "decode_image/Identity:0");
}
[TestMethod]
diff --git a/test/TensorFlowNET.UnitTest/Keras/LayersTest.cs b/test/TensorFlowNET.UnitTest/Keras/LayersTest.cs
index 279a5db5..c6858ba0 100644
--- a/test/TensorFlowNET.UnitTest/Keras/LayersTest.cs
+++ b/test/TensorFlowNET.UnitTest/Keras/LayersTest.cs
@@ -57,7 +57,7 @@ namespace TensorFlowNET.UnitTest.Keras
{ 2, 3, 4, 5 },
{ 3, 4, 5, 6 }
});
- model.compile("rmsprop", "mse");
+ // model.compile("rmsprop", "mse");
var output_array = model.predict(input_array);
Assert.AreEqual((32, 10, 64), output_array.TensorShape);
}