Browse Source

Merge branch 'master' of https://github.com/SciSharp/TensorFlow.NET

tags/keras_v0.3.0
Oceania2018 4 years ago
parent
commit
fda7e6df0b
30 changed files with 359 additions and 167 deletions
  1. +6
    -6
      README.md
  2. +9
    -1
      src/SciSharp.TensorFlow.Redist/README.md
  3. +18
    -1
      src/TensorFlowNET.Core/Data/DatasetV2.cs
  4. +2
    -0
      src/TensorFlowNET.Core/Data/IDatasetV2.cs
  5. +7
    -1
      src/TensorFlowNET.Core/Data/MapDataset.cs
  6. +5
    -0
      src/TensorFlowNET.Core/Functions/ConcreteFunction.cs
  7. +2
    -1
      src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs
  8. +1
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs
  9. +63
    -58
      src/TensorFlowNET.Core/Operations/image_ops_impl.cs
  10. +3
    -3
      src/TensorFlowNET.Core/Tensorflow.Binding.csproj
  11. +6
    -9
      src/TensorFlowNET.Core/Tensors/tensor_util.cs
  12. +33
    -13
      src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs
  13. +35
    -0
      src/TensorFlowNET.Keras/Engine/DataAdapters/DatasetAdapter.cs
  14. +2
    -2
      src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs
  15. +3
    -1
      src/TensorFlowNET.Keras/Engine/Functional.cs
  16. +3
    -7
      src/TensorFlowNET.Keras/Engine/Model.Compile.cs
  17. +44
    -0
      src/TensorFlowNET.Keras/Engine/Model.Fit.cs
  18. +1
    -1
      src/TensorFlowNET.Keras/Engine/Node.cs
  19. +30
    -35
      src/TensorFlowNET.Keras/Engine/Sequential.cs
  20. +5
    -0
      src/TensorFlowNET.Keras/KerasInterface.cs
  21. +5
    -0
      src/TensorFlowNET.Keras/Layers/Rescaling/Rescaling.cs
  22. +24
    -1
      src/TensorFlowNET.Keras/Layers/Reshaping/Flatten.cs
  23. +2
    -2
      src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.index_directory.cs
  24. +1
    -1
      src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs
  25. +2
    -17
      src/TensorFlowNET.Keras/Preprocessings/Preprocessing.paths_and_labels_to_dataset.cs
  26. +12
    -3
      src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
  27. +29
    -0
      src/TensorFlowNET.Keras/Utils/layer_utils.cs
  28. +4
    -2
      tensorflowlib/README.md
  29. +1
    -1
      test/TensorFlowNET.UnitTest/ImageTest.cs
  30. +1
    -1
      test/TensorFlowNET.UnitTest/Keras/LayersTest.cs

+ 6
- 6
README.md View File

@@ -26,12 +26,12 @@ In comparison to other projects, like for instance [TensorFlowSharp](https://www

### How to use

| TensorFlow | tf native1.14 | tf native 1.15 | tf native 2.3 |
| -------------------------- | ------------- | -------------- | ------------- |
| tf.net 0.3x, tf.keras 0.2 | | | x |
| tf.net 0.2x | | x | x |
| tf.net 0.15 | x | x | |
| tf.net 0.14 | x | | |
| TensorFlow | tf native1.14, cuda 10.0 | tf native 1.15, cuda 10.0 | tf native 2.3, cuda 10.1 | tf native 2.4, cuda 11 |
| -------------------------- | ------------- | -------------- | ------------- | ------------- |
| tf.net 0.3x, tf.keras 0.2 | | | x | not compatible |
| tf.net 0.2x | | x | x | |
| tf.net 0.15 | x | x | | |
| tf.net 0.14 | x | | | |

Troubleshooting of running example or installation, please refer [here](tensorflowlib/README.md).



+ 9
- 1
src/SciSharp.TensorFlow.Redist/README.md View File

@@ -22,11 +22,19 @@ https://www.nuget.org/packages/SciSharp.TensorFlow.Redist

Related merged [commits](https://github.com/SciSharp/TensorFlow.NET/commit/854a5ba61ad0e400623821236bd117cc24c6cb77).



#### Download pre-build package

[Mac OSX CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-darwin-x86_64-2.4.0.tar.gz), [Linux CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-2.4.0.tar.gz), [Linux GPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-linux-x86_64-2.4.0.tar.gz), [Windows CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-windows-x86_64-2.4.0.tar.gz), [Windows GPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-windows-x86_64-2.4.0.zip)



#### Pack and Deploy ####

On Windows, the tar command does not support extracting archives with symlinks. So when `dotnet pack` runs on Windows it will only package the Windows binaries.

1. Run `dotnet pack SciSharp.TensorFlow.Redist.nupkgproj` under `src/SciSharp.TensorFlow.Redist` directory in Linux.
2. Run `dotnet nuget push SciSharp.TensorFlow.Redist.2.3.1.nupkg -k APIKEY -s https://api.nuget.org/v3/index.json -t 600`
2. Run `dotnet nuget push SciSharp.TensorFlow.Redist.2.4.0.nupkg -k APIKEY -s https://api.nuget.org/v3/index.json -t 600`



+ 18
- 1
src/TensorFlowNET.Core/Data/DatasetV2.cs View File

@@ -3,6 +3,7 @@ using System.Collections;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Framework.Models;
using static Tensorflow.Binding;

namespace Tensorflow
{
@@ -98,6 +99,20 @@ namespace Tensorflow
return dataset;
}

public Tensor dataset_cardinality(string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"DatasetCardinality", name,
null,
variant_tensor);
return results[0];
}

throw new NotImplementedException("");
}

