Browse Source

Merge pull request #1190 from dogvane/master

解决keras模式下,使用GPU训练时会爆显存的bug。
tags/v0.150.0-BERT-Model
Haiping GitHub 2 years ago
parent
commit
090dc1eee3
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 983 additions and 20 deletions
  1. BIN
      data/img001.bmp
  2. +7
    -0
      src/TensorFlowNET.Core/APIs/tf.image.cs
  3. +7
    -0
      src/TensorFlowNET.Core/APIs/tf.io.cs
  4. +5
    -0
      src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs
  5. +31
    -0
      src/TensorFlowNET.Core/Gradients/nn_grad.cs
  6. +23
    -0
      src/TensorFlowNET.Core/Keras/Engine/IModel.cs
  7. +19
    -0
      src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
  8. +32
    -11
      src/TensorFlowNET.Core/Operations/image_ops_impl.cs
  9. +4
    -1
      src/TensorFlowNET.Core/Tensors/tensor_util.cs
  10. +3
    -0
      src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs
  11. +6
    -6
      src/TensorFlowNET.Keras/Engine/Model.Fit.cs
  12. +1
    -1
      src/TensorFlowNET.Keras/Engine/Model.Predict.cs
  13. +167
    -0
      src/TensorFlowNET.Keras/Layers/Convolution/DepthwiseConv2D.cs
  14. +54
    -1
      src/TensorFlowNET.Keras/Layers/LayersApi.cs
  15. +90
    -0
      test/TensorFlowNET.Graph.UnitTest/ImageTest.cs
  16. +34
    -0
      test/TensorFlowNET.Keras.UnitTest/EagerModeTestBase.cs
  17. +125
    -0
      test/TensorFlowNET.Keras.UnitTest/Layers/Layers.Convolution.Test.cs
  18. +14
    -0
      test/TensorFlowNET.UnitTest/EagerModeTestBase.cs
  19. +317
    -0
      test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs
  20. +44
    -0
      test/TensorFlowNET.UnitTest/NumPy/ShapeTest.cs

BIN
data/img001.bmp View File

Before After
Width: 244  |  Height: 244  |  Size: 179 kB

+ 7
- 0
src/TensorFlowNET.Core/APIs/tf.image.cs View File

@@ -339,6 +339,13 @@ namespace Tensorflow
=> image_ops_impl.decode_image(contents, channels: channels, dtype: dtype,
name: name, expand_animations: expand_animations);

public Tensor encode_png(Tensor contents, string name = null)
=> image_ops_impl.encode_png(contents, name: name);

public Tensor encode_jpeg(Tensor contents, string name = null)
=> image_ops_impl.encode_jpeg(contents, name: name);


/// <summary>
/// Convenience function to check if the 'contents' encodes a JPEG image.
/// </summary>


+ 7
- 0
src/TensorFlowNET.Core/APIs/tf.io.cs View File

@@ -16,6 +16,7 @@

using System.Collections.Generic;
using Tensorflow.IO;
using Tensorflow.Operations;

namespace Tensorflow
{
@@ -46,6 +47,12 @@ namespace Tensorflow
public Tensor[] restore_v2(Tensor prefix, string[] tensor_names,
string[] shape_and_slices, TF_DataType[] dtypes, string name = null)
=> ops.restore_v2(prefix, tensor_names, shape_and_slices, dtypes, name: name);

public Operation write_file(string filename, Tensor conentes, string name = null)
=> write_file(Tensorflow.ops.convert_to_tensor(filename, TF_DataType.TF_STRING), conentes, name);

public Operation write_file(Tensor filename, Tensor conentes, string name = null)
=> gen_ops.write_file(filename, conentes, name);
}

public GFile gfile = new GFile();


+ 5
- 0
src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs View File

@@ -80,6 +80,11 @@ namespace Tensorflow.Eager
Tensor[] op_outputs)
=> (out_grads, unneeded_gradients) =>
{
if(!ops.gradientFunctions.ContainsKey(op_name))
{
throw new Exception($"gradientFunctions not find op_name: {op_name}");
}

if (ops.gradientFunctions[op_name] == null)
return new Tensor[op_inputs.Length];



+ 31
- 0
src/TensorFlowNET.Core/Gradients/nn_grad.cs View File

@@ -229,6 +229,37 @@ namespace Tensorflow.Gradients
};
}

/// <summary>
/// Gradient function for Conv2D.
/// </summary>
/// <param name="op"></param>
/// <param name="grads"></param>
/// <returns></returns>
[RegisterGradient("DepthwiseConv2dNative")]
public static Tensor[] _DepthwiseConv2DGrad(Operation op, Tensor[] grads)
{
var dilations = op.get_attr_list<int>("dilations");
var strides = op.get_attr_list<int>("strides");
var padding = op.get_attr<string>("padding");
var explicit_paddings = op.get_attr_list<int>("explicit_paddings");
var data_format = op.get_attr<string>("data_format");
var shape = gen_array_ops.shape_n(new Tensor[] { op.inputs[0], op.inputs[1] });

return new Tensor[]
{
gen_nn_ops.depthwise_conv2d_native_backprop_input(
shape[0], op.inputs[1], grads[0],
strides, padding, explicit_paddings,
dilations: dilations,
data_format: data_format),
gen_nn_ops.depthwise_conv2d_native_backprop_filter(op.inputs[0], shape[1], grads[0],
strides, padding,
dilations: dilations,
explicit_paddings: explicit_paddings,
data_format: data_format)
};
}

[RegisterGradient("FusedBatchNorm")]
public static Tensor[] _FusedBatchNormGrad(Operation op, Tensor[] grads)
=> _BaseFusedBatchNormGrad(op, 0, grads);


