Browse Source

Fix Sequential model.

tags/keras_v0.3.0
Oceania2018 Haiping 4 years ago
parent
commit
fd64ad1b44
21 changed files with 240 additions and 76 deletions
  1. +6
    -6
      README.md
  2. +9
    -1
      src/SciSharp.TensorFlow.Redist/README.md
  3. +1
    -1
      src/TensorFlowNET.Console/TensorFlowNET.Console.csproj
  4. +1
    -1
      src/TensorFlowNET.Core/Tensors/tensor_util.cs
  5. +33
    -13
      src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs
  6. +35
    -0
      src/TensorFlowNET.Keras/Engine/DataAdapters/DatasetAdapter.cs
  7. +2
    -2
      src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs
  8. +3
    -1
      src/TensorFlowNET.Keras/Engine/Functional.cs
  9. +3
    -7
      src/TensorFlowNET.Keras/Engine/Model.Compile.cs
  10. +44
    -0
      src/TensorFlowNET.Keras/Engine/Model.Fit.cs
  11. +1
    -1
      src/TensorFlowNET.Keras/Engine/Node.cs
  12. +30
    -35
      src/TensorFlowNET.Keras/Engine/Sequential.cs
  13. +5
    -0
      src/TensorFlowNET.Keras/KerasInterface.cs
  14. +5
    -0
      src/TensorFlowNET.Keras/Layers/Rescaling/Rescaling.cs
  15. +24
    -1
      src/TensorFlowNET.Keras/Layers/Reshaping/Flatten.cs
  16. +1
    -1
      src/TensorFlowNET.Keras/Preprocessings/Preprocessing.paths_and_labels_to_dataset.cs
  17. +29
    -0
      src/TensorFlowNET.Keras/Utils/layer_utils.cs
  18. +4
    -2
      tensorflowlib/README.md
  19. +1
    -1
      test/TensorFlowNET.UnitTest/ImageTest.cs
  20. +1
    -1
      test/TensorFlowNET.UnitTest/Keras/LayersTest.cs
  21. +2
    -2
      test/TensorFlowNET.UnitTest/Tensorflow.UnitTest.csproj

+ 6
- 6
README.md View File