public override string ToString()
=> $"{GetType().Name} shapes: {string.Join(", ", structure.Select(x => x.shape))}, types: {string.Join(", ", structure.Select(x => "tf." + x.dtype.as_numpy_name()))}";

@@ -117,7 +132,9 @@ namespace Tensorflow
break;
}

yield return (results[0], results.Length == 1 ? null : results[1]);
yield return results.Length == 2
? (results[0], results[1])
: (null, results[0]);
}
}



+ 2
- 0
src/TensorFlowNET.Core/Data/IDatasetV2.cs View File

@@ -74,5 +74,7 @@ namespace Tensorflow
/// </summary>
/// <returns></returns>
IDatasetV2 apply_options();

Tensor dataset_cardinality(string name = null);
}
}

+ 7
- 1
src/TensorFlowNET.Core/Data/MapDataset.cs View File

@@ -1,5 +1,6 @@
using System;
using Tensorflow.Functions;
using static Tensorflow.Binding;

namespace Tensorflow
{
@@ -14,7 +15,12 @@ namespace Tensorflow
bool preserve_cardinality = false,
bool use_legacy_function = false) : base(input_dataset)
{
var func = new ConcreteFunction(map_func, input_dataset.element_spec[0].dtype);
using var func = new ConcreteFunction($"autograph_{map_func.Method.Name}");
var input = tf.placeholder(input_dataset.element_spec[0].dtype);
var output = map_func(input);
func.ToGraph(input, output);
structure = func.OutputStructure;

variant_tensor = ops.map_dataset(input_dataset.variant_tensor,
func,


+ 5
- 0
src/TensorFlowNET.Core/Functions/ConcreteFunction.cs View File

@@ -109,6 +109,8 @@ namespace Tensorflow.Functions
inputs,
outputs,
null);

OutputStructure = outputs.Select(x => x.ToTensorSpec()).ToArray();
}

public Tensors Invoke(Tensors inputs)
@@ -128,6 +130,9 @@ namespace Tensorflow.Functions
return new ForwardBackwardCall(functions, args, tape_watching: true);
}

public override string ToString()
=> Name;