+ 23
- 0
src/TensorFlowNET.Core/Keras/Engine/IModel.cs View File

@@ -24,6 +24,7 @@ public interface IModel : ILayer
List<ICallback> callbacks = null,
float validation_split = 0f,
ValidationDataPack validation_data = null,
int validation_step = 10,
bool shuffle = true,
Dictionary<int, float> class_weight = null,
NDArray sample_weight = null,
@@ -47,6 +48,20 @@ public interface IModel : ILayer
int workers = 1,
bool use_multiprocessing = false);

public ICallback fit(IDatasetV2 dataset,
int batch_size = -1,
int epochs = 1,
int verbose = 1,
List<ICallback> callbacks = null,
IDatasetV2 validation_data = null,
int validation_step = 10, // 间隔多少次会进行一次验证
bool shuffle = true,
Dictionary<int, float> class_weight = null,
int initial_epoch = 0,
int max_queue_size = 10,
int workers = 1,
bool use_multiprocessing = false);

void save(string filepath,
bool overwrite = true,
bool include_optimizer = true,
@@ -85,6 +100,14 @@ public interface IModel : ILayer
int workers = 1,
bool use_multiprocessing = false);

public Tensors predict(IDatasetV2 dataset,
int batch_size = -1,
int verbose = 0,
int steps = -1,
int max_queue_size = 10,
int workers = 1,
bool use_multiprocessing = false);

void summary(int line_length = -1, float[] positions = null);

IKerasConfig get_config();


+ 19
- 0
src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs View File

@@ -55,6 +55,12 @@ namespace Tensorflow.Keras.Layers
string kernel_initializer = "glorot_uniform",
string bias_initializer = "zeros");

public ILayer Conv2D(int filters,
Shape kernel_size = null,
Shape strides = null,
string padding = "valid"
);

public ILayer Conv2D(int filters,
Shape kernel_size = null,
Shape strides = null,
@@ -95,6 +101,19 @@ namespace Tensorflow.Keras.Layers
bool use_bias = true,
string kernel_initializer = "glorot_uniform",
string bias_initializer = "zeros");
public ILayer DepthwiseConv2D(Shape kernel_size = null,
Shape strides = null,
string padding = "valid",
string data_format = null,
Shape dilation_rate = null,
int groups = 1,
int depth_multiplier = 1,
string activation = null,
bool use_bias = false,
string kernel_initializer = "glorot_uniform",
string bias_initializer = "zeros",
string depthwise_initializer = "glorot_uniform"
);

public ILayer Dense(int units);
public ILayer Dense(int units,


+ 32
- 11
src/TensorFlowNET.Core/Operations/image_ops_impl.cs View File

@@ -102,7 +102,10 @@ namespace Tensorflow
{
throw new ValueError("\'image\' must be fully defined.");
}
var dims = image_shape["-3:"];
var dims = new Shape(new[] {
image_shape.dims[image_shape.dims.Length - 3],
image_shape.dims[image_shape.dims.Length - 2],
image_shape.dims[image_shape.dims.Length - 1]});
foreach (var dim in dims.dims)
{
if (dim == 0)
@@ -112,16 +115,18 @@ namespace Tensorflow
}

var image_shape_last_three_elements = new Shape(new[] {
image_shape.dims[image_shape.dims.Length - 1],
image_shape.dims[image_shape.dims.Length - 3],
image_shape.dims[image_shape.dims.Length - 2],
image_shape.dims[image_shape.dims.Length - 3]});
image_shape.dims[image_shape.dims.Length - 1]});
if (!image_shape_last_three_elements.IsFullyDefined)
{
Tensor image_shape_ = array_ops.shape(image);
var image_shape_return = tf.constant(new[] {
image_shape_.dims[image_shape.dims.Length - 1],
image_shape_.dims[image_shape.dims.Length - 2],
image_shape_.dims[image_shape.dims.Length - 3]});
var image_shape_return = tf.slice(image_shape_, new[] { Math.Max(image_shape.dims.Length - 3, 0) }, new[] { 3 });

//var image_shape_return = tf.constant(new[] {
// image_shape_.dims[image_shape_.dims.Length - 3],
// image_shape_.dims[image_shape_.dims.Length - 2],
// image_shape_.dims[image_shape_.dims.Length - 1]});

return new Operation[] {
check_ops.assert_positive(
@@ -209,10 +214,10 @@ namespace Tensorflow
}

public static Tensor flip_left_right(Tensor image)
=> _flip(image, 0, "flip_left_right");
=> _flip(image, 1, "flip_left_right");

public static Tensor flip_up_down(Tensor image)
=> _flip(image, 1, "flip_up_down");
=> _flip(image, 0, "flip_up_down");

internal static Tensor _flip(Tensor image, int flip_index, string scope_name)
{
@@ -223,11 +228,11 @@ namespace Tensorflow
Shape shape = image.shape;
if (shape.ndim == 3 || shape.ndim == Unknown)
{
return fix_image_flip_shape(image, gen_array_ops.reverse(image, ops.convert_to_tensor(new int[] { flip_index })));
return fix_image_flip_shape(image, gen_array_ops.reverse_v2(image, ops.convert_to_tensor(new int[] { flip_index })));
}
else if (shape.ndim == 4)
{
return gen_array_ops.reverse_v2(image, ops.convert_to_tensor(new[] { (flip_index + 1) % 2 }));
return gen_array_ops.reverse_v2(image, ops.convert_to_tensor(new[] { flip_index + 1 }));
}
else
{
@@ -2047,6 +2052,22 @@ new_height, new_width");
});
}

public static Tensor encode_jpeg(Tensor contents, string name = null)
{
return tf_with(ops.name_scope(name, "encode_jpeg"), scope =>
{
return gen_ops.encode_jpeg(contents, name:name);
});
}

public static Tensor encode_png(Tensor contents, string name = null)
{
return tf_with(ops.name_scope(name, "encode_png"), scope =>
{
return gen_ops.encode_png(contents, name: name);
});
}

public static Tensor is_jpeg(Tensor contents, string name = null)
{
return tf_with(ops.name_scope(name, "is_jpeg"), scope =>


+ 4
- 1
src/TensorFlowNET.Core/Tensors/tensor_util.cs View File

@@ -249,6 +249,9 @@ namespace Tensorflow
case sbyte val:
tensor_proto.IntVal.AddRange(new[] { (int)val });
break;
case byte val:
tensor_proto.IntVal.AddRange(new[] { (int)val });
break;
case int val:
tensor_proto.IntVal.AddRange(new[] { val });
break;
@@ -262,7 +265,7 @@ namespace Tensorflow
tensor_proto.DoubleVal.AddRange(new[] { val });
break;
default:
throw new Exception("make_tensor_proto Not Implemented");
throw new Exception($"make_tensor_proto Not Implemented {values.GetType().Name}");
}
}



+ 3
- 0
src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs View File

@@ -132,6 +132,7 @@ namespace Tensorflow.Keras.Engine
var end_step = step + data_handler.StepIncrement;
if (!is_val)
callbacks.on_test_batch_end(end_step, logs);
GC.Collect();
}
}
callbacks.on_test_end(logs);
@@ -167,7 +168,9 @@ namespace Tensorflow.Keras.Engine
Dictionary<string, float> test_step(DataHandler data_handler, Tensors x, Tensors y)
{
(x,y) = data_handler.DataAdapter.Expand1d(x, y);

var y_pred = Apply(x, training: false);

var loss = compiled_loss.Call(y, y_pred);
compiled_metrics.update_state(y, y_pred);
return metrics.Select(x => (x.Name, x.result())).ToDictionary(x => x.Item1, x => (float)x.Item2);


+ 6
- 6
src/TensorFlowNET.Keras/Engine/Model.Fit.cs View File

@@ -41,6 +41,7 @@ namespace Tensorflow.Keras.Engine
List<ICallback> callbacks = null,
float validation_split = 0f,
ValidationDataPack validation_data = null,
int validation_step = 10,
bool shuffle = true,
Dictionary<int, float> class_weight = null,
NDArray sample_weight = null,
@@ -147,7 +148,7 @@ namespace Tensorflow.Keras.Engine
}
}