@@ -26,12 +26,12 @@ In comparison to other projects, like for instance [TensorFlowSharp](https://www


### How to use ### How to use


| TensorFlow | tf native1.14 | tf native 1.15 | tf native 2.3 |
| -------------------------- | ------------- | -------------- | ------------- |
| tf.net 0.3x, tf.keras 0.2 | | | x |
| tf.net 0.2x | | x | x |
| tf.net 0.15 | x | x | |
| tf.net 0.14 | x | | |
| TensorFlow | tf native1.14, cuda 10.0 | tf native 1.15, cuda 10.0 | tf native 2.3, cuda 10.1 | tf native 2.4, cuda 11 |
| -------------------------- | ------------- | -------------- | ------------- | ------------- |
| tf.net 0.3x, tf.keras 0.2 | | | x | not compatible |
| tf.net 0.2x | | x | x | |
| tf.net 0.15 | x | x | | |
| tf.net 0.14 | x | | | |


Troubleshooting of running example or installation, please refer [here](tensorflowlib/README.md). Troubleshooting of running example or installation, please refer [here](tensorflowlib/README.md).




+ 9
- 1
src/SciSharp.TensorFlow.Redist/README.md View File

@@ -22,11 +22,19 @@ https://www.nuget.org/packages/SciSharp.TensorFlow.Redist


Related merged [commits](https://github.com/SciSharp/TensorFlow.NET/commit/854a5ba61ad0e400623821236bd117cc24c6cb77). Related merged [commits](https://github.com/SciSharp/TensorFlow.NET/commit/854a5ba61ad0e400623821236bd117cc24c6cb77).




#### Download pre-build package

[Mac OSX CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-darwin-x86_64-2.4.0.tar.gz), [Linux CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-2.4.0.tar.gz), [Linux GPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-linux-x86_64-2.4.0.tar.gz), [Windows CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-windows-x86_64-2.4.0.tar.gz), [Windows GPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-windows-x86_64-2.4.0.zip)



#### Pack and Deploy #### #### Pack and Deploy ####


On Windows, the tar command does not support extracting archives with symlinks. So when `dotnet pack` runs on Windows it will only package the Windows binaries. On Windows, the tar command does not support extracting archives with symlinks. So when `dotnet pack` runs on Windows it will only package the Windows binaries.


1. Run `dotnet pack SciSharp.TensorFlow.Redist.nupkgproj` under `src/SciSharp.TensorFlow.Redist` directory in Linux. 1. Run `dotnet pack SciSharp.TensorFlow.Redist.nupkgproj` under `src/SciSharp.TensorFlow.Redist` directory in Linux.
2. Run `dotnet nuget push SciSharp.TensorFlow.Redist.2.3.1.nupkg -k APIKEY -s https://api.nuget.org/v3/index.json -t 600`
2. Run `dotnet nuget push SciSharp.TensorFlow.Redist.2.4.0.nupkg -k APIKEY -s https://api.nuget.org/v3/index.json -t 600`





+ 1
- 1
src/TensorFlowNET.Console/TensorFlowNET.Console.csproj View File

@@ -8,7 +8,7 @@
</PropertyGroup> </PropertyGroup>


<ItemGroup> <ItemGroup>
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.3.0" />
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.3.1" />
</ItemGroup> </ItemGroup>


<ItemGroup> <ItemGroup>


+ 1
- 1
src/TensorFlowNET.Core/Tensors/tensor_util.cs View File

@@ -574,7 +574,7 @@ would not be rank 1.", tensor.op.get_attr("axis")));
return string.Join(string.Empty, nd.ToArray<byte>() return string.Join(string.Empty, nd.ToArray<byte>()
.Select(x => x < 32 || x > 127 ? "\\x" + x.ToString("x") : Convert.ToChar(x).ToString())); .Select(x => x < 32 || x > 127 ? "\\x" + x.ToString("x") : Convert.ToChar(x).ToString()));
case TF_DataType.TF_BOOL: case TF_DataType.TF_BOOL:
return (nd.GetByte(0) > 0).ToString();
return nd.GetBoolean(0).ToString();
case TF_DataType.TF_VARIANT: case TF_DataType.TF_VARIANT:
case TF_DataType.TF_RESOURCE: case TF_DataType.TF_RESOURCE:
return "<unprintable>"; return "<unprintable>";


+ 33
- 13
src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs View File

@@ -37,19 +37,38 @@ namespace Tensorflow.Keras.Engine.DataAdapters
_steps_per_execution_value = args.StepsPerExecution.numpy(); _steps_per_execution_value = args.StepsPerExecution.numpy();
} }


_adapter = new TensorLikeDataAdapter(new TensorLikeDataAdapterArgs
if(args.Dataset == null)
{ {
X = args.X,
Y = args.Y,
BatchSize = args.BatchSize,
Steps = args.StepsPerEpoch,
Epochs = args.Epochs - args.InitialEpoch,
Shuffle = args.Shuffle,
MaxQueueSize = args.MaxQueueSize,
Worker = args.Workers,
UseMultiprocessing = args.UseMultiprocessing,
Model = args.Model
});
_adapter = new TensorLikeDataAdapter(new DataAdapterArgs
{
X = args.X,
Y = args.Y,
BatchSize = args.BatchSize,
Steps = args.StepsPerEpoch,
Epochs = args.Epochs - args.InitialEpoch,
Shuffle = args.Shuffle,
MaxQueueSize = args.MaxQueueSize,
Worker = args.Workers,
UseMultiprocessing = args.UseMultiprocessing,
Model = args.Model
});
}
else
{
_adapter = new DatasetAdapter(new DataAdapterArgs
{
Dataset = args.Dataset,
BatchSize = args.BatchSize,
Steps = args.StepsPerEpoch,
Epochs = args.Epochs - args.InitialEpoch,
Shuffle = args.Shuffle,
MaxQueueSize = args.MaxQueueSize,
Worker = args.Workers,
UseMultiprocessing = args.UseMultiprocessing,
Model = args.Model
});
}
_dataset = _adapter.GetDataset(); _dataset = _adapter.GetDataset();
_inferred_steps = _infer_steps(args.StepsPerEpoch, _dataset); _inferred_steps = _infer_steps(args.StepsPerEpoch, _dataset);
_current_step = 0; _current_step = 0;
@@ -66,7 +85,8 @@ namespace Tensorflow.Keras.Engine.DataAdapters
if (adapter_steps > -1) if (adapter_steps > -1)
return adapter_steps; return adapter_steps;


