From fd64ad1b446bb2ef3845faa6b28e3c834047c905 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 19 Dec 2020 08:33:52 -0600 Subject: [PATCH] Fix Sequential model. --- README.md | 12 ++-- src/SciSharp.TensorFlow.Redist/README.md | 10 ++- .../TensorFlowNET.Console.csproj | 2 +- src/TensorFlowNET.Core/Tensors/tensor_util.cs | 2 +- .../Engine/DataAdapters/DataHandler.cs | 46 +++++++++---- .../Engine/DataAdapters/DatasetAdapter.cs | 35 ++++++++++ .../DataAdapters/TensorLikeDataAdapter.cs | 4 +- src/TensorFlowNET.Keras/Engine/Functional.cs | 4 +- .../Engine/Model.Compile.cs | 10 +-- src/TensorFlowNET.Keras/Engine/Model.Fit.cs | 44 +++++++++++++ src/TensorFlowNET.Keras/Engine/Node.cs | 2 +- src/TensorFlowNET.Keras/Engine/Sequential.cs | 65 +++++++++---------- src/TensorFlowNET.Keras/KerasInterface.cs | 5 ++ .../Layers/Rescaling/Rescaling.cs | 5 ++ .../Layers/Reshaping/Flatten.cs | 25 ++++++- ...eprocessing.paths_and_labels_to_dataset.cs | 2 +- src/TensorFlowNET.Keras/Utils/layer_utils.cs | 29 +++++++++ tensorflowlib/README.md | 6 +- test/TensorFlowNET.UnitTest/ImageTest.cs | 2 +- .../Keras/LayersTest.cs | 2 +- .../Tensorflow.UnitTest.csproj | 4 +- 21 files changed, 240 insertions(+), 76 deletions(-) create mode 100644 src/TensorFlowNET.Keras/Engine/DataAdapters/DatasetAdapter.cs diff --git a/README.md b/README.md index 4cb56684..2fd46c6b 100644 --- a/README.md +++ b/README.md @@ -26,12 +26,12 @@ In comparison to other projects, like for instance [TensorFlowSharp](https://www ### How to use -| TensorFlow | tf native1.14 | tf native 1.15 | tf native 2.3 | -| -------------------------- | ------------- | -------------- | ------------- | -| tf.net 0.3x, tf.keras 0.2 | | | x | -| tf.net 0.2x | | x | x | -| tf.net 0.15 | x | x | | -| tf.net 0.14 | x | | | +| TensorFlow | tf native1.14, cuda 10.0 | tf native 1.15, cuda 10.0 | tf native 2.3, cuda 10.1 | tf native 2.4, cuda 11 | +| -------------------------- | ------------- | -------------- | ------------- | ------------- | +| tf.net 0.3x, tf.keras 0.2 | | | x | not compatible | +| tf.net 0.2x | | x | x | | +| tf.net 0.15 | x | x | | | +| tf.net 0.14 | x | | | | Troubleshooting of running example or installation, please refer [here](tensorflowlib/README.md). diff --git a/src/SciSharp.TensorFlow.Redist/README.md b/src/SciSharp.TensorFlow.Redist/README.md index 26eec870..141bba35 100644 --- a/src/SciSharp.TensorFlow.Redist/README.md +++ b/src/SciSharp.TensorFlow.Redist/README.md @@ -22,11 +22,19 @@ https://www.nuget.org/packages/SciSharp.TensorFlow.Redist Related merged [commits](https://github.com/SciSharp/TensorFlow.NET/commit/854a5ba61ad0e400623821236bd117cc24c6cb77). + + +#### Download pre-build package + +[Mac OSX CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-darwin-x86_64-2.4.0.tar.gz), [Linux CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-2.4.0.tar.gz), [Linux GPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-linux-x86_64-2.4.0.tar.gz), [Windows CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-windows-x86_64-2.4.0.tar.gz), [Windows GPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-windows-x86_64-2.4.0.zip) + + + #### Pack and Deploy #### On Windows, the tar command does not support extracting archives with symlinks. So when `dotnet pack` runs on Windows it will only package the Windows binaries. 1. Run `dotnet pack SciSharp.TensorFlow.Redist.nupkgproj` under `src/SciSharp.TensorFlow.Redist` directory in Linux. -2. Run `dotnet nuget push SciSharp.TensorFlow.Redist.2.3.1.nupkg -k APIKEY -s https://api.nuget.org/v3/index.json -t 600` +2. Run `dotnet nuget push SciSharp.TensorFlow.Redist.2.4.0.nupkg -k APIKEY -s https://api.nuget.org/v3/index.json -t 600` diff --git a/src/TensorFlowNET.Console/TensorFlowNET.Console.csproj b/src/TensorFlowNET.Console/TensorFlowNET.Console.csproj index 6cc631f4..a4fd2b27 100644 --- a/src/TensorFlowNET.Console/TensorFlowNET.Console.csproj +++ b/src/TensorFlowNET.Console/TensorFlowNET.Console.csproj @@ -8,7 +8,7 @@ - + diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index b17ee329..e2697aeb 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -574,7 +574,7 @@ would not be rank 1.", tensor.op.get_attr("axis"))); return string.Join(string.Empty, nd.ToArray() .Select(x => x < 32 || x > 127 ? "\\x" + x.ToString("x") : Convert.ToChar(x).ToString())); case TF_DataType.TF_BOOL: - return (nd.GetByte(0) > 0).ToString(); + return nd.GetBoolean(0).ToString(); case TF_DataType.TF_VARIANT: case TF_DataType.TF_RESOURCE: return ""; diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs index de5f4a8c..1bcb9c1d 100644 --- a/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs +++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs @@ -37,19 +37,38 @@ namespace Tensorflow.Keras.Engine.DataAdapters _steps_per_execution_value = args.StepsPerExecution.numpy(); } - _adapter = new TensorLikeDataAdapter(new TensorLikeDataAdapterArgs + if(args.Dataset == null) { - X = args.X, - Y = args.Y, - BatchSize = args.BatchSize, - Steps = args.StepsPerEpoch, - Epochs = args.Epochs - args.InitialEpoch, - Shuffle = args.Shuffle, - MaxQueueSize = args.MaxQueueSize, - Worker = args.Workers, - UseMultiprocessing = args.UseMultiprocessing, - Model = args.Model - }); + _adapter = new TensorLikeDataAdapter(new DataAdapterArgs + { + X = args.X, + Y = args.Y, + BatchSize = args.BatchSize, + Steps = args.StepsPerEpoch, + Epochs = args.Epochs - args.InitialEpoch, + Shuffle = args.Shuffle, + MaxQueueSize = args.MaxQueueSize, + Worker = args.Workers, + UseMultiprocessing = args.UseMultiprocessing, + Model = args.Model + }); + } + else + { + _adapter = new DatasetAdapter(new DataAdapterArgs + { + Dataset = args.Dataset, + BatchSize = args.BatchSize, + Steps = args.StepsPerEpoch, + Epochs = args.Epochs - args.InitialEpoch, + Shuffle = args.Shuffle, + MaxQueueSize = args.MaxQueueSize, + Worker = args.Workers, + UseMultiprocessing = args.UseMultiprocessing, + Model = args.Model + }); + } + _dataset = _adapter.GetDataset(); _inferred_steps = _infer_steps(args.StepsPerEpoch, _dataset); _current_step = 0; @@ -66,7 +85,8 @@ namespace Tensorflow.Keras.Engine.DataAdapters if (adapter_steps > -1) return adapter_steps; - throw new NotImplementedException(""); + var size = dataset.dataset_cardinality(); + return size.numpy(); } public IEnumerable<(int, OwnedIterator)> enumerate_epochs() diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/DatasetAdapter.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/DatasetAdapter.cs new file mode 100644 index 00000000..d5f9613f --- /dev/null +++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/DatasetAdapter.cs @@ -0,0 +1,35 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Keras.ArgsDefinition; + +namespace Tensorflow.Keras.Engine.DataAdapters +{ + public class DatasetAdapter : IDataAdapter + { + DataAdapterArgs args; + IDatasetV2 _dataset => args.Dataset; + public DatasetAdapter(DataAdapterArgs args) + { + this.args = args; + } + + public bool CanHandle(Tensor x, Tensor y = null) + { + throw new NotImplementedException(); + } + + public IDatasetV2 GetDataset() + => _dataset; + + public int GetSize() + => -1; + + public (Tensor, Tensor) Expand1d(Tensor x, Tensor y) + { + if (y.TensorShape.ndim == 1) + y = array_ops.expand_dims(y, axis: -1); + return (x, y); + } + } +} diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs index 13f634dd..ecf5cbf9 100644 --- a/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs +++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs @@ -9,14 +9,14 @@ namespace Tensorflow.Keras.Engine.DataAdapters /// public class TensorLikeDataAdapter : IDataAdapter { - TensorLikeDataAdapterArgs args; + DataAdapterArgs args; int _size; int _batch_size; int num_samples; int num_full_batches; IDatasetV2 _dataset; - public TensorLikeDataAdapter(TensorLikeDataAdapterArgs args) + public TensorLikeDataAdapter(DataAdapterArgs args) { this.args = args; _process_tensorlike(); diff --git a/src/TensorFlowNET.Keras/Engine/Functional.cs b/src/TensorFlowNET.Keras/Engine/Functional.cs index 56d0863a..b2b109ba 100644 --- a/src/TensorFlowNET.Keras/Engine/Functional.cs +++ b/src/TensorFlowNET.Keras/Engine/Functional.cs @@ -39,10 +39,12 @@ namespace Tensorflow.Keras.Engine _input_coordinates = new List(); _output_coordinates = new List(); tensor_usage_count = new Dictionary(); + if (this is Sequential) + return; _init_graph_network(inputs, outputs); } - void _init_graph_network(Tensors inputs, Tensors outputs) + protected void _init_graph_network(Tensors inputs, Tensors outputs) { _is_graph_network = true; this.inputs = inputs; diff --git a/src/TensorFlowNET.Keras/Engine/Model.Compile.cs b/src/TensorFlowNET.Keras/Engine/Model.Compile.cs index 003262d9..dd91a5de 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Compile.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Compile.cs @@ -9,10 +9,6 @@ namespace Tensorflow.Keras.Engine { LossesContainer compiled_loss; MetricsContainer compiled_metrics; - public void compile(string optimizerName, ILossFunc lossName) - { - throw new NotImplementedException(""); - } public void compile(ILossFunc loss, OptimizerV2 optimizer, string[] metrics) { @@ -29,12 +25,12 @@ namespace Tensorflow.Keras.Engine this.loss = loss; } - public void compile(string optimizerName, string lossName) + public void compile(string optimizer, string loss, string[] metrics) { - switch (optimizerName) + switch (optimizer) { case "rmsprop": - optimizer = new RMSprop(new RMSpropArgs + this.optimizer = new RMSprop(new RMSpropArgs { }); diff --git a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs index 951908f5..3c4960e7 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs @@ -68,5 +68,49 @@ namespace Tensorflow.Keras.Engine Console.WriteLine($"epoch: {epoch + 1}, " + string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2}"))); } } + + public void fit(IDatasetV2 dataset, + IDatasetV2 validation_data = null, + int batch_size = -1, + int epochs = 1, + int verbose = 1, + float validation_split = 0f, + bool shuffle = true, + int initial_epoch = 0, + int max_queue_size = 10, + int workers = 1, + bool use_multiprocessing = false) + { + data_handler = new DataHandler(new DataHandlerArgs + { + Dataset = dataset, + BatchSize = batch_size, + InitialEpoch = initial_epoch, + Epochs = epochs, + Shuffle = shuffle, + MaxQueueSize = max_queue_size, + Workers = workers, + UseMultiprocessing = use_multiprocessing, + Model = this, + StepsPerExecution = _steps_per_execution + }); + + stop_training = false; + _train_counter.assign(0); + Console.WriteLine($"Training..."); + foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) + { + // reset_metrics(); + // callbacks.on_epoch_begin(epoch) + // data_handler.catch_stop_iteration(); + IEnumerable<(string, Tensor)> results = null; + foreach (var step in data_handler.steps()) + { + // callbacks.on_train_batch_begin(step) + results = step_function(iterator); + } + Console.WriteLine($"epoch: {epoch + 1}, " + string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2}"))); + } + } } } diff --git a/src/TensorFlowNET.Keras/Engine/Node.cs b/src/TensorFlowNET.Keras/Engine/Node.cs index d78e5533..fad58534 100644 --- a/src/TensorFlowNET.Keras/Engine/Node.cs +++ b/src/TensorFlowNET.Keras/Engine/Node.cs @@ -35,7 +35,7 @@ namespace Tensorflow.Keras.Engine public int[] node_indices; public int[] tensor_indices; - public Tensors input_tensors => args.InputTensors; + public Tensors input_tensors => is_input ? Outputs : args.InputTensors; public Tensors Outputs => args.Outputs; public TensorShape[] input_shapes; public TensorShape[] output_shapes; diff --git a/src/TensorFlowNET.Keras/Engine/Sequential.cs b/src/TensorFlowNET.Keras/Engine/Sequential.cs index 50974cf7..e2267d53 100644 --- a/src/TensorFlowNET.Keras/Engine/Sequential.cs +++ b/src/TensorFlowNET.Keras/Engine/Sequential.cs @@ -17,6 +17,7 @@ using System.Collections.Generic; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Layers; +using Tensorflow.Keras.Utils; using static Tensorflow.KerasApi; namespace Tensorflow.Keras.Engine @@ -25,36 +26,40 @@ namespace Tensorflow.Keras.Engine /// `Sequential` groups a linear stack of layers into a `tf.keras.Model`. /// `Sequential` provides training and inference features on this model. /// - public class Sequential : Model + public class Sequential : Functional { SequentialArgs args; bool _is_graph_network; - Tensor inputs; - Tensor outputs; - - bool computeOutputAndMaskJointly; - bool autoTrackSubLayers; - TensorShape inferredInputShape; - bool hasExplicitInputShape; - TF_DataType inputDType; - List layers => args.Layers; - public TensorShape output_shape => outputs.TensorShape; + Tensors inputs; + Tensors outputs; + + bool _compute_output_and_mask_jointly; + bool _auto_track_sub_layers; + TensorShape _inferred_input_shape; + bool _has_explicit_input_shape; + TF_DataType _input_dtype; + + public TensorShape output_shape => outputs[0].TensorShape; bool built = false; public Sequential(SequentialArgs args) - : base(new ModelArgs - { - Name = args.Name - }) + : base(args.Inputs, args.Outputs, name: args.Name) { this.args = args; if (args.Layers == null) args.Layers = new List(); // SupportsMasking = true; - computeOutputAndMaskJointly = true; - autoTrackSubLayers = false; - hasExplicitInputShape = false; + _compute_output_and_mask_jointly = true; + _auto_track_sub_layers = false; + _has_explicit_input_shape = false; _is_graph_network = false; + + // Add to the model any layers passed to the constructor. + if (args.Layers != null) + { + foreach (var layer in args.Layers) + add(layer as Layer); + } } public void add(Tensor tensor) @@ -71,7 +76,7 @@ namespace Tensorflow.Keras.Engine { built = false; var set_inputs = false; - if (layers.Count == 0) + if (_layers.Count == 0) { if (layer is InputLayer) { @@ -83,7 +88,7 @@ namespace Tensorflow.Keras.Engine { // Instantiate an input layer. var x = keras.Input( - shape: layer.BatchInputShape, + batch_input_shape: layer.BatchInputShape, dtype: layer.DType, name: layer.Name + "_input"); @@ -99,36 +104,26 @@ namespace Tensorflow.Keras.Engine { // If an input layer (placeholder) is available. outputs = layer.InboundNodes[^1].Outputs; + inputs = layer_utils.get_source_inputs(outputs[0]); + built = true; + _has_explicit_input_shape = true; } - } else if (outputs != null) { outputs = layer.Apply(outputs); + built = true; } if (set_inputs || _is_graph_network) { _init_graph_network(inputs, outputs); + _is_graph_network = true; } else { } } - - void _init_graph_network(Tensor inputs, Tensor outputs) - { - _is_graph_network = true; - this.inputs = inputs; - this.outputs = outputs; - built = true; - _map_graph_network(inputs, outputs); - } - - void _map_graph_network(Tensor inputs, Tensor outputs) - { - layers.add(outputs.KerasHistory.Layer); - } } } diff --git a/src/TensorFlowNET.Keras/KerasInterface.cs b/src/TensorFlowNET.Keras/KerasInterface.cs index 6cb733d3..40519ac4 100644 --- a/src/TensorFlowNET.Keras/KerasInterface.cs +++ b/src/TensorFlowNET.Keras/KerasInterface.cs @@ -62,16 +62,21 @@ namespace Tensorflow.Keras /// public Tensor Input(TensorShape shape = null, int batch_size = -1, + TensorShape batch_input_shape = null, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool sparse = false, bool ragged = false, Tensor tensor = null) { + if (batch_input_shape != null) + shape = batch_input_shape.dims[1..]; + var args = new InputLayerArgs { Name = name, InputShape = shape, + BatchInputShape = batch_input_shape, BatchSize = batch_size, DType = dtype, Sparse = sparse, diff --git a/src/TensorFlowNET.Keras/Layers/Rescaling/Rescaling.cs b/src/TensorFlowNET.Keras/Layers/Rescaling/Rescaling.cs index f81ee161..7466685f 100644 --- a/src/TensorFlowNET.Keras/Layers/Rescaling/Rescaling.cs +++ b/src/TensorFlowNET.Keras/Layers/Rescaling/Rescaling.cs @@ -23,5 +23,10 @@ namespace Tensorflow.Keras.Layers offset = math_ops.cast(args.Offset, args.DType); return math_ops.cast(inputs, args.DType) * scale + offset; } + + public override TensorShape ComputeOutputShape(TensorShape input_shape) + { + return input_shape; + } } } diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Flatten.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Flatten.cs index 316cab8c..66235198 100644 --- a/src/TensorFlowNET.Keras/Layers/Reshaping/Flatten.cs +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Flatten.cs @@ -1,4 +1,5 @@ using System; +using Tensorflow.Framework; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Utils; @@ -15,6 +16,7 @@ namespace Tensorflow.Keras.Layers public Flatten(FlattenArgs args) : base(args) { + this.args = args; args.DataFormat = conv_utils.normalize_data_format(args.DataFormat); input_spec = new InputSpec(min_ndim: 1); _channels_first = args.DataFormat == "channels_first"; @@ -31,8 +33,29 @@ namespace Tensorflow.Keras.Layers { return array_ops.reshape(inputs, new[] { inputs.shape[0], -1 }); } + else + { + var input_shape = inputs.shape; + var rank = inputs.shape.rank; + if (rank == 1) + return array_ops.expand_dims(inputs, axis: 1); + var batch_dim = tensor_shape.dimension_value(input_shape[0]); + if (batch_dim != -1) + { + return array_ops.reshape(inputs, new[] { batch_dim, -1 }); + } - throw new NotImplementedException(""); + var non_batch_dims = ((int[])input_shape)[1..]; + var num = 1; + if (non_batch_dims.Length > 0) + { + for (var i = 0; i < non_batch_dims.Length; i++) + { + num *= non_batch_dims[i]; + } + } + return array_ops.reshape(inputs, new[] { inputs.shape[0], num }); + } } } } diff --git a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.paths_and_labels_to_dataset.cs b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.paths_and_labels_to_dataset.cs index d1bf1d97..abf07735 100644 --- a/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.paths_and_labels_to_dataset.cs +++ b/src/TensorFlowNET.Keras/Preprocessings/Preprocessing.paths_and_labels_to_dataset.cs @@ -31,7 +31,7 @@ namespace Tensorflow.Keras img = tf.image.decode_image( img, channels: num_channels, expand_animations: false); img = tf.image.resize_images_v2(img, image_size, method: interpolation); - img.set_shape((image_size[0], image_size[1], num_channels)); + // img.set_shape((image_size[0], image_size[1], num_channels)); return img; } } diff --git a/src/TensorFlowNET.Keras/Utils/layer_utils.cs b/src/TensorFlowNET.Keras/Utils/layer_utils.cs index 34f553d4..4166fd51 100644 --- a/src/TensorFlowNET.Keras/Utils/layer_utils.cs +++ b/src/TensorFlowNET.Keras/Utils/layer_utils.cs @@ -187,5 +187,34 @@ namespace Tensorflow.Keras.Utils var total = weight_shapes.Select(p => (int)np.prod(p.dims)).Sum(); return total; } + + public static Tensors get_source_inputs(Tensor tensor, ILayer layer = null, int node_index = -1) + { + if (layer == null) + (layer, node_index, _) = tensor.KerasHistory; + if (layer.InboundNodes == null || layer.InboundNodes.Count == 0) + return tensor; + else + { + var node = layer.InboundNodes[node_index]; + if (node.is_input) + return node.input_tensors; + else + { + var source_tensors = new List(); + foreach (var _layer in node.iterate_inbound()) + { + (layer, node_index, tensor) = (_layer.Item1, _layer.Item2, _layer.Item4); + var previous_sources = get_source_inputs(tensor, layer, node_index); + foreach(var x in previous_sources) + { + // should be check if exist? + source_tensors.append(x); + } + } + return source_tensors; + } + } + } } } diff --git a/tensorflowlib/README.md b/tensorflowlib/README.md index eb957b46..20d30f6f 100644 --- a/tensorflowlib/README.md +++ b/tensorflowlib/README.md @@ -24,7 +24,7 @@ More information about [System.Drawing on Linux]( Install-Package SciSharp.TensorFlow.Redist-Windows-GPU PM> Install-Package SciSharp.TensorFlow.Redist-Linux-GPU ``` +Since NuGet limits file size for 250M, we can't ship Linux GPU version as NuGet, you can download the library from [Google TensorFlow Storage](https://storage.googleapis.com/tensorflow). + ### Download prebuild binary manually -Tensorflow packages are built nightly and uploaded to GCS for all supported platforms. They are uploaded to the [libtensorflow-nightly](https://www.tensorflow.org/install/lang_c) GCS bucket and are indexed by operating system and date built. +TensorFlow packages are built nightly and uploaded to GCS for all supported platforms. They are uploaded to the [libtensorflow-nightly](https://www.tensorflow.org/install/lang_c) GCS bucket and are indexed by operating system and date built. ### Build from source for Windows diff --git a/test/TensorFlowNET.UnitTest/ImageTest.cs b/test/TensorFlowNET.UnitTest/ImageTest.cs index 48794650..b32c659e 100644 --- a/test/TensorFlowNET.UnitTest/ImageTest.cs +++ b/test/TensorFlowNET.UnitTest/ImageTest.cs @@ -28,7 +28,7 @@ namespace TensorFlowNET.UnitTest.Basics public void decode_image() { var img = tf.image.decode_image(contents); - Assert.AreEqual(img.name, "decode_image/cond_jpeg/Merge:0"); + Assert.AreEqual(img.name, "decode_image/Identity:0"); } [TestMethod] diff --git a/test/TensorFlowNET.UnitTest/Keras/LayersTest.cs b/test/TensorFlowNET.UnitTest/Keras/LayersTest.cs index 279a5db5..c6858ba0 100644 --- a/test/TensorFlowNET.UnitTest/Keras/LayersTest.cs +++ b/test/TensorFlowNET.UnitTest/Keras/LayersTest.cs @@ -57,7 +57,7 @@ namespace TensorFlowNET.UnitTest.Keras { 2, 3, 4, 5 }, { 3, 4, 5, 6 } }); - model.compile("rmsprop", "mse"); + // model.compile("rmsprop", "mse"); var output_array = model.predict(input_array); Assert.AreEqual((32, 10, 64), output_array.TensorShape); } diff --git a/test/TensorFlowNET.UnitTest/Tensorflow.UnitTest.csproj b/test/TensorFlowNET.UnitTest/Tensorflow.UnitTest.csproj index 68b70eb4..4d63755b 100644 --- a/test/TensorFlowNET.UnitTest/Tensorflow.UnitTest.csproj +++ b/test/TensorFlowNET.UnitTest/Tensorflow.UnitTest.csproj @@ -48,10 +48,10 @@ - + - +