public History fit(IDatasetV2 dataset,
public ICallback fit(IDatasetV2 dataset,
int batch_size = -1,
int epochs = 1,
int verbose = 1,
@@ -156,7 +157,6 @@ namespace Tensorflow.Keras.Engine
int validation_step = 10,
bool shuffle = true,
Dictionary<int, float> class_weight = null,
NDArray sample_weight = null,
int initial_epoch = 0,
int max_queue_size = 10,
int workers = 1,
@@ -170,7 +170,7 @@ namespace Tensorflow.Keras.Engine
InitialEpoch = initial_epoch,
Epochs = epochs,
Shuffle = shuffle,
SampleWeight = sample_weight,
ClassWeight = class_weight,
MaxQueueSize = max_queue_size,
Workers = workers,
UseMultiprocessing = use_multiprocessing,
@@ -218,6 +218,7 @@ namespace Tensorflow.Keras.Engine
var end_step = step + data_handler.StepIncrement;
End_step = end_step;
callbacks.on_train_batch_end(end_step, logs);
GC.Collect();
}

if (validation_data != null)
@@ -233,11 +234,10 @@ namespace Tensorflow.Keras.Engine
callbacks.on_train_batch_end(End_step, logs);
}

GC.Collect();

callbacks.on_epoch_end(epoch, logs);

GC.Collect();
GC.WaitForPendingFinalizers();
if (stop_training)
{
break;
@@ -282,6 +282,7 @@ namespace Tensorflow.Keras.Engine
var end_step = step + data_handler.StepIncrement;
End_step = end_step;
callbacks.on_train_batch_end(end_step, logs);
GC.Collect();
}

if (validation_data != null)
@@ -301,7 +302,6 @@ namespace Tensorflow.Keras.Engine
callbacks.on_epoch_end(epoch, logs);

GC.Collect();
GC.WaitForPendingFinalizers();
if (stop_training)
{
break;


+ 1
- 1
src/TensorFlowNET.Keras/Engine/Model.Predict.cs View File

@@ -102,9 +102,9 @@ namespace Tensorflow.Keras.Engine
for (int i = 0; i < batch_outputs.Length; i++)
batch_outputs[i] = tf.concat(new Tensor[] { batch_outputs[i], tmp_batch_outputs[i] }, axis: 0);
}

var end_step = step + data_handler.StepIncrement;
callbacks.on_predict_batch_end(end_step, new Dictionary<string, Tensors> { { "outputs", batch_outputs } });
GC.Collect();
}
}



+ 167
- 0
src/TensorFlowNET.Keras/Layers/Convolution/DepthwiseConv2D.cs View File

@@ -0,0 +1,167 @@
using System;
using System.Collections.Generic;
using System.Text;
using System;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Saving;
using Tensorflow.Common.Types;
using Tensorflow.Keras.Utils;
using Tensorflow.Operations;
using Newtonsoft.Json;
using System.Security.Cryptography;