throw new NotImplementedException("");
var size = dataset.dataset_cardinality();
return size.numpy();
} }


public IEnumerable<(int, OwnedIterator)> enumerate_epochs() public IEnumerable<(int, OwnedIterator)> enumerate_epochs()


+ 35
- 0
src/TensorFlowNET.Keras/Engine/DataAdapters/DatasetAdapter.cs View File

@@ -0,0 +1,35 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;

namespace Tensorflow.Keras.Engine.DataAdapters
{
public class DatasetAdapter : IDataAdapter
{
DataAdapterArgs args;
IDatasetV2 _dataset => args.Dataset;
public DatasetAdapter(DataAdapterArgs args)
{
this.args = args;
}

public bool CanHandle(Tensor x, Tensor y = null)
{
throw new NotImplementedException();
}

public IDatasetV2 GetDataset()
=> _dataset;

public int GetSize()
=> -1;

public (Tensor, Tensor) Expand1d(Tensor x, Tensor y)
{
if (y.TensorShape.ndim == 1)
y = array_ops.expand_dims(y, axis: -1);
return (x, y);
}
}
}

+ 2
- 2
src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs View File

@@ -9,14 +9,14 @@ namespace Tensorflow.Keras.Engine.DataAdapters
/// </summary> /// </summary>
public class TensorLikeDataAdapter : IDataAdapter public class TensorLikeDataAdapter : IDataAdapter
{ {
TensorLikeDataAdapterArgs args;
DataAdapterArgs args;
int _size; int _size;
int _batch_size; int _batch_size;
int num_samples; int num_samples;
int num_full_batches; int num_full_batches;
IDatasetV2 _dataset; IDatasetV2 _dataset;


public TensorLikeDataAdapter(TensorLikeDataAdapterArgs args)
public TensorLikeDataAdapter(DataAdapterArgs args)
{ {
this.args = args; this.args = args;
_process_tensorlike(); _process_tensorlike();


+ 3
- 1
src/TensorFlowNET.Keras/Engine/Functional.cs View File

@@ -39,10 +39,12 @@ namespace Tensorflow.Keras.Engine
_input_coordinates = new List<KerasHistory>(); _input_coordinates = new List<KerasHistory>();
_output_coordinates = new List<KerasHistory>(); _output_coordinates = new List<KerasHistory>();
tensor_usage_count = new Dictionary<int, int>(); tensor_usage_count = new Dictionary<int, int>();
if (this is Sequential)
return;
_init_graph_network(inputs, outputs); _init_graph_network(inputs, outputs);
} }


void _init_graph_network(Tensors inputs, Tensors outputs)
protected void _init_graph_network(Tensors inputs, Tensors outputs)
{ {
_is_graph_network = true; _is_graph_network = true;
this.inputs = inputs; this.inputs = inputs;


+ 3
- 7
src/TensorFlowNET.Keras/Engine/Model.Compile.cs View File

@@ -9,10 +9,6 @@ namespace Tensorflow.Keras.Engine
{ {
LossesContainer compiled_loss; LossesContainer compiled_loss;
MetricsContainer compiled_metrics; MetricsContainer compiled_metrics;
public void compile(string optimizerName, ILossFunc lossName)
{
throw new NotImplementedException("");
}


public void compile(ILossFunc loss, OptimizerV2 optimizer, string[] metrics) public void compile(ILossFunc loss, OptimizerV2 optimizer, string[] metrics)
{ {
@@ -29,12 +25,12 @@ namespace Tensorflow.Keras.Engine
this.loss = loss; this.loss = loss;
} }


public void compile(string optimizerName, string lossName)
public void compile(string optimizer, string loss, string[] metrics)
{ {
switch (optimizerName)
switch (optimizer)
{ {
case "rmsprop": case "rmsprop":
optimizer = new RMSprop(new RMSpropArgs
this.optimizer = new RMSprop(new RMSpropArgs
{ {


}); });


+ 44
- 0
src/TensorFlowNET.Keras/Engine/Model.Fit.cs View File

@@ -68,5 +68,49 @@ namespace Tensorflow.Keras.Engine
Console.WriteLine($"epoch: {epoch + 1}, " + string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2}"))); Console.WriteLine($"epoch: {epoch + 1}, " + string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2}")));
} }
} }

public void fit(IDatasetV2 dataset,
IDatasetV2 validation_data = null,
int batch_size = -1,
int epochs = 1,
int verbose = 1,
float validation_split = 0f,
bool shuffle = true,
int initial_epoch = 0,
int max_queue_size = 10,
int workers = 1,
bool use_multiprocessing = false)
{
data_handler = new DataHandler(new DataHandlerArgs
{
Dataset = dataset,
BatchSize = batch_size,
InitialEpoch = initial_epoch,
Epochs = epochs,
Shuffle = shuffle,
MaxQueueSize = max_queue_size,
Workers = workers,
UseMultiprocessing = use_multiprocessing,
Model = this,
StepsPerExecution = _steps_per_execution
});

stop_training = false;
_train_counter.assign(0);
Console.WriteLine($"Training...");
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
{
// reset_metrics();
// callbacks.on_epoch_begin(epoch)
// data_handler.catch_stop_iteration();
IEnumerable<(string, Tensor)> results = null;
foreach (var step in data_handler.steps())
{
// callbacks.on_train_batch_begin(step)
results = step_function(iterator);
}
Console.WriteLine($"epoch: {epoch + 1}, " + string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2}")));
}
}
} }
} }

