@@ -26,12 +26,12 @@ In comparison to other projects, like for instance [TensorFlowSharp](https://www | |||
### How to use | |||
| TensorFlow | tf native 1.14 | tf native 1.15 | tf native 2.3 | | |||
| -------------------------- | ------------- | -------------- | ------------- | | |||
| tf.net 0.3x, tf.keras 0.2 | | | x | | |||
| tf.net 0.2x | | x | x | | |||
| tf.net 0.15 | x | x | | | |||
| tf.net 0.14 | x | | | | |||
| TensorFlow | tf native 1.14, cuda 10.0 | tf native 1.15, cuda 10.0 | tf native 2.3, cuda 10.1 | tf native 2.4, cuda 11 | | |||
| -------------------------- | ------------- | -------------- | ------------- | ------------- | | |||
| tf.net 0.3x, tf.keras 0.2 | | | x | not compatible | | |||
| tf.net 0.2x | | x | x | | | |||
| tf.net 0.15 | x | x | | | | |||
| tf.net 0.14 | x | | | | | |||
If you have trouble running the examples or installing, please refer to the [troubleshooting guide](tensorflowlib/README.md). | |||
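Both the managed binding and the native redistributable come from NuGet. A minimal install sketch for the Package Manager console, following the same `PM>` convention used further down in this README; the `TensorFlow.Keras` package ID is an assumption here, and you should pick the CPU or GPU redist variant that matches the table above:

```
PM> Install-Package TensorFlow.NET
PM> Install-Package TensorFlow.Keras
PM> Install-Package SciSharp.TensorFlow.Redist
```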
@@ -22,11 +22,19 @@ https://www.nuget.org/packages/SciSharp.TensorFlow.Redist | |||
Related merged [commits](https://github.com/SciSharp/TensorFlow.NET/commit/854a5ba61ad0e400623821236bd117cc24c6cb77). | |||
#### Download pre-built package | |||
[Mac OSX CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-darwin-x86_64-2.4.0.tar.gz), [Linux CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-2.4.0.tar.gz), [Linux GPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-linux-x86_64-2.4.0.tar.gz), [Windows CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-windows-x86_64-2.4.0.tar.gz), [Windows GPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-windows-x86_64-2.4.0.zip) | |||
#### Pack and Deploy #### | |||
On Windows, the tar command does not support extracting archives with symlinks, so when `dotnet pack` runs on Windows it only packages the Windows binaries. | |||
1. Run `dotnet pack SciSharp.TensorFlow.Redist.nupkgproj` under `src/SciSharp.TensorFlow.Redist` directory in Linux. | |||
2. Run `dotnet nuget push SciSharp.TensorFlow.Redist.2.3.1.nupkg -k APIKEY -s https://api.nuget.org/v3/index.json -t 600` | |||
2. Run `dotnet nuget push SciSharp.TensorFlow.Redist.2.4.0.nupkg -k APIKEY -s https://api.nuget.org/v3/index.json -t 600` | |||
@@ -3,6 +3,7 @@ using System.Collections; | |||
using System.Collections.Generic; | |||
using System.Linq; | |||
using Tensorflow.Framework.Models; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow | |||
{ | |||
@@ -98,6 +99,20 @@ namespace Tensorflow | |||
return dataset; | |||
} | |||
public Tensor dataset_cardinality(string name = null) | |||
{ | |||
if (tf.Context.executing_eagerly()) | |||
{ | |||
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, | |||
"DatasetCardinality", name, | |||
null, | |||
variant_tensor); | |||
return results[0]; | |||
} | |||
throw new NotImplementedException(""); | |||
} | |||
public override string ToString() | |||
=> $"{GetType().Name} shapes: {string.Join(", ", structure.Select(x => x.shape))}, types: {string.Join(", ", structure.Select(x => "tf." + x.dtype.as_numpy_name()))}"; | |||
@@ -117,7 +132,9 @@ namespace Tensorflow | |||
break; | |||
} | |||
yield return (results[0], results.Length == 1 ? null : results[1]); | |||
yield return results.Length == 2 | |||
? (results[0], results[1]) | |||
: (null, results[0]); | |||
} | |||
} | |||
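With this change the enumerator always yields a pair whose second slot holds data: `(first, second)` for two-component elements and `(null, value)` for single-component ones. A hedged consumption sketch, assuming `dataset` is any `IDatasetV2` enumerated as tensor pairs exactly as the iterator above suggests:

```csharp
// Sketch: iterate a dataset eagerly; x is null for single-component datasets,
// y always holds a tensor.
foreach (var (x, y) in dataset)
{
    // use y, and x as well when the dataset carries two components
}
```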
@@ -74,5 +74,7 @@ namespace Tensorflow | |||
/// </summary> | |||
/// <returns></returns> | |||
IDatasetV2 apply_options(); | |||
Tensor dataset_cardinality(string name = null); | |||
} | |||
} |
@@ -1,5 +1,6 @@ | |||
using System; | |||
using Tensorflow.Functions; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow | |||
{ | |||
@@ -14,7 +15,12 @@ namespace Tensorflow | |||
bool preserve_cardinality = false, | |||
bool use_legacy_function = false) : base(input_dataset) | |||
{ | |||
var func = new ConcreteFunction(map_func, input_dataset.element_spec[0].dtype); | |||
using var func = new ConcreteFunction($"autograph_{map_func.Method.Name}"); | |||
var input = tf.placeholder(input_dataset.element_spec[0].dtype); | |||
var output = map_func(input); | |||
func.ToGraph(input, output); | |||
structure = func.OutputStructure; | |||
variant_tensor = ops.map_dataset(input_dataset.variant_tensor, | |||
func, | |||
@@ -109,6 +109,8 @@ namespace Tensorflow.Functions | |||
inputs, | |||
outputs, | |||
null); | |||
OutputStructure = outputs.Select(x => x.ToTensorSpec()).ToArray(); | |||
} | |||
public Tensors Invoke(Tensors inputs) | |||
@@ -128,6 +130,9 @@ namespace Tensorflow.Functions | |||
return new ForwardBackwardCall(functions, args, tape_watching: true); | |||
} | |||
public override string ToString() | |||
=> Name; | |||
public void Dispose() | |||
{ | |||
c_api.TFE_ContextRemoveFunction(tf.Context.Handle, Name, tf.Status.Handle); | |||
@@ -2,10 +2,11 @@ | |||
namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
public class TensorLikeDataAdapterArgs | |||
public class DataAdapterArgs | |||
{ | |||
public Tensor X { get; set; } | |||
public Tensor Y { get; set; } | |||
public IDatasetV2 Dataset { get; set; } | |||
public int BatchSize { get; set; } = 32; | |||
public int Steps { get; set; } | |||
public int Epochs { get; set; } |
@@ -6,6 +6,7 @@ namespace Tensorflow.Keras.ArgsDefinition | |||
{ | |||
public Tensor X { get; set; } | |||
public Tensor Y { get; set; } | |||
public IDatasetV2 Dataset { get; set; } | |||
public int BatchSize { get; set; } = 32; | |||
public int StepsPerEpoch { get; set; } = -1; | |||
public int InitialEpoch { get; set; } = 0; | |||
@@ -1702,74 +1702,79 @@ new_height, new_width"); | |||
public static Tensor decode_image(Tensor contents, int channels = 0, TF_DataType dtype = TF_DataType.TF_UINT8, | |||
string name = null, bool expand_animations = true) | |||
{ | |||
Func<ITensorOrOperation> _jpeg = () => | |||
return tf_with(ops.name_scope(name, "decode_image"), scope => | |||
{ | |||
int jpeg_channels = channels; | |||
var good_channels = math_ops.not_equal(jpeg_channels, 4, name: "check_jpeg_channels"); | |||
string channels_msg = "Channels must be in (None, 0, 1, 3) when decoding JPEG 'images'"; | |||
var assert_channels = control_flow_ops.Assert(good_channels, new string[] { channels_msg }); | |||
return tf_with(ops.control_dependencies(new[] { assert_channels }), delegate | |||
var substr = tf.strings.substr(contents, 0, 3); | |||
Func<ITensorOrOperation> _jpeg = () => | |||
{ | |||
return convert_image_dtype(gen_image_ops.decode_jpeg(contents, channels), dtype); | |||
}); | |||
}; | |||
int jpeg_channels = channels; | |||
var good_channels = math_ops.not_equal(jpeg_channels, 4, name: "check_jpeg_channels"); | |||
string channels_msg = "Channels must be in (None, 0, 1, 3) when decoding JPEG 'images'"; | |||
var assert_channels = control_flow_ops.Assert(good_channels, new string[] { channels_msg }); | |||
return tf_with(ops.control_dependencies(new[] { assert_channels }), delegate | |||
{ | |||
return convert_image_dtype(gen_image_ops.decode_jpeg(contents, channels), dtype); | |||
}); | |||
}; | |||
Func<ITensorOrOperation> _gif = () => | |||
{ | |||
int gif_channels = channels; | |||
var good_channels = math_ops.logical_and( | |||
math_ops.not_equal(gif_channels, 1, name: "check_gif_channels"), | |||
math_ops.not_equal(gif_channels, 4, name: "check_gif_channels")); | |||
string channels_msg = "Channels must be in (None, 0, 3) when decoding GIF images"; | |||
var assert_channels = control_flow_ops.Assert(good_channels, new string[] { channels_msg }); | |||
return tf_with(ops.control_dependencies(new[] { assert_channels }), delegate | |||
/*Func<ITensorOrOperation> _gif = () => | |||
{ | |||
var result = convert_image_dtype(gen_image_ops.decode_gif(contents), dtype); | |||
if (!expand_animations) | |||
result = array_ops.gather(result, 0); | |||
return result; | |||
}); | |||
}; | |||
int gif_channels = channels; | |||
var good_channels = math_ops.logical_and( | |||
math_ops.not_equal(gif_channels, 1, name: "check_gif_channels"), | |||
math_ops.not_equal(gif_channels, 4, name: "check_gif_channels")); | |||
string channels_msg = "Channels must be in (None, 0, 3) when decoding GIF images"; | |||
var assert_channels = control_flow_ops.Assert(good_channels, new string[] { channels_msg }); | |||
return tf_with(ops.control_dependencies(new[] { assert_channels }), delegate | |||
{ | |||
var result = convert_image_dtype(gen_image_ops.decode_gif(contents), dtype); | |||
if (!expand_animations) | |||
result = array_ops.gather(result, 0); | |||
return result; | |||
}); | |||
}; | |||
Func<ITensorOrOperation> _bmp = () => | |||
{ | |||
int bmp_channels = channels; | |||
var signature = tf.strings.substr(contents, 0, 2); | |||
var is_bmp = math_ops.equal(signature, "BM", name: "is_bmp"); | |||
string decode_msg = "Unable to decode bytes as JPEG, PNG, GIF, or BMP"; | |||
var assert_decode = control_flow_ops.Assert(is_bmp, new string[] { decode_msg }); | |||
var good_channels = math_ops.not_equal(bmp_channels, 1, name: "check_channels"); | |||
string channels_msg = "Channels must be in (None, 0, 3) when decoding BMP images"; | |||
var assert_channels = control_flow_ops.Assert(good_channels, new string[] { channels_msg }); | |||
return tf_with(ops.control_dependencies(new[] { assert_decode, assert_channels }), delegate | |||
Func<ITensorOrOperation> _bmp = () => | |||
{ | |||
return convert_image_dtype(gen_image_ops.decode_bmp(contents), dtype); | |||
}); | |||
}; | |||
int bmp_channels = channels; | |||
var signature = tf.strings.substr(contents, 0, 2); | |||
var is_bmp = math_ops.equal(signature, "BM", name: "is_bmp"); | |||
string decode_msg = "Unable to decode bytes as JPEG, PNG, GIF, or BMP"; | |||
var assert_decode = control_flow_ops.Assert(is_bmp, new string[] { decode_msg }); | |||
var good_channels = math_ops.not_equal(bmp_channels, 1, name: "check_channels"); | |||
string channels_msg = "Channels must be in (None, 0, 3) when decoding BMP images"; | |||
var assert_channels = control_flow_ops.Assert(good_channels, new string[] { channels_msg }); | |||
return tf_with(ops.control_dependencies(new[] { assert_decode, assert_channels }), delegate | |||
{ | |||
return convert_image_dtype(gen_image_ops.decode_bmp(contents), dtype); | |||
}); | |||
}; | |||
Func<ITensorOrOperation> _png = () => | |||
{ | |||
return convert_image_dtype(gen_image_ops.decode_png( | |||
contents, | |||
channels, | |||
dtype: dtype), | |||
dtype); | |||
}; | |||
Func<ITensorOrOperation> _png = () => | |||
{ | |||
return convert_image_dtype(gen_image_ops.decode_png( | |||
contents, | |||
channels, | |||
dtype: dtype), | |||
dtype); | |||
}; | |||
Func<ITensorOrOperation> check_gif = () => | |||
{ | |||
return control_flow_ops.cond(is_gif(contents), _gif, _bmp, name: "cond_gif"); | |||
}; | |||
Func<ITensorOrOperation> check_gif = () => | |||
{ | |||
var gif = tf.constant(new byte[] { 0x47, 0x49, 0x46 }, TF_DataType.TF_STRING); | |||
var is_gif = math_ops.equal(substr, gif, name: name); | |||
return control_flow_ops.cond(is_gif, _gif, _bmp, name: "cond_gif"); | |||
}; | |||
Func<ITensorOrOperation> check_png = () => | |||
{ | |||
return control_flow_ops.cond(is_png(contents), _png, check_gif, name: "cond_png"); | |||
}; | |||
Func<ITensorOrOperation> check_png = () => | |||
{ | |||
return control_flow_ops.cond(is_png(contents), _png, check_gif, name: "cond_png"); | |||
};*/ | |||
return tf_with(ops.name_scope(name, "decode_image"), scope => | |||
{ | |||
return control_flow_ops.cond(is_jpeg(contents), _jpeg, check_png, name: "cond_jpeg"); | |||
// return control_flow_ops.cond(is_jpeg(contents), _jpeg, check_png, name: "cond_jpeg"); | |||
return _jpeg() as Tensor; | |||
}); | |||
} | |||
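As the unit-test update at the bottom of this diff shows, the simplified `decode_image` currently routes every input through the JPEG branch (the result is named `decode_image/Identity:0`). A hedged usage sketch, where `contents` is assumed to be a scalar string tensor holding encoded image bytes, `image_size` a `TensorShape` such as (256, 256), and the `"bilinear"` method string is also an assumption:

```csharp
// Sketch: decode raw bytes and resize, mirroring path_to_image later in this PR.
var img = tf.image.decode_image(contents, channels: 3, expand_animations: false);
img = tf.image.resize_images_v2(img, image_size, method: "bilinear");
```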
@@ -5,7 +5,7 @@ | |||
<AssemblyName>TensorFlow.NET</AssemblyName> | |||
<RootNamespace>Tensorflow</RootNamespace> | |||
<TargetTensorFlow>2.2.0</TargetTensorFlow> | |||
<Version>0.31.1</Version> | |||
<Version>0.31.2</Version> | |||
<LangVersion>8.0</LangVersion> | |||
<Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors> | |||
<Company>SciSharp STACK</Company> | |||
@@ -19,7 +19,7 @@ | |||
<Description>Google's TensorFlow full binding in .NET Standard. | |||
Building, training and inferring deep learning models. | |||
https://tensorflownet.readthedocs.io</Description> | |||
<AssemblyVersion>0.31.1.0</AssemblyVersion> | |||
<AssemblyVersion>0.31.2.0</AssemblyVersion> | |||
<PackageReleaseNotes>tf.net 0.20.x and above are based on tensorflow native 2.x. | |||
* Eager Mode is added finally. | |||
@@ -30,7 +30,7 @@ https://tensorflownet.readthedocs.io</Description> | |||
TensorFlow .NET v0.30 is focused on making more of the Keras API work, including: | |||
* tf.keras.datasets | |||
* Building keras model in subclass, functional and sequential api</PackageReleaseNotes> | |||
<FileVersion>0.31.1.0</FileVersion> | |||
<FileVersion>0.31.2.0</FileVersion> | |||
<PackageLicenseFile>LICENSE</PackageLicenseFile> | |||
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance> | |||
<SignAssembly>true</SignAssembly> | |||
@@ -20,6 +20,7 @@ using System.Collections.Generic; | |||
using System.Linq; | |||
using System.Text; | |||
using Tensorflow.Eager; | |||
using static Tensorflow.Binding; | |||
namespace Tensorflow | |||
{ | |||
@@ -410,14 +411,10 @@ would not be rank 1.", tensor.op.get_attr("axis"))); | |||
var value = constant_value(tensor); | |||
if (!(value is null)) | |||
{ | |||
int[] d_ = { }; | |||
foreach (int d in value) | |||
{ | |||
if (d >= 0) | |||
d_[d_.Length] = d; | |||
else | |||
d_[d_.Length] = -1; // None | |||
} | |||
var d_ = new int[value.size]; | |||
foreach (var (index, d) in enumerate(value.ToArray<int>())) | |||
d_[index] = d >= 0 ? d : -1; | |||
ret = ret.merge_with(new TensorShape(d_)); | |||
} | |||
return ret; | |||
@@ -577,7 +574,7 @@ would not be rank 1.", tensor.op.get_attr("axis"))); | |||
return string.Join(string.Empty, nd.ToArray<byte>() | |||
.Select(x => x < 32 || x > 127 ? "\\x" + x.ToString("x") : Convert.ToChar(x).ToString())); | |||
case TF_DataType.TF_BOOL: | |||
return (nd.GetByte(0) > 0).ToString(); | |||
return nd.GetBoolean(0).ToString(); | |||
case TF_DataType.TF_VARIANT: | |||
case TF_DataType.TF_RESOURCE: | |||
return "<unprintable>"; | |||
@@ -37,19 +37,38 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||
_steps_per_execution_value = args.StepsPerExecution.numpy(); | |||
} | |||
_adapter = new TensorLikeDataAdapter(new TensorLikeDataAdapterArgs | |||
if(args.Dataset == null) | |||
{ | |||
X = args.X, | |||
Y = args.Y, | |||
BatchSize = args.BatchSize, | |||
Steps = args.StepsPerEpoch, | |||
Epochs = args.Epochs - args.InitialEpoch, | |||
Shuffle = args.Shuffle, | |||
MaxQueueSize = args.MaxQueueSize, | |||
Worker = args.Workers, | |||
UseMultiprocessing = args.UseMultiprocessing, | |||
Model = args.Model | |||
}); | |||
_adapter = new TensorLikeDataAdapter(new DataAdapterArgs | |||
{ | |||
X = args.X, | |||
Y = args.Y, | |||
BatchSize = args.BatchSize, | |||
Steps = args.StepsPerEpoch, | |||
Epochs = args.Epochs - args.InitialEpoch, | |||
Shuffle = args.Shuffle, | |||
MaxQueueSize = args.MaxQueueSize, | |||
Worker = args.Workers, | |||
UseMultiprocessing = args.UseMultiprocessing, | |||
Model = args.Model | |||
}); | |||
} | |||
else | |||
{ | |||
_adapter = new DatasetAdapter(new DataAdapterArgs | |||
{ | |||
Dataset = args.Dataset, | |||
BatchSize = args.BatchSize, | |||
Steps = args.StepsPerEpoch, | |||
Epochs = args.Epochs - args.InitialEpoch, | |||
Shuffle = args.Shuffle, | |||
MaxQueueSize = args.MaxQueueSize, | |||
Worker = args.Workers, | |||
UseMultiprocessing = args.UseMultiprocessing, | |||
Model = args.Model | |||
}); | |||
} | |||
_dataset = _adapter.GetDataset(); | |||
_inferred_steps = _infer_steps(args.StepsPerEpoch, _dataset); | |||
_current_step = 0; | |||
@@ -66,7 +85,8 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||
if (adapter_steps > -1) | |||
return adapter_steps; | |||
throw new NotImplementedException(""); | |||
var size = dataset.dataset_cardinality(); | |||
return size.numpy(); | |||
} | |||
public IEnumerable<(int, OwnedIterator)> enumerate_epochs() | |||
@@ -0,0 +1,35 @@ | |||
using System; | |||
using System.Collections.Generic; | |||
using System.Text; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
namespace Tensorflow.Keras.Engine.DataAdapters | |||
{ | |||
public class DatasetAdapter : IDataAdapter | |||
{ | |||
DataAdapterArgs args; | |||
IDatasetV2 _dataset => args.Dataset; | |||
public DatasetAdapter(DataAdapterArgs args) | |||
{ | |||
this.args = args; | |||
} | |||
public bool CanHandle(Tensor x, Tensor y = null) | |||
{ | |||
throw new NotImplementedException(); | |||
} | |||
public IDatasetV2 GetDataset() | |||
=> _dataset; | |||
public int GetSize() | |||
=> -1; | |||
public (Tensor, Tensor) Expand1d(Tensor x, Tensor y) | |||
{ | |||
if (y.TensorShape.ndim == 1) | |||
y = array_ops.expand_dims(y, axis: -1); | |||
return (x, y); | |||
} | |||
} | |||
} |
@@ -9,14 +9,14 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||
/// </summary> | |||
public class TensorLikeDataAdapter : IDataAdapter | |||
{ | |||
TensorLikeDataAdapterArgs args; | |||
DataAdapterArgs args; | |||
int _size; | |||
int _batch_size; | |||
int num_samples; | |||
int num_full_batches; | |||
IDatasetV2 _dataset; | |||
public TensorLikeDataAdapter(TensorLikeDataAdapterArgs args) | |||
public TensorLikeDataAdapter(DataAdapterArgs args) | |||
{ | |||
this.args = args; | |||
_process_tensorlike(); | |||
@@ -39,10 +39,12 @@ namespace Tensorflow.Keras.Engine | |||
_input_coordinates = new List<KerasHistory>(); | |||
_output_coordinates = new List<KerasHistory>(); | |||
tensor_usage_count = new Dictionary<int, int>(); | |||
if (this is Sequential) | |||
return; | |||
_init_graph_network(inputs, outputs); | |||
} | |||
void _init_graph_network(Tensors inputs, Tensors outputs) | |||
protected void _init_graph_network(Tensors inputs, Tensors outputs) | |||
{ | |||
_is_graph_network = true; | |||
this.inputs = inputs; | |||
@@ -9,10 +9,6 @@ namespace Tensorflow.Keras.Engine | |||
{ | |||
LossesContainer compiled_loss; | |||
MetricsContainer compiled_metrics; | |||
public void compile(string optimizerName, ILossFunc lossName) | |||
{ | |||
throw new NotImplementedException(""); | |||
} | |||
public void compile(ILossFunc loss, OptimizerV2 optimizer, string[] metrics) | |||
{ | |||
@@ -29,12 +25,12 @@ namespace Tensorflow.Keras.Engine | |||
this.loss = loss; | |||
} | |||
public void compile(string optimizerName, string lossName) | |||
public void compile(string optimizer, string loss, string[] metrics) | |||
{ | |||
switch (optimizerName) | |||
switch (optimizer) | |||
{ | |||
case "rmsprop": | |||
optimizer = new RMSprop(new RMSpropArgs | |||
this.optimizer = new RMSprop(new RMSpropArgs | |||
{ | |||
}); | |||
@@ -68,5 +68,49 @@ namespace Tensorflow.Keras.Engine | |||
Console.WriteLine($"epoch: {epoch + 1}, " + string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2}"))); | |||
} | |||
} | |||
public void fit(IDatasetV2 dataset, | |||
IDatasetV2 validation_data = null, | |||
int batch_size = -1, | |||
int epochs = 1, | |||
int verbose = 1, | |||
float validation_split = 0f, | |||
bool shuffle = true, | |||
int initial_epoch = 0, | |||
int max_queue_size = 10, | |||
int workers = 1, | |||
bool use_multiprocessing = false) | |||
{ | |||
data_handler = new DataHandler(new DataHandlerArgs | |||
{ | |||
Dataset = dataset, | |||
BatchSize = batch_size, | |||
InitialEpoch = initial_epoch, | |||
Epochs = epochs, | |||
Shuffle = shuffle, | |||
MaxQueueSize = max_queue_size, | |||
Workers = workers, | |||
UseMultiprocessing = use_multiprocessing, | |||
Model = this, | |||
StepsPerExecution = _steps_per_execution | |||
}); | |||
stop_training = false; | |||
_train_counter.assign(0); | |||
Console.WriteLine($"Training..."); | |||
foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) | |||
{ | |||
// reset_metrics(); | |||
// callbacks.on_epoch_begin(epoch) | |||
// data_handler.catch_stop_iteration(); | |||
IEnumerable<(string, Tensor)> results = null; | |||
foreach (var step in data_handler.steps()) | |||
{ | |||
// callbacks.on_train_batch_begin(step) | |||
results = step_function(iterator); | |||
} | |||
Console.WriteLine($"epoch: {epoch + 1}, " + string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2}"))); | |||
} | |||
} | |||
} | |||
} |
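A hedged sketch of calling the new overload, assuming `model` is an already-compiled `Model` and `train_ds` is any `IDatasetV2` (for example one produced by the keras image preprocessing pipeline changed later in this diff):

```csharp
// Sketch: train directly from a tf.data pipeline with the new fit overload.
model.fit(train_ds, epochs: 2, shuffle: true);
```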
@@ -35,7 +35,7 @@ namespace Tensorflow.Keras.Engine | |||
public int[] node_indices; | |||
public int[] tensor_indices; | |||
public Tensors input_tensors => args.InputTensors; | |||
public Tensors input_tensors => is_input ? Outputs : args.InputTensors; | |||
public Tensors Outputs => args.Outputs; | |||
public TensorShape[] input_shapes; | |||
public TensorShape[] output_shapes; | |||
@@ -17,6 +17,7 @@ | |||
using System.Collections.Generic; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Layers; | |||
using Tensorflow.Keras.Utils; | |||
using static Tensorflow.KerasApi; | |||
namespace Tensorflow.Keras.Engine | |||
@@ -25,36 +26,40 @@ namespace Tensorflow.Keras.Engine | |||
/// `Sequential` groups a linear stack of layers into a `tf.keras.Model`. | |||
/// `Sequential` provides training and inference features on this model. | |||
/// </summary> | |||
public class Sequential : Model | |||
public class Sequential : Functional | |||
{ | |||
SequentialArgs args; | |||
bool _is_graph_network; | |||
Tensor inputs; | |||
Tensor outputs; | |||
bool computeOutputAndMaskJointly; | |||
bool autoTrackSubLayers; | |||
TensorShape inferredInputShape; | |||
bool hasExplicitInputShape; | |||
TF_DataType inputDType; | |||
List<ILayer> layers => args.Layers; | |||
public TensorShape output_shape => outputs.TensorShape; | |||
Tensors inputs; | |||
Tensors outputs; | |||
bool _compute_output_and_mask_jointly; | |||
bool _auto_track_sub_layers; | |||
TensorShape _inferred_input_shape; | |||
bool _has_explicit_input_shape; | |||
TF_DataType _input_dtype; | |||
public TensorShape output_shape => outputs[0].TensorShape; | |||
bool built = false; | |||
public Sequential(SequentialArgs args) | |||
: base(new ModelArgs | |||
{ | |||
Name = args.Name | |||
}) | |||
: base(args.Inputs, args.Outputs, name: args.Name) | |||
{ | |||
this.args = args; | |||
if (args.Layers == null) | |||
args.Layers = new List<ILayer>(); | |||
// SupportsMasking = true; | |||
computeOutputAndMaskJointly = true; | |||
autoTrackSubLayers = false; | |||
hasExplicitInputShape = false; | |||
_compute_output_and_mask_jointly = true; | |||
_auto_track_sub_layers = false; | |||
_has_explicit_input_shape = false; | |||
_is_graph_network = false; | |||
// Add to the model any layers passed to the constructor. | |||
if (args.Layers != null) | |||
{ | |||
foreach (var layer in args.Layers) | |||
add(layer as Layer); | |||
} | |||
} | |||
public void add(Tensor tensor) | |||
@@ -71,7 +76,7 @@ namespace Tensorflow.Keras.Engine | |||
{ | |||
built = false; | |||
var set_inputs = false; | |||
if (layers.Count == 0) | |||
if (_layers.Count == 0) | |||
{ | |||
if (layer is InputLayer) | |||
{ | |||
@@ -83,7 +88,7 @@ namespace Tensorflow.Keras.Engine | |||
{ | |||
// Instantiate an input layer. | |||
var x = keras.Input( | |||
shape: layer.BatchInputShape, | |||
batch_input_shape: layer.BatchInputShape, | |||
dtype: layer.DType, | |||
name: layer.Name + "_input"); | |||
@@ -99,36 +104,26 @@ namespace Tensorflow.Keras.Engine | |||
{ | |||
// If an input layer (placeholder) is available. | |||
outputs = layer.InboundNodes[^1].Outputs; | |||
inputs = layer_utils.get_source_inputs(outputs[0]); | |||
built = true; | |||
_has_explicit_input_shape = true; | |||
} | |||
} | |||
else if (outputs != null) | |||
{ | |||
outputs = layer.Apply(outputs); | |||
built = true; | |||
} | |||
if (set_inputs || _is_graph_network) | |||
{ | |||
_init_graph_network(inputs, outputs); | |||
_is_graph_network = true; | |||
} | |||
else | |||
{ | |||
} | |||
} | |||
void _init_graph_network(Tensor inputs, Tensor outputs) | |||
{ | |||
_is_graph_network = true; | |||
this.inputs = inputs; | |||
this.outputs = outputs; | |||
built = true; | |||
_map_graph_network(inputs, outputs); | |||
} | |||
void _map_graph_network(Tensor inputs, Tensor outputs) | |||
{ | |||
layers.add(outputs.KerasHistory.Layer); | |||
} | |||
} | |||
} |
@@ -62,16 +62,21 @@ namespace Tensorflow.Keras | |||
/// <returns></returns> | |||
public Tensor Input(TensorShape shape = null, | |||
int batch_size = -1, | |||
TensorShape batch_input_shape = null, | |||
TF_DataType dtype = TF_DataType.DtInvalid, | |||
string name = null, | |||
bool sparse = false, | |||
bool ragged = false, | |||
Tensor tensor = null) | |||
{ | |||
if (batch_input_shape != null) | |||
shape = batch_input_shape.dims[1..]; | |||
var args = new InputLayerArgs | |||
{ | |||
Name = name, | |||
InputShape = shape, | |||
BatchInputShape = batch_input_shape, | |||
BatchSize = batch_size, | |||
DType = dtype, | |||
Sparse = sparse, | |||
@@ -23,5 +23,10 @@ namespace Tensorflow.Keras.Layers | |||
offset = math_ops.cast(args.Offset, args.DType); | |||
return math_ops.cast(inputs, args.DType) * scale + offset; | |||
} | |||
public override TensorShape ComputeOutputShape(TensorShape input_shape) | |||
{ | |||
return input_shape; | |||
} | |||
} | |||
} |
@@ -1,4 +1,5 @@ | |||
using System; | |||
using Tensorflow.Framework; | |||
using Tensorflow.Keras.ArgsDefinition; | |||
using Tensorflow.Keras.Engine; | |||
using Tensorflow.Keras.Utils; | |||
@@ -15,6 +16,7 @@ namespace Tensorflow.Keras.Layers | |||
public Flatten(FlattenArgs args) | |||
: base(args) | |||
{ | |||
this.args = args; | |||
args.DataFormat = conv_utils.normalize_data_format(args.DataFormat); | |||
input_spec = new InputSpec(min_ndim: 1); | |||
_channels_first = args.DataFormat == "channels_first"; | |||
@@ -31,8 +33,29 @@ namespace Tensorflow.Keras.Layers | |||
{ | |||
return array_ops.reshape(inputs, new[] { inputs.shape[0], -1 }); | |||
} | |||
else | |||
{ | |||
var input_shape = inputs.shape; | |||
var rank = inputs.shape.rank; | |||
if (rank == 1) | |||
return array_ops.expand_dims(inputs, axis: 1); | |||
var batch_dim = tensor_shape.dimension_value(input_shape[0]); | |||
if (batch_dim != -1) | |||
{ | |||
return array_ops.reshape(inputs, new[] { batch_dim, -1 }); | |||
} | |||
throw new NotImplementedException(""); | |||
var non_batch_dims = ((int[])input_shape)[1..]; | |||
var num = 1; | |||
if (non_batch_dims.Length > 0) | |||
{ | |||
for (var i = 0; i < non_batch_dims.Length; i++) | |||
{ | |||
num *= non_batch_dims[i]; | |||
} | |||
} | |||
return array_ops.reshape(inputs, new[] { inputs.shape[0], num }); | |||
} | |||
} | |||
} | |||
} |
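The first branch above reshapes a `(batch, d1, d2, ...)` tensor into `(batch, -1)` in one call. A small illustrative sketch, assuming `tf.ones` accepts a `TensorShape` (the shapes are made up):

```csharp
// Sketch: (8, 4, 4) collapses to (8, 16), i.e. (batch, d1*d2).
var x = tf.ones(new TensorShape(8, 4, 4));
var flat = array_ops.reshape(x, new[] { 8, -1 });
```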
@@ -40,8 +40,8 @@ namespace Tensorflow.Keras.Preprocessings | |||
labels.AddRange(Enumerable.Range(0, files.Length).Select(x => label)); | |||
} | |||
var return_labels = new int[labels.Count]; | |||
var return_file_paths = new string[file_paths.Count]; | |||
var return_labels = labels.Select(x => x).ToArray(); | |||
var return_file_paths = file_paths.Select(x => x).ToArray(); | |||
if (shuffle) | |||
{ | |||
@@ -41,7 +41,7 @@ namespace Tensorflow.Keras | |||
int num_channels = 0; | |||
if (color_mode == "rgb") | |||
num_channels = 3; | |||
// C:/Users/haipi/.keras/datasets/flower_photos | |||
var (image_paths, label_list, class_name_list) = keras.preprocessing.dataset_utils.index_directory(directory, | |||
formats: WHITELIST_FORMATS, | |||
class_names: class_names, | |||
@@ -16,27 +16,11 @@ namespace Tensorflow.Keras | |||
var path_ds = tf.data.Dataset.from_tensor_slices(image_paths); | |||
var img_ds = path_ds.map(x => path_to_image(x, image_size, num_channels, interpolation)); | |||
/*Shape shape = (image_paths.Length, image_size.dims[0], image_size.dims[1], num_channels); | |||
Console.WriteLine($"Allocating memory for shape{shape}, {NPTypeCode.Float}"); | |||
var data = np.zeros(shape, NPTypeCode.Float); | |||
for (var i = 0; i < image_paths.Length; i++) | |||
{ | |||
var image = path_to_image(image_paths[i], image_size, num_channels, interpolation); | |||
data[i] = image.numpy(); | |||
if (i % 100 == 0) | |||
Console.WriteLine($"Filled {i}/{image_paths.Length} data into ndarray."); | |||
} | |||
var img_ds = tf.data.Dataset.from_tensor_slices(data); | |||
if (label_mode == "int") | |||
{ | |||
var label_ds = tf.keras.preprocessing.dataset_utils.labels_to_dataset(labels, label_mode, num_classes); | |||
var label_ds = dataset_utils.labels_to_dataset(labels, label_mode, num_classes); | |||
img_ds = tf.data.Dataset.zip(img_ds, label_ds); | |||
} | |||
else*/ | |||
throw new NotImplementedException(""); | |||
return img_ds; | |||
} | |||
@@ -47,6 +31,7 @@ namespace Tensorflow.Keras | |||
img = tf.image.decode_image( | |||
img, channels: num_channels, expand_animations: false); | |||
img = tf.image.resize_images_v2(img, image_size, method: interpolation); | |||
// img.set_shape((image_size[0], image_size[1], num_channels)); | |||
return img; | |||
} | |||
} | |||
@@ -6,7 +6,7 @@ | |||
<LangVersion>8.0</LangVersion> | |||
<RootNamespace>Tensorflow.Keras</RootNamespace> | |||
<Platforms>AnyCPU;x64</Platforms> | |||
<Version>0.2.1</Version> | |||
<Version>0.3.0</Version> | |||
<Authors>Haiping Chen</Authors> | |||
<Product>Keras for .NET</Product> | |||
<Copyright>Apache 2.0, Haiping Chen 2020</Copyright> | |||
@@ -25,11 +25,13 @@ Keras is an API designed for human beings, not machines. Keras follows best prac | |||
<Company>SciSharp STACK</Company> | |||
<GeneratePackageOnBuild>true</GeneratePackageOnBuild> | |||
<PackageTags>tensorflow, keras, deep learning, machine learning</PackageTags> | |||
<PackageRequireLicenseAcceptance>false</PackageRequireLicenseAcceptance> | |||
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance> | |||
<RepositoryType>Git</RepositoryType> | |||
<SignAssembly>true</SignAssembly> | |||
<AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile> | |||
<AssemblyVersion>0.2.1.0</AssemblyVersion> | |||
<AssemblyVersion>0.3.0.0</AssemblyVersion> | |||
<FileVersion>0.3.0.0</FileVersion> | |||
<PackageLicenseFile>LICENSE</PackageLicenseFile> | |||
</PropertyGroup> | |||
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | |||
@@ -55,4 +57,11 @@ Keras is an API designed for human beings, not machines. Keras follows best prac | |||
<ProjectReference Include="..\TensorFlowNET.Core\Tensorflow.Binding.csproj" /> | |||
</ItemGroup> | |||
<ItemGroup> | |||
<None Include="..\..\LICENSE"> | |||
<Pack>True</Pack> | |||
<PackagePath></PackagePath> | |||
</None> | |||
</ItemGroup> | |||
</Project> |
@@ -187,5 +187,34 @@ namespace Tensorflow.Keras.Utils | |||
var total = weight_shapes.Select(p => (int)np.prod(p.dims)).Sum(); | |||
return total; | |||
} | |||
public static Tensors get_source_inputs(Tensor tensor, ILayer layer = null, int node_index = -1) | |||
{ | |||
if (layer == null) | |||
(layer, node_index, _) = tensor.KerasHistory; | |||
if (layer.InboundNodes == null || layer.InboundNodes.Count == 0) | |||
return tensor; | |||
else | |||
{ | |||
var node = layer.InboundNodes[node_index]; | |||
if (node.is_input) | |||
return node.input_tensors; | |||
else | |||
{ | |||
var source_tensors = new List<Tensor>(); | |||
foreach (var _layer in node.iterate_inbound()) | |||
{ | |||
(layer, node_index, tensor) = (_layer.Item1, _layer.Item2, _layer.Item4); | |||
var previous_sources = get_source_inputs(tensor, layer, node_index); | |||
foreach(var x in previous_sources) | |||
{ | |||
// should we check whether it already exists? | |||
source_tensors.append(x); | |||
} | |||
} | |||
return source_tensors; | |||
} | |||
} | |||
} | |||
} | |||
} |
@@ -24,7 +24,7 @@ More information about [System.Drawing on Linux](<https://www.hanselman.com/blog | |||
Before running, verify that you installed CUDA and cuDNN (TensorFlow v1.15 is compatible with CUDA v10.0 and cuDNN v7.4; TensorFlow v2.x is compatible with CUDA v10.2 and cuDNN v7.65), and make sure the installed CUDA version matches the native TensorFlow build. | |||
#### Mac OS | |||
There is no GPU support for macOS. | |||
There is no GPU support for macOS; in the future, TensorFlow will support the [Apple M1 chip](https://github.com/apple/tensorflow_macos). | |||
#### GPU for Windows | |||
@@ -37,9 +37,11 @@ PM> Install-Package SciSharp.TensorFlow.Redist-Windows-GPU | |||
PM> Install-Package SciSharp.TensorFlow.Redist-Linux-GPU | |||
``` | |||
Since NuGet limits package size to 250 MB, we can't ship the Linux GPU version as a NuGet package; you can download the library from [Google TensorFlow Storage](https://storage.googleapis.com/tensorflow) instead. | |||
### Download prebuilt binary manually | |||
Tensorflow packages are built nightly and uploaded to GCS for all supported platforms. They are uploaded to the [libtensorflow-nightly](https://www.tensorflow.org/install/lang_c) GCS bucket and are indexed by operating system and date built. | |||
TensorFlow packages are built nightly and uploaded to GCS for all supported platforms. They are uploaded to the [libtensorflow-nightly](https://www.tensorflow.org/install/lang_c) GCS bucket and are indexed by operating system and date built. | |||
### Build from source for Windows | |||
@@ -28,7 +28,7 @@ namespace TensorFlowNET.UnitTest.Basics | |||
public void decode_image() | |||
{ | |||
var img = tf.image.decode_image(contents); | |||
Assert.AreEqual(img.name, "decode_image/cond_jpeg/Merge:0"); | |||
Assert.AreEqual(img.name, "decode_image/Identity:0"); | |||
} | |||
[TestMethod] | |||
@@ -57,7 +57,7 @@ namespace TensorFlowNET.UnitTest.Keras | |||
{ 2, 3, 4, 5 }, | |||
{ 3, 4, 5, 6 } | |||
}); | |||
model.compile("rmsprop", "mse"); | |||
// model.compile("rmsprop", "mse"); | |||
var output_array = model.predict(input_array); | |||
Assert.AreEqual((32, 10, 64), output_array.TensorShape); | |||
} | |||
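Since the two-argument `compile(string, string)` overload was removed earlier in this diff, re-enabling this call would presumably require the new three-argument signature; a hedged guess only, with the metrics value being an assumption:

```csharp
// Hypothetical update of the disabled call to the new compile signature.
model.compile("rmsprop", "mse", new[] { "accuracy" });
```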