public void Dispose()
{
c_api.TFE_ContextRemoveFunction(tf.Context.Handle, Name, tf.Status.Handle);


src/TensorFlowNET.Core/Keras/ArgsDefinition/TensorLikeDataAdapterArgs.cs → src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs View File

@@ -2,10 +2,11 @@

namespace Tensorflow.Keras.ArgsDefinition
{
public class TensorLikeDataAdapterArgs
public class DataAdapterArgs
{
public Tensor X { get; set; }
public Tensor Y { get; set; }
public IDatasetV2 Dataset { get; set; }
public int BatchSize { get; set; } = 32;
public int Steps { get; set; }
public int Epochs { get; set; }

+ 1
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs View File

@@ -6,6 +6,7 @@ namespace Tensorflow.Keras.ArgsDefinition
{
public Tensor X { get; set; }
public Tensor Y { get; set; }
public IDatasetV2 Dataset { get; set; }
public int BatchSize { get; set; } = 32;
public int StepsPerEpoch { get; set; } = -1;
public int InitialEpoch { get; set; } = 0;


+ 63
- 58
src/TensorFlowNET.Core/Operations/image_ops_impl.cs View File

@@ -1702,74 +1702,79 @@ new_height, new_width");
public static Tensor decode_image(Tensor contents, int channels = 0, TF_DataType dtype = TF_DataType.TF_UINT8,
string name = null, bool expand_animations = true)
{
Func<ITensorOrOperation> _jpeg = () =>
return tf_with(ops.name_scope(name, "decode_image"), scope =>
{
int jpeg_channels = channels;
var good_channels = math_ops.not_equal(jpeg_channels, 4, name: "check_jpeg_channels");
string channels_msg = "Channels must be in (None, 0, 1, 3) when decoding JPEG 'images'";
var assert_channels = control_flow_ops.Assert(good_channels, new string[] { channels_msg });
return tf_with(ops.control_dependencies(new[] { assert_channels }), delegate
var substr = tf.strings.substr(contents, 0, 3);

Func<ITensorOrOperation> _jpeg = () =>
{
return convert_image_dtype(gen_image_ops.decode_jpeg(contents, channels), dtype);
});
};
int jpeg_channels = channels;
var good_channels = math_ops.not_equal(jpeg_channels, 4, name: "check_jpeg_channels");
string channels_msg = "Channels must be in (None, 0, 1, 3) when decoding JPEG 'images'";
var assert_channels = control_flow_ops.Assert(good_channels, new string[] { channels_msg });
return tf_with(ops.control_dependencies(new[] { assert_channels }), delegate
{
return convert_image_dtype(gen_image_ops.decode_jpeg(contents, channels), dtype);
});
};

Func<ITensorOrOperation> _gif = () =>
{
int gif_channels = channels;
var good_channels = math_ops.logical_and(
math_ops.not_equal(gif_channels, 1, name: "check_gif_channels"),
math_ops.not_equal(gif_channels, 4, name: "check_gif_channels"));

string channels_msg = "Channels must be in (None, 0, 3) when decoding GIF images";
var assert_channels = control_flow_ops.Assert(good_channels, new string[] { channels_msg });
return tf_with(ops.control_dependencies(new[] { assert_channels }), delegate
/*Func<ITensorOrOperation> _gif = () =>
{
var result = convert_image_dtype(gen_image_ops.decode_gif(contents), dtype);
if (!expand_animations)
result = array_ops.gather(result, 0);
return result;
});
};
int gif_channels = channels;
var good_channels = math_ops.logical_and(
math_ops.not_equal(gif_channels, 1, name: "check_gif_channels"),
math_ops.not_equal(gif_channels, 4, name: "check_gif_channels"));

string channels_msg = "Channels must be in (None, 0, 3) when decoding GIF images";
var assert_channels = control_flow_ops.Assert(good_channels, new string[] { channels_msg });
return tf_with(ops.control_dependencies(new[] { assert_channels }), delegate
{
var result = convert_image_dtype(gen_image_ops.decode_gif(contents), dtype);
if (!expand_animations)
result = array_ops.gather(result, 0);
return result;
});
};

Func<ITensorOrOperation> _bmp = () =>
{
int bmp_channels = channels;
var signature = tf.strings.substr(contents, 0, 2);
var is_bmp = math_ops.equal(signature, "BM", name: "is_bmp");
string decode_msg = "Unable to decode bytes as JPEG, PNG, GIF, or BMP";
var assert_decode = control_flow_ops.Assert(is_bmp, new string[] { decode_msg });
var good_channels = math_ops.not_equal(bmp_channels, 1, name: "check_channels");
string channels_msg = "Channels must be in (None, 0, 3) when decoding BMP images";
var assert_channels = control_flow_ops.Assert(good_channels, new string[] { channels_msg });
return tf_with(ops.control_dependencies(new[] { assert_decode, assert_channels }), delegate
Func<ITensorOrOperation> _bmp = () =>
{
return convert_image_dtype(gen_image_ops.decode_bmp(contents), dtype);
});
};
int bmp_channels = channels;
var signature = tf.strings.substr(contents, 0, 2);
var is_bmp = math_ops.equal(signature, "BM", name: "is_bmp");
string decode_msg = "Unable to decode bytes as JPEG, PNG, GIF, or BMP";
var assert_decode = control_flow_ops.Assert(is_bmp, new string[] { decode_msg });
var good_channels = math_ops.not_equal(bmp_channels, 1, name: "check_channels");
string channels_msg = "Channels must be in (None, 0, 3) when decoding BMP images";
var assert_channels = control_flow_ops.Assert(good_channels, new string[] { channels_msg });
return tf_with(ops.control_dependencies(new[] { assert_decode, assert_channels }), delegate
{
return convert_image_dtype(gen_image_ops.decode_bmp(contents), dtype);
});
};

Func<ITensorOrOperation> _png = () =>
{
return convert_image_dtype(gen_image_ops.decode_png(
contents,
channels,
dtype: dtype),
dtype);
};
Func<ITensorOrOperation> _png = () =>
{
return convert_image_dtype(gen_image_ops.decode_png(
contents,
channels,
dtype: dtype),
dtype);
};

Func<ITensorOrOperation> check_gif = () =>
{
return control_flow_ops.cond(is_gif(contents), _gif, _bmp, name: "cond_gif");
};
Func<ITensorOrOperation> check_gif = () =>
{
var gif = tf.constant(new byte[] { 0x47, 0x49, 0x46 }, TF_DataType.TF_STRING);
var is_gif = math_ops.equal(substr, gif, name: name);
return control_flow_ops.cond(is_gif, _gif, _bmp, name: "cond_gif");
};

Func<ITensorOrOperation> check_png = () =>
{
return control_flow_ops.cond(is_png(contents), _png, check_gif, name: "cond_png");
};
Func<ITensorOrOperation> check_png = () =>
{
return control_flow_ops.cond(is_png(contents), _png, check_gif, name: "cond_png");
};*/

return tf_with(ops.name_scope(name, "decode_image"), scope =>
{
return control_flow_ops.cond(is_jpeg(contents), _jpeg, check_png, name: "cond_jpeg");
// return control_flow_ops.cond(is_jpeg(contents), _jpeg, check_png, name: "cond_jpeg");
return _jpeg() as Tensor;
});
}



+ 3
- 3
src/TensorFlowNET.Core/Tensorflow.Binding.csproj View File

@@ -5,7 +5,7 @@
<AssemblyName>TensorFlow.NET</AssemblyName>
<RootNamespace>Tensorflow</RootNamespace>
<TargetTensorFlow>2.2.0</TargetTensorFlow>
<Version>0.31.1</Version>
<Version>0.31.2</Version>
<LangVersion>8.0</LangVersion>
<Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors>
<Company>SciSharp STACK</Company>
@@ -19,7 +19,7 @@
<Description>Google's TensorFlow full binding in .NET Standard.
Building, training and infering deep learning models.
https://tensorflownet.readthedocs.io</Description>
<AssemblyVersion>0.31.1.0</AssemblyVersion>
<AssemblyVersion>0.31.2.0</AssemblyVersion>
<PackageReleaseNotes>tf.net 0.20.x and above are based on tensorflow native 2.x.

* Eager Mode is added finally.
@@ -30,7 +30,7 @@ https://tensorflownet.readthedocs.io</Description>
TensorFlow .NET v0.30 is focused on making more Keras API work including:
* tf.keras.datasets
* Building keras model in subclass, functional and sequential api</PackageReleaseNotes>
<FileVersion>0.31.1.0</FileVersion>
<FileVersion>0.31.2.0</FileVersion>
<PackageLicenseFile>LICENSE</PackageLicenseFile>
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance>
<SignAssembly>true</SignAssembly>


+ 6
- 9
src/TensorFlowNET.Core/Tensors/tensor_util.cs View File

@@ -20,6 +20,7 @@ using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Eager;
using static Tensorflow.Binding;

namespace Tensorflow
{
@@ -410,14 +411,10 @@ would not be rank 1.", tensor.op.get_attr("axis")));
var value = constant_value(tensor);
if (!(value is null))
{
int[] d_ = { };
foreach (int d in value)
{
if (d >= 0)
d_[d_.Length] = d;
else
d_[d_.Length] = -1; // None
}
var d_ = new int[value.size];
foreach (var (index, d) in enumerate(value.ToArray<int>()))
d_[index] = d >= 0 ? d : -1;
ret = ret.merge_with(new TensorShape(d_));
}
return ret;
@@ -577,7 +574,7 @@ would not be rank 1.", tensor.op.get_attr("axis")));
return string.Join(string.Empty, nd.ToArray<byte>()
.Select(x => x < 32 || x > 127 ? "\\x" + x.ToString("x") : Convert.ToChar(x).ToString()));
case TF_DataType.TF_BOOL:
return (nd.GetByte(0) > 0).ToString();
return nd.GetBoolean(0).ToString();
case TF_DataType.TF_VARIANT:
case TF_DataType.TF_RESOURCE:
return "<unprintable>";


+ 33
- 13
src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs View File

@@ -37,19 +37,38 @@ namespace Tensorflow.Keras.Engine.DataAdapters
_steps_per_execution_value = args.StepsPerExecution.numpy();
}

_adapter = new TensorLikeDataAdapter(new TensorLikeDataAdapterArgs
if(args.Dataset == null)
{
X = args.X,
Y = args.Y,
BatchSize = args.BatchSize,
Steps = args.StepsPerEpoch,
Epochs = args.Epochs - args.InitialEpoch,
Shuffle = args.Shuffle,
MaxQueueSize = args.MaxQueueSize,
Worker = args.Workers,
UseMultiprocessing = args.UseMultiprocessing,
Model = args.Model
});
_adapter = new TensorLikeDataAdapter(new DataAdapterArgs
{
X = args.X,
Y = args.Y,
BatchSize = args.BatchSize,
Steps = args.StepsPerEpoch,
Epochs = args.Epochs - args.InitialEpoch,
Shuffle = args.Shuffle,
MaxQueueSize = args.MaxQueueSize,
Worker = args.Workers,
UseMultiprocessing = args.UseMultiprocessing,
Model = args.Model
});
}
else
{
_adapter = new DatasetAdapter(new DataAdapterArgs
{
Dataset = args.Dataset,
BatchSize = args.BatchSize,
Steps = args.StepsPerEpoch,
Epochs = args.Epochs - args.InitialEpoch,
Shuffle = args.Shuffle,
MaxQueueSize = args.MaxQueueSize,
Worker = args.Workers,
UseMultiprocessing = args.UseMultiprocessing,
Model = args.Model
});
}
_dataset = _adapter.GetDataset();
_inferred_steps = _infer_steps(args.StepsPerEpoch, _dataset);
_current_step = 0;
@@ -66,7 +85,8 @@ namespace Tensorflow.Keras.Engine.DataAdapters
if (adapter_steps > -1)
return adapter_steps;

throw new NotImplementedException("");
var size = dataset.dataset_cardinality();
return size.numpy();
}

public IEnumerable<(int, OwnedIterator)> enumerate_epochs()


+ 35
- 0
src/TensorFlowNET.Keras/Engine/DataAdapters/DatasetAdapter.cs View File

@@ -0,0 +1,35 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;

namespace Tensorflow.Keras.Engine.DataAdapters
{
public class DatasetAdapter : IDataAdapter
{
DataAdapterArgs args;
IDatasetV2 _dataset => args.Dataset;
public DatasetAdapter(DataAdapterArgs args)
{
this.args = args;
}

public bool CanHandle(Tensor x, Tensor y = null)
{
throw new NotImplementedException();
}

public IDatasetV2 GetDataset()
=> _dataset;

public int GetSize()
=> -1;

public (Tensor, Tensor) Expand1d(Tensor x, Tensor y)
{
if (y.TensorShape.ndim == 1)
y = array_ops.expand_dims(y, axis: -1);
return (x, y);
}
}
}

+ 2
- 2
src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs View File

@@ -9,14 +9,14 @@ namespace Tensorflow.Keras.Engine.DataAdapters
/// </summary>
public class TensorLikeDataAdapter : IDataAdapter
{
TensorLikeDataAdapterArgs args;
DataAdapterArgs args;
int _size;
int _batch_size;
int num_samples;
int num_full_batches;
IDatasetV2 _dataset;

public TensorLikeDataAdapter(TensorLikeDataAdapterArgs args)
public TensorLikeDataAdapter(DataAdapterArgs args)
{
this.args = args;
_process_tensorlike();


+ 3
- 1
src/TensorFlowNET.Keras/Engine/Functional.cs View File

@@ -39,10 +39,12 @@ namespace Tensorflow.Keras.Engine
_input_coordinates = new List<KerasHistory>();
_output_coordinates = new List<KerasHistory>();
tensor_usage_count = new Dictionary<int, int>();
if (this is Sequential)
return;
_init_graph_network(inputs, outputs);
}

void _init_graph_network(Tensors inputs, Tensors outputs)
protected void _init_graph_network(Tensors inputs, Tensors outputs)
{
_is_graph_network = true;
this.inputs = inputs;


+ 3
- 7
src/TensorFlowNET.Keras/Engine/Model.Compile.cs View File

@@ -9,10 +9,6 @@ namespace Tensorflow.Keras.Engine
{
LossesContainer compiled_loss;
MetricsContainer compiled_metrics;
public void compile(string optimizerName, ILossFunc lossName)
{
throw new NotImplementedException("");
}

public void compile(ILossFunc loss, OptimizerV2 optimizer, string[] metrics)
{
@@ -29,12 +25,12 @@ namespace Tensorflow.Keras.Engine
this.loss = loss;
}

public void compile(string optimizerName, string lossName)
public void compile(string optimizer, string loss, string[] metrics)
{
switch (optimizerName)
switch (optimizer)
{
case "rmsprop":
optimizer = new RMSprop(new RMSpropArgs
this.optimizer = new RMSprop(new RMSpropArgs
{

});


+ 44
- 0
src/TensorFlowNET.Keras/Engine/Model.Fit.cs View File

@@ -68,5 +68,49 @@ namespace Tensorflow.Keras.Engine
Console.WriteLine($"epoch: {epoch + 1}, " + string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2}")));
}
}

public void fit(IDatasetV2 dataset,
IDatasetV2 validation_data = null,
int batch_size = -1,
int epochs = 1,
int verbose = 1,
float validation_split = 0f,
bool shuffle = true,
int initial_epoch = 0,
int max_queue_size = 10,
int workers = 1,
bool use_multiprocessing = false)
{
data_handler = new DataHandler(new DataHandlerArgs
{
Dataset = dataset,
BatchSize = batch_size,
InitialEpoch = initial_epoch,
Epochs = epochs,
Shuffle = shuffle,
MaxQueueSize = max_queue_size,
Workers = workers,
UseMultiprocessing = use_multiprocessing,
Model = this,
StepsPerExecution = _steps_per_execution
});

stop_training = false;
_train_counter.assign(0);
Console.WriteLine($"Training...");
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
{
// reset_metrics();
// callbacks.on_epoch_begin(epoch)
// data_handler.catch_stop_iteration();
IEnumerable<(string, Tensor)> results = null;
foreach (var step in data_handler.steps())
{
// callbacks.on_train_batch_begin(step)
results = step_function(iterator);
}
Console.WriteLine($"epoch: {epoch + 1}, " + string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2}")));
}
}
}
}

+ 1
- 1
src/TensorFlowNET.Keras/Engine/Node.cs View File

@@ -35,7 +35,7 @@ namespace Tensorflow.Keras.Engine

public int[] node_indices;
public int[] tensor_indices;
public Tensors input_tensors => args.InputTensors;
public Tensors input_tensors => is_input ? Outputs : args.InputTensors;
public Tensors Outputs => args.Outputs;
public TensorShape[] input_shapes;
public TensorShape[] output_shapes;


+ 30
- 35
src/TensorFlowNET.Keras/Engine/Sequential.cs View File

@@ -17,6 +17,7 @@
using System.Collections.Generic;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Layers;
using Tensorflow.Keras.Utils;
using static Tensorflow.KerasApi;

namespace Tensorflow.Keras.Engine
@@ -25,36 +26,40 @@ namespace Tensorflow.Keras.Engine
/// `Sequential` groups a linear stack of layers into a `tf.keras.Model`.
/// `Sequential` provides training and inference features on this model.
/// </summary>
public class Sequential : Model
public class Sequential : Functional
{
SequentialArgs args;
bool _is_graph_network;
Tensor inputs;
Tensor outputs;
bool computeOutputAndMaskJointly;
bool autoTrackSubLayers;
TensorShape inferredInputShape;
bool hasExplicitInputShape;
TF_DataType inputDType;
List<ILayer> layers => args.Layers;
public TensorShape output_shape => outputs.TensorShape;
Tensors inputs;
Tensors outputs;
bool _compute_output_and_mask_jointly;
bool _auto_track_sub_layers;
TensorShape _inferred_input_shape;
bool _has_explicit_input_shape;
TF_DataType _input_dtype;
public TensorShape output_shape => outputs[0].TensorShape;
bool built = false;

public Sequential(SequentialArgs args)
: base(new ModelArgs
{
Name = args.Name
})
: base(args.Inputs, args.Outputs, name: args.Name)
{
this.args = args;
if (args.Layers == null)
args.Layers = new List<ILayer>();
// SupportsMasking = true;
computeOutputAndMaskJointly = true;
autoTrackSubLayers = false;
hasExplicitInputShape = false;
_compute_output_and_mask_jointly = true;
_auto_track_sub_layers = false;
_has_explicit_input_shape = false;
_is_graph_network = false;

// Add to the model any layers passed to the constructor.
if (args.Layers != null)
{
foreach (var layer in args.Layers)
add(layer as Layer);
}
}

public void add(Tensor tensor)
@@ -71,7 +76,7 @@ namespace Tensorflow.Keras.Engine
{
built = false;
var set_inputs = false;
if (layers.Count == 0)
if (_layers.Count == 0)
{
if (layer is InputLayer)
{
@@ -83,7 +88,7 @@ namespace Tensorflow.Keras.Engine
{
// Instantiate an input layer.
var x = keras.Input(
shape: layer.BatchInputShape,
batch_input_shape: layer.BatchInputShape,
dtype: layer.DType,
name: layer.Name + "_input");

@@ -99,36 +104,26 @@ namespace Tensorflow.Keras.Engine
{
// If an input layer (placeholder) is available.
outputs = layer.InboundNodes[^1].Outputs;
inputs = layer_utils.get_source_inputs(outputs[0]);
built = true;
_has_explicit_input_shape = true;
}

}
else if (outputs != null)
{
outputs = layer.Apply(outputs);
built = true;
}

if (set_inputs || _is_graph_network)
{
_init_graph_network(inputs, outputs);
_is_graph_network = true;
}
else
{

}
}

void _init_graph_network(Tensor inputs, Tensor outputs)
{
_is_graph_network = true;
this.inputs = inputs;
this.outputs = outputs;
built = true;
_map_graph_network(inputs, outputs);
}

void _map_graph_network(Tensor inputs, Tensor outputs)
{
layers.add(outputs.KerasHistory.Layer);
}
}
}

+ 5
- 0
src/TensorFlowNET.Keras/KerasInterface.cs View File

@@ -62,16 +62,21 @@ namespace Tensorflow.Keras
/// <returns></returns>
public Tensor Input(TensorShape shape = null,
int batch_size = -1,
TensorShape batch_input_shape = null,
TF_DataType dtype = TF_DataType.DtInvalid,
string name = null,
bool sparse = false,
bool ragged = false,
Tensor tensor = null)
{
if (batch_input_shape != null)
shape = batch_input_shape.dims[1..];

var args = new InputLayerArgs
{
Name = name,
InputShape = shape,
BatchInputShape = batch_input_shape,
BatchSize = batch_size,
DType = dtype,
Sparse = sparse,


+ 5
- 0
src/TensorFlowNET.Keras/Layers/Rescaling/Rescaling.cs View File

@@ -23,5 +23,10 @@ namespace Tensorflow.Keras.Layers
offset = math_ops.cast(args.Offset, args.DType);
return math_ops.cast(inputs, args.DType) * scale + offset;
}

public override TensorShape ComputeOutputShape(TensorShape input_shape)
{
return input_shape;
}
}
}

+ 24
- 1
src/TensorFlowNET.Keras/Layers/Reshaping/Flatten.cs View File

@@ -1,4 +1,5 @@
using System;
using Tensorflow.Framework;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Utils;
@@ -15,6 +16,7 @@ namespace Tensorflow.Keras.Layers
public Flatten(FlattenArgs args)
: base(args)
{
this.args = args;
args.DataFormat = conv_utils.normalize_data_format(args.DataFormat);
input_spec = new InputSpec(min_ndim: 1);
_channels_first = args.DataFormat == "channels_first";
@@ -31,8 +33,29 @@ namespace Tensorflow.Keras.Layers
{
return array_ops.reshape(inputs, new[] { inputs.shape[0], -1 });
}
else
{
var input_shape = inputs.shape;
var rank = inputs.shape.rank;
if (rank == 1)
return array_ops.expand_dims(inputs, axis: 1);
var batch_dim = tensor_shape.dimension_value(input_shape[0]);
if (batch_dim != -1)
{
return array_ops.reshape(inputs, new[] { batch_dim, -1 });
}

throw new NotImplementedException("");
var non_batch_dims = ((int[])input_shape)[1..];
var num = 1;
if (non_batch_dims.Length > 0)
{
for (var i = 0; i < non_batch_dims.Length; i++)
{
num *= non_batch_dims[i];
}
}
return array_ops.reshape(inputs, new[] { inputs.shape[0], num });
}
}
}
}

+ 2
- 2
src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.index_directory.cs View File

@@ -40,8 +40,8 @@ namespace Tensorflow.Keras.Preprocessings
labels.AddRange(Enumerable.Range(0, files.Length).Select(x => label));
}

var return_labels = new int[labels.Count];
var return_file_paths = new string[file_paths.Count];
var return_labels = labels.Select(x => x).ToArray();
var return_file_paths = file_paths.Select(x => x).ToArray();

if (shuffle)
{


+ 1
- 1
src/TensorFlowNET.Keras/Preprocessings/Preprocessing.image_dataset_from_directory.cs View File

@@ -41,7 +41,7 @@ namespace Tensorflow.Keras
int num_channels = 0;
if (color_mode == "rgb")
num_channels = 3;
// C:/Users/haipi/.keras/datasets/flower_photos
var (image_paths, label_list, class_name_list) = keras.preprocessing.dataset_utils.index_directory(directory,
formats: WHITELIST_FORMATS,
class_names: class_names,


+ 2
- 17
src/TensorFlowNET.Keras/Preprocessings/Preprocessing.paths_and_labels_to_dataset.cs View File

@@ -16,27 +16,11 @@ namespace Tensorflow.Keras
var path_ds = tf.data.Dataset.from_tensor_slices(image_paths);
var img_ds = path_ds.map(x => path_to_image(x, image_size, num_channels, interpolation));

/*Shape shape = (image_paths.Length, image_size.dims[0], image_size.dims[1], num_channels);
Console.WriteLine($"Allocating memory for shape{shape}, {NPTypeCode.Float}");
var data = np.zeros(shape, NPTypeCode.Float);

for (var i = 0; i < image_paths.Length; i++)
{
var image = path_to_image(image_paths[i], image_size, num_channels, interpolation);
data[i] = image.numpy();
if (i % 100 == 0)
Console.WriteLine($"Filled {i}/{image_paths.Length} data into ndarray.");
}

var img_ds = tf.data.Dataset.from_tensor_slices(data);

if (label_mode == "int")
{
var label_ds = tf.keras.preprocessing.dataset_utils.labels_to_dataset(labels, label_mode, num_classes);
var label_ds = dataset_utils.labels_to_dataset(labels, label_mode, num_classes);
img_ds = tf.data.Dataset.zip(img_ds, label_ds);
}
else*/
throw new NotImplementedException("");

return img_ds;
}
@@ -47,6 +31,7 @@ namespace Tensorflow.Keras
img = tf.image.decode_image(
img, channels: num_channels, expand_animations: false);
img = tf.image.resize_images_v2(img, image_size, method: interpolation);
// img.set_shape((image_size[0], image_size[1], num_channels));
return img;
}
}


+ 12
- 3
src/TensorFlowNET.Keras/Tensorflow.Keras.csproj View File

@@ -6,7 +6,7 @@
<LangVersion>8.0</LangVersion>
<RootNamespace>Tensorflow.Keras</RootNamespace>
<Platforms>AnyCPU;x64</Platforms>
<Version>0.2.1</Version>
<Version>0.3.0</Version>
<Authors>Haiping Chen</Authors>
<Product>Keras for .NET</Product>
<Copyright>Apache 2.0, Haiping Chen 2020</Copyright>
@@ -25,11 +25,13 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
<Company>SciSharp STACK</Company>
<GeneratePackageOnBuild>true</GeneratePackageOnBuild>
<PackageTags>tensorflow, keras, deep learning, machine learning</PackageTags>
<PackageRequireLicenseAcceptance>false</PackageRequireLicenseAcceptance>
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance>
<RepositoryType>Git</RepositoryType>
<SignAssembly>true</SignAssembly>
<AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile>
<AssemblyVersion>0.2.1.0</AssemblyVersion>
<AssemblyVersion>0.3.0.0</AssemblyVersion>
<FileVersion>0.3.0.0</FileVersion>
<PackageLicenseFile>LICENSE</PackageLicenseFile>
</PropertyGroup>

<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'">
@@ -55,4 +57,11 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
<ProjectReference Include="..\TensorFlowNET.Core\Tensorflow.Binding.csproj" />
</ItemGroup>

<ItemGroup>
<None Include="..\..\LICENSE">
<Pack>True</Pack>
<PackagePath></PackagePath>
</None>
</ItemGroup>

</Project>

+ 29
- 0
src/TensorFlowNET.Keras/Utils/layer_utils.cs View File

@@ -187,5 +187,34 @@ namespace Tensorflow.Keras.Utils
var total = weight_shapes.Select(p => (int)np.prod(p.dims)).Sum();
return total;
}

public static Tensors get_source_inputs(Tensor tensor, ILayer layer = null, int node_index = -1)
{
if (layer == null)
(layer, node_index, _) = tensor.KerasHistory;
if (layer.InboundNodes == null || layer.InboundNodes.Count == 0)
return tensor;
else
{
var node = layer.InboundNodes[node_index];
if (node.is_input)
return node.input_tensors;
else
{
var source_tensors = new List<Tensor>();
foreach (var _layer in node.iterate_inbound())
{
(layer, node_index, tensor) = (_layer.Item1, _layer.Item2, _layer.Item4);
var previous_sources = get_source_inputs(tensor, layer, node_index);
foreach(var x in previous_sources)
{
// should be check if exist?
source_tensors.append(x);
}
}
return source_tensors;
}
}
}
}
}