+ 1
- 1
src/TensorFlowNET.Keras/Engine/Node.cs View File

@@ -35,7 +35,7 @@ namespace Tensorflow.Keras.Engine


public int[] node_indices; public int[] node_indices;
public int[] tensor_indices; public int[] tensor_indices;
public Tensors input_tensors => args.InputTensors;
public Tensors input_tensors => is_input ? Outputs : args.InputTensors;
public Tensors Outputs => args.Outputs; public Tensors Outputs => args.Outputs;
public TensorShape[] input_shapes; public TensorShape[] input_shapes;
public TensorShape[] output_shapes; public TensorShape[] output_shapes;


+ 30
- 35
src/TensorFlowNET.Keras/Engine/Sequential.cs View File

@@ -17,6 +17,7 @@
using System.Collections.Generic; using System.Collections.Generic;
using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Layers; using Tensorflow.Keras.Layers;
using Tensorflow.Keras.Utils;
using static Tensorflow.KerasApi; using static Tensorflow.KerasApi;


namespace Tensorflow.Keras.Engine namespace Tensorflow.Keras.Engine
@@ -25,36 +26,40 @@ namespace Tensorflow.Keras.Engine
/// `Sequential` groups a linear stack of layers into a `tf.keras.Model`. /// `Sequential` groups a linear stack of layers into a `tf.keras.Model`.
/// `Sequential` provides training and inference features on this model. /// `Sequential` provides training and inference features on this model.
/// </summary> /// </summary>
public class Sequential : Model
public class Sequential : Functional
{ {
SequentialArgs args; SequentialArgs args;
bool _is_graph_network; bool _is_graph_network;
Tensor inputs;
Tensor outputs;
bool computeOutputAndMaskJointly;
bool autoTrackSubLayers;
TensorShape inferredInputShape;
bool hasExplicitInputShape;
TF_DataType inputDType;
List<ILayer> layers => args.Layers;
public TensorShape output_shape => outputs.TensorShape;
Tensors inputs;
Tensors outputs;
bool _compute_output_and_mask_jointly;
bool _auto_track_sub_layers;
TensorShape _inferred_input_shape;
bool _has_explicit_input_shape;
TF_DataType _input_dtype;
public TensorShape output_shape => outputs[0].TensorShape;
bool built = false; bool built = false;


public Sequential(SequentialArgs args) public Sequential(SequentialArgs args)
: base(new ModelArgs
{
Name = args.Name
})
: base(args.Inputs, args.Outputs, name: args.Name)
{ {
this.args = args; this.args = args;
if (args.Layers == null) if (args.Layers == null)
args.Layers = new List<ILayer>(); args.Layers = new List<ILayer>();
// SupportsMasking = true; // SupportsMasking = true;
computeOutputAndMaskJointly = true;
autoTrackSubLayers = false;
hasExplicitInputShape = false;
_compute_output_and_mask_jointly = true;
_auto_track_sub_layers = false;
_has_explicit_input_shape = false;
_is_graph_network = false; _is_graph_network = false;

// Add to the model any layers passed to the constructor.
if (args.Layers != null)
{
foreach (var layer in args.Layers)
add(layer as Layer);
}
} }


public void add(Tensor tensor) public void add(Tensor tensor)
@@ -71,7 +76,7 @@ namespace Tensorflow.Keras.Engine
{ {
built = false; built = false;
var set_inputs = false; var set_inputs = false;
if (layers.Count == 0)
if (_layers.Count == 0)
{ {
if (layer is InputLayer) if (layer is InputLayer)
{ {
@@ -83,7 +88,7 @@ namespace Tensorflow.Keras.Engine
{ {
// Instantiate an input layer. // Instantiate an input layer.
var x = keras.Input( var x = keras.Input(
shape: layer.BatchInputShape,
batch_input_shape: layer.BatchInputShape,
dtype: layer.DType, dtype: layer.DType,
name: layer.Name + "_input"); name: layer.Name + "_input");


@@ -99,36 +104,26 @@ namespace Tensorflow.Keras.Engine
{ {
// If an input layer (placeholder) is available. // If an input layer (placeholder) is available.
outputs = layer.InboundNodes[^1].Outputs; outputs = layer.InboundNodes[^1].Outputs;
inputs = layer_utils.get_source_inputs(outputs[0]);
built = true;
_has_explicit_input_shape = true;
} }

} }
else if (outputs != null) else if (outputs != null)
{ {
outputs = layer.Apply(outputs); outputs = layer.Apply(outputs);
built = true;
} }


if (set_inputs || _is_graph_network) if (set_inputs || _is_graph_network)
{ {
_init_graph_network(inputs, outputs); _init_graph_network(inputs, outputs);
_is_graph_network = true;
} }
else else
{ {


} }
} }

void _init_graph_network(Tensor inputs, Tensor outputs)
{
_is_graph_network = true;
this.inputs = inputs;
this.outputs = outputs;
built = true;
_map_graph_network(inputs, outputs);
}

void _map_graph_network(Tensor inputs, Tensor outputs)
{
layers.add(outputs.KerasHistory.Layer);
}
} }
} }

+ 5
- 0
src/TensorFlowNET.Keras/KerasInterface.cs View File

@@ -62,16 +62,21 @@ namespace Tensorflow.Keras
/// <returns></returns> /// <returns></returns>
public Tensor Input(TensorShape shape = null, public Tensor Input(TensorShape shape = null,
int batch_size = -1, int batch_size = -1,
TensorShape batch_input_shape = null,
TF_DataType dtype = TF_DataType.DtInvalid, TF_DataType dtype = TF_DataType.DtInvalid,
string name = null, string name = null,
bool sparse = false, bool sparse = false,
bool ragged = false, bool ragged = false,
Tensor tensor = null) Tensor tensor = null)
{ {
if (batch_input_shape != null)
shape = batch_input_shape.dims[1..];

var args = new InputLayerArgs var args = new InputLayerArgs
{ {
Name = name, Name = name,
InputShape = shape, InputShape = shape,
BatchInputShape = batch_input_shape,
BatchSize = batch_size, BatchSize = batch_size,
DType = dtype, DType = dtype,
Sparse = sparse, Sparse = sparse,


+ 5
- 0
src/TensorFlowNET.Keras/Layers/Rescaling/Rescaling.cs View File

@@ -23,5 +23,10 @@ namespace Tensorflow.Keras.Layers
offset = math_ops.cast(args.Offset, args.DType); offset = math_ops.cast(args.Offset, args.DType);
return math_ops.cast(inputs, args.DType) * scale + offset; return math_ops.cast(inputs, args.DType) * scale + offset;
} }

public override TensorShape ComputeOutputShape(TensorShape input_shape)
{
return input_shape;
}
} }
} }