namespace Tensorflow.Keras.Layers
{
public class DepthwiseConv2DArgs: Conv2DArgs
{
/// <summary>
/// depth_multiplier: The number of depthwise convolution output channels for
/// each input channel.The total number of depthwise convolution output
/// channels will be equal to `filters_in* depth_multiplier`.
/// </summary>
[JsonProperty("depth_multiplier")]
public int DepthMultiplier { get; set; } = 1;

[JsonProperty("depthwise_initializer")]
public IInitializer DepthwiseInitializer { get; set; }
}

public class DepthwiseConv2D : Conv2D
{
/// <summary>
/// depth_multiplier: The number of depthwise convolution output channels for
/// each input channel.The total number of depthwise convolution output
/// channels will be equal to `filters_in* depth_multiplier`.
/// </summary>
int DepthMultiplier = 1;
IInitializer DepthwiseInitializer;

int[] strides;

int[] dilation_rate;

string getDataFormat()
{
return data_format == "channels_first" ? "NCHW" : "NHWC";
}

static int _id = 1;

public DepthwiseConv2D(DepthwiseConv2DArgs args):base(args)
{
args.Padding = args.Padding.ToUpper();

if(string.IsNullOrEmpty(args.Name))
name = "DepthwiseConv2D_" + _id;

this.DepthMultiplier = args.DepthMultiplier;
this.DepthwiseInitializer = args.DepthwiseInitializer;

}

public override void build(KerasShapesWrapper input_shape)
{
//base.build(input_shape);

var shape = input_shape.ToSingleShape();

int channel_axis = data_format == "channels_first" ? 1 : -1;
var input_channel = channel_axis < 0 ?
shape.dims[shape.ndim + channel_axis] :
shape.dims[channel_axis];

var arg = args as DepthwiseConv2DArgs;

if (arg.Strides.ndim != shape.ndim)
{
if (arg.Strides.ndim == 2)
{
this.strides = new int[] { 1, (int)arg.Strides[0], (int)arg.Strides[1], 1 };
}
else
{
this.strides = conv_utils.normalize_tuple(new int[] { (int)arg.Strides[0] }, shape.ndim, "strides");
}
}
else
{
this.strides = arg.Strides.dims.Select(o=>(int)(o)).ToArray();
}

if (arg.DilationRate.ndim != shape.ndim)
{
this.dilation_rate = conv_utils.normalize_tuple(new int[] { (int)arg.DilationRate[0] }, shape.ndim, "dilation_rate");
}

long channel_data = data_format == "channels_first" ? shape[0] : shape[shape.Length - 1];

var depthwise_kernel_shape = this.kernel_size.dims.concat(new long[] {
channel_data,
this.DepthMultiplier
});

this.kernel = this.add_weight(
shape: depthwise_kernel_shape,
initializer: this.DepthwiseInitializer != null ? this.DepthwiseInitializer : this.kernel_initializer,
name: "depthwise_kernel",
trainable: true,
dtype: DType,
regularizer: this.kernel_regularizer
);

var axes = new Dictionary<int, int>();
axes.Add(-1, (int)input_channel);
inputSpec = new InputSpec(min_ndim: rank + 2, axes: axes);


if (use_bias)
{
bias = add_weight(name: "bias",
shape: ((int)channel_data),
initializer: bias_initializer,
trainable: true,
dtype: DType);
}

built = true;
_buildInputShape = input_shape;
}

protected override Tensors Call(Tensors inputs, Tensors state = null,
bool? training = false, IOptionalArgs? optional_args = null)
{
Tensor outputs = null;

outputs = gen_nn_ops.depthwise_conv2d_native(
inputs,
filter: this.kernel.AsTensor(),
strides: this.strides,
padding: this.padding,
dilations: this.dilation_rate,
data_format: this.getDataFormat(),
name: name
);

if (use_bias)
{
if (data_format == "channels_first")
{
throw new NotImplementedException("call channels_first");
}
else
{
outputs = gen_nn_ops.bias_add(outputs, ops.convert_to_tensor(bias),
data_format: this.getDataFormat(), name: name);
}
}

if (activation != null)
outputs = activation.Apply(outputs);


return outputs;
}

}
}

+ 54
- 1
src/TensorFlowNET.Keras/Layers/LayersApi.cs View File

@@ -112,7 +112,28 @@ namespace Tensorflow.Keras.Layers
KernelInitializer = GetInitializerByName(kernel_initializer),
BiasInitializer = GetInitializerByName(bias_initializer)
});

public ILayer Conv2D(int filters,
Shape kernel_size = null,
Shape strides = null,
string padding = "valid")
=> new Conv2D(new Conv2DArgs
{
Rank = 2,
Filters = filters,
KernelSize = (kernel_size == null) ? (5, 5) : kernel_size,
Strides = strides == null ? (1, 1) : strides,
Padding = padding,
DataFormat = null,
DilationRate = (1, 1),
Groups = 1,
UseBias = false,
KernelRegularizer = null,
KernelInitializer =tf.glorot_uniform_initializer,
BiasInitializer = tf.zeros_initializer,
BiasRegularizer = null,
ActivityRegularizer = null,
Activation = keras.activations.Linear,
});
/// <summary>
/// 2D convolution layer (e.g. spatial convolution over images).
/// This layer creates a convolution kernel that is convolved with the layer input to produce a tensor of outputs.
@@ -210,6 +231,38 @@ namespace Tensorflow.Keras.Layers
Activation = keras.activations.GetActivationFromName(activation)
});

