@@ -26,12 +26,12 @@ In comparison to other projects, like for instance [TensorFlowSharp](https://www | |||||
### How to use | ### How to use | ||||
| TensorFlow | tf native1.14 | tf native 1.15 | tf native 2.3 | | |||||
| -------------------------- | ------------- | -------------- | ------------- | | |||||
| tf.net 0.3x, tf.keras 0.2 | | | x | | |||||
| tf.net 0.2x | | x | x | | |||||
| tf.net 0.15 | x | x | | | |||||
| tf.net 0.14 | x | | | | |||||
| TensorFlow | tf native1.14, cuda 10.0 | tf native 1.15, cuda 10.0 | tf native 2.3, cuda 10.1 | tf native 2.4, cuda 11 | | |||||
| -------------------------- | ------------- | -------------- | ------------- | ------------- | | |||||
| tf.net 0.3x, tf.keras 0.2 | | | x | not compatible | | |||||
| tf.net 0.2x | | x | x | | | |||||
| tf.net 0.15 | x | x | | | | |||||
| tf.net 0.14 | x | | | | | |||||
Troubleshooting of running example or installation, please refer [here](tensorflowlib/README.md). | Troubleshooting of running example or installation, please refer [here](tensorflowlib/README.md). | ||||
@@ -22,11 +22,19 @@ https://www.nuget.org/packages/SciSharp.TensorFlow.Redist | |||||
Related merged [commits](https://github.com/SciSharp/TensorFlow.NET/commit/854a5ba61ad0e400623821236bd117cc24c6cb77). | Related merged [commits](https://github.com/SciSharp/TensorFlow.NET/commit/854a5ba61ad0e400623821236bd117cc24c6cb77). | ||||
#### Download pre-build package | |||||
[Mac OSX CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-darwin-x86_64-2.4.0.tar.gz), [Linux CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-2.4.0.tar.gz), [Linux GPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-linux-x86_64-2.4.0.tar.gz), [Windows CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-windows-x86_64-2.4.0.tar.gz), [Windows GPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-windows-x86_64-2.4.0.zip) | |||||
#### Pack and Deploy #### | #### Pack and Deploy #### | ||||
On Windows, the tar command does not support extracting archives with symlinks. So when `dotnet pack` runs on Windows it will only package the Windows binaries. | On Windows, the tar command does not support extracting archives with symlinks. So when `dotnet pack` runs on Windows it will only package the Windows binaries. | ||||
1. Run `dotnet pack SciSharp.TensorFlow.Redist.nupkgproj` under `src/SciSharp.TensorFlow.Redist` directory in Linux. | 1. Run `dotnet pack SciSharp.TensorFlow.Redist.nupkgproj` under `src/SciSharp.TensorFlow.Redist` directory in Linux. | ||||
2. Run `dotnet nuget push SciSharp.TensorFlow.Redist.2.3.1.nupkg -k APIKEY -s https://api.nuget.org/v3/index.json -t 600` | |||||
2. Run `dotnet nuget push SciSharp.TensorFlow.Redist.2.4.0.nupkg -k APIKEY -s https://api.nuget.org/v3/index.json -t 600` | |||||
@@ -8,7 +8,7 @@ | |||||
</PropertyGroup> | </PropertyGroup> | ||||
<ItemGroup> | <ItemGroup> | ||||
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.3.0" /> | |||||
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.3.1" /> | |||||
</ItemGroup> | </ItemGroup> | ||||
<ItemGroup> | <ItemGroup> | ||||
@@ -574,7 +574,7 @@ would not be rank 1.", tensor.op.get_attr("axis"))); | |||||
return string.Join(string.Empty, nd.ToArray<byte>() | return string.Join(string.Empty, nd.ToArray<byte>() | ||||
.Select(x => x < 32 || x > 127 ? "\\x" + x.ToString("x") : Convert.ToChar(x).ToString())); | .Select(x => x < 32 || x > 127 ? "\\x" + x.ToString("x") : Convert.ToChar(x).ToString())); | ||||
case TF_DataType.TF_BOOL: | case TF_DataType.TF_BOOL: | ||||
return (nd.GetByte(0) > 0).ToString(); | |||||
return nd.GetBoolean(0).ToString(); | |||||
case TF_DataType.TF_VARIANT: | case TF_DataType.TF_VARIANT: | ||||
case TF_DataType.TF_RESOURCE: | case TF_DataType.TF_RESOURCE: | ||||
return "<unprintable>"; | return "<unprintable>"; | ||||
@@ -37,19 +37,38 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||||
_steps_per_execution_value = args.StepsPerExecution.numpy(); | _steps_per_execution_value = args.StepsPerExecution.numpy(); | ||||
} | } | ||||
_adapter = new TensorLikeDataAdapter(new TensorLikeDataAdapterArgs | |||||
if(args.Dataset == null) | |||||
{ | { | ||||
X = args.X, | |||||
Y = args.Y, | |||||
BatchSize = args.BatchSize, | |||||
Steps = args.StepsPerEpoch, | |||||
Epochs = args.Epochs - args.InitialEpoch, | |||||
Shuffle = args.Shuffle, | |||||
MaxQueueSize = args.MaxQueueSize, | |||||
Worker = args.Workers, | |||||
UseMultiprocessing = args.UseMultiprocessing, | |||||
Model = args.Model | |||||
}); | |||||
_adapter = new TensorLikeDataAdapter(new DataAdapterArgs | |||||
{ | |||||
X = args.X, | |||||
Y = args.Y, | |||||
BatchSize = args.BatchSize, | |||||
Steps = args.StepsPerEpoch, | |||||
Epochs = args.Epochs - args.InitialEpoch, | |||||
Shuffle = args.Shuffle, | |||||
MaxQueueSize = args.MaxQueueSize, | |||||
Worker = args.Workers, | |||||
UseMultiprocessing = args.UseMultiprocessing, | |||||
Model = args.Model | |||||
}); | |||||
} | |||||
else | |||||
{ | |||||
_adapter = new DatasetAdapter(new DataAdapterArgs | |||||
{ | |||||
Dataset = args.Dataset, | |||||
BatchSize = args.BatchSize, | |||||
Steps = args.StepsPerEpoch, | |||||
Epochs = args.Epochs - args.InitialEpoch, | |||||
Shuffle = args.Shuffle, | |||||
MaxQueueSize = args.MaxQueueSize, | |||||
Worker = args.Workers, | |||||
UseMultiprocessing = args.UseMultiprocessing, | |||||
Model = args.Model | |||||
}); | |||||
} | |||||
_dataset = _adapter.GetDataset(); | _dataset = _adapter.GetDataset(); | ||||
_inferred_steps = _infer_steps(args.StepsPerEpoch, _dataset); | _inferred_steps = _infer_steps(args.StepsPerEpoch, _dataset); | ||||
_current_step = 0; | _current_step = 0; | ||||
@@ -66,7 +85,8 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||||
if (adapter_steps > -1) | if (adapter_steps > -1) | ||||
return adapter_steps; | return adapter_steps; | ||||
throw new NotImplementedException(""); | |||||
var size = dataset.dataset_cardinality(); | |||||
return size.numpy(); | |||||
} | } | ||||
public IEnumerable<(int, OwnedIterator)> enumerate_epochs() | public IEnumerable<(int, OwnedIterator)> enumerate_epochs() | ||||
@@ -0,0 +1,35 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using Tensorflow.Keras.ArgsDefinition; | |||||
namespace Tensorflow.Keras.Engine.DataAdapters | |||||
{ | |||||
public class DatasetAdapter : IDataAdapter | |||||
{ | |||||
DataAdapterArgs args; | |||||
IDatasetV2 _dataset => args.Dataset; | |||||
public DatasetAdapter(DataAdapterArgs args) | |||||
{ | |||||
this.args = args; | |||||
} | |||||
public bool CanHandle(Tensor x, Tensor y = null) | |||||
{ | |||||
throw new NotImplementedException(); | |||||
} | |||||
public IDatasetV2 GetDataset() | |||||
=> _dataset; | |||||
public int GetSize() | |||||
=> -1; | |||||
public (Tensor, Tensor) Expand1d(Tensor x, Tensor y) | |||||
{ | |||||
if (y.TensorShape.ndim == 1) | |||||
y = array_ops.expand_dims(y, axis: -1); | |||||
return (x, y); | |||||
} | |||||
} | |||||
} |
@@ -9,14 +9,14 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||||
/// </summary> | /// </summary> | ||||
public class TensorLikeDataAdapter : IDataAdapter | public class TensorLikeDataAdapter : IDataAdapter | ||||
{ | { | ||||
TensorLikeDataAdapterArgs args; | |||||
DataAdapterArgs args; | |||||
int _size; | int _size; | ||||
int _batch_size; | int _batch_size; | ||||
int num_samples; | int num_samples; | ||||
int num_full_batches; | int num_full_batches; | ||||
IDatasetV2 _dataset; | IDatasetV2 _dataset; | ||||
public TensorLikeDataAdapter(TensorLikeDataAdapterArgs args) | |||||
public TensorLikeDataAdapter(DataAdapterArgs args) | |||||
{ | { | ||||
this.args = args; | this.args = args; | ||||
_process_tensorlike(); | _process_tensorlike(); | ||||
@@ -39,10 +39,12 @@ namespace Tensorflow.Keras.Engine | |||||
_input_coordinates = new List<KerasHistory>(); | _input_coordinates = new List<KerasHistory>(); | ||||
_output_coordinates = new List<KerasHistory>(); | _output_coordinates = new List<KerasHistory>(); | ||||
tensor_usage_count = new Dictionary<int, int>(); | tensor_usage_count = new Dictionary<int, int>(); | ||||
if (this is Sequential) | |||||
return; | |||||
_init_graph_network(inputs, outputs); | _init_graph_network(inputs, outputs); | ||||
} | } | ||||
void _init_graph_network(Tensors inputs, Tensors outputs) | |||||
protected void _init_graph_network(Tensors inputs, Tensors outputs) | |||||
{ | { | ||||
_is_graph_network = true; | _is_graph_network = true; | ||||
this.inputs = inputs; | this.inputs = inputs; | ||||
@@ -9,10 +9,6 @@ namespace Tensorflow.Keras.Engine | |||||
{ | { | ||||
LossesContainer compiled_loss; | LossesContainer compiled_loss; | ||||
MetricsContainer compiled_metrics; | MetricsContainer compiled_metrics; | ||||
public void compile(string optimizerName, ILossFunc lossName) | |||||
{ | |||||
throw new NotImplementedException(""); | |||||
} | |||||
public void compile(ILossFunc loss, OptimizerV2 optimizer, string[] metrics) | public void compile(ILossFunc loss, OptimizerV2 optimizer, string[] metrics) | ||||
{ | { | ||||
@@ -29,12 +25,12 @@ namespace Tensorflow.Keras.Engine | |||||
this.loss = loss; | this.loss = loss; | ||||
} | } | ||||
public void compile(string optimizerName, string lossName) | |||||
public void compile(string optimizer, string loss, string[] metrics) | |||||
{ | { | ||||
switch (optimizerName) | |||||
switch (optimizer) | |||||
{ | { | ||||
case "rmsprop": | case "rmsprop": | ||||
optimizer = new RMSprop(new RMSpropArgs | |||||
this.optimizer = new RMSprop(new RMSpropArgs | |||||
{ | { | ||||
}); | }); | ||||
@@ -68,5 +68,49 @@ namespace Tensorflow.Keras.Engine | |||||
Console.WriteLine($"epoch: {epoch + 1}, " + string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2}"))); | Console.WriteLine($"epoch: {epoch + 1}, " + string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2}"))); | ||||
} | } | ||||
} | } | ||||
public void fit(IDatasetV2 dataset, | |||||
IDatasetV2 validation_data = null, | |||||
int batch_size = -1, | |||||
int epochs = 1, | |||||
int verbose = 1, | |||||
float validation_split = 0f, | |||||
bool shuffle = true, | |||||
int initial_epoch = 0, | |||||
int max_queue_size = 10, | |||||
int workers = 1, | |||||
bool use_multiprocessing = false) | |||||
{ | |||||
data_handler = new DataHandler(new DataHandlerArgs | |||||
{ | |||||
Dataset = dataset, | |||||
BatchSize = batch_size, | |||||
InitialEpoch = initial_epoch, | |||||
Epochs = epochs, | |||||
Shuffle = shuffle, | |||||
MaxQueueSize = max_queue_size, | |||||
Workers = workers, | |||||
UseMultiprocessing = use_multiprocessing, | |||||
Model = this, | |||||
StepsPerExecution = _steps_per_execution | |||||
}); | |||||
stop_training = false; | |||||
_train_counter.assign(0); | |||||
Console.WriteLine($"Training..."); | |||||
foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) | |||||
{ | |||||
// reset_metrics(); | |||||
// callbacks.on_epoch_begin(epoch) | |||||
// data_handler.catch_stop_iteration(); | |||||
IEnumerable<(string, Tensor)> results = null; | |||||
foreach (var step in data_handler.steps()) | |||||
{ | |||||
// callbacks.on_train_batch_begin(step) | |||||
results = step_function(iterator); | |||||
} | |||||
Console.WriteLine($"epoch: {epoch + 1}, " + string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2}"))); | |||||
} | |||||
} | |||||
} | } | ||||
} | } |
@@ -35,7 +35,7 @@ namespace Tensorflow.Keras.Engine | |||||
public int[] node_indices; | public int[] node_indices; | ||||
public int[] tensor_indices; | public int[] tensor_indices; | ||||
public Tensors input_tensors => args.InputTensors; | |||||
public Tensors input_tensors => is_input ? Outputs : args.InputTensors; | |||||
public Tensors Outputs => args.Outputs; | public Tensors Outputs => args.Outputs; | ||||
public TensorShape[] input_shapes; | public TensorShape[] input_shapes; | ||||
public TensorShape[] output_shapes; | public TensorShape[] output_shapes; | ||||
@@ -17,6 +17,7 @@ | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
using Tensorflow.Keras.Layers; | using Tensorflow.Keras.Layers; | ||||
using Tensorflow.Keras.Utils; | |||||
using static Tensorflow.KerasApi; | using static Tensorflow.KerasApi; | ||||
namespace Tensorflow.Keras.Engine | namespace Tensorflow.Keras.Engine | ||||
@@ -25,36 +26,40 @@ namespace Tensorflow.Keras.Engine | |||||
/// `Sequential` groups a linear stack of layers into a `tf.keras.Model`. | /// `Sequential` groups a linear stack of layers into a `tf.keras.Model`. | ||||
/// `Sequential` provides training and inference features on this model. | /// `Sequential` provides training and inference features on this model. | ||||
/// </summary> | /// </summary> | ||||
public class Sequential : Model | |||||
public class Sequential : Functional | |||||
{ | { | ||||
SequentialArgs args; | SequentialArgs args; | ||||
bool _is_graph_network; | bool _is_graph_network; | ||||
Tensor inputs; | |||||
Tensor outputs; | |||||
bool computeOutputAndMaskJointly; | |||||
bool autoTrackSubLayers; | |||||
TensorShape inferredInputShape; | |||||
bool hasExplicitInputShape; | |||||
TF_DataType inputDType; | |||||
List<ILayer> layers => args.Layers; | |||||
public TensorShape output_shape => outputs.TensorShape; | |||||
Tensors inputs; | |||||
Tensors outputs; | |||||
bool _compute_output_and_mask_jointly; | |||||
bool _auto_track_sub_layers; | |||||
TensorShape _inferred_input_shape; | |||||
bool _has_explicit_input_shape; | |||||
TF_DataType _input_dtype; | |||||
public TensorShape output_shape => outputs[0].TensorShape; | |||||
bool built = false; | bool built = false; | ||||
public Sequential(SequentialArgs args) | public Sequential(SequentialArgs args) | ||||
: base(new ModelArgs | |||||
{ | |||||
Name = args.Name | |||||
}) | |||||
: base(args.Inputs, args.Outputs, name: args.Name) | |||||
{ | { | ||||
this.args = args; | this.args = args; | ||||
if (args.Layers == null) | if (args.Layers == null) | ||||
args.Layers = new List<ILayer>(); | args.Layers = new List<ILayer>(); | ||||
// SupportsMasking = true; | // SupportsMasking = true; | ||||
computeOutputAndMaskJointly = true; | |||||
autoTrackSubLayers = false; | |||||
hasExplicitInputShape = false; | |||||
_compute_output_and_mask_jointly = true; | |||||
_auto_track_sub_layers = false; | |||||
_has_explicit_input_shape = false; | |||||
_is_graph_network = false; | _is_graph_network = false; | ||||
// Add to the model any layers passed to the constructor. | |||||
if (args.Layers != null) | |||||
{ | |||||
foreach (var layer in args.Layers) | |||||
add(layer as Layer); | |||||
} | |||||
} | } | ||||
public void add(Tensor tensor) | public void add(Tensor tensor) | ||||
@@ -71,7 +76,7 @@ namespace Tensorflow.Keras.Engine | |||||
{ | { | ||||
built = false; | built = false; | ||||
var set_inputs = false; | var set_inputs = false; | ||||
if (layers.Count == 0) | |||||
if (_layers.Count == 0) | |||||
{ | { | ||||
if (layer is InputLayer) | if (layer is InputLayer) | ||||
{ | { | ||||
@@ -83,7 +88,7 @@ namespace Tensorflow.Keras.Engine | |||||
{ | { | ||||
// Instantiate an input layer. | // Instantiate an input layer. | ||||
var x = keras.Input( | var x = keras.Input( | ||||
shape: layer.BatchInputShape, | |||||
batch_input_shape: layer.BatchInputShape, | |||||
dtype: layer.DType, | dtype: layer.DType, | ||||
name: layer.Name + "_input"); | name: layer.Name + "_input"); | ||||
@@ -99,36 +104,26 @@ namespace Tensorflow.Keras.Engine | |||||
{ | { | ||||
// If an input layer (placeholder) is available. | // If an input layer (placeholder) is available. | ||||
outputs = layer.InboundNodes[^1].Outputs; | outputs = layer.InboundNodes[^1].Outputs; | ||||
inputs = layer_utils.get_source_inputs(outputs[0]); | |||||
built = true; | |||||
_has_explicit_input_shape = true; | |||||
} | } | ||||
} | } | ||||
else if (outputs != null) | else if (outputs != null) | ||||
{ | { | ||||
outputs = layer.Apply(outputs); | outputs = layer.Apply(outputs); | ||||
built = true; | |||||
} | } | ||||
if (set_inputs || _is_graph_network) | if (set_inputs || _is_graph_network) | ||||
{ | { | ||||
_init_graph_network(inputs, outputs); | _init_graph_network(inputs, outputs); | ||||
_is_graph_network = true; | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
} | } | ||||
} | } | ||||
void _init_graph_network(Tensor inputs, Tensor outputs) | |||||
{ | |||||
_is_graph_network = true; | |||||
this.inputs = inputs; | |||||
this.outputs = outputs; | |||||
built = true; | |||||
_map_graph_network(inputs, outputs); | |||||
} | |||||
void _map_graph_network(Tensor inputs, Tensor outputs) | |||||
{ | |||||
layers.add(outputs.KerasHistory.Layer); | |||||
} | |||||
} | } | ||||
} | } |
@@ -62,16 +62,21 @@ namespace Tensorflow.Keras | |||||
/// <returns></returns> | /// <returns></returns> | ||||
public Tensor Input(TensorShape shape = null, | public Tensor Input(TensorShape shape = null, | ||||
int batch_size = -1, | int batch_size = -1, | ||||
TensorShape batch_input_shape = null, | |||||
TF_DataType dtype = TF_DataType.DtInvalid, | TF_DataType dtype = TF_DataType.DtInvalid, | ||||
string name = null, | string name = null, | ||||
bool sparse = false, | bool sparse = false, | ||||
bool ragged = false, | bool ragged = false, | ||||
Tensor tensor = null) | Tensor tensor = null) | ||||
{ | { | ||||
if (batch_input_shape != null) | |||||
shape = batch_input_shape.dims[1..]; | |||||
var args = new InputLayerArgs | var args = new InputLayerArgs | ||||
{ | { | ||||
Name = name, | Name = name, | ||||
InputShape = shape, | InputShape = shape, | ||||
BatchInputShape = batch_input_shape, | |||||
BatchSize = batch_size, | BatchSize = batch_size, | ||||
DType = dtype, | DType = dtype, | ||||
Sparse = sparse, | Sparse = sparse, | ||||
@@ -23,5 +23,10 @@ namespace Tensorflow.Keras.Layers | |||||
offset = math_ops.cast(args.Offset, args.DType); | offset = math_ops.cast(args.Offset, args.DType); | ||||
return math_ops.cast(inputs, args.DType) * scale + offset; | return math_ops.cast(inputs, args.DType) * scale + offset; | ||||
} | } | ||||
public override TensorShape ComputeOutputShape(TensorShape input_shape) | |||||
{ | |||||
return input_shape; | |||||
} | |||||
} | } | ||||
} | } |
@@ -1,4 +1,5 @@ | |||||
using System; | using System; | ||||
using Tensorflow.Framework; | |||||
using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
using Tensorflow.Keras.Utils; | using Tensorflow.Keras.Utils; | ||||
@@ -15,6 +16,7 @@ namespace Tensorflow.Keras.Layers | |||||
public Flatten(FlattenArgs args) | public Flatten(FlattenArgs args) | ||||
: base(args) | : base(args) | ||||
{ | { | ||||
this.args = args; | |||||
args.DataFormat = conv_utils.normalize_data_format(args.DataFormat); | args.DataFormat = conv_utils.normalize_data_format(args.DataFormat); | ||||
input_spec = new InputSpec(min_ndim: 1); | input_spec = new InputSpec(min_ndim: 1); | ||||
_channels_first = args.DataFormat == "channels_first"; | _channels_first = args.DataFormat == "channels_first"; | ||||
@@ -31,8 +33,29 @@ namespace Tensorflow.Keras.Layers | |||||
{ | { | ||||
return array_ops.reshape(inputs, new[] { inputs.shape[0], -1 }); | return array_ops.reshape(inputs, new[] { inputs.shape[0], -1 }); | ||||
} | } | ||||
else | |||||
{ | |||||
var input_shape = inputs.shape; | |||||
var rank = inputs.shape.rank; | |||||
if (rank == 1) | |||||
return array_ops.expand_dims(inputs, axis: 1); | |||||
var batch_dim = tensor_shape.dimension_value(input_shape[0]); | |||||
if (batch_dim != -1) | |||||
{ | |||||
return array_ops.reshape(inputs, new[] { batch_dim, -1 }); | |||||
} | |||||
throw new NotImplementedException(""); | |||||
var non_batch_dims = ((int[])input_shape)[1..]; | |||||
var num = 1; | |||||
if (non_batch_dims.Length > 0) | |||||
{ | |||||
for (var i = 0; i < non_batch_dims.Length; i++) | |||||
{ | |||||
num *= non_batch_dims[i]; | |||||
} | |||||
} | |||||
return array_ops.reshape(inputs, new[] { inputs.shape[0], num }); | |||||
} | |||||
} | } | ||||
} | } | ||||
} | } |
@@ -31,7 +31,7 @@ namespace Tensorflow.Keras | |||||
img = tf.image.decode_image( | img = tf.image.decode_image( | ||||
img, channels: num_channels, expand_animations: false); | img, channels: num_channels, expand_animations: false); | ||||
img = tf.image.resize_images_v2(img, image_size, method: interpolation); | img = tf.image.resize_images_v2(img, image_size, method: interpolation); | ||||
img.set_shape((image_size[0], image_size[1], num_channels)); | |||||
// img.set_shape((image_size[0], image_size[1], num_channels)); | |||||
return img; | return img; | ||||
} | } | ||||
} | } | ||||
@@ -187,5 +187,34 @@ namespace Tensorflow.Keras.Utils | |||||
var total = weight_shapes.Select(p => (int)np.prod(p.dims)).Sum(); | var total = weight_shapes.Select(p => (int)np.prod(p.dims)).Sum(); | ||||
return total; | return total; | ||||
} | } | ||||
public static Tensors get_source_inputs(Tensor tensor, ILayer layer = null, int node_index = -1) | |||||
{ | |||||
if (layer == null) | |||||
(layer, node_index, _) = tensor.KerasHistory; | |||||
if (layer.InboundNodes == null || layer.InboundNodes.Count == 0) | |||||
return tensor; | |||||
else | |||||
{ | |||||
var node = layer.InboundNodes[node_index]; | |||||
if (node.is_input) | |||||
return node.input_tensors; | |||||
else | |||||
{ | |||||
var source_tensors = new List<Tensor>(); | |||||
foreach (var _layer in node.iterate_inbound()) | |||||
{ | |||||
(layer, node_index, tensor) = (_layer.Item1, _layer.Item2, _layer.Item4); | |||||
var previous_sources = get_source_inputs(tensor, layer, node_index); | |||||
foreach(var x in previous_sources) | |||||
{ | |||||
// should be check if exist? | |||||
source_tensors.append(x); | |||||
} | |||||
} | |||||
return source_tensors; | |||||
} | |||||
} | |||||
} | |||||
} | } | ||||
} | } |
@@ -24,7 +24,7 @@ More information about [System.Drawing on Linux](<https://www.hanselman.com/blog | |||||
Before running verify you installed CUDA and cuDNN (TensorFlow v1.15 is compatible with CUDA v10.0 and cuDNN v7.4 , TensorFlow v2.x is compatible with CUDA v10.2 and cuDNN v7.65), and make sure the corresponding cuda version is compatible. | Before running verify you installed CUDA and cuDNN (TensorFlow v1.15 is compatible with CUDA v10.0 and cuDNN v7.4 , TensorFlow v2.x is compatible with CUDA v10.2 and cuDNN v7.65), and make sure the corresponding cuda version is compatible. | ||||
#### Mac OS | #### Mac OS | ||||
There is no GPU support for macOS. | |||||
There is no GPU support for macOS, in the future TensorFlow will support [Apple M1 chip](https://github.com/apple/tensorflow_macos). | |||||
#### GPU for Windows | #### GPU for Windows | ||||
@@ -37,9 +37,11 @@ PM> Install-Package SciSharp.TensorFlow.Redist-Windows-GPU | |||||
PM> Install-Package SciSharp.TensorFlow.Redist-Linux-GPU | PM> Install-Package SciSharp.TensorFlow.Redist-Linux-GPU | ||||
``` | ``` | ||||
Since NuGet limits file size for 250M, we can't ship Linux GPU version as NuGet, you can download the library from [Google TensorFlow Storage](https://storage.googleapis.com/tensorflow). | |||||
### Download prebuild binary manually | ### Download prebuild binary manually | ||||
Tensorflow packages are built nightly and uploaded to GCS for all supported platforms. They are uploaded to the [libtensorflow-nightly](https://www.tensorflow.org/install/lang_c) GCS bucket and are indexed by operating system and date built. | |||||
TensorFlow packages are built nightly and uploaded to GCS for all supported platforms. They are uploaded to the [libtensorflow-nightly](https://www.tensorflow.org/install/lang_c) GCS bucket and are indexed by operating system and date built. | |||||
### Build from source for Windows | ### Build from source for Windows | ||||
@@ -28,7 +28,7 @@ namespace TensorFlowNET.UnitTest.Basics | |||||
public void decode_image() | public void decode_image() | ||||
{ | { | ||||
var img = tf.image.decode_image(contents); | var img = tf.image.decode_image(contents); | ||||
Assert.AreEqual(img.name, "decode_image/cond_jpeg/Merge:0"); | |||||
Assert.AreEqual(img.name, "decode_image/Identity:0"); | |||||
} | } | ||||
[TestMethod] | [TestMethod] | ||||
@@ -57,7 +57,7 @@ namespace TensorFlowNET.UnitTest.Keras | |||||
{ 2, 3, 4, 5 }, | { 2, 3, 4, 5 }, | ||||
{ 3, 4, 5, 6 } | { 3, 4, 5, 6 } | ||||
}); | }); | ||||
model.compile("rmsprop", "mse"); | |||||
// model.compile("rmsprop", "mse"); | |||||
var output_array = model.predict(input_array); | var output_array = model.predict(input_array); | ||||
Assert.AreEqual((32, 10, 64), output_array.TensorShape); | Assert.AreEqual((32, 10, 64), output_array.TensorShape); | ||||
} | } | ||||
@@ -48,10 +48,10 @@ | |||||
<ItemGroup> | <ItemGroup> | ||||
<PackageReference Include="FluentAssertions" Version="5.10.3" /> | <PackageReference Include="FluentAssertions" Version="5.10.3" /> | ||||
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.138" /> | <PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.138" /> | ||||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.7.1" /> | |||||
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.8.3" /> | |||||
<PackageReference Include="MSTest.TestAdapter" Version="2.1.2" /> | <PackageReference Include="MSTest.TestAdapter" Version="2.1.2" /> | ||||
<PackageReference Include="MSTest.TestFramework" Version="2.1.2" /> | <PackageReference Include="MSTest.TestFramework" Version="2.1.2" /> | ||||
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.3.0" /> | |||||
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.3.1" /> | |||||
</ItemGroup> | </ItemGroup> | ||||
<ItemGroup> | <ItemGroup> | ||||