+ 24
- 1
src/TensorFlowNET.Keras/Layers/Reshaping/Flatten.cs View File

@@ -1,4 +1,5 @@
using System; using System;
using Tensorflow.Framework;
using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine; using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Utils; using Tensorflow.Keras.Utils;
@@ -15,6 +16,7 @@ namespace Tensorflow.Keras.Layers
public Flatten(FlattenArgs args) public Flatten(FlattenArgs args)
: base(args) : base(args)
{ {
this.args = args;
args.DataFormat = conv_utils.normalize_data_format(args.DataFormat); args.DataFormat = conv_utils.normalize_data_format(args.DataFormat);
input_spec = new InputSpec(min_ndim: 1); input_spec = new InputSpec(min_ndim: 1);
_channels_first = args.DataFormat == "channels_first"; _channels_first = args.DataFormat == "channels_first";
@@ -31,8 +33,29 @@ namespace Tensorflow.Keras.Layers
{ {
return array_ops.reshape(inputs, new[] { inputs.shape[0], -1 }); return array_ops.reshape(inputs, new[] { inputs.shape[0], -1 });
} }
else
{
var input_shape = inputs.shape;
var rank = inputs.shape.rank;
if (rank == 1)
return array_ops.expand_dims(inputs, axis: 1);
var batch_dim = tensor_shape.dimension_value(input_shape[0]);
if (batch_dim != -1)
{
return array_ops.reshape(inputs, new[] { batch_dim, -1 });
}


throw new NotImplementedException("");
var non_batch_dims = ((int[])input_shape)[1..];
var num = 1;
if (non_batch_dims.Length > 0)
{
for (var i = 0; i < non_batch_dims.Length; i++)
{
num *= non_batch_dims[i];
}
}
return array_ops.reshape(inputs, new[] { inputs.shape[0], num });
}
} }
} }
} }