public ILayer DepthwiseConv2D(Shape kernel_size = null,
Shape strides = null,
string padding = "valid",
string data_format = null,
Shape dilation_rate = null,
int groups = 1,
int depth_multiplier = 1,
string activation = null,
bool use_bias = false,
string kernel_initializer = "glorot_uniform",
string bias_initializer = "zeros",
string depthwise_initializer = "glorot_uniform"
)
=> new DepthwiseConv2D(new DepthwiseConv2DArgs
{
Rank = 2,
Filters = 1,
KernelSize = (kernel_size == null) ? (5, 5) : kernel_size,
Strides = strides == null ? (1) : strides,
Padding = padding,
DepthMultiplier = depth_multiplier,
DataFormat = data_format,
DilationRate = dilation_rate == null ? (1) : dilation_rate,
Groups = groups,
UseBias = use_bias,
KernelInitializer = GetInitializerByName(kernel_initializer),
DepthwiseInitializer = GetInitializerByName(depthwise_initializer == null ? kernel_initializer : depthwise_initializer),
BiasInitializer = GetInitializerByName(bias_initializer),
Activation = keras.activations.GetActivationFromName(activation),
});


/// <summary>
/// Transposed convolution layer (sometimes called Deconvolution).
/// </summary>


+ 90
- 0
test/TensorFlowNET.Graph.UnitTest/ImageTest.cs View File

@@ -4,6 +4,7 @@ using System.Linq;
using Tensorflow;
using static Tensorflow.Binding;
using System;
using System.IO;

namespace TensorFlowNET.UnitTest
{
@@ -164,5 +165,94 @@ namespace TensorFlowNET.UnitTest
Assert.AreEqual(result.size, 16ul);
Assert.AreEqual(result[0, 0, 0, 0], 12f);
}

[TestMethod]
public void ImageSaveTest()
{
var imgPath = TestHelper.GetFullPathFromDataDir("img001.bmp");
var jpegImgPath = TestHelper.GetFullPathFromDataDir("img001.jpeg");
var pngImgPath = TestHelper.GetFullPathFromDataDir("img001.png");

File.Delete(jpegImgPath);
File.Delete(pngImgPath);

var contents = tf.io.read_file(imgPath);
var bmp = tf.image.decode_image(contents);
Assert.AreEqual(bmp.name, "decode_image/DecodeImage:0");

var jpeg = tf.image.encode_jpeg(bmp);
var op1 = tf.io.write_file(jpegImgPath, jpeg);

var png = tf.image.encode_png(bmp);
var op2 = tf.io.write_file(pngImgPath, png);

this.session().run(op1);
this.session().run(op2);

Assert.IsTrue(File.Exists(jpegImgPath), "not find file:" + jpegImgPath);
Assert.IsTrue(File.Exists(pngImgPath), "not find file:" + pngImgPath);

// 如果要测试图片正确性,需要注释下面两行代码
File.Delete(jpegImgPath);
File.Delete(pngImgPath);
}

[TestMethod]
public void ImageFlipTest()
{
var imgPath = TestHelper.GetFullPathFromDataDir("img001.bmp");

var contents = tf.io.read_file(imgPath);
var bmp = tf.image.decode_image(contents);

// 左右翻转
var lrImgPath = TestHelper.GetFullPathFromDataDir("img001_lr.png");
File.Delete(lrImgPath);

var lr = tf.image.flip_left_right(bmp);
var png = tf.image.encode_png(lr);
var op = tf.io.write_file(lrImgPath, png);
this.session().run(op);

Assert.IsTrue(File.Exists(lrImgPath), "not find file:" + lrImgPath);

// 上下翻转
var updownImgPath = TestHelper.GetFullPathFromDataDir("img001_updown.png");
File.Delete(updownImgPath);

var updown = tf.image.flip_up_down(bmp);
var pngupdown = tf.image.encode_png(updown);
var op2 = tf.io.write_file(updownImgPath, pngupdown);
this.session().run(op2);
Assert.IsTrue(File.Exists(updownImgPath));


// 暂时先人工观测图片是否翻转,观测时需要删除下面这两行代码
File.Delete(lrImgPath);
File.Delete(updownImgPath);

// 多图翻转
// 目前直接通过 bmp 拿到 shape ,这里先用默认定义图片大小来构建了
var mImg = tf.stack(new[] { bmp, lr }, axis:0);
print(mImg.shape);

var up2 = tf.image.flip_up_down(mImg);

var updownImgPath_m1 = TestHelper.GetFullPathFromDataDir("img001_m_ud.png"); // 直接上下翻转
File.Delete(updownImgPath_m1);

var img001_updown_m2 = TestHelper.GetFullPathFromDataDir("img001_m_lr_ud.png"); // 先左右再上下
File.Delete(img001_updown_m2);

var png2 = tf.image.encode_png(up2[0]);
tf.io.write_file(updownImgPath_m1, png2);

png2 = tf.image.encode_png(up2[1]);
tf.io.write_file(img001_updown_m2, png2);

// 如果要测试图片正确性,需要注释下面两行代码
File.Delete(updownImgPath_m1);
File.Delete(img001_updown_m2);
}
}
}

+ 34
- 0
test/TensorFlowNET.Keras.UnitTest/EagerModeTestBase.cs View File

@@ -33,6 +33,40 @@ namespace Tensorflow.Keras.UnitTest
return ret;
}


public void AssertArray(int[] f1, int[] f2)
{
bool ret = false;
for (var i = 0; i < f1.Length; i++)
{
ret = f1[i] == f2[i];
if (!ret)
break;
}

if (!ret)
{
Assert.Fail($"Array not Equal:[{string.Join(",", f1)}] [{string.Join(",", f2)}]");
}
}