+ 4
- 2
tensorflowlib/README.md View File

@@ -24,7 +24,7 @@ More information about [System.Drawing on Linux](<https://www.hanselman.com/blog
Before running verify you installed CUDA and cuDNN (TensorFlow v1.15 is compatible with CUDA v10.0 and cuDNN v7.4 , TensorFlow v2.x is compatible with CUDA v10.2 and cuDNN v7.65), and make sure the corresponding cuda version is compatible.

#### Mac OS
There is no GPU support for macOS.
There is no GPU support for macOS, in the future TensorFlow will support [Apple M1 chip](https://github.com/apple/tensorflow_macos).

#### GPU for Windows

@@ -37,9 +37,11 @@ PM> Install-Package SciSharp.TensorFlow.Redist-Windows-GPU
PM> Install-Package SciSharp.TensorFlow.Redist-Linux-GPU
```

Since NuGet limits file size for 250M, we can't ship Linux GPU version as NuGet, you can download the library from [Google TensorFlow Storage](https://storage.googleapis.com/tensorflow).

### Download prebuild binary manually

Tensorflow packages are built nightly and uploaded to GCS for all supported platforms. They are uploaded to the [libtensorflow-nightly](https://www.tensorflow.org/install/lang_c) GCS bucket and are indexed by operating system and date built.
TensorFlow packages are built nightly and uploaded to GCS for all supported platforms. They are uploaded to the [libtensorflow-nightly](https://www.tensorflow.org/install/lang_c) GCS bucket and are indexed by operating system and date built.


### Build from source for Windows


+ 1
- 1
test/TensorFlowNET.UnitTest/ImageTest.cs View File

@@ -28,7 +28,7 @@ namespace TensorFlowNET.UnitTest.Basics
public void decode_image()
{
var img = tf.image.decode_image(contents);
Assert.AreEqual(img.name, "decode_image/cond_jpeg/Merge:0");
Assert.AreEqual(img.name, "decode_image/Identity:0");
}

[TestMethod]


+ 1
- 1
test/TensorFlowNET.UnitTest/Keras/LayersTest.cs View File

@@ -57,7 +57,7 @@ namespace TensorFlowNET.UnitTest.Keras
{ 2, 3, 4, 5 },
{ 3, 4, 5, 6 }
});
model.compile("rmsprop", "mse");
// model.compile("rmsprop", "mse");
var output_array = model.predict(input_array);
Assert.AreEqual((32, 10, 64), output_array.TensorShape);
}


Loading…
Cancel
Save