+ 1
- 1
src/TensorFlowNET.Keras/Preprocessings/Preprocessing.paths_and_labels_to_dataset.cs View File

@@ -31,7 +31,7 @@ namespace Tensorflow.Keras
img = tf.image.decode_image( img = tf.image.decode_image(
img, channels: num_channels, expand_animations: false); img, channels: num_channels, expand_animations: false);
img = tf.image.resize_images_v2(img, image_size, method: interpolation); img = tf.image.resize_images_v2(img, image_size, method: interpolation);
img.set_shape((image_size[0], image_size[1], num_channels));
// img.set_shape((image_size[0], image_size[1], num_channels));
return img; return img;
} }
} }


+ 29
- 0
src/TensorFlowNET.Keras/Utils/layer_utils.cs View File

@@ -187,5 +187,34 @@ namespace Tensorflow.Keras.Utils
var total = weight_shapes.Select(p => (int)np.prod(p.dims)).Sum(); var total = weight_shapes.Select(p => (int)np.prod(p.dims)).Sum();
return total; return total;
} }

public static Tensors get_source_inputs(Tensor tensor, ILayer layer = null, int node_index = -1)
{
if (layer == null)
(layer, node_index, _) = tensor.KerasHistory;
if (layer.InboundNodes == null || layer.InboundNodes.Count == 0)
return tensor;
else
{
var node = layer.InboundNodes[node_index];
if (node.is_input)
return node.input_tensors;
else
{
var source_tensors = new List<Tensor>();
foreach (var _layer in node.iterate_inbound())
{
(layer, node_index, tensor) = (_layer.Item1, _layer.Item2, _layer.Item4);
var previous_sources = get_source_inputs(tensor, layer, node_index);
foreach(var x in previous_sources)
{
// should be check if exist?
source_tensors.append(x);
}
}
return source_tensors;
}
}
}
} }
} }