public void AssertArray(float[] f1, float[] f2)
{
bool ret = false;
var tolerance = .00001f;
for (var i = 0; i < f1.Length; i++)
{
ret = Math.Abs(f1[i] - f2[i]) <= tolerance;
if (!ret)
break;
}

if (!ret)
{
Assert.Fail($"Array float not Equal:[{string.Join(",", f1)}] [{string.Join(",", f2)}]");
}
}

public bool Equal(double[] d1, double[] d2)
{
bool ret = false;


+ 125
- 0
test/TensorFlowNET.Keras.UnitTest/Layers/Layers.Convolution.Test.cs View File

@@ -1,6 +1,8 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System.Linq;
using Tensorflow.NumPy;
using static Tensorflow.KerasApi;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.UnitTest.Layers
{
@@ -193,5 +195,128 @@ namespace Tensorflow.Keras.UnitTest.Layers
Assert.AreEqual(x.dims[2], y.shape[2]);
Assert.AreEqual(filters, y.shape[3]);
}


[TestMethod]
public void BasicDepthwiseConv2D()
{
var conv = keras.layers.DepthwiseConv2D(kernel_size:3, strides:1, activation: null,
padding:"same", depthwise_initializer: "ones");

var x = np.arange(2 * 9* 9* 3).reshape((2, 9, 9, 3));
var x2 = ops.convert_to_tensor(x, TF_DataType.TF_FLOAT);

var y = conv.Apply(x2);

print($"input:{x2.shape} DepthwiseConv2D.out: {y.shape}");


Assert.AreEqual(4, y.shape.ndim);
var arr = y.numpy().reshape((2, 9, 9, 3));

AssertArray(x[new int[] { 1, 1, 1 }].ToArray<int>(), new int[] { 273, 274, 275 });
AssertArray(arr[new int[] { 1, 1, 1 }].ToArray<float>(), new float[] { 2457f, 2466f, 2475f });

var bn = keras.layers.BatchNormalization();
var y2 = bn.Apply(y);
arr = y2.numpy().ToArray<float>();

double delta = 0.0001; // 误差范围

Assert.AreEqual(arr[0], 59.97002f, delta);
Assert.AreEqual(arr[1], 63.96802f, delta);
}


[TestMethod]
public void BasicDepthwiseConv2D_strides_2()
{
var conv = keras.layers.DepthwiseConv2D(kernel_size: 3, strides: (1, 2, 2, 1), activation: null,
padding: "same", depthwise_initializer: "ones");

var x = np.arange(2 * 9 * 9 * 3).reshape((2, 9, 9, 3));
var x2 = ops.convert_to_tensor(x, TF_DataType.TF_FLOAT);

var y = conv.Apply(x2);

print($"input:{x2.shape} DepthwiseConv2D.out: {y.shape}");

Assert.AreEqual(4, y.shape.ndim);
var arr = y.numpy().reshape((2, 5, 5, 3));

AssertArray(x[new int[] { 1, 1, 1 }].ToArray<int>(), new int[] { 273, 274, 275 });
AssertArray(arr[new int[] { 1, 1, 1 }].ToArray<float>(), new float[] { 2727f, 2736f, 2745f });

var bn = keras.layers.BatchNormalization();
var y2 = bn.Apply(y);
arr = y2.numpy().ToArray<float>();

double delta = 0.0001; // 误差范围

Assert.AreEqual(arr[0], 59.97002f, delta);
Assert.AreEqual(arr[1], 63.96802f, delta);
}



[TestMethod]
public void BasicDepthwiseConv2D_strides_3()
{
var conv = keras.layers.DepthwiseConv2D(kernel_size: 3, strides: 3, activation: null,
padding: "same", depthwise_initializer: "ones");

var x = np.arange(2 * 9 * 9 * 3).reshape((2, 9, 9, 3));
var x2 = ops.convert_to_tensor(x, TF_DataType.TF_FLOAT);

var y = conv.Apply(x2);

print($"input:{x2.shape} DepthwiseConv2D.out: {y.shape}");

Assert.AreEqual(4, y.shape.ndim);
var arr = y.numpy().reshape((2, 3, 3, 3));

AssertArray(x[new int[] { 1, 1, 1 }].ToArray<int>(), new int[] { 273, 274, 275 });
AssertArray(arr[new int[] { 1, 1, 1 }].ToArray<float>(), new float[] { 3267f, 3276f, 3285f });

var bn = keras.layers.BatchNormalization();
var y2 = bn.Apply(y);
arr = y2.numpy().ToArray<float>();

double delta = 0.0001; // 误差范围
Assert.AreEqual(arr[0], 269.86508f, delta);
Assert.AreEqual(arr[1], 278.8606f, delta);

}
[TestMethod]
public void BasicDepthwiseConv2D_UseBias()
{
var conv = keras.layers.DepthwiseConv2D(kernel_size: 3, strides: 1, activation: null,
use_bias: true, padding: "same",
depthwise_initializer: "ones",
bias_initializer:"ones"
);

var weight = conv.get_weights();

var x = np.arange(9 * 9 * 3).reshape((1, 9, 9, 3));
var x2 = ops.convert_to_tensor(x, TF_DataType.TF_FLOAT);
var y = conv.Apply(x2);

Assert.AreEqual(4, y.shape.ndim);
var arr = y.numpy().ToArray<float>();

Assert.AreEqual(arr[0], 61f);
Assert.AreEqual(arr[1], 65f);

var bn = keras.layers.BatchNormalization();
var y2 = bn.Apply(y);
arr = y2.numpy().ToArray<float>();

double delta = 0.0001; // 误差范围

Assert.AreEqual(arr[0], 60.96952f, delta);
Assert.AreEqual(arr[1], 64.96752f, delta);
}
}
}

