diff --git a/README.md b/README.md
index 36ec1660..0198c873 100644
--- a/README.md
+++ b/README.md
@@ -15,20 +15,6 @@
English | [中文](docs/README-CN.md)
-**=========================================================**
-
-### [Voting: Naming Convention Approach of v1.0.0](https://github.com/SciSharp/TensorFlow.NET/issues/1074)
-
-Dear all,
-
-We would like to urge you to participate in our upcoming vote regarding the naming convention for TensorFlow.NET version 1.0.0 in [#1074](https://github.com/SciSharp/TensorFlow.NET/issues/1074). Your participation in the vote is essential to help us decide on the best approach for improving the naming convention used in previous versions.
-
-Thank you,
-
-TensorFlow.NET Authors
-
-**=========================================================**
-
*master branch and v0.100.x is corresponding to tensorflow v2.10, v0.6x branch is from tensorflow v2.6, v0.15-tensorflow1.15 is from tensorflow1.15. Please add `https://www.myget.org/F/scisharp/api/v3/index.json` to nuget source to use nightly release.*
@@ -75,9 +61,12 @@ PM> Install-Package TensorFlow.Keras
The second part is the computing support part. Only one of the following packages is needed, depending on your device and system.
```
-### CPU version for Windows, Linux and Mac
+### CPU version for Windows and Linux
PM> Install-Package SciSharp.TensorFlow.Redist
+### CPU version for macOS
+PM> Install-Package SciSharp.TensorFlow.Redist-OSX
+
### GPU version for Windows (CUDA and cuDNN are required)
PM> Install-Package SciSharp.TensorFlow.Redist-Windows-GPU
diff --git a/TensorFlow.NET.sln b/TensorFlow.NET.sln
index 87729e27..214b039d 100644
--- a/TensorFlow.NET.sln
+++ b/TensorFlow.NET.sln
@@ -39,6 +39,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Benchmark", "too
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Console", "tools\TensorFlowNET.Console\Tensorflow.Console.csproj", "{1DC32255-BA1F-4D6D-A9C9-5BD5ED71CAA0}"
EndProject
+Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "TensorFlow.Kernel.UnitTest", "test\TensorFlow.Kernel.UnitTest\TensorFlow.Kernel.UnitTest.csproj", "{654A027D-1364-4729-880B-144DFE1FF5BB}"
+EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@@ -322,6 +324,24 @@ Global
{1DC32255-BA1F-4D6D-A9C9-5BD5ED71CAA0}.Release|x64.Build.0 = Release|x64
{1DC32255-BA1F-4D6D-A9C9-5BD5ED71CAA0}.Release|x86.ActiveCfg = Release|Any CPU
{1DC32255-BA1F-4D6D-A9C9-5BD5ED71CAA0}.Release|x86.Build.0 = Release|Any CPU
+ {654A027D-1364-4729-880B-144DFE1FF5BB}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {654A027D-1364-4729-880B-144DFE1FF5BB}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {654A027D-1364-4729-880B-144DFE1FF5BB}.Debug|x64.ActiveCfg = Debug|Any CPU
+ {654A027D-1364-4729-880B-144DFE1FF5BB}.Debug|x64.Build.0 = Debug|Any CPU
+ {654A027D-1364-4729-880B-144DFE1FF5BB}.Debug|x86.ActiveCfg = Debug|Any CPU
+ {654A027D-1364-4729-880B-144DFE1FF5BB}.Debug|x86.Build.0 = Debug|Any CPU
+ {654A027D-1364-4729-880B-144DFE1FF5BB}.GPU|Any CPU.ActiveCfg = Debug|Any CPU
+ {654A027D-1364-4729-880B-144DFE1FF5BB}.GPU|Any CPU.Build.0 = Debug|Any CPU
+ {654A027D-1364-4729-880B-144DFE1FF5BB}.GPU|x64.ActiveCfg = Debug|Any CPU
+ {654A027D-1364-4729-880B-144DFE1FF5BB}.GPU|x64.Build.0 = Debug|Any CPU
+ {654A027D-1364-4729-880B-144DFE1FF5BB}.GPU|x86.ActiveCfg = Debug|Any CPU
+ {654A027D-1364-4729-880B-144DFE1FF5BB}.GPU|x86.Build.0 = Debug|Any CPU
+ {654A027D-1364-4729-880B-144DFE1FF5BB}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {654A027D-1364-4729-880B-144DFE1FF5BB}.Release|Any CPU.Build.0 = Release|Any CPU
+ {654A027D-1364-4729-880B-144DFE1FF5BB}.Release|x64.ActiveCfg = Release|Any CPU
+ {654A027D-1364-4729-880B-144DFE1FF5BB}.Release|x64.Build.0 = Release|Any CPU
+ {654A027D-1364-4729-880B-144DFE1FF5BB}.Release|x86.ActiveCfg = Release|Any CPU
+ {654A027D-1364-4729-880B-144DFE1FF5BB}.Release|x86.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
@@ -342,6 +362,7 @@ Global
{D24FCAA5-548C-4251-B226-A1B6535D0845} = {E1A5D2B7-10AF-4876-85C0-7714EF274214}
{C23563DB-FE21-48E7-A411-87A109E4A899} = {E1A5D2B7-10AF-4876-85C0-7714EF274214}
{1DC32255-BA1F-4D6D-A9C9-5BD5ED71CAA0} = {E1A5D2B7-10AF-4876-85C0-7714EF274214}
+ {654A027D-1364-4729-880B-144DFE1FF5BB} = {1B0918B9-65AD-4F34-A287-AF4597B27DBD}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {2DEAD3CC-486B-4918-A607-50B0DE7B114A}
diff --git a/data/img001.bmp b/data/img001.bmp
new file mode 100644
index 00000000..d149d76f
Binary files /dev/null and b/data/img001.bmp differ
diff --git a/src/TensorFlowNET.Core/APIs/c_api.customize.cs b/src/TensorFlowNET.Core/APIs/c_api.customize.cs
index 510e52eb..bee4897e 100644
--- a/src/TensorFlowNET.Core/APIs/c_api.customize.cs
+++ b/src/TensorFlowNET.Core/APIs/c_api.customize.cs
@@ -8,10 +8,10 @@ namespace Tensorflow
public partial class c_api
{
[DllImport(TensorFlowLibName)]
- public static extern void TFC_SetAttr(SafeGraphHandle graph, IntPtr op, string attr_name, SafeBufferHandle attr_value_proto, SafeStatusHandle status);
+ public static extern void TF_SetAttr(SafeGraphHandle graph, IntPtr op, string attr_name, SafeBufferHandle attr_value_proto, SafeStatusHandle status);
[DllImport(TensorFlowLibName)]
- public static extern SafeBufferHandle TFC_GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output);
+ public static extern SafeBufferHandle TF_GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output);
[DllImport(TensorFlowLibName)]
- public static extern void TFC_SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data, long proto_len, SafeStatusHandle status);
+ public static extern void TF_SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data, long proto_len, SafeStatusHandle status);
}
}
diff --git a/src/TensorFlowNET.Core/APIs/tf.array.cs b/src/TensorFlowNET.Core/APIs/tf.array.cs
index 4d9c3da5..b529cd31 100644
--- a/src/TensorFlowNET.Core/APIs/tf.array.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.array.cs
@@ -140,6 +140,16 @@ namespace Tensorflow
public Tensor gather(Tensor @params, Tensor indices, string name = null, int axis = 0)
=> array_ops.gather(@params, indices, name: name, axis: ops.convert_to_tensor(axis));
+ /// <summary>
+ /// Gather slices from `params` into a Tensor with shape specified by `indices`.
+ /// </summary>
+ /// <param name="params">Tensor to gather slices from.</param>
+ /// <param name="indices">Index tensor that selects the slices.</param>
+ /// <param name="name">Optional name forwarded to the underlying operation.</param>
+ /// <returns>The tensor produced by <c>gen_array_ops.gather_nd</c>.</returns>
+ public Tensor gather_nd(Tensor @params, Tensor indices, string name = null)
+ => gen_array_ops.gather_nd(@params, indices, name: name);
+
///
/// Return the elements, either from `x` or `y`, depending on the `condition`.
///
diff --git a/src/TensorFlowNET.Core/APIs/tf.image.cs b/src/TensorFlowNET.Core/APIs/tf.image.cs
index ac9cbc60..41ef5296 100644
--- a/src/TensorFlowNET.Core/APIs/tf.image.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.image.cs
@@ -339,6 +339,13 @@ namespace Tensorflow
=> image_ops_impl.decode_image(contents, channels: channels, dtype: dtype,
name: name, expand_animations: expand_animations);
+ // Delegates to image_ops_impl.encode_png. NOTE(review): `contents` is presumably the image tensor to encode - confirm.
+ public Tensor encode_png(Tensor contents, string name = null)
+ => image_ops_impl.encode_png(contents, name: name);
+
+ // Delegates to image_ops_impl.encode_jpeg. NOTE(review): `contents` is presumably the image tensor to encode - confirm.
+ public Tensor encode_jpeg(Tensor contents, string name = null)
+ => image_ops_impl.encode_jpeg(contents, name: name);
///
/// Convenience function to check if the 'contents' encodes a JPEG image.
///
diff --git a/src/TensorFlowNET.Core/APIs/tf.io.cs b/src/TensorFlowNET.Core/APIs/tf.io.cs
index be1e86e6..ea1e44b2 100644
--- a/src/TensorFlowNET.Core/APIs/tf.io.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.io.cs
@@ -16,6 +16,7 @@
using System.Collections.Generic;
using Tensorflow.IO;
+using Tensorflow.Operations;
namespace Tensorflow
{
@@ -46,6 +47,12 @@ namespace Tensorflow
public Tensor[] restore_v2(Tensor prefix, string[] tensor_names,
string[] shape_and_slices, TF_DataType[] dtypes, string name = null)
=> ops.restore_v2(prefix, tensor_names, shape_and_slices, dtypes, name: name);
+ // Converts `filename` to a TF_STRING tensor and forwards to the Tensor overload. NOTE(review): `conentes` looks like a typo of `contents`; renaming it would change the named-argument surface.
+ public Operation write_file(string filename, Tensor conentes, string name = null)
+ => write_file(Tensorflow.ops.convert_to_tensor(filename, TF_DataType.TF_STRING), conentes, name);
+ // Forwards the filename tensor, the contents tensor and the optional name to gen_ops.write_file.
+ public Operation write_file(Tensor filename, Tensor conentes, string name = null)
+ => gen_ops.write_file(filename, conentes, name);
}
public GFile gfile = new GFile();
diff --git a/src/TensorFlowNET.Core/APIs/tf.nn.cs b/src/TensorFlowNET.Core/APIs/tf.nn.cs
index 397c68c7..112c4862 100644
--- a/src/TensorFlowNET.Core/APIs/tf.nn.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.nn.cs
@@ -101,6 +101,8 @@ namespace Tensorflow
name: name);
public IActivation relu() => new relu();
+
+
public IActivation swish() => new swish();
public IActivation tanh() => new tanh();
@@ -111,6 +113,9 @@ namespace Tensorflow
public Tensor relu(Tensor features, string name = null)
=> gen_nn_ops.relu(features, name);
+ // Tensor overload: forwards `features` and the optional name to gen_nn_ops.relu6.
+ public Tensor relu6(Tensor features, string name = null)
+ => gen_nn_ops.relu6(features, name);
public Tensor[] fused_batch_norm(Tensor x,
Tensor scale,
Tensor offset,
diff --git a/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs b/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs
index 59d5fd03..2bdd65f5 100644
--- a/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs
+++ b/src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs
@@ -80,6 +80,11 @@ namespace Tensorflow.Eager
Tensor[] op_outputs)
=> (out_grads, unneeded_gradients) =>
{
+ if(!ops.gradientFunctions.ContainsKey(op_name))
+ {
+ throw new Exception($"gradientFunctions not find op_name: {op_name}");
+ }
+
if (ops.gradientFunctions[op_name] == null)
return new Tensor[op_inputs.Length];
diff --git a/src/TensorFlowNET.Core/Gradients/array_grad.cs b/src/TensorFlowNET.Core/Gradients/array_grad.cs
index 4b702799..a4da60ee 100644
--- a/src/TensorFlowNET.Core/Gradients/array_grad.cs
+++ b/src/TensorFlowNET.Core/Gradients/array_grad.cs
@@ -381,5 +381,48 @@ namespace Tensorflow.Gradients
var axis = op.inputs[1];
return new Tensor[] { array_ops.reverse(grad, axis), null };
}
+
+ [RegisterGradient("Tile")]
+ public static Tensor[] _TileGrad(Operation op, Tensor[] grads)
+ {
+ // Gradient of Tile: op.inputs[1] and the input shape are stacked, transposed and flattened so their entries interleave;
+ // the even positions of that flattened shape are the axes summed out of the reshaped gradient.
+ var grad = grads[0];
+ var input_shape = array_ops.shape(op.inputs[0], out_type: op.inputs[1].dtype);
+ var split_shape = array_ops.reshape(array_ops.transpose(array_ops.stack(new Tensor[] { op.inputs[1], input_shape })), new Shape(-1));
+ var axes = math_ops.range(0, array_ops.size(split_shape), 2);
+
+ // Not ported: the reference branch that sum-reduces grad along the first dimension for IndexedSlices:
+ //   if isinstance(grad, indexed_slices_lib.IndexedSlices):
+ //     input_shape_0 = math_ops.cast(input_shape[0], grad.indices.dtype)
+ //     grad = math_ops.unsorted_segment_sum(grad.values, math_ops.mod(grad.indices, input_shape_0), input_shape_0)
+ //     split_shape = array_ops.concat([[1], split_shape[1:]], axis = 0)
+ var input_grad = math_ops.reduce_sum(array_ops.reshape(grad, split_shape), axes);
+ if (!tf.Context.executing_eagerly()) // outside eager execution, copy the input's shape onto the gradient
+ {
+ input_grad.set_shape(op.inputs[0].GetShape());
+ }
+ return new Tensor[] { input_grad, null }; // null: no gradient is produced for op.inputs[1]
+ }
+
+ [RegisterGradient("GatherNd")]
+ public static Tensor[] _GatherNdGrad(Operation op, Tensor[] grads)
+ {
+     // Gradient of GatherNd: the incoming gradient is routed back to the first input;
+     // the indices input (op.inputs[1]) receives no gradient.
+     var source = op.inputs[0];
+     var gather_indices = op.inputs[1];
+     var upstream = grads[0];
+     var source_shape = array_ops.shape(source, out_type: gather_indices.dtype);
+     var index_shape = gather_indices.shape;
+     // Rank-2 indices with a trailing dimension of 1 are squeezed and wrapped as
+     // IndexedSlices; any other layout goes through scatter_nd.
+     bool trailing_unit_dim = index_shape.ndim == 2 && index_shape.dims[index_shape.Length - 1] == 1;
+     Tensor source_grad = trailing_unit_dim
+         ? (Tensor)new IndexedSlices(upstream, array_ops.squeeze(gather_indices, axis: -1), source_shape)
+         : gen_array_ops.scatter_nd(gather_indices, upstream, source_shape);
+     return new Tensor[] { source_grad, null };
+ }
+
}
}
diff --git a/src/TensorFlowNET.Core/Gradients/nn_grad.cs b/src/TensorFlowNET.Core/Gradients/nn_grad.cs
index a43a91b9..87646a9e 100644
--- a/src/TensorFlowNET.Core/Gradients/nn_grad.cs
+++ b/src/TensorFlowNET.Core/Gradients/nn_grad.cs
@@ -229,6 +229,37 @@ namespace Tensorflow.Gradients
};
}
+ /// <summary>
+ /// Gradient function for DepthwiseConv2dNative.
+ /// </summary>
+ /// <param name="op">The forward operation; its attributes and its first two inputs are read.</param>
+ /// <param name="grads">Incoming gradients; only <c>grads[0]</c> is used.</param>
+ /// <returns>Two tensors: the depthwise_conv2d_native_backprop_input result, then the depthwise_conv2d_native_backprop_filter result.</returns>
+ [RegisterGradient("DepthwiseConv2dNative")]
+ public static Tensor[] _DepthwiseConv2DGrad(Operation op, Tensor[] grads)
+ {
+ var dilations = op.get_attr_list("dilations");
+ var strides = op.get_attr_list("strides");
+ var padding = op.get_attr("padding");
+ var explicit_paddings = op.get_attr_list("explicit_paddings");
+ var data_format = op.get_attr("data_format");
+ var shape = gen_array_ops.shape_n(new Tensor[] { op.inputs[0], op.inputs[1] });
+
+ return new Tensor[]
+ {
+ gen_nn_ops.depthwise_conv2d_native_backprop_input(
+ shape[0], op.inputs[1], grads[0],
+ strides, padding, explicit_paddings,
+ dilations: dilations,
+ data_format: data_format),
+ gen_nn_ops.depthwise_conv2d_native_backprop_filter(op.inputs[0], shape[1], grads[0],
+ strides, padding,
+ dilations: dilations,
+ explicit_paddings: explicit_paddings,
+ data_format: data_format)
+ };
+ }
+
[RegisterGradient("FusedBatchNorm")]
public static Tensor[] _FusedBatchNormGrad(Operation op, Tensor[] grads)
=> _BaseFusedBatchNormGrad(op, 0, grads);
diff --git a/src/TensorFlowNET.Core/Keras/Activations/Activations.cs b/src/TensorFlowNET.Core/Keras/Activations/Activations.cs
index f0d59ed6..37264104 100644
--- a/src/TensorFlowNET.Core/Keras/Activations/Activations.cs
+++ b/src/TensorFlowNET.Core/Keras/Activations/Activations.cs
@@ -32,6 +32,7 @@ namespace Tensorflow.Keras
Activation Linear { get; }
Activation Relu { get; }
+ Activation Relu6 { get; }
Activation Sigmoid { get; }
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs
index 78882e82..ba033283 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs
@@ -1,5 +1,6 @@
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
+using Tensorflow.NumPy;
namespace Tensorflow.Keras.ArgsDefinition
{
@@ -16,5 +17,7 @@ namespace Tensorflow.Keras.ArgsDefinition
public int Worker { get; set; }
public bool UseMultiprocessing { get; set; }
public IModel Model { get; set; }
+ public Dictionary ClassWeight = null;
+ public NDArray SampleWeight = null;
}
}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs
index 82530e95..72d0bb81 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs
@@ -1,5 +1,6 @@
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
+using Tensorflow.NumPy;
namespace Tensorflow.Keras.ArgsDefinition
{
@@ -18,5 +19,7 @@ namespace Tensorflow.Keras.ArgsDefinition
public bool UseMultiprocessing { get; set; } = false;
public IModel Model { get; set; }
public IVariableV1 StepsPerExecution { get; set; }
+ public Dictionary ClassWeight = null;
+ public NDArray SampleWeight = null;
}
}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Merging/MergeArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Merging/MergeArgs.cs
index 0140b3dd..9bcf1908 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Merging/MergeArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Merging/MergeArgs.cs
@@ -1,13 +1,15 @@
-using System;
+using Newtonsoft.Json;
+using System;
using System.Collections.Generic;
using System.Text;
namespace Tensorflow.Keras.ArgsDefinition
{
// TODO: complete the implementation
- public class MergeArgs : LayerArgs
+ public class MergeArgs : AutoSerializeLayerArgs
{
public Tensors Inputs { get; set; }
+ [JsonProperty("axis")]
public int Axis { get; set; }
}
}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUOptionalArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUOptionalArgs.cs
index d441dc82..1d215576 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUOptionalArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/GRUOptionalArgs.cs
@@ -4,10 +4,8 @@ using System.Text;
namespace Tensorflow.Keras.ArgsDefinition
{
- public class GRUOptionalArgs
+ public class GRUOptionalArgs : RnnOptionalArgs
{
public string Identifier => "GRU";
-
- public Tensor Mask { get; set; } = null;
}
}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMOptionalArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMOptionalArgs.cs
new file mode 100644
index 00000000..2829927c
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMOptionalArgs.cs
@@ -0,0 +1,11 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+ // Optional-arguments type tagged with the identifier "LSTM"; it adds nothing else on top of RnnOptionalArgs.
+ namespace Tensorflow.Keras.ArgsDefinition.Rnn
+ {
+ public class LSTMOptionalArgs : RnnOptionalArgs
+ {
+ public string Identifier => "LSTM"; // NOTE(review): GRUOptionalArgs is declared in Tensorflow.Keras.ArgsDefinition (no ".Rnn") - confirm the namespace difference is intended
+ }
+ }
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNOptionalArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNOptionalArgs.cs
new file mode 100644
index 00000000..a8b8caf0
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/SimpleRNNOptionalArgs.cs
@@ -0,0 +1,11 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+ // Optional-arguments type tagged with the identifier "SimpleRNN"; it adds nothing else on top of RnnOptionalArgs.
+ namespace Tensorflow.Keras.ArgsDefinition.Rnn
+ {
+ public class SimpleRNNOptionalArgs : RnnOptionalArgs
+ {
+ public string Identifier => "SimpleRNN"; // NOTE(review): GRUOptionalArgs is declared in Tensorflow.Keras.ArgsDefinition (no ".Rnn") - confirm the namespace difference is intended
+ }
+ }
diff --git a/src/TensorFlowNET.Core/Keras/Engine/IModel.cs b/src/TensorFlowNET.Core/Keras/Engine/IModel.cs
index 19f3df9b..889c76d9 100644
--- a/src/TensorFlowNET.Core/Keras/Engine/IModel.cs
+++ b/src/TensorFlowNET.Core/Keras/Engine/IModel.cs
@@ -3,6 +3,7 @@ using Tensorflow.Keras.Losses;
using Tensorflow.Keras.Metrics;
using Tensorflow.Keras.Saving;
using Tensorflow.NumPy;
+using Tensorflow.Util;
namespace Tensorflow.Keras.Engine;
@@ -22,8 +23,11 @@ public interface IModel : ILayer
int verbose = 1,
List callbacks = null,
float validation_split = 0f,
- (NDArray val_x, NDArray val_y)? validation_data = null,
+ ValidationDataPack validation_data = null,
+ int validation_step = 10,
bool shuffle = true,
+ Dictionary class_weight = null,
+ NDArray sample_weight = null,
int initial_epoch = 0,
int max_queue_size = 10,
int workers = 1,
@@ -35,8 +39,24 @@ public interface IModel : ILayer
int verbose = 1,
List callbacks = null,
float validation_split = 0f,
- (IEnumerable val_x, NDArray val_y)? validation_data = null,
+ ValidationDataPack validation_data = null,
bool shuffle = true,
+ Dictionary class_weight = null,
+ NDArray sample_weight = null,
+ int initial_epoch = 0,
+ int max_queue_size = 10,
+ int workers = 1,
+ bool use_multiprocessing = false);
+
+ public ICallback fit(IDatasetV2 dataset,
+ int batch_size = -1,
+ int epochs = 1,
+ int verbose = 1,
+ List callbacks = null,
+ IDatasetV2 validation_data = null,
+ int validation_step = 10, // run validation once every this many iterations
+ bool shuffle = true,
+ Dictionary class_weight = null,
int initial_epoch = 0,
int max_queue_size = 10,
int workers = 1,
@@ -63,6 +83,8 @@ public interface IModel : ILayer
Dictionary evaluate(NDArray x, NDArray y,
int batch_size = -1,
int verbose = 1,
+ NDArray sample_weight = null,
+
int steps = -1,
int max_queue_size = 10,
int workers = 1,
@@ -78,6 +100,14 @@ public interface IModel : ILayer
int workers = 1,
bool use_multiprocessing = false);
+ public Tensors predict(IDatasetV2 dataset,
+ int batch_size = -1,
+ int verbose = 0,
+ int steps = -1,
+ int max_queue_size = 10,
+ int workers = 1,
+ bool use_multiprocessing = false);
+
void summary(int line_length = -1, float[] positions = null);
IKerasConfig get_config();
diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
index 5e08eadc..57273eb0 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
@@ -55,6 +55,12 @@ namespace Tensorflow.Keras.Layers
string kernel_initializer = "glorot_uniform",
string bias_initializer = "zeros");
+ public ILayer Conv2D(int filters,
+ Shape kernel_size = null,
+ Shape strides = null,
+ string padding = "valid"
+ );
+
public ILayer Conv2D(int filters,
Shape kernel_size = null,
Shape strides = null,
@@ -95,6 +101,19 @@ namespace Tensorflow.Keras.Layers
bool use_bias = true,
string kernel_initializer = "glorot_uniform",
string bias_initializer = "zeros");
+ public ILayer DepthwiseConv2D(Shape kernel_size = null,
+ Shape strides = null,
+ string padding = "valid",
+ string data_format = null,
+ Shape dilation_rate = null,
+ int groups = 1,
+ int depth_multiplier = 1,
+ string activation = null,
+ bool use_bias = false,
+ string kernel_initializer = "glorot_uniform",
+ string bias_initializer = "zeros",
+ string depthwise_initializer = "glorot_uniform"
+ );
public ILayer Dense(int units);
public ILayer Dense(int units,
@@ -161,6 +180,9 @@ namespace Tensorflow.Keras.Layers
public ILayer Normalization(Shape? input_shape = null, int? axis = -1, float? mean = null, float? variance = null, bool invert = false);
public ILayer LeakyReLU(float alpha = 0.3f);
+ public ILayer ReLU6();
+
+
public IRnnCell LSTMCell(int uints,
string activation = "tanh",
string recurrent_activation = "sigmoid",
diff --git a/src/TensorFlowNET.Core/NumPy/Numpy.Manipulation.cs b/src/TensorFlowNET.Core/NumPy/Numpy.Manipulation.cs
index 94085605..5e257417 100644
--- a/src/TensorFlowNET.Core/NumPy/Numpy.Manipulation.cs
+++ b/src/TensorFlowNET.Core/NumPy/Numpy.Manipulation.cs
@@ -30,6 +30,15 @@ namespace Tensorflow.NumPy
[AutoNumPy]
public static NDArray stack(params NDArray[] arrays) => new NDArray(array_ops.stack(arrays));
+ // Stacks `arrays` along `axis` through array_ops.stack and wraps the result in an NDArray.
+ [AutoNumPy]
+ public static NDArray stack(NDArray[] arrays, int axis = 0) => new NDArray(array_ops.stack(arrays, axis));
+ // Two-element tuple overload: Item1 and Item2 are stacked, in that order, along `axis`.
+ [AutoNumPy]
+ public static NDArray stack((NDArray, NDArray) tuple, int axis = 0) => new NDArray(array_ops.stack(new[] { tuple.Item1, tuple.Item2 }, axis));
+ // Three-element tuple overload: Item1, Item2 and Item3 are stacked, in that order, along `axis`.
+ [AutoNumPy]
+ public static NDArray stack((NDArray, NDArray, NDArray) tuple, int axis = 0) => new NDArray(array_ops.stack(new[] { tuple.Item1, tuple.Item2, tuple.Item3 }, axis));
[AutoNumPy]
public static NDArray moveaxis(NDArray array, Axis source, Axis destination) => new NDArray(array_ops.moveaxis(array, source, destination));
}
diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs
index e59c381c..2105c53f 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.cs
@@ -437,7 +437,7 @@ namespace Tensorflow
internal void _set_attr_with_buf(string attr_name, Buffer attr_buf)
{
Status status = new();
- c_api.TFC_SetAttr(graph, _handle, attr_name, attr_buf, status);
+ c_api.TF_SetAttr(graph, _handle, attr_name, attr_buf, status);
status.Check(true);
}
}
diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs
index f80dcd2c..548a885e 100644
--- a/src/TensorFlowNET.Core/Operations/array_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/array_ops.cs
@@ -166,6 +166,11 @@ namespace Tensorflow
throw new ValueError("mask cannot be scalar.");
var leading_size = gen_math_ops.prod(shape(tensor_tensor)[$"{axis}:{axis + ndims_mask}"], ops.convert_to_tensor(new[] { 0 }));
+ if (leading_size.rank == 0)
+ {
+ leading_size = expand_dims(leading_size, 0);
+ }
+
var shape1 = concat(new[]
{
shape(tensor_tensor)[$":{axis}"],
@@ -185,7 +190,7 @@ namespace Tensorflow
private static Tensor _apply_mask_1d(Tensor reshaped_tensor, Tensor mask, int axis = 0)
{
- var indices = squeeze(where(mask), axis: new[] { 1 });
+ var indices = squeeze(where_v2(mask), axis: new[] { 1 });
return gather(reshaped_tensor, indices, axis: ops.convert_to_tensor(axis));
}
@@ -829,7 +834,7 @@ namespace Tensorflow
/// A `Tensor`. Has the same type as `input`.
/// Contains the same data as `input`, but has one or more dimensions of
/// size 1 removed.
- public static Tensor squeeze(Tensor input, int[] axis = null, string name = null)
+ public static Tensor squeeze(Tensor input, Axis axis = null, string name = null)
=> gen_array_ops.squeeze(input, axis, name);
public static Tensor identity(Tensor input, string name = null)
@@ -990,7 +995,7 @@ namespace Tensorflow
return @params.sparse_read(indices, name);
}
- public static Tensor transpose(T1 a, Axis perm, string name = "transpose", bool conjugate = false)
+ public static Tensor transpose(T1 a, Axis perm = null, string name = "transpose", bool conjugate = false)
{
return tf_with(ops.name_scope(name, "transpose", new { a }), scope =>
{
@@ -1139,5 +1144,18 @@ namespace Tensorflow
var _op = tf.OpDefLib._apply_op_helper("Placeholder", name: name, args: new { dtype, shape });
return _op.output;
}
+
+ /// <summary>Resolves a possibly negative <paramref name="axis"/> to its non-negative form for a tensor of rank <paramref name="ndims"/>; throws ValueError when it cannot be resolved.</summary>
+ public static int get_positive_axis(int axis, int ndims=-100, string axis_name="axis", string ndims_name= "ndims")
+ {
+     // Any negative ndims (the -100 default included) means the rank is not statically known, so only a non-negative axis passes through.
+     if (ndims < 0)
+     {
+         if (axis < 0) throw new ValueError($"{axis_name}={axis} may only be negative if {ndims_name} is statically known.");
+         return axis;
+     }
+     if (axis < -ndims || axis >= ndims) throw new ValueError($"{axis_name}={axis} out of bounds:expected {-ndims}<={axis_name}<{ndims}");
+     return axis < 0 ? axis + ndims : axis;
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Operations/handle_data_util.cs b/src/TensorFlowNET.Core/Operations/handle_data_util.cs
index a01efc52..363d3144 100644
--- a/src/TensorFlowNET.Core/Operations/handle_data_util.cs
+++ b/src/TensorFlowNET.Core/Operations/handle_data_util.cs
@@ -51,7 +51,7 @@ namespace Tensorflow.Operations
}
Status status = new();
var proto = handle_data.ToByteArray();
- c_api.TFC_SetHandleShapeAndType(target_t.graph.c_graph, target_t._as_tf_output(), proto, proto.Length, status);
+ c_api.TF_SetHandleShapeAndType(target_t.graph.c_graph, target_t._as_tf_output(), proto, proto.Length, status);
status.Check(true);
}
diff --git a/src/TensorFlowNET.Core/Operations/image_ops_impl.cs b/src/TensorFlowNET.Core/Operations/image_ops_impl.cs
index 318b8b14..f1aff28e 100644
--- a/src/TensorFlowNET.Core/Operations/image_ops_impl.cs
+++ b/src/TensorFlowNET.Core/Operations/image_ops_impl.cs
@@ -102,7 +102,10 @@ namespace Tensorflow
{
throw new ValueError("\'image\' must be fully defined.");
}
- var dims = image_shape["-3:"];
+ var dims = new Shape(new[] {
+ image_shape.dims[image_shape.dims.Length - 3],
+ image_shape.dims[image_shape.dims.Length - 2],
+ image_shape.dims[image_shape.dims.Length - 1]});
foreach (var dim in dims.dims)
{
if (dim == 0)
@@ -112,16 +115,18 @@ namespace Tensorflow
}
var image_shape_last_three_elements = new Shape(new[] {
- image_shape.dims[image_shape.dims.Length - 1],
+ image_shape.dims[image_shape.dims.Length - 3],
image_shape.dims[image_shape.dims.Length - 2],
- image_shape.dims[image_shape.dims.Length - 3]});
+ image_shape.dims[image_shape.dims.Length - 1]});
if (!image_shape_last_three_elements.IsFullyDefined)
{
Tensor image_shape_ = array_ops.shape(image);
- var image_shape_return = tf.constant(new[] {
- image_shape_.dims[image_shape.dims.Length - 1],
- image_shape_.dims[image_shape.dims.Length - 2],
- image_shape_.dims[image_shape.dims.Length - 3]});
+ var image_shape_return = tf.slice(image_shape_, new[] { Math.Max(image_shape.dims.Length - 3, 0) }, new[] { 3 });
+
+ //var image_shape_return = tf.constant(new[] {
+ // image_shape_.dims[image_shape_.dims.Length - 3],
+ // image_shape_.dims[image_shape_.dims.Length - 2],
+ // image_shape_.dims[image_shape_.dims.Length - 1]});
return new Operation[] {
check_ops.assert_positive(
@@ -209,10 +214,10 @@ namespace Tensorflow
}
public static Tensor flip_left_right(Tensor image)
- => _flip(image, 0, "flip_left_right");
+ => _flip(image, 1, "flip_left_right");
public static Tensor flip_up_down(Tensor image)
- => _flip(image, 1, "flip_up_down");
+ => _flip(image, 0, "flip_up_down");
internal static Tensor _flip(Tensor image, int flip_index, string scope_name)
{
@@ -223,11 +228,11 @@ namespace Tensorflow
Shape shape = image.shape;
if (shape.ndim == 3 || shape.ndim == Unknown)
{
- return fix_image_flip_shape(image, gen_array_ops.reverse(image, ops.convert_to_tensor(new int[] { flip_index })));
+ return fix_image_flip_shape(image, gen_array_ops.reverse_v2(image, ops.convert_to_tensor(new int[] { flip_index })));
}
else if (shape.ndim == 4)
{
- return gen_array_ops.reverse_v2(image, ops.convert_to_tensor(new[] { (flip_index + 1) % 2 }));
+ return gen_array_ops.reverse_v2(image, ops.convert_to_tensor(new[] { flip_index + 1 }));
}
else
{
@@ -2047,6 +2052,22 @@ new_height, new_width");
});
}
+ // Calls gen_ops.encode_jpeg inside an "encode_jpeg" name scope. NOTE(review): the op is given the raw `name`, not the opened `scope` - confirm that is intended.
+ public static Tensor encode_jpeg(Tensor contents, string name = null)
+ {
+ return tf_with(ops.name_scope(name, "encode_jpeg"), scope =>
+ {
+ return gen_ops.encode_jpeg(contents, name:name);
+ });
+ }
+ // Calls gen_ops.encode_png inside an "encode_png" name scope. NOTE(review): the op is given the raw `name`, not the opened `scope` - confirm that is intended.
+ public static Tensor encode_png(Tensor contents, string name = null)
+ {
+ return tf_with(ops.name_scope(name, "encode_png"), scope =>
+ {
+ return gen_ops.encode_png(contents, name: name);
+ });
+ }
public static Tensor is_jpeg(Tensor contents, string name = null)
{
return tf_with(ops.name_scope(name, "is_jpeg"), scope =>
diff --git a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj
index be714618..42c0399d 100644
--- a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj
+++ b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj
@@ -4,8 +4,8 @@
netstandard2.0;net6.0
Tensorflow.Binding
Tensorflow
- 2.11.0
- 0.110.3
+ 2.15.0
+ 0.150.0
10.0
enable
Haiping Chen, Eli Belash, Yaohui Liu, Meinrad Recheis
@@ -20,12 +20,16 @@
Google's TensorFlow full binding in .NET Standard.
Building, training and infering deep learning models.
https://tensorflownet.readthedocs.io
- 0.110.3.0
+ 0.150.0.0
+ tf.net 0.150.x and above are based on tensorflow native 2.15.0
+ * Support BERT model.
+
tf.net 0.110.x and above are based on tensorflow native 2.11.0
* Support RNN, LSTM model.
* Support Transformer model.
-
+ * Added IMDB dataset.
+
tf.net 0.100.x and above are based on tensorflow native 2.10.0
* Eager Mode is added finally.
@@ -42,8 +46,9 @@ https://tensorflownet.readthedocs.io
tf.net 0.7x.x aligns with TensorFlow v2.7.x native library.
tf.net 0.10x.x aligns with TensorFlow v2.10.x native library.
tf.net 0.11x.x aligns with TensorFlow v2.11.x native library.
+ tf.net 0.15x.x aligns with TensorFlow v2.15.x native library.
- 0.110.3.0
+ 0.150.0.0
LICENSE
true
packages
@@ -174,8 +179,8 @@ https://tensorflownet.readthedocs.io
-
-
+
+
diff --git a/src/TensorFlowNET.Core/Tensors/Ragged/RaggedTensor.cs b/src/TensorFlowNET.Core/Tensors/Ragged/RaggedTensor.cs
index 4f85e108..0f09d412 100644
--- a/src/TensorFlowNET.Core/Tensors/Ragged/RaggedTensor.cs
+++ b/src/TensorFlowNET.Core/Tensors/Ragged/RaggedTensor.cs
@@ -163,5 +163,38 @@ namespace Tensorflow
{
return tensor.Tag as RaggedTensor;
}
+ /// <summary>Number of rows reported by the row partition, cast to <paramref name="out_type"/> inside a "RaggedNRows" name scope.</summary>
+ public Tensor nrows(TF_DataType out_type, string name = null)
+ {
+     return tf_with(ops.name_scope(name, "RaggedNRows"), scope =>
+     {
+         return math_ops.cast(this._row_partition.nrows(), dtype: out_type);
+     });
+ }
+ // Lengths of the rows along `axis` (negative values are resolved against this.shape.rank). NOTE(review): `name` is never read in this body.
+ public RaggedTensor row_lengths(int axis=-1, string name=null)
+ {
+ if (axis == 0) return this._row_partition.nrows(); // literal 0 and 1 are answered directly by the row partition
+ if (axis == 1) return this._row_partition.row_lengths();
+ var values = (RaggedTensor)this._values; // NOTE(review): this cast runs before the `is RaggedTensor` test below - confirm it is safe for non-ragged values
+ axis = array_ops.get_positive_axis(
+ axis, this.shape.rank, ndims_name: "rank(this)");
+ if (axis == 0) return this.nrows(this._row_partition.GetDataType());
+ else if (axis == 1)
+ {
+ var splits = this._row_partition.row_splits;
+ return splits[new Slice(start: 1)] - splits[new Slice(stop: -1)]; // differences of adjacent row_splits entries
+ }
+ else if (this._values is RaggedTensor)
+ {
+ return values.row_lengths(axis - 1); // deeper axes are delegated to the nested ragged values
+ }
+ else
+ {
+ var shape = array_ops.shape(values, out_type: this._row_partition.GetDataType());
+ return array_ops.ones(shape[new Slice(stop:axis - 1)], this._row_partition.GetDataType()) * // ones over the leading dims, scaled by shape[axis - 1]
+ shape[axis - 1];
+ }
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Tensors/Ragged/RowPartition.cs b/src/TensorFlowNET.Core/Tensors/Ragged/RowPartition.cs
index 29dc525d..9e242ff3 100644
--- a/src/TensorFlowNET.Core/Tensors/Ragged/RowPartition.cs
+++ b/src/TensorFlowNET.Core/Tensors/Ragged/RowPartition.cs
@@ -14,10 +14,15 @@
limitations under the License.
******************************************************************************/
+using Serilog.Debugging;
using System;
+using System.Collections.Concurrent;
using System.Collections.Generic;
+//using System.ComponentModel.DataAnnotations;
using System.Text;
+using System.Xml.Linq;
using Tensorflow.Framework;
+using Tensorflow.NumPy;
using static Tensorflow.Binding;
namespace Tensorflow
@@ -99,5 +104,55 @@ namespace Tensorflow
return new RowPartition(row_splits);
});
}
+
+ public static RowPartition from_row_lengths(Tensor row_lengths, // Builds a RowPartition from a tensor of row lengths.
+ bool validate=true, // NOTE(review): never read in this method.
+ TF_DataType dtype = TF_DataType.TF_INT32,
+ TF_DataType dtype_hint= TF_DataType.TF_INT32)
+ {
+ row_lengths = _convert_row_partition(
+ row_lengths, "row_lengths", dtype_hint: dtype_hint, dtype: dtype);
+ Tensor row_limits = math_ops.cumsum(row_lengths, tf.constant(-1)); // Cumulative sum of the lengths; the constant -1 is presumably the axis argument.
+ Tensor row_splits = array_ops.concat(new Tensor[] { tf.convert_to_tensor(np.array(new int[] { 0 }, TF_DataType.TF_INT64)), row_limits }, axis:0); // Prepends a leading 0. NOTE(review): the 0 is int64 while row_limits may be int32 - confirm concat accepts mixed dtypes.
+ return new RowPartition(row_splits: row_splits, row_lengths: row_lengths);
+ }
+
+ /// <summary>
+ /// Validates a row-partition tensor and returns it.
+ /// An int32 NDArray input is first passed through ops.convert_to_tensor under <paramref name="name"/>.
+ /// </summary>
+ /// <param name="partition">The partitioning tensor to check.</param>
+ /// <param name="name">Name used for the conversion and in the error message.</param>
+ /// <param name="dtype">Accepted for signature compatibility; not consulted by this method.</param>
+ /// <param name="dtype_hint">Accepted for signature compatibility; not consulted by this method.</param>
+ /// <exception cref="ValueError">The partition dtype is neither int32 nor int64.</exception>
+ public static Tensor _convert_row_partition(Tensor partition, string name, TF_DataType dtype,
+ TF_DataType dtype_hint= TF_DataType.TF_INT64)
+ {
+ bool is_int32_ndarray = partition is NDArray && partition.GetDataType() == np.int32;
+ if (is_int32_ndarray)
+ {
+ partition = ops.convert_to_tensor(partition, name: name);
+ }
+
+ // Only int32 and int64 partitions are accepted.
+ bool is_supported = partition.GetDataType() == np.int32 || partition.GetDataType() == np.int64;
+ if (!is_supported)
+ {
+ throw new ValueError($"{name} must have dtype int32 or int64");
+ }
+
+ return partition;
+ }
+
+ /// <summary>
+ /// Returns the number of rows created by this `RowPartition`.
+ /// </summary>
+ public Tensor nrows()
+ {
+ // A stored row count takes precedence over anything derived from row_splits.
+ if (this._nrows != null)
+ {
+ return this._nrows;
+ }
+
+ var split_count = tensor_shape.dimension_at_index(this._row_splits.shape, 0);
+ if (split_count == null)
+ {
+ // No dimension object available: read the leading size from the shape op instead.
+ var splits_shape = array_ops.shape(this._row_splits, out_type: this.row_splits.dtype);
+ return splits_shape[0] - 1;
+ }
+
+ // There is one more split than there are rows.
+ return constant_op.constant(split_count.value - 1, dtype: this.row_splits.dtype);
+ }
+
+ public Tensor row_lengths()
+ {
+ // NOTE(review): despite the name, each branch below yields a single count (or -1) rather than one length per row - confirm this is intended.
+ if (this._row_splits != null)
+ {
+ int nrows_plus_one = tensor_shape.dimension_value(this._row_splits.shape[0]);
+ return tf.constant(nrows_plus_one - 1); // Leading dimension of row_splits, minus one.
+
+ }
+ if (this._row_lengths != null)
+ {
+ var nrows = tensor_shape.dimension_value(this._row_lengths.shape[0]);
+ return tf.constant(nrows); // Leading dimension of the stored row_lengths tensor.
+ }
+ if(this._nrows != null)
+ {
+ return tensor_util.constant_value(this._nrows); // Falls back to the constant value of the stored nrows tensor.
+ }
+ return tf.constant(-1); // None of the three fields is set.
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs
index e65c4850..f688d4d5 100644
--- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs
+++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs
@@ -249,6 +249,9 @@ namespace Tensorflow
case sbyte val:
tensor_proto.IntVal.AddRange(new[] { (int)val });
break;
+ case byte val:
+ tensor_proto.IntVal.AddRange(new[] { (int)val });
+ break;
case int val:
tensor_proto.IntVal.AddRange(new[] { val });
break;
@@ -262,7 +265,7 @@ namespace Tensorflow
tensor_proto.DoubleVal.AddRange(new[] { val });
break;
default:
- throw new Exception("make_tensor_proto Not Implemented");
+ throw new Exception($"make_tensor_proto Not Implemented {values.GetType().Name}");
}
}
diff --git a/src/TensorFlowNET.Core/Util/Data.cs b/src/TensorFlowNET.Core/Util/Data.cs
new file mode 100644
index 00000000..a14c69b1
--- /dev/null
+++ b/src/TensorFlowNET.Core/Util/Data.cs
@@ -0,0 +1,66 @@
+using Tensorflow.NumPy;
+
+namespace Tensorflow.Util
+{
+ /// <summary>
+ /// ValidationDataPack is used to pass validation data to the fit method.
+ /// It can receive data as a tuple `(x_val, y_val)` or `(x_val, y_val, sample_weight_val)` of NumPy arrays.
+ /// </summary>
+ public class ValidationDataPack
+ {
+ public NDArray val_x; // Validation inputs.
+ public NDArray val_y; // Validation targets.
+ public NDArray val_sample_weight = null; // Stays null unless a three-element tuple supplies it.
+ // Wraps an (x, y) pair; val_sample_weight keeps its null default.
+ public ValidationDataPack((NDArray, NDArray) validation_data)
+ {
+ this.val_x = validation_data.Item1;
+ this.val_y = validation_data.Item2;
+ }
+ // Wraps an (x, y, sample_weight) triple.
+ public ValidationDataPack((NDArray, NDArray, NDArray) validation_data)
+ {
+ this.val_x = validation_data.Item1;
+ this.val_y = validation_data.Item2;
+ this.val_sample_weight = validation_data.Item3;
+ }
+ // Enumerable-input form. NOTE(review): no using for IEnumerable/ToArray is visible in this file - confirm they are in scope.
+ public ValidationDataPack((IEnumerable, NDArray) validation_data)
+ {
+ this.val_x = validation_data.Item1.ToArray()[0]; // Only the first element of the enumerable is kept.
+ this.val_y = validation_data.Item2;
+ }
+ // Enumerable-input form with sample weights.
+ public ValidationDataPack((IEnumerable, NDArray, NDArray) validation_data)
+ {
+ this.val_x = validation_data.Item1.ToArray()[0]; // Only the first element of the enumerable is kept.
+ this.val_y = validation_data.Item2;
+ this.val_sample_weight = validation_data.Item3;
+ }
+ // Implicit conversions: a plain tuple can be used wherever a ValidationDataPack is expected.
+ public static implicit operator ValidationDataPack((NDArray, NDArray) validation_data)
+ => new ValidationDataPack(validation_data);
+
+ public static implicit operator ValidationDataPack((NDArray, NDArray, NDArray) validation_data)
+ => new ValidationDataPack(validation_data);
+
+ public static implicit operator ValidationDataPack((IEnumerable, NDArray) validation_data)
+ => new ValidationDataPack(validation_data);
+
+ public static implicit operator ValidationDataPack((IEnumerable, NDArray, NDArray) validation_data)
+ => new ValidationDataPack(validation_data);
+ // Deconstruct overloads: allow `var (x, y) = pack;` and `var (x, y, w) = pack;`.
+ public void Deconstruct(out NDArray val_x, out NDArray val_y)
+ {
+ val_x = this.val_x;
+ val_y = this.val_y;
+ }
+
+ public void Deconstruct(out NDArray val_x, out NDArray val_y, out NDArray val_sample_weight)
+ {
+ val_x = this.val_x;
+ val_y = this.val_y;
+ val_sample_weight = this.val_sample_weight; // May be null when the pack was built without weights.
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs
index 351fd18f..6f51150a 100644
--- a/src/TensorFlowNET.Core/ops.cs
+++ b/src/TensorFlowNET.Core/ops.cs
@@ -590,7 +590,7 @@ namespace Tensorflow
public static HandleData get_resource_handle_data(Tensor graph_op)
{
- var handle_data = c_api.TFC_GetHandleShapeAndType(graph_op.graph.c_graph, graph_op._as_tf_output());
+ var handle_data = c_api.TF_GetHandleShapeAndType(graph_op.graph.c_graph, graph_op._as_tf_output());
try{
var handle_str = c_api.ByteStringPiece(handle_data.DangerousGetHandle() == IntPtr.Zero ? null : new Buffer(handle_data));
return HandleData.Parser.ParseFrom(handle_str);
diff --git a/src/TensorFlowNET.Keras/Activations.cs b/src/TensorFlowNET.Keras/Activations.cs
index ce5b4eb1..d3801902 100644
--- a/src/TensorFlowNET.Keras/Activations.cs
+++ b/src/TensorFlowNET.Keras/Activations.cs
@@ -20,6 +20,11 @@ namespace Tensorflow.Keras
Name = "relu",
ActivationFunction = (features, name) => tf.Context.ExecuteOp("Relu", name, new ExecuteOpArgs(features))
};
+ private static Activation _relu6 = new Activation() // Activation named "relu6".
+ {
+ Name = "relu6",
+ ActivationFunction = (features, name) => tf.Context.ExecuteOp("Relu6", name, new ExecuteOpArgs(features)) // Executes the "Relu6" op on the input features.
+ };
private static Activation _sigmoid = new Activation()
{
Name = "sigmoid",
@@ -55,6 +60,7 @@ namespace Tensorflow.Keras
_nameActivationMap = new Dictionary();
RegisterActivation(_relu);
+ RegisterActivation(_relu6);
RegisterActivation(_linear);
RegisterActivation(_sigmoid);
RegisterActivation(_softmax);
@@ -65,6 +71,7 @@ namespace Tensorflow.Keras
public Activation Linear => _linear;
public Activation Relu => _relu;
+ public Activation Relu6 => _relu6;
public Activation Sigmoid => _sigmoid;
diff --git a/src/TensorFlowNET.Keras/Datasets/Imdb.cs b/src/TensorFlowNET.Keras/Datasets/Imdb.cs
index 1c980518..4d6df913 100644
--- a/src/TensorFlowNET.Keras/Datasets/Imdb.cs
+++ b/src/TensorFlowNET.Keras/Datasets/Imdb.cs
@@ -112,35 +112,39 @@ namespace Tensorflow.Keras.Datasets
if (start_char != null)
{
- int[,] new_x_train_array = new int[x_train_array.GetLength(0), x_train_array.GetLength(1) + 1];
- for (var i = 0; i < x_train_array.GetLength(0); i++)
+ var (d1, d2) = (x_train_array.GetLength(0), x_train_array.GetLength(1));
+ int[,] new_x_train_array = new int[d1, d2 + 1];
+ for (var i = 0; i < d1; i++)
{
new_x_train_array[i, 0] = (int)start_char;
- Array.Copy(x_train_array, i * x_train_array.GetLength(1), new_x_train_array, i * new_x_train_array.GetLength(1) + 1, x_train_array.GetLength(1));
+ Array.Copy(x_train_array, i * d2, new_x_train_array, i * (d2 + 1) + 1, d2);
}
- int[,] new_x_test_array = new int[x_test_array.GetLength(0), x_test_array.GetLength(1) + 1];
- for (var i = 0; i < x_test_array.GetLength(0); i++)
+ (d1, d2) = (x_test_array.GetLength(0), x_test_array.GetLength(1));
+ int[,] new_x_test_array = new int[d1, d2 + 1];
+ for (var i = 0; i < d1; i++)
{
new_x_test_array[i, 0] = (int)start_char;
- Array.Copy(x_test_array, i * x_test_array.GetLength(1), new_x_test_array, i * new_x_test_array.GetLength(1) + 1, x_test_array.GetLength(1));
+ Array.Copy(x_test_array, i * d2, new_x_test_array, i * (d2 + 1) + 1, d2);
}
x_train_array = new_x_train_array;
x_test_array = new_x_test_array;
}
else if (index_from != 0)
{
- for (var i = 0; i < x_train_array.GetLength(0); i++)
+ var (d1, d2) = (x_train_array.GetLength(0), x_train_array.GetLength(1));
+ for (var i = 0; i < d1; i++)
{
- for (var j = 0; j < x_train_array.GetLength(1); j++)
+ for (var j = 0; j < d2; j++)
{
if (x_train_array[i, j] == 0)
break;
x_train_array[i, j] += index_from;
}
}
- for (var i = 0; i < x_test_array.GetLength(0); i++)
+ (d1, d2) = (x_test_array.GetLength(0), x_test_array.GetLength(1));
+ for (var i = 0; i < d1; i++)
{
- for (var j = 0; j < x_test_array.GetLength(1); j++)
+ for (var j = 0; j < d2; j++)
{
if (x_test_array[i, j] == 0)
break;
@@ -169,9 +173,10 @@ namespace Tensorflow.Keras.Datasets
if (num_words == null)
{
+ var (d1, d2) = (xs_array.GetLength(0), xs_array.GetLength(1));
num_words = 0;
- for (var i = 0; i < xs_array.GetLength(0); i++)
- for (var j = 0; j < xs_array.GetLength(1); j++)
+ for (var i = 0; i < d1; i++)
+ for (var j = 0; j < d2; j++)
num_words = max((int)num_words, (int)xs_array[i, j]);
}
diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs
index 6c7d53b2..b2750496 100644
--- a/src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs
+++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataAdapter.cs
@@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Util;
namespace Tensorflow.Keras.Engine.DataAdapters
{
@@ -34,9 +35,67 @@ namespace Tensorflow.Keras.Engine.DataAdapters
return (x, y);
}
+ /// <summary>
+ /// Gives every rank-1 tensor in x, y and sample_weight a trailing axis (expand_dims with axis -1).
+ /// The three collections are updated in place and returned in the same order.
+ /// </summary>
+ public virtual (Tensors, Tensors, Tensors) Expand1d(Tensors x, Tensors y, Tensors sample_weight)
+ {
+ // Same treatment for each group, processed in the order x, y, sample_weight.
+ foreach (var group in new[] { x, y, sample_weight })
+ {
+ for (int idx = 0; idx < group.Length; idx++)
+ {
+ var item = group[idx];
+ if (item.shape.ndim == 1)
+ {
+ group[idx] = array_ops.expand_dims(item, axis: -1);
+ }
+ }
+ }
+
+ return (x, y, sample_weight);
+ }
+
public virtual bool ShouldRecreateIterator()
{
return true;
}
+
+ /// <summary>
+ /// Splits (x, y, sample_weight) along the first dimension into a training part and a validation part.
+ /// The first Convert.ToInt32(rows * (1 - validation_split)) rows are kept for training; the rest become validation data.
+ /// </summary>
+ /// <param name="x_y_sample_weight">Inputs, targets and optional (nullable) sample weights.</param>
+ /// <param name="validation_split">Fraction of the rows held out for validation.</param>
+ /// <returns>The training triple and the packed validation data.</returns>
+ public static ((NDArray, NDArray, NDArray),ValidationDataPack) train_validation_split((NDArray, NDArray, NDArray) x_y_sample_weight, float validation_split)
+ {
+ var (x, y, sample_weight) = x_y_sample_weight;
+ int split_at = Convert.ToInt32(x.dims[0] * (1 - validation_split));
+ var train_x = x[new Slice(0, split_at)];
+ var train_y = y[new Slice(0, split_at)];
+
+ if (sample_weight != null)
+ {
+ // Weights are split at the same row as the data.
+ ValidationDataPack weighted = (x[new Slice(split_at)], y[new Slice(split_at)], sample_weight[new Slice(split_at)]);
+ return ((train_x, train_y, sample_weight[new Slice(0, split_at)]), weighted);
+ }
+
+ ValidationDataPack unweighted = (x[new Slice(split_at)], y[new Slice(split_at)]);
+ return ((train_x, train_y, sample_weight), unweighted);
+ }
+
+ /// <summary>
+ /// Multi-input variant of train_validation_split: every array in x, together with y and the optional
+ /// sample weights, is split along the first dimension at Convert.ToInt32(rows * (1 - validation_split)).
+ /// </summary>
+ /// <param name="x_y_sample_weight">Input arrays, targets and optional (nullable) sample weights.</param>
+ /// <param name="validation_split">Fraction of the rows held out for validation.</param>
+ /// <returns>The training triple and the packed validation data; both weight slots are null when no weights were given.</returns>
+ public static ((IEnumerable, NDArray, NDArray), ValidationDataPack) train_validation_split((IEnumerable, NDArray, NDArray) x_y_sample_weight, float validation_split)
+ {
+ var x = x_y_sample_weight.Item1;
+ var y = x_y_sample_weight.Item2;
+ var sample_weight = x_y_sample_weight.Item3;
+ int train_count = Convert.ToInt32(y.dims[0] * (1 - validation_split));
+ var train_x = x.Select(input => input[new Slice(0, train_count)] as NDArray);
+ var train_y = y[new Slice(0, train_count)];
+ var val_x = x.Select(input => input[new Slice(train_count)] as NDArray);
+ var val_y = y[new Slice(train_count)];
+
+ // sample_weight is optional (fit passes null when the caller supplies none), so it must not
+ // be sliced unconditionally; this mirrors the null handling of the single-input overload.
+ NDArray val_sample_weight = null;
+ if (sample_weight != null)
+ {
+ val_sample_weight = sample_weight[new Slice(train_count)];
+ sample_weight = sample_weight[new Slice(0, train_count)];
+ }
+
+ ValidationDataPack validation_data = (val_x, val_y, val_sample_weight);
+ return ((train_x, train_y, sample_weight), validation_data);
+ }
}
}
diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs
index 4723222f..a305e503 100644
--- a/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs
+++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs
@@ -2,6 +2,9 @@
using System.Collections.Generic;
using Tensorflow.Keras.ArgsDefinition;
using static Tensorflow.Binding;
+using Tensorflow.Keras.Utils;
+using Tensorflow.Util;
+using Tensorflow.Framework;
namespace Tensorflow.Keras.Engine.DataAdapters
{
@@ -23,11 +26,13 @@ namespace Tensorflow.Keras.Engine.DataAdapters
long _steps_per_execution_value;
int _initial_epoch => args.InitialEpoch;
int _epochs => args.Epochs;
+ NDArray _sample_weight => args.SampleWeight;
IVariableV1 _steps_per_execution;
public DataHandler(DataHandlerArgs args)
{
this.args = args;
+
if (args.StepsPerExecution == null)
{
_steps_per_execution = tf.Variable(1L);
@@ -48,6 +53,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters
BatchSize = args.BatchSize,
Steps = args.StepsPerEpoch,
Epochs = args.Epochs - args.InitialEpoch,
+ SampleWeight = args.SampleWeight,
Shuffle = args.Shuffle,
MaxQueueSize = args.MaxQueueSize,
Worker = args.Workers,
@@ -72,10 +78,75 @@ namespace Tensorflow.Keras.Engine.DataAdapters
}
_dataset = _adapter.GetDataset();
- _inferred_steps = _infer_steps(args.StepsPerEpoch, _dataset);
_current_step = 0;
_step_increment = _steps_per_execution_value - 1;
_insufficient_data = false;
+ _configure_dataset_and_inferred_steps(args.X, args.ClassWeight);
+ }
+
+ void _configure_dataset_and_inferred_steps(Tensors x, Dictionary class_weight) // NOTE(review): `x` is not read in this body.
+ {
+ if (_dataset == null) // Fetch the dataset from the adapter if none has been set yet.
+ {
+ _dataset = _adapter.GetDataset();
+ _inferred_steps = _infer_steps(args.StepsPerEpoch, _dataset);
+ }
+ // When class weights are given, map every dataset element through the class-weight function.
+ if (class_weight != null)
+ {
+ _dataset = _dataset.map(_make_class_weight_map_fn(class_weight));
+ }
+ _inferred_steps = _infer_steps(args.StepsPerEpoch, _dataset); // Always recomputed last, overwriting the value assigned in the first branch.
+ }
+
+
+ /// <summary>
+ /// Builds the dataset map function that turns class weights into per-element sample weights.
+ /// The returned function reads x and y from the element, looks up the weight of each target class and
+ /// returns (x, y, sw), where sw is the class weight multiplied by the sample weight when one is configured.
+ /// </summary>
+ /// <param name="class_weight">Mapping from class id to weight; the ids must be exactly 0..N-1.</param>
+ /// <returns>A function suitable for IDatasetV2.map.</returns>
+ /// <exception cref="ValueError">The class ids are not 0..N-1, or the targets have rank greater than 2.</exception>
+ Func _make_class_weight_map_fn(Dictionary class_weight)
+ {
+ var class_ids = class_weight.Keys.OrderBy(key => key).ToList();
+ // Weights are looked up positionally (gather by class id), so the ids must be exactly 0..N-1.
+ // Checking only that they are contiguous would accept e.g. {3, 4, 5} and then index out of range.
+ var expected_class_ids = range(0, class_ids.Count);
+ if (!class_ids.SequenceEqual(expected_class_ids))
+ {
+ // Report the offending keys; interpolating the dictionary itself only prints its type name.
+ throw new ValueError("Expected `class_weight` to be a dict with keys from 0 to one less "+
+ $"than the number of classes, found [{string.Join(", ", class_ids)}]");
+ }
+
+ // Weights ordered by class id, so that index i holds the weight of class i.
+ var class_weight_list = new List();
+ foreach (var class_id in class_ids)
+ {
+ class_weight_list.Add(class_weight[class_id]);
+ }
+ var class_weight_tensor = tf.convert_to_tensor(class_weight_list.ToArray());
+
+ Func _class_weight_map_fn = (Tensors data) =>
+ {
+ var x = data[0];
+ var y = data[1];
+ // NOTE(review): this uses the handler-level sample weights rather than a per-element slice
+ // from `data` - confirm the shapes line up when the dataset is batched.
+ var sw = _sample_weight == null ? null : ops.convert_to_tensor(_sample_weight);
+
+ if (y.shape.rank > 2)
+ {
+ throw new ValueError("`class_weight` not supported for 3+ dimensional targets.");
+ }
+
+ // Rank-2 targets with more than one column are reduced with argmax over dimension 1;
+ // anything else is flattened and cast to int64 class ids.
+ var y_classes = smart_module.smart_cond(
+ y.shape.rank == 2 && y.shape[1] > 1,
+ () => math_ops.argmax(y, dimension: 1),
+ () => math_ops.cast(tf.reshape(y, (-1)), TF_DataType.TF_INT64));
+
+ var cw = array_ops.gather(class_weight_tensor, y_classes);
+ if (sw != null)
+ {
+ cw = tf.cast(cw, sw.dtype);
+ // Class and sample weights are multiplicative. The product has to land in `sw`,
+ // the value that is returned; storing it in `cw` silently dropped the class weights.
+ sw *= cw;
+ }
+ else
+ {
+ sw = cw;
+ }
+ return new Tensors { x, y, sw };
+ };
+
+ return _class_weight_map_fn;
+ }
long _infer_steps(int steps_per_epoch, IDatasetV2 dataset)
diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/IDataAdapter.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/IDataAdapter.cs
index 4bdc4979..bb71b0a2 100644
--- a/src/TensorFlowNET.Keras/Engine/DataAdapters/IDataAdapter.cs
+++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/IDataAdapter.cs
@@ -17,6 +17,8 @@
IDatasetV2 GetDataset();
int GetSize();
(Tensors, Tensors) Expand1d(Tensors x, Tensors y);
+ (Tensors, Tensors, Tensors) Expand1d(Tensors x, Tensors y, Tensors sample_weight);
+
bool ShouldRecreateIterator();
}
}
diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs
index 16e646a3..978a3f51 100644
--- a/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs
+++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs
@@ -20,7 +20,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters
public TensorLikeDataAdapter(DataAdapterArgs args)
{
this.args = args;
- _process_tensorlike();
+ Tensor sample_weight_tensor = args.SampleWeight != null ? _process_tensorlike(args.SampleWeight) : null;
num_samples = (int)args.X.shape[0];
var batch_size = args.BatchSize == -1 ? 32 : args.BatchSize;
_batch_size = batch_size;
@@ -37,6 +37,8 @@ namespace Tensorflow.Keras.Engine.DataAdapters
inputs.AddRange(args.X);
if (args.Y != null)
inputs.AddRange(args.Y);
+ if (sample_weight_tensor != null)
+ inputs.Add(sample_weight_tensor);
dataset = slice_inputs(indices_dataset, inputs);
dataset.FirstInputTensorCount = args.X.Length;
}
@@ -94,8 +96,9 @@ namespace Tensorflow.Keras.Engine.DataAdapters
public override bool ShouldRecreateIterator() => false;
- void _process_tensorlike()
+ Tensor _process_tensorlike(NDArray sample_weights)
{
+ return tf.convert_to_tensor(sample_weights);
}
}
}
diff --git a/src/TensorFlowNET.Keras/Engine/Functional.FromConfig.cs b/src/TensorFlowNET.Keras/Engine/Functional.FromConfig.cs
index 7b826af8..375fc910 100644
--- a/src/TensorFlowNET.Keras/Engine/Functional.FromConfig.cs
+++ b/src/TensorFlowNET.Keras/Engine/Functional.FromConfig.cs
@@ -30,7 +30,7 @@ namespace Tensorflow.Keras.Engine
created_layers = created_layers ?? new Dictionary();
var node_index_map = new Dictionary<(string, int), int>();
var node_count_by_layer = new Dictionary();
- var unprocessed_nodes = new Dictionary();
+ var unprocessed_nodes = new Dictionary>();
// First, we create all layers and enqueue nodes to be processed
foreach (var layer_data in config.Layers)
process_layer(created_layers, layer_data, unprocessed_nodes, node_count_by_layer);
@@ -79,7 +79,7 @@ namespace Tensorflow.Keras.Engine
static void process_layer(Dictionary created_layers,
LayerConfig layer_data,
- Dictionary unprocessed_nodes,
+ Dictionary> unprocessed_nodes,
Dictionary node_count_by_layer)
{
ILayer layer = null;
@@ -92,32 +92,38 @@ namespace Tensorflow.Keras.Engine
created_layers[layer_name] = layer;
}
- node_count_by_layer[layer] = _should_skip_first_node(layer) ? 1 : 0;
+ node_count_by_layer[layer] = layer_data.InboundNodes.Count - (_should_skip_first_node(layer) ? 1 : 0);
var inbound_nodes_data = layer_data.InboundNodes;
foreach (var node_data in inbound_nodes_data)
{
if (!unprocessed_nodes.ContainsKey(layer))
- unprocessed_nodes[layer] = node_data;
+ unprocessed_nodes[layer] = new List() { node_data };
else
- unprocessed_nodes.Add(layer, node_data);
+ unprocessed_nodes[layer].Add(node_data);
}
}
static void process_node(ILayer layer,
- NodeConfig node_data,
+ List nodes_data,
Dictionary created_layers,
Dictionary node_count_by_layer,
Dictionary<(string, int), int> node_index_map)
{
+
var input_tensors = new List();
- var inbound_layer_name = node_data.Name;
- var inbound_node_index = node_data.NodeIndex;
- var inbound_tensor_index = node_data.TensorIndex;
- var inbound_layer = created_layers[inbound_layer_name];
- var inbound_node = inbound_layer.InboundNodes[inbound_node_index];
- input_tensors.Add(inbound_node.Outputs[inbound_node_index]);
+ for (int i = 0; i < nodes_data.Count; i++)
+ {
+ var node_data = nodes_data[i];
+ var inbound_layer_name = node_data.Name;
+ var inbound_node_index = node_data.NodeIndex;
+ var inbound_tensor_index = node_data.TensorIndex;
+
+ var inbound_layer = created_layers[inbound_layer_name];
+ var inbound_node = inbound_layer.InboundNodes[inbound_node_index];
+ input_tensors.Add(inbound_node.Outputs[inbound_node_index]);
+ }
var output_tensors = layer.Apply(input_tensors);
diff --git a/src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs b/src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs
index ed5c2de0..49811417 100644
--- a/src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs
+++ b/src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs
@@ -27,6 +27,6 @@ public abstract partial class Layer
children = new Dictionary();
}
- return children.Concat(base._trackable_children(save_type, cache)).ToDictionary(x => x.Key, x => x.Value);
+ return children.Concat(base._trackable_children(save_type, cache)).GroupBy(x => x.Key).Select(g => g.First()).ToDictionary(x => x.Key, x => x.Value);
}
}
\ No newline at end of file
diff --git a/src/TensorFlowNET.Keras/Engine/LossesContainer.cs b/src/TensorFlowNET.Keras/Engine/LossesContainer.cs
index 6a91450d..c06fca59 100644
--- a/src/TensorFlowNET.Keras/Engine/LossesContainer.cs
+++ b/src/TensorFlowNET.Keras/Engine/LossesContainer.cs
@@ -26,11 +26,11 @@ namespace Tensorflow.Keras.Engine
///
///
///
- public Tensor Call(Tensor y_true, Tensor y_pred)
+ public Tensor Call(Tensor y_true, Tensor y_pred, Tensor sample_weight = null)
{
if (!_built)
Build(y_pred);
- var loss_value = _losses.Call(y_true, y_pred);
+ var loss_value = _losses.Call(y_true, y_pred, sample_weight:sample_weight);
var loss_metric_value = loss_value;
var batch_dim = array_ops.shape(y_true)[0];
diff --git a/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs b/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs
index a74a77f1..474d5e5a 100644
--- a/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs
+++ b/src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs
@@ -30,6 +30,7 @@ namespace Tensorflow.Keras.Engine
public Dictionary evaluate(NDArray x, NDArray y,
int batch_size = -1,
int verbose = 1,
+ NDArray sample_weight = null,
int steps = -1,
int max_queue_size = 10,
int workers = 1,
@@ -51,6 +52,7 @@ namespace Tensorflow.Keras.Engine
StepsPerEpoch = steps,
InitialEpoch = 0,
Epochs = 1,
+ SampleWeight = sample_weight,
MaxQueueSize = max_queue_size,
Workers = workers,
UseMultiprocessing = use_multiprocessing,
@@ -130,6 +132,7 @@ namespace Tensorflow.Keras.Engine
var end_step = step + data_handler.StepIncrement;
if (!is_val)
callbacks.on_test_batch_end(end_step, logs);
+ GC.Collect();
}
}
callbacks.on_test_end(logs);
@@ -140,7 +143,8 @@ namespace Tensorflow.Keras.Engine
Dictionary test_function(DataHandler data_handler, OwnedIterator iterator)
{
var data = iterator.next();
- var outputs = test_step(data_handler, data[0], data[1]);
+ var outputs = data.Length == 2 ? test_step(data_handler, data[0], data[1]) :
+ test_step(data_handler, data[0], data[1], data[2]);
tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1));
return outputs;
}
@@ -149,7 +153,13 @@ namespace Tensorflow.Keras.Engine
{
var data = iterator.next();
var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount;
- var outputs = test_step(data_handler, data.Take(x_size).ToArray(), data.Skip(x_size).ToArray());
+ var outputs = data.Length == 2 ?
+ test_step(data_handler, new Tensors(data.Take(x_size).ToArray()), new Tensors(data.Skip(x_size).ToArray())) :
+ test_step(
+ data_handler,
+ new Tensors(data.Take(x_size).ToArray()),
+ new Tensors(data.Skip(x_size).Take(x_size).ToArray()),
+ new Tensors(data.Skip(2 * x_size).ToArray()));
tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1));
return outputs;
}
@@ -157,11 +167,22 @@ namespace Tensorflow.Keras.Engine
Dictionary test_step(DataHandler data_handler, Tensors x, Tensors y)
{
- (x, y) = data_handler.DataAdapter.Expand1d(x, y);
+ (x,y) = data_handler.DataAdapter.Expand1d(x, y);
+
var y_pred = Apply(x, training: false);
+
var loss = compiled_loss.Call(y, y_pred);
compiled_metrics.update_state(y, y_pred);
return metrics.Select(x => (x.Name, x.result())).ToDictionary(x => x.Item1, x => (float)x.Item2);
}
+
+ /// <summary>
+ /// Runs one evaluation step with sample weights: expands rank-1 inputs, runs the model in
+ /// inference mode, evaluates the compiled loss with the weights and updates the compiled metrics.
+ /// </summary>
+ /// <returns>The current result of every metric, keyed by metric name.</returns>
+ Dictionary test_step(DataHandler data_handler, Tensors x, Tensors y, Tensors sample_weight)
+ {
+ (x, y, sample_weight) = data_handler.DataAdapter.Expand1d(x, y, sample_weight);
+
+ var predictions = Apply(x, training: false);
+
+ // The returned loss value is not needed here; the call is made for its effect on the loss container.
+ compiled_loss.Call(y, predictions, sample_weight: sample_weight);
+ compiled_metrics.update_state(y, predictions);
+
+ return metrics.ToDictionary(metric => metric.Name, metric => (float)metric.result());
+ }
}
}
diff --git a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs
index d6f89d8b..d61211c7 100644
--- a/src/TensorFlowNET.Keras/Engine/Model.Fit.cs
+++ b/src/TensorFlowNET.Keras/Engine/Model.Fit.cs
@@ -6,10 +6,12 @@ using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine.DataAdapters;
using System.Diagnostics;
using Tensorflow.Keras.Callbacks;
-using System.Data;
+using Tensorflow.Util;
namespace Tensorflow.Keras.Engine
{
+
+
public partial class Model
{
///
@@ -19,19 +21,30 @@ namespace Tensorflow.Keras.Engine
///
///
///
- ///
///
+ ///
///
///
///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
public ICallback fit(NDArray x, NDArray y,
int batch_size = -1,
int epochs = 1,
int verbose = 1,
List callbacks = null,
float validation_split = 0f,
- (NDArray val_x, NDArray val_y)? validation_data = null,
+ ValidationDataPack validation_data = null,
+ int validation_step = 10,
bool shuffle = true,
+ Dictionary class_weight = null,
+ NDArray sample_weight = null,
int initial_epoch = 0,
int max_queue_size = 10,
int workers = 1,
@@ -43,25 +56,24 @@ namespace Tensorflow.Keras.Engine
$"The array x and y should have same value at dim 0, but got {x.dims[0]} and {y.dims[0]}");
}
- var train_x = x;
- var train_y = y;
+ // The default dtype of an NDArray is double, so sample_weight is cast to float before it is multiplied with the loss, whose dtype is float.
+ sample_weight = sample_weight?.astype(TF_DataType.TF_FLOAT);
if (validation_split != 0f && validation_data == null)
{
- int train_count = Convert.ToInt32(x.dims[0] * (1 - validation_split));
- train_x = x[new Slice(0, train_count)];
- train_y = y[new Slice(0, train_count)];
- validation_data = (val_x: x[new Slice(train_count)], val_y: y[new Slice(train_count)]);
+ ((x, y, sample_weight), validation_data) = DataAdapter.train_validation_split((x, y, sample_weight), validation_split);
}
var data_handler = new DataHandler(new DataHandlerArgs
{
- X = train_x,
- Y = train_y,
+ X = x,
+ Y = y,
+ SampleWeight = sample_weight,
BatchSize = batch_size,
InitialEpoch = initial_epoch,
Epochs = epochs,
Shuffle = shuffle,
+ ClassWeight = class_weight,
MaxQueueSize = max_queue_size,
Workers = workers,
UseMultiprocessing = use_multiprocessing,
@@ -73,14 +85,17 @@ namespace Tensorflow.Keras.Engine
train_step_func: train_step_function);
}
+
public ICallback fit(IEnumerable x, NDArray y,
int batch_size = -1,
int epochs = 1,
int verbose = 1,
List callbacks = null,
float validation_split = 0f,
- (IEnumerable val_x, NDArray val_y)? validation_data = null,
+ ValidationDataPack validation_data = null,
bool shuffle = true,
+ Dictionary class_weight = null,
+ NDArray sample_weight = null,
int initial_epoch = 0,
int max_queue_size = 10,
int workers = 1,
@@ -95,27 +110,24 @@ namespace Tensorflow.Keras.Engine
}
}
- var train_x = x;
- var train_y = y;
+ sample_weight = sample_weight?.astype(TF_DataType.TF_FLOAT);
+
if (validation_split != 0f && validation_data == null)
{
- int train_count = Convert.ToInt32(y.dims[0] * (1 - validation_split));
- train_x = x.Select(x => x[new Slice(0, train_count)] as NDArray);
- train_y = y[new Slice(0, train_count)];
- var val_x = x.Select(x => x[new Slice(train_count)] as NDArray);
- var val_y = y[new Slice(train_count)];
- validation_data = (val_x, val_y);
+ ((x, y, sample_weight), validation_data) = DataAdapter.train_validation_split((x, y, sample_weight), validation_split);
}
var data_handler = new DataHandler(new DataHandlerArgs
{
- X = new Tensors(train_x.ToArray()),
- Y = train_y,
+ X = new Tensors(x.ToArray()),
+ Y = y,
+ SampleWeight = sample_weight,
BatchSize = batch_size,
InitialEpoch = initial_epoch,
Epochs = epochs,
Shuffle = shuffle,
+ ClassWeight = class_weight,
MaxQueueSize = max_queue_size,
Workers = workers,
UseMultiprocessing = use_multiprocessing,
@@ -136,14 +148,15 @@ namespace Tensorflow.Keras.Engine
}
}
- public History fit(IDatasetV2 dataset,
+ public ICallback fit(IDatasetV2 dataset,
int batch_size = -1,
int epochs = 1,
int verbose = 1,
List callbacks = null,
IDatasetV2 validation_data = null,
- int validation_step = 10, // 间隔多少次会进行一次验证
+ int validation_step = 10,
bool shuffle = true,
+ Dictionary class_weight = null,
int initial_epoch = 0,
int max_queue_size = 10,
int workers = 1,
@@ -157,6 +170,7 @@ namespace Tensorflow.Keras.Engine
InitialEpoch = initial_epoch,
Epochs = epochs,
Shuffle = shuffle,
+ ClassWeight = class_weight,
MaxQueueSize = max_queue_size,
Workers = workers,
UseMultiprocessing = use_multiprocessing,
@@ -204,13 +218,14 @@ namespace Tensorflow.Keras.Engine
var end_step = step + data_handler.StepIncrement;
End_step = end_step;
callbacks.on_train_batch_end(end_step, logs);
+ GC.Collect();
}
if (validation_data != null)
{
if (validation_step > 0 && epoch ==0 || (epoch) % validation_step != 0)
continue;
-
+
var val_logs = evaluate(validation_data);
foreach(var log in val_logs)
{
@@ -219,11 +234,10 @@ namespace Tensorflow.Keras.Engine
callbacks.on_train_batch_end(End_step, logs);
}
+ GC.Collect();
callbacks.on_epoch_end(epoch, logs);
- GC.Collect();
- GC.WaitForPendingFinalizers();
if (stop_training)
{
break;
@@ -233,7 +247,7 @@ namespace Tensorflow.Keras.Engine
return callbacks.History;
}
- History FitInternal(DataHandler data_handler, int epochs, int verbose, List callbackList, (NDArray, NDArray)? validation_data,
+ History FitInternal(DataHandler data_handler, int epochs, int verbose, List callbackList, ValidationDataPack validation_data,
Func> train_step_func)
{
stop_training = false;
@@ -268,13 +282,15 @@ namespace Tensorflow.Keras.Engine
var end_step = step + data_handler.StepIncrement;
End_step = end_step;
callbacks.on_train_batch_end(end_step, logs);
+ GC.Collect();
}
if (validation_data != null)
{
// Because evaluate calls call_test_batch_end, this interferes with our output on the screen
// so we need to pass a is_val parameter to stop on_test_batch_end
- var val_logs = evaluate(validation_data.Value.Item1, validation_data.Value.Item2, is_val:true);
+ var (val_x, val_y, val_sample_weight) = validation_data;
+ var val_logs = evaluate(val_x, val_y, sample_weight:val_sample_weight, is_val:true);
foreach (var log in val_logs)
{
logs["val_" + log.Key] = log.Value;
@@ -286,7 +302,6 @@ namespace Tensorflow.Keras.Engine
callbacks.on_epoch_end(epoch, logs);
GC.Collect();
- GC.WaitForPendingFinalizers();
if (stop_training)
{
break;
@@ -296,64 +311,5 @@ namespace Tensorflow.Keras.Engine
return callbacks.History;
}
- History FitInternal(DataHandler data_handler, int epochs, int verbose, List callbackList, (IEnumerable, NDArray)? validation_data,
- Func> train_step_func)
- {
- stop_training = false;
- _train_counter.assign(0);
- var callbacks = new CallbackList(new CallbackParams
- {
- Model = this,
- Verbose = verbose,
- Epochs = epochs,
- Steps = data_handler.Inferredsteps
- });
-
- if (callbackList != null)
- {
- foreach (var callback in callbackList)
- callbacks.callbacks.add(callback);
- }
-
- callbacks.on_train_begin();
-
- foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
- {
- reset_metrics();
- callbacks.on_epoch_begin(epoch);
- // data_handler.catch_stop_iteration();
- var logs = new Dictionary();
- long End_step = 0;
- foreach (var step in data_handler.steps())
- {
- callbacks.on_train_batch_begin(step);
- logs = train_step_func(data_handler, iterator);
- var end_step = step + data_handler.StepIncrement;
- End_step = end_step;
- callbacks.on_train_batch_end(end_step, logs);
- }
-
- if (validation_data != null)
- {
- var val_logs = evaluate(validation_data.Value.Item1, validation_data.Value.Item2);
- foreach (var log in val_logs)
- {
- logs["val_" + log.Key] = log.Value;
- callbacks.on_train_batch_end(End_step, logs);
- }
- }
-
- callbacks.on_epoch_end(epoch, logs);
-
- GC.Collect();
- GC.WaitForPendingFinalizers();
- if (stop_training)
- {
- break;
- }
- }
-
- return callbacks.History;
- }
}
}
diff --git a/src/TensorFlowNET.Keras/Engine/Model.Predict.cs b/src/TensorFlowNET.Keras/Engine/Model.Predict.cs
index cbe4a729..e3a5aba6 100644
--- a/src/TensorFlowNET.Keras/Engine/Model.Predict.cs
+++ b/src/TensorFlowNET.Keras/Engine/Model.Predict.cs
@@ -102,9 +102,9 @@ namespace Tensorflow.Keras.Engine
for (int i = 0; i < batch_outputs.Length; i++)
batch_outputs[i] = tf.concat(new Tensor[] { batch_outputs[i], tmp_batch_outputs[i] }, axis: 0);
}
-
var end_step = step + data_handler.StepIncrement;
callbacks.on_predict_batch_end(end_step, new Dictionary { { "outputs", batch_outputs } });
+ GC.Collect();
}
}
diff --git a/src/TensorFlowNET.Keras/Engine/Model.Train.cs b/src/TensorFlowNET.Keras/Engine/Model.Train.cs
index ad3c70d2..8f1ec808 100644
--- a/src/TensorFlowNET.Keras/Engine/Model.Train.cs
+++ b/src/TensorFlowNET.Keras/Engine/Model.Train.cs
@@ -12,7 +12,9 @@ namespace Tensorflow.Keras.Engine
Dictionary train_step_function(DataHandler data_handler, OwnedIterator iterator)
{
var data = iterator.next();
- var outputs = train_step(data_handler, data[0], data[1]);
+ // whether have sample_weight
+ var outputs = data.Length == 2 ? train_step(data_handler, data[0], data[1]) :
+ train_step(data_handler, data[0], data[1], data[2]);
tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1));
return outputs;
}
@@ -21,7 +23,13 @@ namespace Tensorflow.Keras.Engine
{
var data = iterator.next();
var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount;
- var outputs = train_step(data_handler, new Tensors(data.Take(x_size).ToArray()), new Tensors(data.Skip(x_size).ToArray()));
+ var outputs = data.Length == 2 ?
+ train_step(data_handler, new Tensors(data.Take(x_size).ToArray()), new Tensors(data.Skip(x_size).ToArray())) :
+ train_step(
+ data_handler,
+ new Tensors(data.Take(x_size).ToArray()),
+ new Tensors(data.Skip(x_size).Take(x_size).ToArray()),
+ new Tensors(data.Skip(2 * x_size).ToArray()));
tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1));
return outputs;
}
@@ -61,6 +69,34 @@ namespace Tensorflow.Keras.Engine
});
return dict;
}
+ // Runs one training step with optional per-sample weights and returns the metric values by name.
+ Dictionary train_step(DataHandler data_handler, Tensors x, Tensors y, Tensors sample_weight = null)
+ {
+ (x, y, sample_weight) = data_handler.DataAdapter.Expand1d(x, y, sample_weight);
+ using var tape = tf.GradientTape();
+ var y_pred = Apply(x, training: true);
+ var loss = compiled_loss.Call(y, y_pred, sample_weight:sample_weight);
+ // _minimize wraps the gradient computation and the optimizer update. A custom
+ // training step could instead take tape.gradient(loss, TrainableVariables) and
+ // hand the zipped (gradient, variable) pairs to optimizer.apply_gradients;
+ // per the original note, _minimize additionally does loss scaling and gradient clipping.
+ _minimize(tape, optimizer, loss, TrainableVariables);
+ // NOTE(review): sample_weight is not forwarded to the metrics - confirm this is intended.
+ compiled_metrics.update_state(y, y_pred);
+
+ // Report every metric as a scalar; results with ndim > 0 are averaged first.
+ var metric_values = new Dictionary();
+ foreach (var metric in metrics.ToList())
+ {
+ var metric_result = metric.result();
+ if (metric_result.ndim > 0)
+ {
+ metric_result = tf.reduce_mean(metric_result);
+ }
+ metric_values[metric.Name] = (float)metric_result;
+ }
+ return metric_values;
+ }
void _minimize(GradientTape tape, IOptimizer optimizer, Tensor loss, List trainable_variables)
{
diff --git a/src/TensorFlowNET.Keras/Engine/Model.Training.cs b/src/TensorFlowNET.Keras/Engine/Model.Training.cs
index 50d934d9..457b3d69 100644
--- a/src/TensorFlowNET.Keras/Engine/Model.Training.cs
+++ b/src/TensorFlowNET.Keras/Engine/Model.Training.cs
@@ -10,8 +10,38 @@ namespace Tensorflow.Keras.Engine
{
public partial class Model
{
+ static Dictionary> weightsCache
+ = new Dictionary>();
+
public void load_weights(string filepath, bool by_name = false, bool skip_mismatch = false, object options = null)
{
+ // Get from cache
+ if (weightsCache.ContainsKey(filepath))
+ {
+ var filtered_layers = new List();
+ foreach (var layer in Layers)
+ {
+ var weights = hdf5_format._legacy_weights(layer);
+ if (weights.Count > 0)
+ filtered_layers.append(layer);
+ }
+
+ var weight_value_tuples = new List<(IVariableV1, NDArray)>();
+ filtered_layers.Select((layer, i) =>
+ {
+ var symbolic_weights = hdf5_format._legacy_weights(layer);
+ foreach(var weight in symbolic_weights)
+ {
+ var weight_value = weightsCache[filepath].First(x => x.Item1 == weight.Name).Item2;
+ weight_value_tuples.Add((weight, weight_value));
+ }
+ return layer;
+ }).ToList();
+
+ keras.backend.batch_set_value(weight_value_tuples);
+ return;
+ }
+
long fileId = Hdf5.OpenFile(filepath, true);
if(fileId < 0)
{
@@ -29,8 +59,11 @@ namespace Tensorflow.Keras.Engine
throw new NotImplementedException("");
else
{
- hdf5_format.load_weights_from_hdf5_group(fileId, Layers);
+ var weight_value_tuples = hdf5_format.load_weights_from_hdf5_group(fileId, Layers);
Hdf5.CloseFile(fileId);
+
+ weightsCache[filepath] = weight_value_tuples.Select(x => (x.Item1.Name, x.Item2)).ToList();
+ keras.backend.batch_set_value(weight_value_tuples);
}
}
diff --git a/src/TensorFlowNET.Keras/Layers/Activation/ReLu6.cs b/src/TensorFlowNET.Keras/Layers/Activation/ReLu6.cs
new file mode 100644
index 00000000..5af3f767
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Layers/Activation/ReLu6.cs
@@ -0,0 +1,25 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Keras.Engine;
+using Tensorflow.Common.Types;
+using static Tensorflow.Binding;
+
+namespace Tensorflow.Keras.Layers
+{
+ /// <summary>
+ /// ReLU6 activation layer: Call returns tf.nn.relu6(inputs); state, training and optional_args are not used.
+ /// </summary>
+ public class ReLu6 : Layer
+ {
+ public ReLu6() : base(new LayerArgs { })
+ {
+ }
+
+ protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = null, IOptionalArgs? optional_args = null)
+ {
+ return tf.nn.relu6(inputs);
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Keras/Layers/Convolution/DepthwiseConv2D.cs b/src/TensorFlowNET.Keras/Layers/Convolution/DepthwiseConv2D.cs
new file mode 100644
index 00000000..dae4a403
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Layers/Convolution/DepthwiseConv2D.cs
@@ -0,0 +1,167 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using System;
+using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Keras.Saving;
+using Tensorflow.Common.Types;
+using Tensorflow.Keras.Utils;
+using Tensorflow.Operations;
+using Newtonsoft.Json;
+using System.Security.Cryptography;
+
+namespace Tensorflow.Keras.Layers
+{
+ public class DepthwiseConv2DArgs: Conv2DArgs // Conv2DArgs extended with the depthwise-specific settings
+ {
+ /// <summary>
+ /// depth_multiplier: The number of depthwise convolution output channels for
+ /// each input channel. The total number of depthwise convolution output
+ /// channels will be equal to `filters_in * depth_multiplier`.
+ /// </summary>
+ [JsonProperty("depth_multiplier")]
+ public int DepthMultiplier { get; set; } = 1;
+ /// <summary>Initializer for the depthwise kernel, serialized as "depthwise_initializer".</summary>
+ [JsonProperty("depthwise_initializer")]
+ public IInitializer DepthwiseInitializer { get; set; }
+ }
+
+ /// <summary>
+ /// Depthwise 2D convolution layer: build creates a kernel of shape kernel_size +
+ /// (input_channels, depth_multiplier) that Call feeds to gen_nn_ops.depthwise_conv2d_native.
+ /// </summary>
+ public class DepthwiseConv2D : Conv2D
+ {
+ /// <summary>
+ /// The number of depthwise convolution output channels for each input channel.
+ /// The total number of depthwise convolution output channels will be equal to
+ /// `filters_in * depth_multiplier`.
+ /// </summary>
+ int DepthMultiplier = 1;
+ /// <summary>Initializer for the depthwise kernel; build falls back to kernel_initializer when it is null.</summary>
+ IInitializer DepthwiseInitializer;
+ /// <summary>Strides passed to the native op; set in build.</summary>
+ int[] strides;
+ /// <summary>Dilations passed to the native op; set in build.</summary>
+ int[] dilation_rate;
+
+ /// <summary>Translates data_format into the layout string expected by the native ops.</summary>
+ string getDataFormat()
+ {
+ return data_format == "channels_first" ? "NCHW" : "NHWC";
+ }
+
+ // NOTE(review): never incremented, so every unnamed layer is named "DepthwiseConv2D_1" - confirm this is intended.
+ static int _id = 1;
+
+ /// <summary>Copies the depthwise settings from args and upper-cases the padding name.</summary>
+ public DepthwiseConv2D(DepthwiseConv2DArgs args):base(args)
+ {
+ // Invariant casing: ToUpper() is culture-sensitive (under tr-TR "valid" does not become "VALID").
+ args.Padding = args.Padding.ToUpperInvariant();
+
+ if(string.IsNullOrEmpty(args.Name))
+ name = "DepthwiseConv2D_" + _id;
+ this.DepthMultiplier = args.DepthMultiplier;
+ this.DepthwiseInitializer = args.DepthwiseInitializer;
+ }
+
+ /// <summary>
+ /// Creates the depthwise kernel (plus the bias when use_bias is set) and derives the
+ /// strides and dilations for the native op from the layer arguments.
+ /// </summary>
+ public override void build(KerasShapesWrapper input_shape)
+ {
+ //base.build(input_shape);
+ var shape = input_shape.ToSingleShape();
+ int channel_axis = data_format == "channels_first" ? 1 : -1;
+ var input_channel = channel_axis < 0 ?
+ shape.dims[shape.ndim + channel_axis] :
+ shape.dims[channel_axis];
+ var arg = args as DepthwiseConv2DArgs;
+
+ if (arg.Strides.ndim != shape.ndim)
+ {
+ if (arg.Strides.ndim == 2)
+ {
+ this.strides = new int[] { 1, (int)arg.Strides[0], (int)arg.Strides[1], 1 };
+ }
+ else
+ {
+ this.strides = conv_utils.normalize_tuple(new int[] { (int)arg.Strides[0] }, shape.ndim, "strides");
+ }
+ }
+ else
+ {
+ this.strides = arg.Strides.dims.Select(o=>(int)(o)).ToArray();
+ }
+
+ if (arg.DilationRate.ndim != shape.ndim)
+ {
+ this.dilation_rate = conv_utils.normalize_tuple(new int[] { (int)arg.DilationRate[0] }, shape.ndim, "dilation_rate");
+ }
+ else
+ {
+ // A full-rank dilation rate is taken as given (dilation_rate used to stay null here).
+ this.dilation_rate = arg.DilationRate.dims.Select(o => (int)o).ToArray();
+ }
+
+ // Reuse the channel count read via channel_axis; shape[0] is not the channel axis for channels_first.
+ long channel_data = input_channel;
+ var depthwise_kernel_shape = this.kernel_size.dims.concat(new long[] { channel_data, this.DepthMultiplier });
+
+ this.kernel = this.add_weight(
+ shape: depthwise_kernel_shape,
+ initializer: this.DepthwiseInitializer ?? this.kernel_initializer,
+ name: "depthwise_kernel",
+ trainable: true,
+ dtype: DType,
+ regularizer: this.kernel_regularizer
+ );
+
+ var axes = new Dictionary<int, int>();
+ axes.Add(channel_axis, (int)input_channel);
+ inputSpec = new InputSpec(min_ndim: rank + 2, axes: axes);
+ if (use_bias)
+ {
+ // One bias value per output channel, i.e. input channels * depth multiplier.
+ bias = add_weight(name: "bias",
+ shape: ((int)(channel_data * this.DepthMultiplier)),
+ initializer: bias_initializer,
+ trainable: true,
+ dtype: DType);
+ }
+ built = true;
+ _buildInputShape = input_shape;
+ }
+
+ /// <summary>Runs the depthwise convolution, then the optional bias add and activation.</summary>
+ protected override Tensors Call(Tensors inputs, Tensors state = null,
+ bool? training = false, IOptionalArgs? optional_args = null)
+ {
+ Tensor outputs = gen_nn_ops.depthwise_conv2d_native(
+ inputs,
+ filter: this.kernel.AsTensor(),
+ strides: this.strides,
+ padding: this.padding,
+ dilations: this.dilation_rate,
+ data_format: this.getDataFormat(),
+ name: name
+ );
+ if (use_bias)
+ {
+ if (data_format == "channels_first")
+ {
+ throw new NotImplementedException("call channels_first");
+ }
+ outputs = gen_nn_ops.bias_add(outputs, ops.convert_to_tensor(bias),
+ data_format: this.getDataFormat(), name: name);
+ }
+
+ if (activation != null)
+ outputs = activation.Apply(outputs);
+
+ return outputs;
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.cs
index 928e7e33..e2adb23d 100644
--- a/src/TensorFlowNET.Keras/Layers/LayersApi.cs
+++ b/src/TensorFlowNET.Keras/Layers/LayersApi.cs
@@ -112,7 +112,28 @@ namespace Tensorflow.Keras.Layers
KernelInitializer = GetInitializerByName(kernel_initializer),
BiasInitializer = GetInitializerByName(bias_initializer)
});
-
+ public ILayer Conv2D(int filters, // short overload: only filters, kernel size, strides and padding are configurable
+ Shape kernel_size = null,
+ Shape strides = null,
+ string padding = "valid")
+ => new Conv2D(new Conv2DArgs
+ {
+ Rank = 2,
+ Filters = filters,
+ KernelSize = (kernel_size == null) ? (5, 5) : kernel_size, // defaults to 5x5
+ Strides = strides == null ? (1, 1) : strides, // defaults to 1x1
+ Padding = padding,
+ DataFormat = null,
+ DilationRate = (1, 1),
+ Groups = 1,
+ UseBias = false, // this overload never adds a bias term
+ KernelRegularizer = null,
+ KernelInitializer =tf.glorot_uniform_initializer,
+ BiasInitializer = tf.zeros_initializer,
+ BiasRegularizer = null,
+ ActivityRegularizer = null,
+ Activation = keras.activations.Linear,
+ });
///
/// 2D convolution layer (e.g. spatial convolution over images).
/// This layer creates a convolution kernel that is convolved with the layer input to produce a tensor of outputs.
@@ -210,6 +231,38 @@ namespace Tensorflow.Keras.Layers
Activation = keras.activations.GetActivationFromName(activation)
});
+ public ILayer DepthwiseConv2D(Shape kernel_size = null, // builds a DepthwiseConv2D layer from Keras-style arguments
+ Shape strides = null,
+ string padding = "valid",
+ string data_format = null,
+ Shape dilation_rate = null,
+ int groups = 1,
+ int depth_multiplier = 1,
+ string activation = null,
+ bool use_bias = false,
+ string kernel_initializer = "glorot_uniform",
+ string bias_initializer = "zeros",
+ string depthwise_initializer = "glorot_uniform"
+ )
+ => new DepthwiseConv2D(new DepthwiseConv2DArgs
+ {
+ Rank = 2,
+ Filters = 1, // fixed; not exposed as a parameter of this method
+ KernelSize = (kernel_size == null) ? (5, 5) : kernel_size, // defaults to 5x5
+ Strides = strides == null ? (1) : strides, // defaults to a single 1
+ Padding = padding,
+ DepthMultiplier = depth_multiplier,
+ DataFormat = data_format,
+ DilationRate = dilation_rate == null ? (1) : dilation_rate, // defaults to a single 1
+ Groups = groups,
+ UseBias = use_bias,
+ KernelInitializer = GetInitializerByName(kernel_initializer),
+ DepthwiseInitializer = GetInitializerByName(depthwise_initializer == null ? kernel_initializer : depthwise_initializer), // null falls back to kernel_initializer
+ BiasInitializer = GetInitializerByName(bias_initializer),
+ Activation = keras.activations.GetActivationFromName(activation),
+ });
+
+
///
/// Transposed convolution layer (sometimes called Deconvolution).
///
@@ -682,6 +735,15 @@ namespace Tensorflow.Keras.Layers
});
+ /// <summary>
+ /// Creates a ReLU6 activation layer (see <see cref="ReLu6"/>).
+ /// The method takes no arguments.
+ /// </summary>
+ /// <returns>A new ReLu6 layer.</returns>
+ public ILayer ReLU6()
+ => new ReLu6();
+
+
public IRnnCell SimpleRNNCell(
int units,
string activation = "tanh",
diff --git a/src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs b/src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs
index a2a8286b..fa82426c 100644
--- a/src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs
+++ b/src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs
@@ -39,6 +39,7 @@ namespace Tensorflow.Keras.Layers
shape_set.Add(shape);
}*/
_buildInputShape = input_shape;
+ built = true;
}
protected override Tensors _merge_function(Tensors inputs)
diff --git a/src/TensorFlowNET.Keras/Saving/hdf5_format.cs b/src/TensorFlowNET.Keras/Saving/hdf5_format.cs
index bab0efec..68b73953 100644
--- a/src/TensorFlowNET.Keras/Saving/hdf5_format.cs
+++ b/src/TensorFlowNET.Keras/Saving/hdf5_format.cs
@@ -82,7 +82,7 @@ namespace Tensorflow.Keras.Saving
}
- public static void load_weights_from_hdf5_group(long f, List layers)
+ public static List<(IVariableV1, NDArray)> load_weights_from_hdf5_group(long f, List layers)
{
string original_keras_version = "2.5.0";
string original_backend = null;
@@ -152,7 +152,7 @@ namespace Tensorflow.Keras.Saving
weight_value_tuples.AddRange(zip(symbolic_weights, weight_values));
}
- keras.backend.batch_set_value(weight_value_tuples);
+ return weight_value_tuples;
}
public static void toarrayf4(long filepath = -1, Dictionary custom_objects = null, bool compile = false)
diff --git a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
index 36d1bc1d..eb8ebf93 100644
--- a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
+++ b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj
@@ -7,7 +7,7 @@
enable
Tensorflow.Keras
AnyCPU;x64
- 0.11.3
+ 0.15.0
Haiping Chen
Keras for .NET
Apache 2.0, Haiping Chen since 2018
@@ -30,6 +30,7 @@
* Fixed memory leak for YOLOv3 model.
* Support RNN and LSTM models
* Support Transformer model
+ * Support BERT model
Keras for .NET
@@ -42,8 +43,8 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
Git
False
Open.snk
- 0.11.3.0
- 0.11.3.0
+ 0.15.0.0
+ 0.15.0.0
LICENSE
Debug;Release;GPU
@@ -143,7 +144,7 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
-
+
diff --git a/src/TensorFlowNET.Keras/Utils/data_utils.cs b/src/TensorFlowNET.Keras/Utils/data_utils.cs
index e6db0ef7..b0bc1554 100644
--- a/src/TensorFlowNET.Keras/Utils/data_utils.cs
+++ b/src/TensorFlowNET.Keras/Utils/data_utils.cs
@@ -53,15 +53,17 @@ namespace Tensorflow.Keras.Utils
new_seq, new_label: shortened lists for `seq` and `label`.
*/
+ var nRow = seq.GetLength(0);
+ var nCol = seq.GetLength(1);
List new_seq = new List();
List new_label = new List();
- for (var i = 0; i < seq.GetLength(0); i++)
+ for (var i = 0; i < nRow; i++)
{
- if (maxlen < seq.GetLength(1) && seq[i, maxlen] != 0)
+ if (maxlen < nCol && seq[i, maxlen] != 0)
continue;
int[] sentence = new int[maxlen];
- for (var j = 0; j < maxlen && j < seq.GetLength(1); j++)
+ for (var j = 0; j < maxlen && j < nCol; j++)
{
sentence[j] = seq[i, j];
}
diff --git a/src/TensorFlowNET.Keras/Utils/generic_utils.cs b/src/TensorFlowNET.Keras/Utils/generic_utils.cs
index 5402f499..20937e2e 100644
--- a/src/TensorFlowNET.Keras/Utils/generic_utils.cs
+++ b/src/TensorFlowNET.Keras/Utils/generic_utils.cs
@@ -112,12 +112,23 @@ namespace Tensorflow.Keras.Utils
foreach (var token in layersToken)
{
var args = deserialize_layer_args(token["class_name"].ToObject(), token["config"]);
+
+ List nodeConfig = null; //python tensorflow sometimes exports inbound nodes in an extra nested array
+ if (token["inbound_nodes"].Count() > 0 && token["inbound_nodes"][0].Count() > 0 && token["inbound_nodes"][0][0].Count() > 0)
+ {
+ nodeConfig = token["inbound_nodes"].ToObject>>().FirstOrDefault() ?? new List();
+ }
+ else
+ {
+ nodeConfig = token["inbound_nodes"].ToObject>();
+ }
+
config.Layers.Add(new LayerConfig()
{
Config = args,
Name = token["name"].ToObject(),
ClassName = token["class_name"].ToObject(),
- InboundNodes = token["inbound_nodes"].ToObject>()
+ InboundNodes = nodeConfig,
});
}
config.InputLayers = json["input_layers"].ToObject>();
diff --git a/src/TensorflowNET.Hub/Tensorflow.Hub.csproj b/src/TensorflowNET.Hub/Tensorflow.Hub.csproj
index 3c09f808..efa37598 100644
--- a/src/TensorflowNET.Hub/Tensorflow.Hub.csproj
+++ b/src/TensorflowNET.Hub/Tensorflow.Hub.csproj
@@ -26,7 +26,7 @@
-
+
diff --git a/test/TensorFlow.Kernel.UnitTest/TensorFlow.Kernel.UnitTest.csproj b/test/TensorFlow.Kernel.UnitTest/TensorFlow.Kernel.UnitTest.csproj
new file mode 100644
index 00000000..21b2731b
--- /dev/null
+++ b/test/TensorFlow.Kernel.UnitTest/TensorFlow.Kernel.UnitTest.csproj
@@ -0,0 +1,24 @@
+
+
+
+ net6.0
+ enable
+ enable
+
+ false
+ true
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/test/TensorFlow.Kernel.UnitTest/array_ops/concat_op_test.cs b/test/TensorFlow.Kernel.UnitTest/array_ops/concat_op_test.cs
new file mode 100644
index 00000000..67d0aa60
--- /dev/null
+++ b/test/TensorFlow.Kernel.UnitTest/array_ops/concat_op_test.cs
@@ -0,0 +1,63 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Tensorflow;
+using Tensorflow.NumPy;
+using static Tensorflow.Binding;
+
+namespace TensorFlow.Kernel.UnitTest
+{
+ // Tests for array_ops.concat; results are compared as flattened element sequences.
+ [TestClass]
+ public class concat_op_test
+ {
+ [TestMethod]
+ public void testConcatEmpty()
+ {
+ var t1 = tf.constant(new int[] { });
+ var t2 = tf.constant(new int[] { });
+ var c = array_ops.concat(new[] { t1, t2 }, 0);
+ var expected = np.array(new int[] { });
+ Assert.IsTrue(Enumerable.SequenceEqual(expected.ToArray(), c.numpy().ToArray()));
+ }
+
+ [TestMethod]
+ public void testConcatNegativeAxis()
+ {
+ var t1 = tf.constant(new int[,] { { 1, 2, 3 }, { 4, 5, 6 } });
+ var t2 = tf.constant(new int[,] { { 7, 8, 9 }, { 10, 11, 12 } });
+ var c = array_ops.concat(new[] { t1, t2 }, -2);
+ // Axis -2 of a rank-2 tensor is axis 0, so the rows are appended and the result is 4x3.
+ var expected = np.array(new int[,] { { 1, 2, 3 }, { 4, 5, 6 }, { 7, 8, 9 }, { 10, 11, 12 } });
+ Assert.IsTrue(Enumerable.SequenceEqual(expected.ToArray(), c.numpy().ToArray()));
+
+ c = array_ops.concat(new[] { t1, t2 }, -1);
+ expected = np.array(new int[,] { { 1, 2, 3, 7, 8, 9 }, { 4, 5, 6, 10, 11, 12 } });
+ Assert.IsTrue(Enumerable.SequenceEqual(expected.ToArray(), c.numpy().ToArray()));
+ }
+
+ [TestMethod]
+ [DataRow(TF_DataType.TF_INT32)]
+ [DataRow(TF_DataType.TF_INT64)]
+ [DataRow(TF_DataType.TF_UINT32)]
+ [DataRow(TF_DataType.TF_UINT64)]
+ public void testConcatDtype(TF_DataType dtype)
+ {
+ var t1 = tf.constant(new int[,] { { 1, 2, 3 }, { 4, 5, 6 } }, dtype: dtype);
+ var t2 = tf.constant(new int[,] { { 7, 8, 9 }, { 10, 11, 12 } }, dtype: dtype);
+ var c = array_ops.concat(new[] { t1, t2 }, 1);
+ var expected = np.array(new int[,] { { 1, 2, 3, 7, 8, 9 }, { 4, 5, 6, 10, 11, 12 } });
+ Assert.IsTrue(Enumerable.SequenceEqual(expected.ToArray(), tf.cast(c, TF_DataType.TF_INT32).numpy().ToArray()));
+ }
+
+ [TestMethod]
+ [DataRow(TF_DataType.TF_INT32)]
+ [DataRow(TF_DataType.TF_INT64)]
+ public void testConcatAxisType(TF_DataType dtype)
+ {
+ var t1 = tf.constant(new int[,] { { 1, 2, 3 }, { 4, 5, 6 } });
+ var t2 = tf.constant(new int[,] { { 7, 8, 9 }, { 10, 11, 12 } });
+ var c = array_ops.concat(new[] { t1, t2 }, tf.constant(1, dtype: dtype));
+ var expected = np.array(new int[,] { { 1, 2, 3, 7, 8, 9 }, { 4, 5, 6, 10, 11, 12 } });
+ Assert.IsTrue(Enumerable.SequenceEqual(expected.ToArray(), tf.cast(c, TF_DataType.TF_INT32).numpy().ToArray()));
+ }
+ }
+}
\ No newline at end of file
diff --git a/test/TensorFlowNET.Graph.UnitTest/Basics/TensorTest.cs b/test/TensorFlowNET.Graph.UnitTest/Basics/TensorTest.cs
index 90de7874..8093c1f2 100644
--- a/test/TensorFlowNET.Graph.UnitTest/Basics/TensorTest.cs
+++ b/test/TensorFlowNET.Graph.UnitTest/Basics/TensorTest.cs
@@ -3,6 +3,7 @@ using Tensorflow.NumPy;
using System;
using System.Linq;
using static Tensorflow.Binding;
+using Tensorflow;
namespace TensorFlowNET.UnitTest.Basics
{
@@ -60,14 +61,14 @@ namespace TensorFlowNET.UnitTest.Basics
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 15, 21, 16, 22, 17, 23 }, result[0, 3].ToArray()));
}
- [TestMethod, Ignore]
+ [TestMethod]
public void boolean_mask()
{
+ if (!tf.executing_eagerly())
+ tf.enable_eager_execution();
var tensor = new[] { 0, 1, 2, 3 };
var mask = np.array(new[] { true, false, true, false });
var masked = tf.boolean_mask(tensor, mask);
- var sess = tf.Session();
- var result = sess.run(masked);
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 2 }, masked.ToArray()));
}
}
diff --git a/test/TensorFlowNET.Graph.UnitTest/ImageTest.cs b/test/TensorFlowNET.Graph.UnitTest/ImageTest.cs
index d671b609..127b65bf 100644
--- a/test/TensorFlowNET.Graph.UnitTest/ImageTest.cs
+++ b/test/TensorFlowNET.Graph.UnitTest/ImageTest.cs
@@ -4,6 +4,7 @@ using System.Linq;
using Tensorflow;
using static Tensorflow.Binding;
using System;
+using System.IO;
namespace TensorFlowNET.UnitTest
{
@@ -164,5 +165,94 @@ namespace TensorFlowNET.UnitTest
Assert.AreEqual(result.size, 16ul);
Assert.AreEqual(result[0, 0, 0, 0], 12f);
}
+
+ [TestMethod]
+ public void ImageSaveTest()
+ {
+ var imgPath = TestHelper.GetFullPathFromDataDir("img001.bmp");
+ var jpegImgPath = TestHelper.GetFullPathFromDataDir("img001.jpeg");
+ var pngImgPath = TestHelper.GetFullPathFromDataDir("img001.png");
+ // Remove leftovers from earlier runs so the File.Exists checks below are meaningful.
+ File.Delete(jpegImgPath);
+ File.Delete(pngImgPath);
+
+ var contents = tf.io.read_file(imgPath);
+ var bmp = tf.image.decode_image(contents);
+ Assert.AreEqual(bmp.name, "decode_image/DecodeImage:0");
+
+ var jpeg = tf.image.encode_jpeg(bmp);
+ var op1 = tf.io.write_file(jpegImgPath, jpeg);
+
+ var png = tf.image.encode_png(bmp);
+ var op2 = tf.io.write_file(pngImgPath, png);
+ // The write ops are passed to session().run before the files are checked.
+ this.session().run(op1);
+ this.session().run(op2);
+
+ Assert.IsTrue(File.Exists(jpegImgPath), "not find file:" + jpegImgPath);
+ Assert.IsTrue(File.Exists(pngImgPath), "not find file:" + pngImgPath);
+
+ // To check the written images by eye, comment out the two lines below.
+ File.Delete(jpegImgPath);
+ File.Delete(pngImgPath);
+ }
+
+ [TestMethod]
+ public void ImageFlipTest()
+ {
+ var imgPath = TestHelper.GetFullPathFromDataDir("img001.bmp");
+
+ var contents = tf.io.read_file(imgPath);
+ var bmp = tf.image.decode_image(contents);
+
+ // Flip left-right
+ var lrImgPath = TestHelper.GetFullPathFromDataDir("img001_lr.png");
+ File.Delete(lrImgPath);
+
+ var lr = tf.image.flip_left_right(bmp);
+ var png = tf.image.encode_png(lr);
+ var op = tf.io.write_file(lrImgPath, png);
+ this.session().run(op);
+
+ Assert.IsTrue(File.Exists(lrImgPath), "not find file:" + lrImgPath);
+
+ // Flip up-down
+ var updownImgPath = TestHelper.GetFullPathFromDataDir("img001_updown.png");
+ File.Delete(updownImgPath);
+
+ var updown = tf.image.flip_up_down(bmp);
+ var pngupdown = tf.image.encode_png(updown);
+ var op2 = tf.io.write_file(updownImgPath, pngupdown);
+ this.session().run(op2);
+ Assert.IsTrue(File.Exists(updownImgPath));
+
+
+ // For now the flip itself is checked by eye; remove the two lines below when inspecting the images.
+ File.Delete(lrImgPath);
+ File.Delete(updownImgPath);
+
+ // Flip a batch of images
+ // (translated note) Currently the shape is taken directly from bmp; a default image size is used to build this for now.
+ var mImg = tf.stack(new[] { bmp, lr }, axis:0);
+ print(mImg.shape);
+
+ var up2 = tf.image.flip_up_down(mImg);
+
+ var updownImgPath_m1 = TestHelper.GetFullPathFromDataDir("img001_m_ud.png"); // flipped up-down only
+ File.Delete(updownImgPath_m1);
+
+ var img001_updown_m2 = TestHelper.GetFullPathFromDataDir("img001_m_lr_ud.png"); // flipped left-right, then up-down
+ File.Delete(img001_updown_m2);
+
+ var png2 = tf.image.encode_png(up2[0]);
+ tf.io.write_file(updownImgPath_m1, png2); // NOTE(review): unlike above, the result is not passed to session().run - confirm the file is written
+
+ png2 = tf.image.encode_png(up2[1]);
+ tf.io.write_file(img001_updown_m2, png2); // NOTE(review): same as above - result not passed to session().run
+
+ // To check the written images by eye, comment out the two lines below.
+ File.Delete(updownImgPath_m1);
+ File.Delete(img001_updown_m2);
+ }
}
}
diff --git a/test/TensorFlowNET.Keras.UnitTest/EagerModeTestBase.cs b/test/TensorFlowNET.Keras.UnitTest/EagerModeTestBase.cs
index c7eab364..635f13a5 100644
--- a/test/TensorFlowNET.Keras.UnitTest/EagerModeTestBase.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/EagerModeTestBase.cs
@@ -33,6 +33,40 @@ namespace Tensorflow.Keras.UnitTest
return ret;
}
+
+ /// <summary>Fails the test unless f1 and f2 have the same length and equal elements.</summary>
+ public void AssertArray(int[] f1, int[] f2)
+ {
+ // Start from the length check so that empty arrays pass and arrays of different length fail.
+ bool ret = f1.Length == f2.Length;
+ for (var i = 0; ret && i < f1.Length; i++)
+ {
+ ret = f1[i] == f2[i];
+ }
+
+ if (!ret)
+ {
+ Assert.Fail($"Array not Equal:[{string.Join(",", f1)}] [{string.Join(",", f2)}]");
+ }
+ }
+
+ /// <summary>Fails the test unless f1 and f2 have the same length and agree element-wise within 1e-5.</summary>
+ public void AssertArray(float[] f1, float[] f2)
+ {
+ var tolerance = .00001f;
+ // Start from the length check so that empty arrays pass and arrays of different length fail.
+ bool ret = f1.Length == f2.Length;
+ for (var i = 0; ret && i < f1.Length; i++)
+ {
+ ret = Math.Abs(f1[i] - f2[i]) <= tolerance;
+ }
+
+ if (!ret)
+ {
+ Assert.Fail($"Array float not Equal:[{string.Join(",", f1)}] [{string.Join(",", f2)}]");
+ }
+ }
+
public bool Equal(double[] d1, double[] d2)
{
bool ret = false;
diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/Layers.Convolution.Test.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/Layers.Convolution.Test.cs
index 997dcb4f..15c6e80f 100644
--- a/test/TensorFlowNET.Keras.UnitTest/Layers/Layers.Convolution.Test.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/Layers/Layers.Convolution.Test.cs
@@ -1,6 +1,8 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System.Linq;
using Tensorflow.NumPy;
using static Tensorflow.KerasApi;
+using static Tensorflow.Binding;
namespace Tensorflow.Keras.UnitTest.Layers
{
@@ -193,5 +195,128 @@ namespace Tensorflow.Keras.UnitTest.Layers
Assert.AreEqual(x.dims[2], y.shape[2]);
Assert.AreEqual(filters, y.shape[3]);
}
+
+
+ [TestMethod]
+ public void BasicDepthwiseConv2D()
+ {
+ var conv = keras.layers.DepthwiseConv2D(kernel_size:3, strides:1, activation: null,
+ padding:"same", depthwise_initializer: "ones");
+
+ var x = np.arange(2 * 9* 9* 3).reshape((2, 9, 9, 3));
+ var x2 = ops.convert_to_tensor(x, TF_DataType.TF_FLOAT);
+
+ var y = conv.Apply(x2);
+
+ print($"input:{x2.shape} DepthwiseConv2D.out: {y.shape}");
+
+ // The output is reshaped to the input shape (2, 9, 9, 3) before indexing.
+ Assert.AreEqual(4, y.shape.ndim);
+ var arr = y.numpy().reshape((2, 9, 9, 3));
+ // The expected outputs are 9x the inputs at the same index (2457 = 9 * 273, ...).
+ AssertArray(x[new int[] { 1, 1, 1 }].ToArray(), new int[] { 273, 274, 275 });
+ AssertArray(arr[new int[] { 1, 1, 1 }].ToArray(), new float[] { 2457f, 2466f, 2475f });
+
+ var bn = keras.layers.BatchNormalization();
+ var y2 = bn.Apply(y);
+ arr = y2.numpy().ToArray();
+
+ double delta = 0.0001; // tolerance
+
+ Assert.AreEqual(arr[0], 59.97002f, delta);
+ Assert.AreEqual(arr[1], 63.96802f, delta);
+ }
+
+
+ [TestMethod]
+ public void BasicDepthwiseConv2D_strides_2()
+ {
+ var conv = keras.layers.DepthwiseConv2D(kernel_size: 3, strides: (1, 2, 2, 1), activation: null,
+ padding: "same", depthwise_initializer: "ones");
+
+ var x = np.arange(2 * 9 * 9 * 3).reshape((2, 9, 9, 3));
+ var x2 = ops.convert_to_tensor(x, TF_DataType.TF_FLOAT);
+
+ var y = conv.Apply(x2);
+
+ print($"input:{x2.shape} DepthwiseConv2D.out: {y.shape}");
+
+ Assert.AreEqual(4, y.shape.ndim);
+ var arr = y.numpy().reshape((2, 5, 5, 3)); // the output is read as (2, 5, 5, 3)
+ // 2727 = 9 * 303, where 303 is x[1, 2, 2, 0].
+ AssertArray(x[new int[] { 1, 1, 1 }].ToArray(), new int[] { 273, 274, 275 });
+ AssertArray(arr[new int[] { 1, 1, 1 }].ToArray(), new float[] { 2727f, 2736f, 2745f });
+
+ var bn = keras.layers.BatchNormalization();
+ var y2 = bn.Apply(y);
+ arr = y2.numpy().ToArray();
+
+ double delta = 0.0001; // tolerance
+
+ Assert.AreEqual(arr[0], 59.97002f, delta);
+ Assert.AreEqual(arr[1], 63.96802f, delta);
+ }
+
+
+
+ [TestMethod]
+ public void BasicDepthwiseConv2D_strides_3()
+ {
+ var conv = keras.layers.DepthwiseConv2D(kernel_size: 3, strides: 3, activation: null,
+ padding: "same", depthwise_initializer: "ones");
+
+ var x = np.arange(2 * 9 * 9 * 3).reshape((2, 9, 9, 3));
+ var x2 = ops.convert_to_tensor(x, TF_DataType.TF_FLOAT);
+
+ var y = conv.Apply(x2);
+
+ print($"input:{x2.shape} DepthwiseConv2D.out: {y.shape}");
+
+ Assert.AreEqual(4, y.shape.ndim);
+ var arr = y.numpy().reshape((2, 3, 3, 3)); // the output is read as (2, 3, 3, 3)
+ // 3267 = 9 * 363, where 363 is x[1, 4, 4, 0].
+ AssertArray(x[new int[] { 1, 1, 1 }].ToArray(), new int[] { 273, 274, 275 });
+ AssertArray(arr[new int[] { 1, 1, 1 }].ToArray(), new float[] { 3267f, 3276f, 3285f });
+
+ var bn = keras.layers.BatchNormalization();
+ var y2 = bn.Apply(y);
+ arr = y2.numpy().ToArray();
+
+ double delta = 0.0001; // tolerance
+
+ Assert.AreEqual(arr[0], 269.86508f, delta);
+ Assert.AreEqual(arr[1], 278.8606f, delta);
+
+ }
+ [TestMethod]
+ public void BasicDepthwiseConv2D_UseBias()
+ {
+ var conv = keras.layers.DepthwiseConv2D(kernel_size: 3, strides: 1, activation: null,
+ use_bias: true, padding: "same",
+ depthwise_initializer: "ones",
+ bias_initializer:"ones"
+ );
+
+ var weight = conv.get_weights(); // not used by the assertions below
+
+ var x = np.arange(9 * 9 * 3).reshape((1, 9, 9, 3));
+ var x2 = ops.convert_to_tensor(x, TF_DataType.TF_FLOAT);
+ var y = conv.Apply(x2);
+
+ Assert.AreEqual(4, y.shape.ndim);
+ var arr = y.numpy().ToArray();
+ // 61 = 0 + 3 + 27 + 30 (the in-bounds corner inputs of channel 0) + 1 (bias).
+ Assert.AreEqual(arr[0], 61f);
+ Assert.AreEqual(arr[1], 65f);
+
+ var bn = keras.layers.BatchNormalization();
+ var y2 = bn.Apply(y);
+ arr = y2.numpy().ToArray();
+
+ double delta = 0.0001; // tolerance
+
+ Assert.AreEqual(arr[0], 60.96952f, delta);
+ Assert.AreEqual(arr[1], 64.96752f, delta);
+ }
}
}
diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/Layers.Merging.Test.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/Layers.Merging.Test.cs
index 36e44e48..9bc2fa76 100644
--- a/test/TensorFlowNET.Keras.UnitTest/Layers/Layers.Merging.Test.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/Layers/Layers.Merging.Test.cs
@@ -1,4 +1,5 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System.Collections.Generic;
using Tensorflow.NumPy;
using static Tensorflow.KerasApi;
@@ -8,12 +9,16 @@ namespace Tensorflow.Keras.UnitTest.Layers
public class LayersMergingTest : EagerModeTestBase
{
[TestMethod]
- public void Concatenate()
+ // DataRow columns: concatenation axis, then the expected sizes of dimensions 1..3 of the result.
+ [DataRow(1, 4, 1, 5)]
+ [DataRow(2, 2, 2, 5)]
+ [DataRow(3, 2, 1, 10)]
+ public void Concatenate(int axis, int shapeA, int shapeB, int shapeC)
{
- var x = np.arange(20).reshape((2, 2, 5));
- var y = np.arange(20, 30).reshape((2, 1, 5));
- var z = keras.layers.Concatenate(axis: 1).Apply(new Tensors(x, y));
- Assert.AreEqual((2, 3, 5), z.shape);
+ // Two (1, 2, 1, 5) inputs; only the concatenation axis is expected to double in size.
+ var first = np.arange(10).reshape((1, 2, 1, 5));
+ var second = np.arange(10, 20).reshape((1, 2, 1, 5));
+ var layer = keras.layers.Concatenate(axis: axis);
+ var merged = layer.Apply(new Tensors(first, second));
+ Assert.AreEqual((1, shapeA, shapeB, shapeC), merged.shape);
}
+
}
}
diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs
index dbf5cae1..67e2b046 100644
--- a/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs
@@ -74,8 +74,8 @@ namespace Tensorflow.Keras.UnitTest.Layers
OneHot = true,
ValidationSize = 55000,
}).Result;
-
- model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size: 16, epochs: 1);
+ var sample_weight = np.ones(((int)dataset.Train.Data.shape[0]));
+ model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size: 16, epochs: 1, sample_weight:sample_weight);
}
[TestMethod]
diff --git a/test/TensorFlowNET.Keras.UnitTest/Model/ModelLoadTest.cs b/test/TensorFlowNET.Keras.UnitTest/Model/ModelLoadTest.cs
index cb570fc0..53a67cbf 100644
--- a/test/TensorFlowNET.Keras.UnitTest/Model/ModelLoadTest.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/Model/ModelLoadTest.cs
@@ -1,10 +1,13 @@
using Microsoft.VisualStudio.TestPlatform.Utilities;
using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Newtonsoft.Json.Linq;
using System.Linq;
+using System.Xml.Linq;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Optimizers;
using Tensorflow.Keras.UnitTest.Helpers;
using Tensorflow.NumPy;
+using static HDF.PInvoke.H5Z;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
@@ -124,4 +127,44 @@ public class ModelLoadTest
var model = tf.saved_model.load(@"D:\development\temp\saved_model") as Tensorflow.Keras.Engine.Model;
model.summary();
}
+
+
+
+    /// <summary>
+    /// Builds a two-branch Conv2D model joined by Concatenate(axis: 3), saves it,
+    /// reloads it, and checks that both instances predict outputs of the same shape.
+    /// </summary>
+    [TestMethod]
+    public void CreateConcatenateModelSaveAndLoad()
+    {
+        var savedModelPath = @"Assets/concat_axis3_model";
+
+        // Small demo network whose only purpose is to see whether the axis value of
+        // the Concatenate layer survives a save/load round trip.
+        var inputs = tf.keras.layers.Input((8, 8, 5));
+
+        var branchA = tf.keras.layers.Conv2D(2, kernel_size: 3, activation: "relu", padding: "same").Apply(inputs);
+        branchA.Name = "conv1";
+
+        var branchB = tf.keras.layers.Conv2D(2, kernel_size: 3, activation: "relu", padding: "same").Apply(inputs);
+        branchB.Name = "conv2";
+
+        var merged = tf.keras.layers.Concatenate(axis: 3).Apply((branchA, branchB));
+        merged.Name = "concat1";
+
+        var model = tf.keras.Model(inputs, merged);
+        model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.CategoricalCrossentropy());
+        model.save(savedModelPath);
+
+        var sample = np.arange(320).reshape((1, 8, 8, 5)).astype(TF_DataType.TF_FLOAT);
+
+        // Two branches with 2 filters each give 4 channels after concatenation on axis 3.
+        var originalOutput = model.predict(sample);
+        Assert.AreEqual((1, 8, 8, 4), originalOutput.shape);
+
+        // Drop the in-memory model and clear the Keras session before loading from disk.
+        model = null;
+        keras.backend.clear_session();
+
+        var reloaded = tf.keras.models.load_model(savedModelPath);
+        var reloadedOutput = reloaded.predict(sample);
+
+        Assert.AreEqual(originalOutput.shape, reloadedOutput.shape);
+    }
+
}
diff --git a/test/TensorFlowNET.UnitTest/EagerModeTestBase.cs b/test/TensorFlowNET.UnitTest/EagerModeTestBase.cs
index d08f4e50..b7b9ae12 100644
--- a/test/TensorFlowNET.UnitTest/EagerModeTestBase.cs
+++ b/test/TensorFlowNET.UnitTest/EagerModeTestBase.cs
@@ -20,6 +20,20 @@ namespace TensorFlowNET.UnitTest
return Math.Abs(f1 - f2) <= tolerance;
}
+        /// <summary>
+        /// Compares two <see cref="long"/> arrays element by element.
+        /// </summary>
+        /// <param name="l1">First array; may be null.</param>
+        /// <param name="l2">Second array; may be null.</param>
+        /// <returns>
+        /// True when both arguments are the same instance (including both null) or have
+        /// the same length and equal elements at every index; otherwise false.
+        /// </returns>
+        public bool Equal(long[] l1, long[] l2)
+        {
+            // Same instance, or both null: nothing left to compare.
+            if (ReferenceEquals(l1, l2))
+                return true;
+
+            // Exactly one side is null: report a mismatch instead of throwing.
+            if (l1 is null || l2 is null)
+                return false;
+
+            if (l1.Length != l2.Length)
+                return false;
+
+            for (var i = 0; i < l1.Length; i++)
+            {
+                if (l1[i] != l2[i])
+                    return false;
+            }
+
+            return true;
+        }
+
public bool Equal(float[] f1, float[] f2)
{
bool ret = false;
diff --git a/test/TensorFlowNET.UnitTest/GradientTest/GradientEagerTest.cs b/test/TensorFlowNET.UnitTest/GradientTest/GradientEagerTest.cs
index e41e1d61..1cfceb3e 100644
--- a/test/TensorFlowNET.UnitTest/GradientTest/GradientEagerTest.cs
+++ b/test/TensorFlowNET.UnitTest/GradientTest/GradientEagerTest.cs
@@ -62,7 +62,7 @@ namespace TensorFlowNET.UnitTest.Gradient
// Calcute the gradient of (x1-x2)^2
// by Automatic Differentiation in Eager mode
// Expected is 2*(abs(x1-x2))
- Tensor x1 = new NDArray( new float[] { 1, 3, 5, 21, 19, 17 });
+ Tensor x1 = new NDArray(new float[] { 1, 3, 5, 21, 19, 17 });
Tensor x2 = new NDArray(new float[] { 29, 27, 23, 7, 11, 13 });
float[] expected = new float[]
{
@@ -173,5 +173,34 @@ namespace TensorFlowNET.UnitTest.Gradient
var result = grad(x, 4);
Assert.AreEqual((float)result, 4.0f);
}
+
+        /// <summary>
+        /// Checks the eager-mode gradient of tf.tile: tiling a one-element tensor with
+        /// multiples = 2 is expected to give a gradient of 2 with respect to that element.
+        /// </summary>
+        [TestMethod]
+        public void Tile()
+        {
+            var a = tf.constant(new int[] { 1 }, TF_DataType.TF_FLOAT);
+            var b = tf.constant(new int[] { 2 });
+            using (var tape = tf.GradientTape())
+            {
+                tape.watch(a);
+                var y = tf.tile(a, b);
+                var grad = tape.gradient(y, a);
+                // MSTest's signature is AreEqual(expected, actual); passing the expected value
+                // first keeps the failure message ("Expected:<..>. Actual:<..>") truthful.
+                Assert.AreEqual(2.0f, (float)grad.numpy());
+            }
+        }
+
+        /// <summary>
+        /// Checks the eager-mode gradient of tf.gather_nd: gathering element [row, 1] from
+        /// each row of a 3x3 matrix is expected to give a gradient of 1 at those positions
+        /// and 0 everywhere else.
+        /// </summary>
+        [TestMethod]
+        public void GatherNdTest()
+        {
+            var matrix = tf.constant(new float[,] { { 1.0f, 2.0f, 3.0f }, { 1.0f, 2.0f, 3.0f }, { 1.0f, 2.0f, 3.0f } }, dtype: TF_DataType.TF_FLOAT);
+            var picks = tf.constant(new int[,] { { 0, 1 }, { 1, 1 }, { 2, 1 } }, dtype: TF_DataType.TF_INT32);
+            var expectedGrad = np.array(new float[,] { { 0f, 1f, 0f }, { 0f, 1f, 0f }, { 0f, 1f, 0f } });
+
+            using (var tape = tf.GradientTape())
+            {
+                tape.watch(matrix);
+                var gathered = tf.gather_nd(matrix, picks);
+                var actualGrad = tape.gradient(gathered, matrix);
+                Assert.IsTrue(Enumerable.SequenceEqual(actualGrad.ToArray(), expectedGrad.ToArray()));
+            }
+        }
}
}
diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs
index 675689bb..e25c9779 100644
--- a/test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs
+++ b/test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs
@@ -3,6 +3,7 @@ using Tensorflow.NumPy;
using Tensorflow;
using static Tensorflow.Binding;
using System.Linq;
+using Tensorflow.Operations;
namespace TensorFlowNET.UnitTest.ManagedAPI
{
@@ -105,5 +106,321 @@ namespace TensorFlowNET.UnitTest.ManagedAPI
Assert.IsTrue(Equal(a[0].ToArray().Reverse().ToArray(), b[0].ToArray()));
Assert.IsTrue(Equal(a[1].ToArray().Reverse().ToArray(), b[1].ToArray()));
}
+
+        /// <summary>
+        /// Flips a single 2x3 three-channel image (one marker pixel at row 0, column 0)
+        /// left-right and up-down through tf.image, tf.reverse and gen_array_ops.reverse_v2,
+        /// and checks each result against a hand-written expected image.
+        /// </summary>
+        [TestMethod]
+        public void ReverseImgArray3D()
+        {
+            // Source image: marker pixel in the first column of the first row.
+            var original = ops.convert_to_tensor(new float[,,] {
+                { { 237, 28, 36 }, { 255, 255, 255 }, { 255, 255, 255 } },
+                { { 255, 255, 255 }, { 255, 255, 255 }, { 255, 255, 255 } }
+            });
+
+            // Expected left-right flip: the marker moves to the last column of the first row.
+            var expectedLeftRight = ops.convert_to_tensor(new float[,,] {
+                { { 255, 255, 255 }, { 255, 255, 255 }, { 237, 28, 36 } },
+                { { 255, 255, 255 }, { 255, 255, 255 }, { 255, 255, 255 } }
+            });
+
+            var flippedLr = tf.image.flip_left_right(original);
+            Assert.IsTrue(Equal(expectedLeftRight.numpy().ToArray(), flippedLr.numpy().ToArray()), "tf.image.flip_left_right fail.");
+
+            var reversedAxis1 = tf.reverse(original, 1);
+            Assert.IsTrue(Equal(expectedLeftRight.numpy().ToArray(), reversedAxis1.numpy().ToArray()), "tf.reverse (axis=1) fail.");
+
+            var reversedV2Axis1 = gen_array_ops.reverse_v2(original, ops.convert_to_tensor(new[] { 1 }));
+            Assert.IsTrue(Equal(expectedLeftRight.numpy().ToArray(), reversedV2Axis1.numpy().ToArray()), "gen_array_ops.reverse_v2 axis=1 fail.");
+
+            // Expected up-down flip: the marker moves to the first column of the last row.
+            var expectedUpDown = ops.convert_to_tensor(new float[,,] {
+                { { 255, 255, 255 }, { 255, 255, 255 }, { 255, 255, 255 } },
+                { { 237, 28, 36 }, { 255, 255, 255 }, { 255, 255, 255 } }
+            });
+
+            var flippedUd = tf.image.flip_up_down(original);
+            Assert.IsTrue(Equal(expectedUpDown.numpy().ToArray(), flippedUd.numpy().ToArray()), "tf.image.flip_up_down fail.");
+
+            var reversedAxis0 = tf.reverse(original, new Axis(0));
+            Assert.IsTrue(Equal(expectedUpDown.numpy().ToArray(), reversedAxis0.numpy().ToArray()), "tf.reverse (axis=0) fail.");
+
+            var reversedV2Axis0 = gen_array_ops.reverse_v2(original, ops.convert_to_tensor(new[] { 0 }));
+            Assert.IsTrue(Equal(expectedUpDown.numpy().ToArray(), reversedV2Axis0.numpy().ToArray()), "gen_array_ops.reverse_v2 axis=0 fail.");
+        }
+
+        /// <summary>
+        /// Flips a batch of two 2x3 three-channel images up-down and left-right through
+        /// gen_array_ops.reverse_v2, tf.reverse and tf.image, and checks each result against
+        /// hand-written expected batches.
+        /// </summary>
+        [TestMethod]
+        public void ReverseImgArray4D()
+        {
+            // Batch of two images: the original (marker pixel at the top-left) plus a
+            // left-right flipped copy of it.
+            var m = new float[,,,] {
+                {
+                    { { 237, 28, 36 }, { 255, 255, 255 }, { 255, 255, 255 } },
+                    { { 255, 255, 255 }, { 255, 255, 255 }, { 255, 255, 255 } }
+                },
+                {
+                    { { 255, 255, 255 }, { 255, 255, 255 }, { 237, 28, 36 } },
+                    { { 255, 255, 255 }, { 255, 255, 255 }, { 255, 255, 255 } }
+                }
+            };
+            var sourceImg = ops.convert_to_tensor(m);
+
+            // Expected left-right flip of every image in the batch.
+            var lrArray = new float[,,,] {
+                {
+                    { { 255, 255, 255 }, { 255, 255, 255 }, { 237, 28, 36 } },
+                    { { 255, 255, 255 }, { 255, 255, 255 }, { 255, 255, 255 } }
+                },
+                {
+                    { { 237, 28, 36 }, { 255, 255, 255 }, { 255, 255, 255 } },
+                    { { 255, 255, 255 }, { 255, 255, 255 }, { 255, 255, 255 } }
+                }
+            };
+            var lrImg = ops.convert_to_tensor(lrArray);
+
+            // Expected up-down flip of every image in the batch.
+            var udArray = new float[,,,] {
+                {
+                    { { 255, 255, 255 }, { 255, 255, 255 }, { 255, 255, 255 } },
+                    { { 237, 28, 36 }, { 255, 255, 255 }, { 255, 255, 255 } }
+                },
+                {
+                    { { 255, 255, 255 }, { 255, 255, 255 }, { 255, 255, 255 } },
+                    { { 255, 255, 255 }, { 255, 255, 255 }, { 237, 28, 36 } }
+                }
+            };
+            var udImg = ops.convert_to_tensor(udArray);
+
+            // Up-down flip: axis 1 is the row axis of the 4-D batch.
+            var ud3 = gen_array_ops.reverse_v2(sourceImg, ops.convert_to_tensor(new[] { 1 }));
+            Assert.IsTrue(Equal(udImg.numpy().ToArray(), ud3.numpy().ToArray()), "gen_array_ops.reverse_v2 axis=1 fail.");
+
+            var ud2 = tf.reverse(sourceImg, new Axis(1));
+            Assert.IsTrue(Equal(udImg.numpy().ToArray(), ud2.numpy().ToArray()), "tf.reverse (axis=1) fail.");
+
+            var ud = tf.image.flip_up_down(sourceImg);
+            Assert.IsTrue(Equal(udImg.numpy().ToArray(), ud.numpy().ToArray()), "tf.image.flip_up_down fail.");
+
+            // Left-right flip.
+            var lr = tf.image.flip_left_right(sourceImg);
+            Assert.IsTrue(Equal(lrImg.numpy().ToArray(), lr.numpy().ToArray()), "tf.image.flip_left_right fail.");
+
+            // The second image is the mirror of the first, so reversing the batch axis (0)
+            // produces the same data as mirroring every image. The messages below name
+            // axis 0, the axis these two calls actually reverse.
+            var lr2 = tf.reverse(sourceImg, 0);
+            Assert.IsTrue(Equal(lrImg.numpy().ToArray(), lr2.numpy().ToArray()), "tf.reverse (axis=0) fail.");
+
+            var lr3 = gen_array_ops.reverse_v2(sourceImg, ops.convert_to_tensor(new[] { 0 }));
+            Assert.IsTrue(Equal(lrImg.numpy().ToArray(), lr3.numpy().ToArray()), "gen_array_ops.reverse_v2 axis=0 fail.");
+        }
+
+        /// <summary>
+        /// Same checks as the 2x3 batch test, but on a batch of two 3x3 three-channel
+        /// images: up-down and left-right flips through gen_array_ops.reverse_v2,
+        /// tf.reverse and tf.image are compared with hand-written expected batches.
+        /// </summary>
+        [TestMethod]
+        public void ReverseImgArray4D_3x3()
+        {
+            // Batch of two images: the original (marker pixel at the top-left) plus a
+            // left-right flipped copy of it.
+            var m = new float[,,,] {
+                {
+                    { { 237, 28, 36 }, { 255, 255, 255 }, { 255, 255, 255 } },
+                    { { 255, 255, 255 }, { 255, 255, 255 }, { 255, 255, 255 } },
+                    { { 255, 255, 255 }, { 255, 255, 255 }, { 255, 255, 255 } }
+                },
+                {
+                    { { 255, 255, 255 }, { 255, 255, 255 }, { 237, 28, 36 } },
+                    { { 255, 255, 255 }, { 255, 255, 255 }, { 255, 255, 255 } },
+                    { { 255, 255, 255 }, { 255, 255, 255 }, { 255, 255, 255 } }
+                }
+            };
+            var sourceImg = ops.convert_to_tensor(m);
+
+            // Expected left-right flip of every image in the batch.
+            var lrArray = new float[,,,] {
+                {
+                    { { 255, 255, 255 }, { 255, 255, 255 }, { 237, 28, 36 } },
+                    { { 255, 255, 255 }, { 255, 255, 255 }, { 255, 255, 255 } },
+                    { { 255, 255, 255 }, { 255, 255, 255 }, { 255, 255, 255 } }
+                },
+                {
+                    { { 237, 28, 36 }, { 255, 255, 255 }, { 255, 255, 255 } },
+                    { { 255, 255, 255 }, { 255, 255, 255 }, { 255, 255, 255 } },
+                    { { 255, 255, 255 }, { 255, 255, 255 }, { 255, 255, 255 } }
+                }
+            };
+            var lrImg = ops.convert_to_tensor(lrArray);
+
+            // Expected up-down flip of every image in the batch.
+            var udArray = new float[,,,] {
+                {
+                    { { 255, 255, 255 }, { 255, 255, 255 }, { 255, 255, 255 } },
+                    { { 255, 255, 255 }, { 255, 255, 255 }, { 255, 255, 255 } },
+                    { { 237, 28, 36 }, { 255, 255, 255 }, { 255, 255, 255 } }
+                },
+                {
+                    { { 255, 255, 255 }, { 255, 255, 255 }, { 255, 255, 255 } },
+                    { { 255, 255, 255 }, { 255, 255, 255 }, { 255, 255, 255 } },
+                    { { 255, 255, 255 }, { 255, 255, 255 }, { 237, 28, 36 } }
+                }
+            };
+            var udImg = ops.convert_to_tensor(udArray);
+
+            // Up-down flip: axis 1 is the row axis of the 4-D batch.
+            var ud3 = gen_array_ops.reverse_v2(sourceImg, ops.convert_to_tensor(new[] { 1 }));
+            Assert.IsTrue(Equal(udImg.numpy().ToArray(), ud3.numpy().ToArray()), "gen_array_ops.reverse_v2 axis=1 fail.");
+
+            var ud2 = tf.reverse(sourceImg, new Axis(1));
+            Assert.IsTrue(Equal(udImg.numpy().ToArray(), ud2.numpy().ToArray()), "tf.reverse (axis=1) fail.");
+
+            var ud = tf.image.flip_up_down(sourceImg);
+            Assert.IsTrue(Equal(udImg.numpy().ToArray(), ud.numpy().ToArray()), "tf.image.flip_up_down fail.");
+
+            // Left-right flip.
+            var lr = tf.image.flip_left_right(sourceImg);
+            Assert.IsTrue(Equal(lrImg.numpy().ToArray(), lr.numpy().ToArray()), "tf.image.flip_left_right fail.");
+
+            // The second image is the mirror of the first, so reversing the batch axis (0)
+            // produces the same data as mirroring every image. The messages below name
+            // axis 0, the axis these two calls actually reverse.
+            var lr2 = tf.reverse(sourceImg, 0);
+            Assert.IsTrue(Equal(lrImg.numpy().ToArray(), lr2.numpy().ToArray()), "tf.reverse (axis=0) fail.");
+
+            var lr3 = gen_array_ops.reverse_v2(sourceImg, ops.convert_to_tensor(new[] { 0 }));
+            Assert.IsTrue(Equal(lrImg.numpy().ToArray(), lr3.numpy().ToArray()), "gen_array_ops.reverse_v2 axis=0 fail.");
+        }
}
}
diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/RaggedTensorTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/RaggedTensorTest.cs
new file mode 100644
index 00000000..7a3de882
--- /dev/null
+++ b/test/TensorFlowNET.UnitTest/ManagedAPI/RaggedTensorTest.cs
@@ -0,0 +1,26 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Tensorflow;
+using Tensorflow.NumPy;
+using static Tensorflow.Binding;
+
+namespace TensorFlowNET.UnitTest.ManagedAPI
+{
+    /// <summary>
+    /// Tests for the RowPartition factory methods.
+    /// </summary>
+    // [TestClass] is required for MSTest to discover the test methods in this class.
+    [TestClass]
+    public class RaggedTensorTest : EagerModeTestBase
+    {
+        /// <summary>
+        /// Builds a RowPartition from five row lengths and reads back its row lengths
+        /// and its row count.
+        /// </summary>
+        [TestMethod]
+        public void Test_from_row_lengths()
+        {
+            var row_lengths = tf.convert_to_tensor(np.array(new int[] { 2, 0, 3, 1, 1 }, TF_DataType.TF_INT64));
+            var rp = RowPartition.from_row_lengths(row_lengths, validate: false);
+            // Called only to exercise row_lengths(); the result is not asserted on.
+            var rp_row_lengths = rp.row_lengths();
+            var rp_nrows = rp.nrows();
+            // NOTE(review): this compares nrows() with a second nrows() call, so it only shows the
+            // call is repeatable. Asserting against the number of input row lengths (5) would be
+            // a stronger check — confirm the intended expectation.
+            Assert.IsTrue(rp_nrows.ToArray()[0] == rp.nrows().ToArray()[0]);
+        }
+    }
+}
diff --git a/test/TensorFlowNET.UnitTest/NumPy/ShapeTest.cs b/test/TensorFlowNET.UnitTest/NumPy/ShapeTest.cs
new file mode 100644
index 00000000..f5a8685b
--- /dev/null
+++ b/test/TensorFlowNET.UnitTest/NumPy/ShapeTest.cs
@@ -0,0 +1,44 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using Tensorflow.NumPy;
+using System;
+using System.Linq;
+using static Tensorflow.Binding;
+using Tensorflow;
+
+namespace TensorFlowNET.UnitTest.NumPy
+{
+    /// <summary>
+    /// Tests for slicing a Shape with a string range expression.
+    /// </summary>
+    [TestClass]
+    public class ShapeTest : EagerModeTestBase
+    {
+        /// <summary>
+        /// Compares the last three dimensions of a 3-D and a 4-D shape, picked by explicit
+        /// indexing, with the result of the "-3:" slice expression.
+        /// </summary>
+        [Ignore]
+        [TestMethod]
+        public unsafe void ShapeGetLastElements()
+        {
+            // Test code taken from the function _CheckAtLeast3DImage.
+            // The earlier _CheckAtLeast3DImage had a bug; it now passes its tests, and the code below is correct.
+            // TODO: the shape["-3:"] syntax currently has a bug that needs fixing; re-enable this unit test
+            // once it is fixed. The test is ignored for now.
+
+            var image_shape = new Shape(new[] { 32, 64, 3 });
+            var image_shape_4d = new Shape(new[] { 4, 64, 32, 3 });
+
+            var image_shape_last_three_elements = new Shape(new[] {
+                image_shape.dims[image_shape.dims.Length - 3],
+                image_shape.dims[image_shape.dims.Length - 2],
+                image_shape.dims[image_shape.dims.Length - 1]});
+
+            var image_shape_last_three_elements2 = image_shape["-3:"];
+
+            Assert.IsTrue(Equal(image_shape_last_three_elements.dims, image_shape_last_three_elements2.dims), "3dims get fail.");
+
+            var image_shape_last_three_elements_4d = new Shape(new[] {
+                image_shape_4d.dims[image_shape_4d.dims.Length - 3],
+                image_shape_4d.dims[image_shape_4d.dims.Length - 2],
+                image_shape_4d.dims[image_shape_4d.dims.Length - 1]});
+
+            var image_shape_last_three_elements2_4d = image_shape_4d["-3:"];
+
+            // Use the element-wise Equal helper, as in the 3-D assertion above. object.Equals
+            // compares array references and would always be false for two distinct arrays.
+            Assert.IsTrue(Equal(image_shape_last_three_elements_4d.dims, image_shape_last_three_elements2_4d.dims), "4dims get fail.");
+        }
+    }
+}
\ No newline at end of file
diff --git a/tools/TensorFlowNET.Console/Tensorflow.Console.csproj b/tools/TensorFlowNET.Console/Tensorflow.Console.csproj
index ecc2d30b..bb60b6b6 100644
--- a/tools/TensorFlowNET.Console/Tensorflow.Console.csproj
+++ b/tools/TensorFlowNET.Console/Tensorflow.Console.csproj
@@ -19,13 +19,10 @@
AnyCPU
-
-
-
-
+
diff --git a/tools/Tensorflow.CodeGen/Tensorflow.CodeGen.csproj b/tools/Tensorflow.CodeGen/Tensorflow.CodeGen.csproj
index 03195e6a..2afc68a3 100644
--- a/tools/Tensorflow.CodeGen/Tensorflow.CodeGen.csproj
+++ b/tools/Tensorflow.CodeGen/Tensorflow.CodeGen.csproj
@@ -9,7 +9,6 @@
-
diff --git a/tools/Tensorflow.UnitTest.RedistHolder/Tensorflow.UnitTest.RedistHolder.csproj b/tools/Tensorflow.UnitTest.RedistHolder/Tensorflow.UnitTest.RedistHolder.csproj
index 1ca387db..0d1018ca 100644
--- a/tools/Tensorflow.UnitTest.RedistHolder/Tensorflow.UnitTest.RedistHolder.csproj
+++ b/tools/Tensorflow.UnitTest.RedistHolder/Tensorflow.UnitTest.RedistHolder.csproj
@@ -5,7 +5,7 @@
-
+