+ 4
- 2
tensorflowlib/README.md View File

@@ -24,7 +24,7 @@ More information about [System.Drawing on Linux](<https://www.hanselman.com/blog
Before running verify you installed CUDA and cuDNN (TensorFlow v1.15 is compatible with CUDA v10.0 and cuDNN v7.4 , TensorFlow v2.x is compatible with CUDA v10.2 and cuDNN v7.65), and make sure the corresponding cuda version is compatible. Before running verify you installed CUDA and cuDNN (TensorFlow v1.15 is compatible with CUDA v10.0 and cuDNN v7.4 , TensorFlow v2.x is compatible with CUDA v10.2 and cuDNN v7.65), and make sure the corresponding cuda version is compatible.


#### Mac OS #### Mac OS
There is no GPU support for macOS.
There is no GPU support for macOS, in the future TensorFlow will support [Apple M1 chip](https://github.com/apple/tensorflow_macos).


#### GPU for Windows #### GPU for Windows


@@ -37,9 +37,11 @@ PM> Install-Package SciSharp.TensorFlow.Redist-Windows-GPU
PM> Install-Package SciSharp.TensorFlow.Redist-Linux-GPU PM> Install-Package SciSharp.TensorFlow.Redist-Linux-GPU
``` ```


Since NuGet limits file size for 250M, we can't ship Linux GPU version as NuGet, you can download the library from [Google TensorFlow Storage](https://storage.googleapis.com/tensorflow).

### Download prebuild binary manually ### Download prebuild binary manually


Tensorflow packages are built nightly and uploaded to GCS for all supported platforms. They are uploaded to the [libtensorflow-nightly](https://www.tensorflow.org/install/lang_c) GCS bucket and are indexed by operating system and date built.
TensorFlow packages are built nightly and uploaded to GCS for all supported platforms. They are uploaded to the [libtensorflow-nightly](https://www.tensorflow.org/install/lang_c) GCS bucket and are indexed by operating system and date built.




### Build from source for Windows ### Build from source for Windows


+ 1
- 1
test/TensorFlowNET.UnitTest/ImageTest.cs View File

@@ -28,7 +28,7 @@ namespace TensorFlowNET.UnitTest.Basics
public void decode_image() public void decode_image()
{ {
var img = tf.image.decode_image(contents); var img = tf.image.decode_image(contents);
Assert.AreEqual(img.name, "decode_image/cond_jpeg/Merge:0");
Assert.AreEqual(img.name, "decode_image/Identity:0");
} }


[TestMethod] [TestMethod]


+ 1
- 1
test/TensorFlowNET.UnitTest/Keras/LayersTest.cs View File

@@ -57,7 +57,7 @@ namespace TensorFlowNET.UnitTest.Keras
{ 2, 3, 4, 5 }, { 2, 3, 4, 5 },
{ 3, 4, 5, 6 } { 3, 4, 5, 6 }
}); });
model.compile("rmsprop", "mse");
// model.compile("rmsprop", "mse");
var output_array = model.predict(input_array); var output_array = model.predict(input_array);
Assert.AreEqual((32, 10, 64), output_array.TensorShape); Assert.AreEqual((32, 10, 64), output_array.TensorShape);
} }


+ 2
- 2
test/TensorFlowNET.UnitTest/Tensorflow.UnitTest.csproj View File

@@ -48,10 +48,10 @@
<ItemGroup> <ItemGroup>
<PackageReference Include="FluentAssertions" Version="5.10.3" /> <PackageReference Include="FluentAssertions" Version="5.10.3" />
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.138" /> <PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.138" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.7.1" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.8.3" />
<PackageReference Include="MSTest.TestAdapter" Version="2.1.2" /> <PackageReference Include="MSTest.TestAdapter" Version="2.1.2" />
<PackageReference Include="MSTest.TestFramework" Version="2.1.2" /> <PackageReference Include="MSTest.TestFramework" Version="2.1.2" />
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.3.0" />
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.3.1" />
</ItemGroup> </ItemGroup>


<ItemGroup> <ItemGroup>


Loading…
Cancel
Save