+ 14
- 0
test/TensorFlowNET.UnitTest/EagerModeTestBase.cs View File

@@ -20,6 +20,20 @@ namespace TensorFlowNET.UnitTest
return Math.Abs(f1 - f2) <= tolerance;
}

public bool Equal(long[] l1, long[] l2)
{
if (l1.Length != l2.Length)
return false;

for (var i = 0; i < l1.Length; i++)
{
if (l1[i] != l2[i])
return false;
}

return true;
}

public bool Equal(float[] f1, float[] f2)
{
bool ret = false;


+ 317
- 0
test/TensorFlowNET.UnitTest/ManagedAPI/ArrayOpsTest.cs View File

@@ -3,6 +3,7 @@ using Tensorflow.NumPy;
using Tensorflow;
using static Tensorflow.Binding;
using System.Linq;
using Tensorflow.Operations;

namespace TensorFlowNET.UnitTest.ManagedAPI
{
@@ -105,5 +106,321 @@ namespace TensorFlowNET.UnitTest.ManagedAPI
Assert.IsTrue(Equal(a[0].ToArray<float>().Reverse().ToArray(), b[0].ToArray<float>()));
Assert.IsTrue(Equal(a[1].ToArray<float>().Reverse().ToArray(), b[1].ToArray<float>()));
}

[TestMethod]
public void ReverseImgArray3D()
{
// 创建 sourceImg 数组
var sourceImgArray = new float[,,] {
{
{ 237, 28, 36 },
{ 255, 255, 255 },
{ 255, 255, 255 }
},
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 255, 255, 255 }
}
};
var sourceImg = ops.convert_to_tensor(sourceImgArray);

// 创建 lrImg 数组
var lrImgArray = new float[,,] {
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 237, 28, 36 }
},
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 255, 255, 255 }
}
};
var lrImg = ops.convert_to_tensor(lrImgArray);

var lr = tf.image.flip_left_right(sourceImg);
Assert.IsTrue(Equal(lrImg.numpy().ToArray<float>(), lr.numpy().ToArray<float>()), "tf.image.flip_left_right fail.");

var lr2 = tf.reverse(sourceImg, 1);
Assert.IsTrue(Equal(lrImg.numpy().ToArray<float>(), lr2.numpy().ToArray<float>()), "tf.reverse (axis=1) fail.");

var lr3 = gen_array_ops.reverse_v2(sourceImg, ops.convert_to_tensor(new[] { 1 }));
Assert.IsTrue(Equal(lrImg.numpy().ToArray<float>(), lr3.numpy().ToArray<float>()), "gen_array_ops.reverse_v2 axis=1 fail.");

// 创建 udImg 数组
var udImgArray = new float[,,] {
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 255, 255, 255 }
},
{
{ 237, 28, 36 },
{ 255, 255, 255 },
{ 255, 255, 255 }
}
};
var udImg = ops.convert_to_tensor(udImgArray);

var ud = tf.image.flip_up_down(sourceImg);
Assert.IsTrue(Equal(udImg.numpy().ToArray<float>(), ud.numpy().ToArray<float>()), "tf.image.flip_up_down fail.");

var ud2 = tf.reverse(sourceImg, new Axis(0));
Assert.IsTrue(Equal(udImg.numpy().ToArray<float>(), ud2.numpy().ToArray<float>()), "tf.reverse (axis=0) fail.");

var ud3 = gen_array_ops.reverse_v2(sourceImg, ops.convert_to_tensor(new[] { 0 }));
Assert.IsTrue(Equal(udImg.numpy().ToArray<float>(), ud3.numpy().ToArray<float>()), "gen_array_ops.reverse_v2 axis=0 fail.");
}

[TestMethod]
public void ReverseImgArray4D()
{
// 原图左上角,加一张左右翻转后的图片
var m = new float[,,,] {
{
{
{ 237, 28, 36 },
{ 255, 255, 255 },
{ 255, 255, 255 }
},
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 255, 255, 255 }
}
},
{
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 237, 28, 36 }
},
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 255, 255, 255 }
}
}
};
var sourceImg = ops.convert_to_tensor(m);

var lrArray = new float[,,,] {
{
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 237, 28, 36 },
},
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 255, 255, 255 }
}
},
{
{
{ 237, 28, 36 },
{ 255, 255, 255 },
{ 255, 255, 255 },
},
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 255, 255, 255 }
}
}
};
var lrImg = ops.convert_to_tensor(lrArray);

// 创建 ud 数组
var udArray = new float[,,,] {
{
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 255, 255, 255 }
},
{
{ 237, 28, 36 },
{ 255, 255, 255 },
{ 255, 255, 255 }
}
},
{
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 255, 255, 255 }
},
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 237, 28, 36 }
}
}
};
var udImg = ops.convert_to_tensor(udArray);

var ud3 = gen_array_ops.reverse_v2(sourceImg, ops.convert_to_tensor(new[] { 1 }));
Assert.IsTrue(Equal(udImg.numpy().ToArray<float>(), ud3.numpy().ToArray<float>()), "gen_array_ops.reverse_v2 axis=1 fail.");

var ud2 = tf.reverse(sourceImg, new Axis(1));
Assert.IsTrue(Equal(udImg.numpy().ToArray<float>(), ud2.numpy().ToArray<float>()), "tf.reverse (axis=1) fail.");

var ud = tf.image.flip_up_down(sourceImg);
Assert.IsTrue(Equal(udImg.numpy().ToArray<float>(), ud.numpy().ToArray<float>()), "tf.image.flip_up_down fail.");

