diff --git a/README.md b/README.md
index f437ad8e..d04a8a90 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@

-**TensorFlow.NET** (TF.NET) provides a .NET Standard binding for [TensorFlow](https://www.tensorflow.org/). It aims to implement the complete Tensorflow API in C# which allows .NET developers to develop, train and deploy Machine Learning models with the cross-platform .NET Standard framework.
+**TensorFlow.NET** (TF.NET) provides a .NET Standard binding for [TensorFlow](https://www.tensorflow.org/). It aims to implement the complete Tensorflow API in C# which allows .NET developers to develop, train and deploy Machine Learning models with the cross-platform .NET Standard framework. TensorFlow.NET has built-in Keras high-level interface and is released as an independent package [TensorFlow.Keras](https://www.nuget.org/packages/TensorFlow.Keras/).
[](https://gitter.im/sci-sharp/community)
[](https://ci.appveyor.com/project/Haiping-Chen/tensorflow-net)
@@ -14,7 +14,7 @@

-### Why TensorFlow.NET ?
+### Why TensorFlow in .NET/ C# ?
`SciSharp STACK`'s mission is to bring popular data science technology into the .NET world and to provide .NET developers with a powerful Machine Learning tool set without reinventing the wheel. Since the APIs are kept as similar as possible you can immediately adapt any existing Tensorflow code in C# with a zero learning curve. Take a look at a comparison picture and see how comfortably a Tensorflow/Python script translates into a C# program with TensorFlow.NET.
@@ -22,20 +22,23 @@
SciSharp's philosophy allows a large number of machine learning code written in Python to be quickly migrated to .NET, enabling .NET developers to use cutting edge machine learning models and access a vast number of Tensorflow resources which would not be possible without this project.
-In comparison to other projects, like for instance TensorFlowSharp which only provide Tensorflow's low-level C++ API and can only run models that were built using Python, Tensorflow.NET also implements Tensorflow's high level API where all the magic happens. This computation graph building layer is still under active development. Once it is completely implemented you can build new Machine Learning models in C#.
+In comparison to other projects, like for instance [TensorFlowSharp](https://www.nuget.org/packages/TensorFlowSharp/) which only provide Tensorflow's low-level C++ API and can only run models that were built using Python, Tensorflow.NET also implements Tensorflow's high level API where all the magic happens. This computation graph building layer is still under active development. Once it is completely implemented you can build new Machine Learning models in C#.
### How to use
-| TensorFlow | tf native1.14 | tf native 1.15 | tf native 2.3 |
-| ----------- | ------------- | -------------- | ------------- |
-| tf.net 0.20 | | x | x |
-| tf.net 0.15 | x | x | |
-| tf.net 0.14 | x | | |
+| TensorFlow | tf native1.14 | tf native 1.15 | tf native 2.3 |
+| ------------------------- | ------------- | -------------- | ------------- |
+| tf.net 0.30, tf.keras 0.1 | | | x |
+| tf.net 0.20 | | x | x |
+| tf.net 0.15 | x | x | |
+| tf.net 0.14 | x | | |
Install TF.NET and TensorFlow binary through NuGet.
```sh
### install tensorflow C# binding
PM> Install-Package TensorFlow.NET
+### install keras for tensorflow
+PM> Install-Package TensorFlow.Keras
### Install tensorflow binary
### For CPU version
@@ -45,13 +48,14 @@ PM> Install-Package SciSharp.TensorFlow.Redist
PM> Install-Package SciSharp.TensorFlow.Redist-Windows-GPU
```
-Import TF.NET in your project.
+Import TF.NET and Keras API in your project.
```cs
using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
```
-Linear Regression:
+Linear Regression in `Eager` mode:
```c#
// Parameters
@@ -92,6 +96,52 @@ foreach (var step in range(1, training_steps + 1))
Run this example in [Jupyter Notebook](https://github.com/SciSharp/SciSharpCube).
+Toy version of `ResNet` in `Keras` functional API:
+
+```csharp
+// input layer
+var inputs = keras.Input(shape: (32, 32, 3), name: "img");
+
+// convolutional layer
+var x = layers.Conv2D(32, 3, activation: "relu").Apply(inputs);
+x = layers.Conv2D(64, 3, activation: "relu").Apply(x);
+var block_1_output = layers.MaxPooling2D(3).Apply(x);
+
+x = layers.Conv2D(64, 3, activation: "relu", padding: "same").Apply(block_1_output);
+x = layers.Conv2D(64, 3, activation: "relu", padding: "same").Apply(x);
+var block_2_output = layers.add(x, block_1_output);
+
+x = layers.Conv2D(64, 3, activation: "relu", padding: "same").Apply(block_2_output);
+x = layers.Conv2D(64, 3, activation: "relu", padding: "same").Apply(x);
+var block_3_output = layers.add(x, block_2_output);
+
+x = layers.Conv2D(64, 3, activation: "relu").Apply(block_3_output);
+x = layers.GlobalAveragePooling2D().Apply(x);
+x = layers.Dense(256, activation: "relu").Apply(x);
+x = layers.Dropout(0.5f).Apply(x);
+
+// output layer
+var outputs = layers.Dense(10).Apply(x);
+
+// build keras model
+model = keras.Model(inputs, outputs, name: "toy_resnet");
+model.summary();
+
+// compile keras model in tensorflow static graph
+model.compile(optimizer: keras.optimizers.RMSprop(1e-3f),
+ loss: keras.losses.CategoricalCrossentropy(from_logits: true),
+ metrics: new[] { "acc" });
+
+// prepare dataset
+var ((x_train, y_train), (x_test, y_test)) = keras.datasets.cifar10.load_data();
+
+// training
+model.fit(x_train[new Slice(0, 1000)], y_train[new Slice(0, 1000)],
+ batch_size: 64,
+ epochs: 10,
+ validation_split: 0.2f);
+```
+
Read the docs & book [The Definitive Guide to Tensorflow.NET](https://tensorflownet.readthedocs.io/en/latest/FrontCover.html).
There are many examples reside at [TensorFlow.NET Examples](https://github.com/SciSharp/TensorFlow.NET-Examples).
diff --git a/src/TensorFlowNET.Keras/BackendImpl.cs b/src/TensorFlowNET.Keras/BackendImpl.cs
index b88b1bb7..8454c9bf 100644
--- a/src/TensorFlowNET.Keras/BackendImpl.cs
+++ b/src/TensorFlowNET.Keras/BackendImpl.cs
@@ -167,5 +167,21 @@ namespace Tensorflow.Keras
public class _DummyEagerGraph
{ }
+
+ ///
+ /// Categorical crossentropy between an output tensor and a target tensor.
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ public Tensor categorical_crossentropy(Tensor target, Tensor output, bool from_logits = false, int axis = -1)
+ {
+ if (from_logits)
+ return tf.nn.softmax_cross_entropy_with_logits_v2(labels: target, logits: output, axis: axis);
+
+ throw new NotImplementedException("");
+ }
}
}
diff --git a/src/TensorFlowNET.Keras/Datasets/Cifar10.cs b/src/TensorFlowNET.Keras/Datasets/Cifar10.cs
new file mode 100644
index 00000000..c449def2
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Datasets/Cifar10.cs
@@ -0,0 +1,135 @@
+using NumSharp;
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Text;
+using static Tensorflow.Binding;
+using Tensorflow.Keras.Utils;
+
+namespace Tensorflow.Keras.Datasets
+{
+ public class Cifar10
+ {
+ string origin_folder = "https://www.cs.toronto.edu/~kriz/";
+ string file_name = "cifar-10-python.tar.gz";
+ string dest_folder = "cifar-10-batches";
+
+ ///
+ /// Loads [CIFAR10 dataset](https://www.cs.toronto.edu/~kriz/cifar.html).
+ ///
+ ///
+ public DatasetPass load_data()
+ {
+ var dst = Download();
+
+ var data_list = new List();
+ var label_list = new List();
+
+ foreach (var i in range(1, 6))
+ {
+ var fpath = Path.Combine(dst, $"data_batch_{i}");
+ var (data, labels) = load_batch(fpath);
+ data_list.Add(data);
+ label_list.Add(labels);
+ }
+
+ var x_train_tensor = tf.concat(data_list, 0);
+ var y_train_tensor = tf.concat(label_list, 0);
+ var y_train = np.array(y_train_tensor.BufferToArray()).reshape(y_train_tensor.shape);
+
+ // test data
+ var fpath_test = Path.Combine(dst, "test_batch");
+ var (x_test, y_test) = load_batch(fpath_test);
+
+ // channels_last
+ x_train_tensor = tf.transpose(x_train_tensor, new[] { 0, 2, 3, 1 });
+ var x_train = np.array(x_train_tensor.BufferToArray()).reshape(x_train_tensor.shape);
+
+ var x_test_tensor = tf.transpose(x_test, new[] { 0, 2, 3, 1 });
+ x_test = np.array(x_test_tensor.BufferToArray()).reshape(x_test_tensor.shape);
+
+ return new DatasetPass
+ {
+ Train = (x_train, y_train),
+ Test = (x_test, y_test)
+ };
+ }
+
+ (NDArray, NDArray) load_batch(string fpath, string label_key = "labels")
+ {
+ var pickle = File.ReadAllBytes(fpath);
+ // read description
+ var start_pos = 7;
+ var desc = read_description(ref start_pos, pickle);
+ var labels = read_labels(ref start_pos, pickle);
+ var data = read_data(ref start_pos, pickle);
+
+ return (data.Item2, labels.Item2);
+ }
+
+ (string, string) read_description(ref int start_pos, byte[] pickle)
+ {
+ var key_length = pickle[start_pos];
+ start_pos++;
+ var span = new Span(pickle, start_pos, key_length);
+ var key = Encoding.ASCII.GetString(span.ToArray());
+ start_pos += key_length + 3;
+
+ var value_length = pickle[start_pos];
+ start_pos++;
+ var value = Encoding.ASCII.GetString(new Span(pickle, start_pos, value_length).ToArray());
+ start_pos += value_length;
+ start_pos += 3;
+
+ return (key, value);
+ }
+
+ (string, NDArray) read_labels(ref int start_pos, byte[] pickle)
+ {
+ byte[] value = new byte[10000];
+
+ var key_length = pickle[start_pos];
+ start_pos++;
+ var span = new Span(pickle, start_pos, key_length);
+ var key = Encoding.ASCII.GetString(span.ToArray());
+ start_pos += key_length + 6;
+
+ var value_length = 10000;
+ for (int i = 0; i < value_length; i++)
+ {
+ if (i > 0 && i % 1000 == 0)
+ start_pos += 2;
+ value[i] = pickle[start_pos + 1];
+ start_pos += 2;
+ }
+ start_pos += 2;
+
+ return (key, np.array(value));
+ }
+
+ (string, NDArray) read_data(ref int start_pos, byte[] pickle)
+ {
+ var key_length = pickle[start_pos];
+ start_pos++;
+ var span = new Span(pickle, start_pos, key_length);
+ var key = Encoding.ASCII.GetString(span.ToArray());
+ start_pos += key_length + 133;
+ var value_length = 3072 * 10000;
+ var value = new Span(pickle, start_pos, value_length).ToArray();
+ start_pos += value_length;
+
+ return (key, np.array(value).reshape(10000, 3, 32, 32));
+ }
+
+ string Download()
+ {
+ var dst = Path.Combine(Path.GetTempPath(), dest_folder);
+ Directory.CreateDirectory(dst);
+
+ Web.Download(origin_folder + file_name, dst, file_name);
+ Compress.ExtractTGZ(Path.Combine(Path.GetTempPath(), file_name), dst);
+
+ return Path.Combine(dst, "cifar-10-batches-py");
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Keras/Datasets/KerasDataset.cs b/src/TensorFlowNET.Keras/Datasets/KerasDataset.cs
index 49f90e09..d5442298 100644
--- a/src/TensorFlowNET.Keras/Datasets/KerasDataset.cs
+++ b/src/TensorFlowNET.Keras/Datasets/KerasDataset.cs
@@ -19,5 +19,6 @@ namespace Tensorflow.Keras.Datasets
public class KerasDataset
{
public Mnist mnist { get; } = new Mnist();
+ public Cifar10 cifar10 { get; } = new Cifar10();
}
}
diff --git a/src/TensorFlowNET.Keras/Datasets/MNIST.cs b/src/TensorFlowNET.Keras/Datasets/MNIST.cs
index 4cc4dbdb..9cdc56b5 100644
--- a/src/TensorFlowNET.Keras/Datasets/MNIST.cs
+++ b/src/TensorFlowNET.Keras/Datasets/MNIST.cs
@@ -17,7 +17,7 @@
using NumSharp;
using System;
using System.IO;
-using System.Net;
+using Tensorflow.Keras.Utils;
namespace Tensorflow.Keras.Datasets
{
@@ -65,8 +65,7 @@ namespace Tensorflow.Keras.Datasets
return fileSaveTo;
}
- using var wc = new WebClient();
- wc.DownloadFileTaskAsync(origin_folder + file_name, fileSaveTo).Wait();
+ Web.Download(origin_folder + file_name, Path.GetTempPath(), file_name);
return fileSaveTo;
}
diff --git a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
index 6b7819f9..b99d9be4 100644
--- a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
+++ b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
@@ -6,7 +6,7 @@
8.0
Tensorflow.Keras
AnyCPU;x64
- 0.1.0
+ 0.2.0
Haiping Chen
Keras for .NET
Apache 2.0, Haiping Chen 2020
@@ -14,7 +14,10 @@
https://github.com/SciSharp/TensorFlow.NET
https://avatars3.githubusercontent.com/u/44989469?s=200&v=4
https://github.com/SciSharp/TensorFlow.NET
- Keras for .NET is a C# version of Keras ported from the python version.
+ Keras for .NET is a C# version of Keras ported from the python version.
+
+* Support CIFAR-10 dataset in keras.datasets.
+* Support Conv2D functional API.
Keras for .NET
Keras is an API designed for human beings, not machines. Keras follows best practices for reducing cognitive load: it offers consistent & simple APIs, it minimizes the number of user actions required for common use cases, and it provides clear & actionable error messages.
@@ -27,11 +30,17 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
DEBUG;TRACE
+ false
+
+
+
+ false
+
diff --git a/src/TensorFlowNET.Keras/Utils/Compress.cs b/src/TensorFlowNET.Keras/Utils/Compress.cs
new file mode 100644
index 00000000..5a4f99c7
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Utils/Compress.cs
@@ -0,0 +1,102 @@
+/*****************************************************************************
+ Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+******************************************************************************/
+
+using ICSharpCode.SharpZipLib.Core;
+using ICSharpCode.SharpZipLib.GZip;
+using ICSharpCode.SharpZipLib.Tar;
+using System;
+using System.IO;
+using System.IO.Compression;
+using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace Tensorflow.Keras.Utils
+{
+ public class Compress
+ {
+ public static void ExtractGZip(string gzipFileName, string targetDir)
+ {
+ // Use a 4K buffer. Any larger is a waste.
+ byte[] dataBuffer = new byte[4096];
+
+ using (System.IO.Stream fs = new FileStream(gzipFileName, FileMode.Open, FileAccess.Read))
+ {
+ using (GZipInputStream gzipStream = new GZipInputStream(fs))
+ {
+ // Change this to your needs
+ string fnOut = Path.Combine(targetDir, Path.GetFileNameWithoutExtension(gzipFileName));
+
+ using (FileStream fsOut = File.Create(fnOut))
+ {
+ StreamUtils.Copy(gzipStream, fsOut, dataBuffer);
+ }
+ }
+ }
+ }
+
+ public static void UnZip(String gzArchiveName, String destFolder)
+ {
+ var flag = gzArchiveName.Split(Path.DirectorySeparatorChar).Last().Split('.').First() + ".bin";
+ if (File.Exists(Path.Combine(destFolder, flag))) return;
+
+ Console.WriteLine($"Extracting.");
+ var task = Task.Run(() =>
+ {
+ ZipFile.ExtractToDirectory(gzArchiveName, destFolder);
+ });
+
+ while (!task.IsCompleted)
+ {
+ Thread.Sleep(200);
+ Console.Write(".");
+ }
+
+ File.Create(Path.Combine(destFolder, flag));
+ Console.WriteLine("");
+ Console.WriteLine("Extracting is completed.");
+ }
+
+ public static void ExtractTGZ(String gzArchiveName, String destFolder)
+ {
+ var flag = gzArchiveName.Split(Path.DirectorySeparatorChar).Last().Split('.').First() + ".bin";
+ if (File.Exists(Path.Combine(destFolder, flag))) return;
+
+ Console.WriteLine($"Extracting.");
+ var task = Task.Run(() =>
+ {
+ using (var inStream = File.OpenRead(gzArchiveName))
+ {
+ using (var gzipStream = new GZipInputStream(inStream))
+ {
+ using (TarArchive tarArchive = TarArchive.CreateInputTarArchive(gzipStream))
+ tarArchive.ExtractContents(destFolder);
+ }
+ }
+ });
+
+ while (!task.IsCompleted)
+ {
+ Thread.Sleep(200);
+ Console.Write(".");
+ }
+
+ File.Create(Path.Combine(destFolder, flag));
+ Console.WriteLine("");
+ Console.WriteLine("Extracting is completed.");
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Keras/Utils/Web.cs b/src/TensorFlowNET.Keras/Utils/Web.cs
new file mode 100644
index 00000000..839b6470
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Utils/Web.cs
@@ -0,0 +1,57 @@
+/*****************************************************************************
+ Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+******************************************************************************/
+
+using System;
+using System.IO;
+using System.Linq;
+using System.Net;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace Tensorflow.Keras.Utils
+{
+ public class Web
+ {
+ public static bool Download(string url, string destDir, string destFileName)
+ {
+ if (destFileName == null)
+ destFileName = url.Split(Path.DirectorySeparatorChar).Last();
+
+ Directory.CreateDirectory(destDir);
+
+ string relativeFilePath = Path.Combine(destDir, destFileName);
+
+ if (File.Exists(relativeFilePath))
+ {
+ Console.WriteLine($"{relativeFilePath} already exists.");
+ return false;
+ }
+
+ var wc = new WebClient();
+ Console.WriteLine($"Downloading {relativeFilePath}");
+ var download = Task.Run(() => wc.DownloadFile(url, relativeFilePath));
+ while (!download.IsCompleted)
+ {
+ Thread.Sleep(1000);
+ Console.Write(".");
+ }
+ Console.WriteLine("");
+ Console.WriteLine($"Downloaded {relativeFilePath}");
+
+ return true;
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Keras/Utils/np_utils.cs b/src/TensorFlowNET.Keras/Utils/np_utils.cs
new file mode 100644
index 00000000..595254dc
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Utils/np_utils.cs
@@ -0,0 +1,31 @@
+using NumSharp;
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Keras.Utils
+{
+ public class np_utils
+ {
+ ///
+ /// Converts a class vector (integers) to binary class matrix.
+ ///
+ ///
+ ///
+ ///
+ ///
+ public static NDArray to_categorical(NDArray y, int num_classes = -1, TF_DataType dtype = TF_DataType.TF_FLOAT)
+ {
+ var y1 = y.astype(NPTypeCode.Int32).ToArray();
+ // var input_shape = y.shape[..^1];
+ var categorical = np.zeros((y.size, num_classes), dtype: dtype.as_numpy_dtype());
+ // categorical[np.arange(y.size), y] = 1;
+ for (int i = 0; i < y.size; i++)
+ {
+ categorical[i][y1[i]] = 1;
+ }
+
+ return categorical;
+ }
+ }
+}