// 左右翻转
var lr = tf.image.flip_left_right(sourceImg);
Assert.IsTrue(Equal(lrImg.numpy().ToArray<float>(), lr.numpy().ToArray<float>()), "tf.image.flip_left_right fail.");

var lr2 = tf.reverse(sourceImg, 0);
Assert.IsTrue(Equal(lrImg.numpy().ToArray<float>(), lr2.numpy().ToArray<float>()), "tf.reverse (axis=1) fail.");

var lr3 = gen_array_ops.reverse_v2(sourceImg, ops.convert_to_tensor(new[] { 0 }));
Assert.IsTrue(Equal(lrImg.numpy().ToArray<float>(), lr3.numpy().ToArray<float>()), "gen_array_ops.reverse_v2 axis=1 fail.");

}

[TestMethod]
public void ReverseImgArray4D_3x3()
{
// 原图左上角,加一张左右翻转后的图片
var m = new float[,,,] {
{
{
{ 237, 28, 36 },
{ 255, 255, 255 },
{ 255, 255, 255 }
},
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 255, 255, 255 }
},
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 255, 255, 255 }
}
},
{
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 237, 28, 36 }
},
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 255, 255, 255 }
},
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 255, 255, 255 }
}
}
};
var sourceImg = ops.convert_to_tensor(m);

var lrArray = new float[,,,] {
{
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 237, 28, 36 },
},
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 255, 255, 255 }
},
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 255, 255, 255 }
}
},
{
{
{ 237, 28, 36 },
{ 255, 255, 255 },
{ 255, 255, 255 },
},
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 255, 255, 255 }
},
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 255, 255, 255 }
}
}
};
var lrImg = ops.convert_to_tensor(lrArray);

// 创建 ud 数组
var udArray = new float[,,,] {
{
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 255, 255, 255 }
},
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 255, 255, 255 }
},
{
{ 237, 28, 36 },
{ 255, 255, 255 },
{ 255, 255, 255 }
}
},
{ {
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 255, 255, 255 }
},
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 255, 255, 255 }
},
{
{ 255, 255, 255 },
{ 255, 255, 255 },
{ 237, 28, 36 }
}
}
};
var udImg = ops.convert_to_tensor(udArray);

var ud3 = gen_array_ops.reverse_v2(sourceImg, ops.convert_to_tensor(new[] { 1 }));
Assert.IsTrue(Equal(udImg.numpy().ToArray<float>(), ud3.numpy().ToArray<float>()), "gen_array_ops.reverse_v2 axis=1 fail.");

var ud2 = tf.reverse(sourceImg, new Axis(1));
Assert.IsTrue(Equal(udImg.numpy().ToArray<float>(), ud2.numpy().ToArray<float>()), "tf.reverse (axis=1) fail.");

var ud = tf.image.flip_up_down(sourceImg);
Assert.IsTrue(Equal(udImg.numpy().ToArray<float>(), ud.numpy().ToArray<float>()), "tf.image.flip_up_down fail.");

// 左右翻转
var lr = tf.image.flip_left_right(sourceImg);
Assert.IsTrue(Equal(lrImg.numpy().ToArray<float>(), lr.numpy().ToArray<float>()), "tf.image.flip_left_right fail.");

var lr2 = tf.reverse(sourceImg, 0);
Assert.IsTrue(Equal(lrImg.numpy().ToArray<float>(), lr2.numpy().ToArray<float>()), "tf.reverse (axis=1) fail.");

var lr3 = gen_array_ops.reverse_v2(sourceImg, ops.convert_to_tensor(new[] { 0 }));
Assert.IsTrue(Equal(lrImg.numpy().ToArray<float>(), lr3.numpy().ToArray<float>()), "gen_array_ops.reverse_v2 axis=1 fail.");

}
}
}

+ 44
- 0
test/TensorFlowNET.UnitTest/NumPy/ShapeTest.cs View File

@@ -0,0 +1,44 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow.NumPy;
using System;
using System.Linq;
using static Tensorflow.Binding;
using Tensorflow;

namespace TensorFlowNET.UnitTest.NumPy
{
[TestClass]
public class ShapeTest : EagerModeTestBase
{
[Ignore]
[TestMethod]
public unsafe void ShapeGetLastElements()
{
// test code from function _CheckAtLeast3DImage
// 之前的 _CheckAtLeast3DImage 有bug,现在通过测试,下面的代码是正确的
// todo: shape["-3:"] 的写法,目前有bug,需要修复,单元测试等修复后再放开,暂时先忽略测试

var image_shape = new Shape(new[] { 32, 64, 3 });
var image_shape_4d = new Shape(new[] { 4, 64, 32, 3 });

var image_shape_last_three_elements = new Shape(new[] {
image_shape.dims[image_shape.dims.Length - 3],
image_shape.dims[image_shape.dims.Length - 2],
image_shape.dims[image_shape.dims.Length - 1]});

var image_shape_last_three_elements2 = image_shape["-3:"];

Assert.IsTrue(Equal(image_shape_last_three_elements.dims, image_shape_last_three_elements2.dims), "3dims get fail.");

var image_shape_last_three_elements_4d = new Shape(new[] {
image_shape_4d.dims[image_shape_4d.dims.Length - 3],
image_shape_4d.dims[image_shape_4d.dims.Length - 2],
image_shape_4d.dims[image_shape_4d.dims.Length - 1]});

var image_shape_last_three_elements2_4d = image_shape_4d["-3:"];

Assert.IsTrue(Equals(image_shape_last_three_elements_4d.dims, image_shape_last_three_elements2_4d.dims), "4dims get fail.");
}

}
}

Loading…
Cancel
Save