解决keras模式下,使用GPU训练时会爆显存的bug。tags/v0.150.0-BERT-Model
@@ -339,6 +339,13 @@ namespace Tensorflow | |||||
=> image_ops_impl.decode_image(contents, channels: channels, dtype: dtype, | => image_ops_impl.decode_image(contents, channels: channels, dtype: dtype, | ||||
name: name, expand_animations: expand_animations); | name: name, expand_animations: expand_animations); | ||||
public Tensor encode_png(Tensor contents, string name = null) | |||||
=> image_ops_impl.encode_png(contents, name: name); | |||||
public Tensor encode_jpeg(Tensor contents, string name = null) | |||||
=> image_ops_impl.encode_jpeg(contents, name: name); | |||||
/// <summary> | /// <summary> | ||||
/// Convenience function to check if the 'contents' encodes a JPEG image. | /// Convenience function to check if the 'contents' encodes a JPEG image. | ||||
/// </summary> | /// </summary> | ||||
@@ -16,6 +16,7 @@ | |||||
using System.Collections.Generic; | using System.Collections.Generic; | ||||
using Tensorflow.IO; | using Tensorflow.IO; | ||||
using Tensorflow.Operations; | |||||
namespace Tensorflow | namespace Tensorflow | ||||
{ | { | ||||
@@ -46,6 +47,12 @@ namespace Tensorflow | |||||
public Tensor[] restore_v2(Tensor prefix, string[] tensor_names, | public Tensor[] restore_v2(Tensor prefix, string[] tensor_names, | ||||
string[] shape_and_slices, TF_DataType[] dtypes, string name = null) | string[] shape_and_slices, TF_DataType[] dtypes, string name = null) | ||||
=> ops.restore_v2(prefix, tensor_names, shape_and_slices, dtypes, name: name); | => ops.restore_v2(prefix, tensor_names, shape_and_slices, dtypes, name: name); | ||||
public Operation write_file(string filename, Tensor conentes, string name = null) | |||||
=> write_file(Tensorflow.ops.convert_to_tensor(filename, TF_DataType.TF_STRING), conentes, name); | |||||
public Operation write_file(Tensor filename, Tensor conentes, string name = null) | |||||
=> gen_ops.write_file(filename, conentes, name); | |||||
} | } | ||||
public GFile gfile = new GFile(); | public GFile gfile = new GFile(); | ||||
@@ -80,6 +80,11 @@ namespace Tensorflow.Eager | |||||
Tensor[] op_outputs) | Tensor[] op_outputs) | ||||
=> (out_grads, unneeded_gradients) => | => (out_grads, unneeded_gradients) => | ||||
{ | { | ||||
if(!ops.gradientFunctions.ContainsKey(op_name)) | |||||
{ | |||||
throw new Exception($"gradientFunctions not find op_name: {op_name}"); | |||||
} | |||||
if (ops.gradientFunctions[op_name] == null) | if (ops.gradientFunctions[op_name] == null) | ||||
return new Tensor[op_inputs.Length]; | return new Tensor[op_inputs.Length]; | ||||
@@ -229,6 +229,37 @@ namespace Tensorflow.Gradients | |||||
}; | }; | ||||
} | } | ||||
/// <summary> | |||||
/// Gradient function for Conv2D. | |||||
/// </summary> | |||||
/// <param name="op"></param> | |||||
/// <param name="grads"></param> | |||||
/// <returns></returns> | |||||
[RegisterGradient("DepthwiseConv2dNative")] | |||||
public static Tensor[] _DepthwiseConv2DGrad(Operation op, Tensor[] grads) | |||||
{ | |||||
var dilations = op.get_attr_list<int>("dilations"); | |||||
var strides = op.get_attr_list<int>("strides"); | |||||
var padding = op.get_attr<string>("padding"); | |||||
var explicit_paddings = op.get_attr_list<int>("explicit_paddings"); | |||||
var data_format = op.get_attr<string>("data_format"); | |||||
var shape = gen_array_ops.shape_n(new Tensor[] { op.inputs[0], op.inputs[1] }); | |||||
return new Tensor[] | |||||
{ | |||||
gen_nn_ops.depthwise_conv2d_native_backprop_input( | |||||
shape[0], op.inputs[1], grads[0], | |||||
strides, padding, explicit_paddings, | |||||
dilations: dilations, | |||||
data_format: data_format), | |||||
gen_nn_ops.depthwise_conv2d_native_backprop_filter(op.inputs[0], shape[1], grads[0], | |||||
strides, padding, | |||||
dilations: dilations, | |||||
explicit_paddings: explicit_paddings, | |||||
data_format: data_format) | |||||
}; | |||||
} | |||||
[RegisterGradient("FusedBatchNorm")] | [RegisterGradient("FusedBatchNorm")] | ||||
public static Tensor[] _FusedBatchNormGrad(Operation op, Tensor[] grads) | public static Tensor[] _FusedBatchNormGrad(Operation op, Tensor[] grads) | ||||
=> _BaseFusedBatchNormGrad(op, 0, grads); | => _BaseFusedBatchNormGrad(op, 0, grads); | ||||
@@ -24,6 +24,7 @@ public interface IModel : ILayer | |||||
List<ICallback> callbacks = null, | List<ICallback> callbacks = null, | ||||
float validation_split = 0f, | float validation_split = 0f, | ||||
ValidationDataPack validation_data = null, | ValidationDataPack validation_data = null, | ||||
int validation_step = 10, | |||||
bool shuffle = true, | bool shuffle = true, | ||||
Dictionary<int, float> class_weight = null, | Dictionary<int, float> class_weight = null, | ||||
NDArray sample_weight = null, | NDArray sample_weight = null, | ||||
@@ -47,6 +48,20 @@ public interface IModel : ILayer | |||||
int workers = 1, | int workers = 1, | ||||
bool use_multiprocessing = false); | bool use_multiprocessing = false); | ||||
public ICallback fit(IDatasetV2 dataset, | |||||
int batch_size = -1, | |||||
int epochs = 1, | |||||
int verbose = 1, | |||||
List<ICallback> callbacks = null, | |||||
IDatasetV2 validation_data = null, | |||||
int validation_step = 10, // 间隔多少次会进行一次验证 | |||||
bool shuffle = true, | |||||
Dictionary<int, float> class_weight = null, | |||||
int initial_epoch = 0, | |||||
int max_queue_size = 10, | |||||
int workers = 1, | |||||
bool use_multiprocessing = false); | |||||
void save(string filepath, | void save(string filepath, | ||||
bool overwrite = true, | bool overwrite = true, | ||||
bool include_optimizer = true, | bool include_optimizer = true, | ||||
@@ -85,6 +100,14 @@ public interface IModel : ILayer | |||||
int workers = 1, | int workers = 1, | ||||
bool use_multiprocessing = false); | bool use_multiprocessing = false); | ||||
public Tensors predict(IDatasetV2 dataset, | |||||
int batch_size = -1, | |||||
int verbose = 0, | |||||
int steps = -1, | |||||
int max_queue_size = 10, | |||||
int workers = 1, | |||||
bool use_multiprocessing = false); | |||||
void summary(int line_length = -1, float[] positions = null); | void summary(int line_length = -1, float[] positions = null); | ||||
IKerasConfig get_config(); | IKerasConfig get_config(); | ||||
@@ -55,6 +55,12 @@ namespace Tensorflow.Keras.Layers | |||||
string kernel_initializer = "glorot_uniform", | string kernel_initializer = "glorot_uniform", | ||||
string bias_initializer = "zeros"); | string bias_initializer = "zeros"); | ||||
public ILayer Conv2D(int filters, | |||||
Shape kernel_size = null, | |||||
Shape strides = null, | |||||
string padding = "valid" | |||||
); | |||||
public ILayer Conv2D(int filters, | public ILayer Conv2D(int filters, | ||||
Shape kernel_size = null, | Shape kernel_size = null, | ||||
Shape strides = null, | Shape strides = null, | ||||
@@ -95,6 +101,19 @@ namespace Tensorflow.Keras.Layers | |||||
bool use_bias = true, | bool use_bias = true, | ||||
string kernel_initializer = "glorot_uniform", | string kernel_initializer = "glorot_uniform", | ||||
string bias_initializer = "zeros"); | string bias_initializer = "zeros"); | ||||
public ILayer DepthwiseConv2D(Shape kernel_size = null, | |||||
Shape strides = null, | |||||
string padding = "valid", | |||||
string data_format = null, | |||||
Shape dilation_rate = null, | |||||
int groups = 1, | |||||
int depth_multiplier = 1, | |||||
string activation = null, | |||||
bool use_bias = false, | |||||
string kernel_initializer = "glorot_uniform", | |||||
string bias_initializer = "zeros", | |||||
string depthwise_initializer = "glorot_uniform" | |||||
); | |||||
public ILayer Dense(int units); | public ILayer Dense(int units); | ||||
public ILayer Dense(int units, | public ILayer Dense(int units, | ||||
@@ -102,7 +102,10 @@ namespace Tensorflow | |||||
{ | { | ||||
throw new ValueError("\'image\' must be fully defined."); | throw new ValueError("\'image\' must be fully defined."); | ||||
} | } | ||||
var dims = image_shape["-3:"]; | |||||
var dims = new Shape(new[] { | |||||
image_shape.dims[image_shape.dims.Length - 3], | |||||
image_shape.dims[image_shape.dims.Length - 2], | |||||
image_shape.dims[image_shape.dims.Length - 1]}); | |||||
foreach (var dim in dims.dims) | foreach (var dim in dims.dims) | ||||
{ | { | ||||
if (dim == 0) | if (dim == 0) | ||||
@@ -112,16 +115,18 @@ namespace Tensorflow | |||||
} | } | ||||
var image_shape_last_three_elements = new Shape(new[] { | var image_shape_last_three_elements = new Shape(new[] { | ||||
image_shape.dims[image_shape.dims.Length - 1], | |||||
image_shape.dims[image_shape.dims.Length - 3], | |||||
image_shape.dims[image_shape.dims.Length - 2], | image_shape.dims[image_shape.dims.Length - 2], | ||||
image_shape.dims[image_shape.dims.Length - 3]}); | |||||
image_shape.dims[image_shape.dims.Length - 1]}); | |||||
if (!image_shape_last_three_elements.IsFullyDefined) | if (!image_shape_last_three_elements.IsFullyDefined) | ||||
{ | { | ||||
Tensor image_shape_ = array_ops.shape(image); | Tensor image_shape_ = array_ops.shape(image); | ||||
var image_shape_return = tf.constant(new[] { | |||||
image_shape_.dims[image_shape.dims.Length - 1], | |||||
image_shape_.dims[image_shape.dims.Length - 2], | |||||
image_shape_.dims[image_shape.dims.Length - 3]}); | |||||
var image_shape_return = tf.slice(image_shape_, new[] { Math.Max(image_shape.dims.Length - 3, 0) }, new[] { 3 }); | |||||
//var image_shape_return = tf.constant(new[] { | |||||
// image_shape_.dims[image_shape_.dims.Length - 3], | |||||
// image_shape_.dims[image_shape_.dims.Length - 2], | |||||
// image_shape_.dims[image_shape_.dims.Length - 1]}); | |||||
return new Operation[] { | return new Operation[] { | ||||
check_ops.assert_positive( | check_ops.assert_positive( | ||||
@@ -209,10 +214,10 @@ namespace Tensorflow | |||||
} | } | ||||
public static Tensor flip_left_right(Tensor image) | public static Tensor flip_left_right(Tensor image) | ||||
=> _flip(image, 0, "flip_left_right"); | |||||
=> _flip(image, 1, "flip_left_right"); | |||||
public static Tensor flip_up_down(Tensor image) | public static Tensor flip_up_down(Tensor image) | ||||
=> _flip(image, 1, "flip_up_down"); | |||||
=> _flip(image, 0, "flip_up_down"); | |||||
internal static Tensor _flip(Tensor image, int flip_index, string scope_name) | internal static Tensor _flip(Tensor image, int flip_index, string scope_name) | ||||
{ | { | ||||
@@ -223,11 +228,11 @@ namespace Tensorflow | |||||
Shape shape = image.shape; | Shape shape = image.shape; | ||||
if (shape.ndim == 3 || shape.ndim == Unknown) | if (shape.ndim == 3 || shape.ndim == Unknown) | ||||
{ | { | ||||
return fix_image_flip_shape(image, gen_array_ops.reverse(image, ops.convert_to_tensor(new int[] { flip_index }))); | |||||
return fix_image_flip_shape(image, gen_array_ops.reverse_v2(image, ops.convert_to_tensor(new int[] { flip_index }))); | |||||
} | } | ||||
else if (shape.ndim == 4) | else if (shape.ndim == 4) | ||||
{ | { | ||||
return gen_array_ops.reverse_v2(image, ops.convert_to_tensor(new[] { (flip_index + 1) % 2 })); | |||||
return gen_array_ops.reverse_v2(image, ops.convert_to_tensor(new[] { flip_index + 1 })); | |||||
} | } | ||||
else | else | ||||
{ | { | ||||
@@ -2047,6 +2052,22 @@ new_height, new_width"); | |||||
}); | }); | ||||
} | } | ||||
public static Tensor encode_jpeg(Tensor contents, string name = null) | |||||
{ | |||||
return tf_with(ops.name_scope(name, "encode_jpeg"), scope => | |||||
{ | |||||
return gen_ops.encode_jpeg(contents, name:name); | |||||
}); | |||||
} | |||||
public static Tensor encode_png(Tensor contents, string name = null) | |||||
{ | |||||
return tf_with(ops.name_scope(name, "encode_png"), scope => | |||||
{ | |||||
return gen_ops.encode_png(contents, name: name); | |||||
}); | |||||
} | |||||
public static Tensor is_jpeg(Tensor contents, string name = null) | public static Tensor is_jpeg(Tensor contents, string name = null) | ||||
{ | { | ||||
return tf_with(ops.name_scope(name, "is_jpeg"), scope => | return tf_with(ops.name_scope(name, "is_jpeg"), scope => | ||||
@@ -249,6 +249,9 @@ namespace Tensorflow | |||||
case sbyte val: | case sbyte val: | ||||
tensor_proto.IntVal.AddRange(new[] { (int)val }); | tensor_proto.IntVal.AddRange(new[] { (int)val }); | ||||
break; | break; | ||||
case byte val: | |||||
tensor_proto.IntVal.AddRange(new[] { (int)val }); | |||||
break; | |||||
case int val: | case int val: | ||||
tensor_proto.IntVal.AddRange(new[] { val }); | tensor_proto.IntVal.AddRange(new[] { val }); | ||||
break; | break; | ||||
@@ -262,7 +265,7 @@ namespace Tensorflow | |||||
tensor_proto.DoubleVal.AddRange(new[] { val }); | tensor_proto.DoubleVal.AddRange(new[] { val }); | ||||
break; | break; | ||||
default: | default: | ||||
throw new Exception("make_tensor_proto Not Implemented"); | |||||
throw new Exception($"make_tensor_proto Not Implemented {values.GetType().Name}"); | |||||
} | } | ||||
} | } | ||||
@@ -132,6 +132,7 @@ namespace Tensorflow.Keras.Engine | |||||
var end_step = step + data_handler.StepIncrement; | var end_step = step + data_handler.StepIncrement; | ||||
if (!is_val) | if (!is_val) | ||||
callbacks.on_test_batch_end(end_step, logs); | callbacks.on_test_batch_end(end_step, logs); | ||||
GC.Collect(); | |||||
} | } | ||||
} | } | ||||
callbacks.on_test_end(logs); | callbacks.on_test_end(logs); | ||||
@@ -167,7 +168,9 @@ namespace Tensorflow.Keras.Engine | |||||
Dictionary<string, float> test_step(DataHandler data_handler, Tensors x, Tensors y) | Dictionary<string, float> test_step(DataHandler data_handler, Tensors x, Tensors y) | ||||
{ | { | ||||
(x,y) = data_handler.DataAdapter.Expand1d(x, y); | (x,y) = data_handler.DataAdapter.Expand1d(x, y); | ||||
var y_pred = Apply(x, training: false); | var y_pred = Apply(x, training: false); | ||||
var loss = compiled_loss.Call(y, y_pred); | var loss = compiled_loss.Call(y, y_pred); | ||||
compiled_metrics.update_state(y, y_pred); | compiled_metrics.update_state(y, y_pred); | ||||
return metrics.Select(x => (x.Name, x.result())).ToDictionary(x => x.Item1, x => (float)x.Item2); | return metrics.Select(x => (x.Name, x.result())).ToDictionary(x => x.Item1, x => (float)x.Item2); | ||||
@@ -41,6 +41,7 @@ namespace Tensorflow.Keras.Engine | |||||
List<ICallback> callbacks = null, | List<ICallback> callbacks = null, | ||||
float validation_split = 0f, | float validation_split = 0f, | ||||
ValidationDataPack validation_data = null, | ValidationDataPack validation_data = null, | ||||
int validation_step = 10, | |||||
bool shuffle = true, | bool shuffle = true, | ||||
Dictionary<int, float> class_weight = null, | Dictionary<int, float> class_weight = null, | ||||
NDArray sample_weight = null, | NDArray sample_weight = null, | ||||
@@ -147,7 +148,7 @@ namespace Tensorflow.Keras.Engine | |||||
} | } | ||||
} | } | ||||
public History fit(IDatasetV2 dataset, | |||||
public ICallback fit(IDatasetV2 dataset, | |||||
int batch_size = -1, | int batch_size = -1, | ||||
int epochs = 1, | int epochs = 1, | ||||
int verbose = 1, | int verbose = 1, | ||||
@@ -156,7 +157,6 @@ namespace Tensorflow.Keras.Engine | |||||
int validation_step = 10, | int validation_step = 10, | ||||
bool shuffle = true, | bool shuffle = true, | ||||
Dictionary<int, float> class_weight = null, | Dictionary<int, float> class_weight = null, | ||||
NDArray sample_weight = null, | |||||
int initial_epoch = 0, | int initial_epoch = 0, | ||||
int max_queue_size = 10, | int max_queue_size = 10, | ||||
int workers = 1, | int workers = 1, | ||||
@@ -170,7 +170,7 @@ namespace Tensorflow.Keras.Engine | |||||
InitialEpoch = initial_epoch, | InitialEpoch = initial_epoch, | ||||
Epochs = epochs, | Epochs = epochs, | ||||
Shuffle = shuffle, | Shuffle = shuffle, | ||||
SampleWeight = sample_weight, | |||||
ClassWeight = class_weight, | |||||
MaxQueueSize = max_queue_size, | MaxQueueSize = max_queue_size, | ||||
Workers = workers, | Workers = workers, | ||||
UseMultiprocessing = use_multiprocessing, | UseMultiprocessing = use_multiprocessing, | ||||
@@ -218,6 +218,7 @@ namespace Tensorflow.Keras.Engine | |||||
var end_step = step + data_handler.StepIncrement; | var end_step = step + data_handler.StepIncrement; | ||||
End_step = end_step; | End_step = end_step; | ||||
callbacks.on_train_batch_end(end_step, logs); | callbacks.on_train_batch_end(end_step, logs); | ||||
GC.Collect(); | |||||
} | } | ||||
if (validation_data != null) | if (validation_data != null) | ||||
@@ -233,11 +234,10 @@ namespace Tensorflow.Keras.Engine | |||||
callbacks.on_train_batch_end(End_step, logs); | callbacks.on_train_batch_end(End_step, logs); | ||||
} | } | ||||
GC.Collect(); | |||||
callbacks.on_epoch_end(epoch, logs); | callbacks.on_epoch_end(epoch, logs); | ||||
GC.Collect(); | |||||
GC.WaitForPendingFinalizers(); | |||||
if (stop_training) | if (stop_training) | ||||
{ | { | ||||
break; | break; | ||||
@@ -282,6 +282,7 @@ namespace Tensorflow.Keras.Engine | |||||
var end_step = step + data_handler.StepIncrement; | var end_step = step + data_handler.StepIncrement; | ||||
End_step = end_step; | End_step = end_step; | ||||
callbacks.on_train_batch_end(end_step, logs); | callbacks.on_train_batch_end(end_step, logs); | ||||
GC.Collect(); | |||||
} | } | ||||
if (validation_data != null) | if (validation_data != null) | ||||
@@ -301,7 +302,6 @@ namespace Tensorflow.Keras.Engine | |||||
callbacks.on_epoch_end(epoch, logs); | callbacks.on_epoch_end(epoch, logs); | ||||
GC.Collect(); | GC.Collect(); | ||||
GC.WaitForPendingFinalizers(); | |||||
if (stop_training) | if (stop_training) | ||||
{ | { | ||||
break; | break; | ||||
@@ -102,9 +102,9 @@ namespace Tensorflow.Keras.Engine | |||||
for (int i = 0; i < batch_outputs.Length; i++) | for (int i = 0; i < batch_outputs.Length; i++) | ||||
batch_outputs[i] = tf.concat(new Tensor[] { batch_outputs[i], tmp_batch_outputs[i] }, axis: 0); | batch_outputs[i] = tf.concat(new Tensor[] { batch_outputs[i], tmp_batch_outputs[i] }, axis: 0); | ||||
} | } | ||||
var end_step = step + data_handler.StepIncrement; | var end_step = step + data_handler.StepIncrement; | ||||
callbacks.on_predict_batch_end(end_step, new Dictionary<string, Tensors> { { "outputs", batch_outputs } }); | callbacks.on_predict_batch_end(end_step, new Dictionary<string, Tensors> { { "outputs", batch_outputs } }); | ||||
GC.Collect(); | |||||
} | } | ||||
} | } | ||||
@@ -0,0 +1,167 @@ | |||||
using System; | |||||
using System.Collections.Generic; | |||||
using System.Text; | |||||
using System; | |||||
using Tensorflow.Keras.ArgsDefinition; | |||||
using Tensorflow.Keras.Saving; | |||||
using Tensorflow.Common.Types; | |||||
using Tensorflow.Keras.Utils; | |||||
using Tensorflow.Operations; | |||||
using Newtonsoft.Json; | |||||
using System.Security.Cryptography; | |||||
namespace Tensorflow.Keras.Layers | |||||
{ | |||||
public class DepthwiseConv2DArgs: Conv2DArgs | |||||
{ | |||||
/// <summary> | |||||
/// depth_multiplier: The number of depthwise convolution output channels for | |||||
/// each input channel.The total number of depthwise convolution output | |||||
/// channels will be equal to `filters_in* depth_multiplier`. | |||||
/// </summary> | |||||
[JsonProperty("depth_multiplier")] | |||||
public int DepthMultiplier { get; set; } = 1; | |||||
[JsonProperty("depthwise_initializer")] | |||||
public IInitializer DepthwiseInitializer { get; set; } | |||||
} | |||||
public class DepthwiseConv2D : Conv2D | |||||
{ | |||||
/// <summary> | |||||
/// depth_multiplier: The number of depthwise convolution output channels for | |||||
/// each input channel.The total number of depthwise convolution output | |||||
/// channels will be equal to `filters_in* depth_multiplier`. | |||||
/// </summary> | |||||
int DepthMultiplier = 1; | |||||
IInitializer DepthwiseInitializer; | |||||
int[] strides; | |||||
int[] dilation_rate; | |||||
string getDataFormat() | |||||
{ | |||||
return data_format == "channels_first" ? "NCHW" : "NHWC"; | |||||
} | |||||
static int _id = 1; | |||||
public DepthwiseConv2D(DepthwiseConv2DArgs args):base(args) | |||||
{ | |||||
args.Padding = args.Padding.ToUpper(); | |||||
if(string.IsNullOrEmpty(args.Name)) | |||||
name = "DepthwiseConv2D_" + _id; | |||||
this.DepthMultiplier = args.DepthMultiplier; | |||||
this.DepthwiseInitializer = args.DepthwiseInitializer; | |||||
} | |||||
public override void build(KerasShapesWrapper input_shape) | |||||
{ | |||||
//base.build(input_shape); | |||||
var shape = input_shape.ToSingleShape(); | |||||
int channel_axis = data_format == "channels_first" ? 1 : -1; | |||||
var input_channel = channel_axis < 0 ? | |||||
shape.dims[shape.ndim + channel_axis] : | |||||
shape.dims[channel_axis]; | |||||
var arg = args as DepthwiseConv2DArgs; | |||||
if (arg.Strides.ndim != shape.ndim) | |||||
{ | |||||
if (arg.Strides.ndim == 2) | |||||
{ | |||||
this.strides = new int[] { 1, (int)arg.Strides[0], (int)arg.Strides[1], 1 }; | |||||
} | |||||
else | |||||
{ | |||||
this.strides = conv_utils.normalize_tuple(new int[] { (int)arg.Strides[0] }, shape.ndim, "strides"); | |||||
} | |||||
} | |||||
else | |||||
{ | |||||
this.strides = arg.Strides.dims.Select(o=>(int)(o)).ToArray(); | |||||
} | |||||
if (arg.DilationRate.ndim != shape.ndim) | |||||
{ | |||||
this.dilation_rate = conv_utils.normalize_tuple(new int[] { (int)arg.DilationRate[0] }, shape.ndim, "dilation_rate"); | |||||
} | |||||
long channel_data = data_format == "channels_first" ? shape[0] : shape[shape.Length - 1]; | |||||
var depthwise_kernel_shape = this.kernel_size.dims.concat(new long[] { | |||||
channel_data, | |||||
this.DepthMultiplier | |||||
}); | |||||
this.kernel = this.add_weight( | |||||
shape: depthwise_kernel_shape, | |||||
initializer: this.DepthwiseInitializer != null ? this.DepthwiseInitializer : this.kernel_initializer, | |||||
name: "depthwise_kernel", | |||||
trainable: true, | |||||
dtype: DType, | |||||
regularizer: this.kernel_regularizer | |||||
); | |||||
var axes = new Dictionary<int, int>(); | |||||
axes.Add(-1, (int)input_channel); | |||||
inputSpec = new InputSpec(min_ndim: rank + 2, axes: axes); | |||||
if (use_bias) | |||||
{ | |||||
bias = add_weight(name: "bias", | |||||
shape: ((int)channel_data), | |||||
initializer: bias_initializer, | |||||
trainable: true, | |||||
dtype: DType); | |||||
} | |||||
built = true; | |||||
_buildInputShape = input_shape; | |||||
} | |||||
protected override Tensors Call(Tensors inputs, Tensors state = null, | |||||
bool? training = false, IOptionalArgs? optional_args = null) | |||||
{ | |||||
Tensor outputs = null; | |||||
outputs = gen_nn_ops.depthwise_conv2d_native( | |||||
inputs, | |||||
filter: this.kernel.AsTensor(), | |||||
strides: this.strides, | |||||
padding: this.padding, | |||||
dilations: this.dilation_rate, | |||||
data_format: this.getDataFormat(), | |||||
name: name | |||||
); | |||||
if (use_bias) | |||||
{ | |||||
if (data_format == "channels_first") | |||||
{ | |||||
throw new NotImplementedException("call channels_first"); | |||||
} | |||||
else | |||||
{ | |||||
outputs = gen_nn_ops.bias_add(outputs, ops.convert_to_tensor(bias), | |||||
data_format: this.getDataFormat(), name: name); | |||||
} | |||||
} | |||||
if (activation != null) | |||||
outputs = activation.Apply(outputs); | |||||
return outputs; | |||||
} | |||||
} | |||||
} |
@@ -112,7 +112,28 @@ namespace Tensorflow.Keras.Layers | |||||
KernelInitializer = GetInitializerByName(kernel_initializer), | KernelInitializer = GetInitializerByName(kernel_initializer), | ||||
BiasInitializer = GetInitializerByName(bias_initializer) | BiasInitializer = GetInitializerByName(bias_initializer) | ||||
}); | }); | ||||
public ILayer Conv2D(int filters, | |||||
Shape kernel_size = null, | |||||
Shape strides = null, | |||||
string padding = "valid") | |||||
=> new Conv2D(new Conv2DArgs | |||||
{ | |||||
Rank = 2, | |||||
Filters = filters, | |||||
KernelSize = (kernel_size == null) ? (5, 5) : kernel_size, | |||||
Strides = strides == null ? (1, 1) : strides, | |||||
Padding = padding, | |||||
DataFormat = null, | |||||
DilationRate = (1, 1), | |||||
Groups = 1, | |||||
UseBias = false, | |||||
KernelRegularizer = null, | |||||
KernelInitializer =tf.glorot_uniform_initializer, | |||||
BiasInitializer = tf.zeros_initializer, | |||||
BiasRegularizer = null, | |||||
ActivityRegularizer = null, | |||||
Activation = keras.activations.Linear, | |||||
}); | |||||
/// <summary> | /// <summary> | ||||
/// 2D convolution layer (e.g. spatial convolution over images). | /// 2D convolution layer (e.g. spatial convolution over images). | ||||
/// This layer creates a convolution kernel that is convolved with the layer input to produce a tensor of outputs. | /// This layer creates a convolution kernel that is convolved with the layer input to produce a tensor of outputs. | ||||
@@ -210,6 +231,38 @@ namespace Tensorflow.Keras.Layers | |||||
Activation = keras.activations.GetActivationFromName(activation) | Activation = keras.activations.GetActivationFromName(activation) | ||||
}); | }); | ||||
public ILayer DepthwiseConv2D(Shape kernel_size = null, | |||||
Shape strides = null, | |||||
string padding = "valid", | |||||
string data_format = null, | |||||
Shape dilation_rate = null, | |||||
int groups = 1, | |||||
int depth_multiplier = 1, | |||||
string activation = null, | |||||
bool use_bias = false, | |||||
string kernel_initializer = "glorot_uniform", | |||||
string bias_initializer = "zeros", | |||||
string depthwise_initializer = "glorot_uniform" | |||||
) | |||||
=> new DepthwiseConv2D(new DepthwiseConv2DArgs | |||||
{ | |||||
Rank = 2, | |||||
Filters = 1, | |||||
KernelSize = (kernel_size == null) ? (5, 5) : kernel_size, | |||||
Strides = strides == null ? (1) : strides, | |||||
Padding = padding, | |||||
DepthMultiplier = depth_multiplier, | |||||
DataFormat = data_format, | |||||
DilationRate = dilation_rate == null ? (1) : dilation_rate, | |||||
Groups = groups, | |||||
UseBias = use_bias, | |||||
KernelInitializer = GetInitializerByName(kernel_initializer), | |||||
DepthwiseInitializer = GetInitializerByName(depthwise_initializer == null ? kernel_initializer : depthwise_initializer), | |||||
BiasInitializer = GetInitializerByName(bias_initializer), | |||||
Activation = keras.activations.GetActivationFromName(activation), | |||||
}); | |||||
/// <summary> | /// <summary> | ||||
/// Transposed convolution layer (sometimes called Deconvolution). | /// Transposed convolution layer (sometimes called Deconvolution). | ||||
/// </summary> | /// </summary> | ||||
@@ -4,6 +4,7 @@ using System.Linq; | |||||
using Tensorflow; | using Tensorflow; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
using System; | using System; | ||||
using System.IO; | |||||
namespace TensorFlowNET.UnitTest | namespace TensorFlowNET.UnitTest | ||||
{ | { | ||||
@@ -164,5 +165,94 @@ namespace TensorFlowNET.UnitTest | |||||
Assert.AreEqual(result.size, 16ul); | Assert.AreEqual(result.size, 16ul); | ||||
Assert.AreEqual(result[0, 0, 0, 0], 12f); | Assert.AreEqual(result[0, 0, 0, 0], 12f); | ||||
} | } | ||||
[TestMethod] | |||||
public void ImageSaveTest() | |||||
{ | |||||
var imgPath = TestHelper.GetFullPathFromDataDir("img001.bmp"); | |||||
var jpegImgPath = TestHelper.GetFullPathFromDataDir("img001.jpeg"); | |||||
var pngImgPath = TestHelper.GetFullPathFromDataDir("img001.png"); | |||||
File.Delete(jpegImgPath); | |||||
File.Delete(pngImgPath); | |||||
var contents = tf.io.read_file(imgPath); | |||||
var bmp = tf.image.decode_image(contents); | |||||
Assert.AreEqual(bmp.name, "decode_image/DecodeImage:0"); | |||||
var jpeg = tf.image.encode_jpeg(bmp); | |||||
var op1 = tf.io.write_file(jpegImgPath, jpeg); | |||||
var png = tf.image.encode_png(bmp); | |||||
var op2 = tf.io.write_file(pngImgPath, png); | |||||
this.session().run(op1); | |||||
this.session().run(op2); | |||||
Assert.IsTrue(File.Exists(jpegImgPath), "not find file:" + jpegImgPath); | |||||
Assert.IsTrue(File.Exists(pngImgPath), "not find file:" + pngImgPath); | |||||
// 如果要测试图片正确性,需要注释下面两行代码 | |||||
File.Delete(jpegImgPath); | |||||
File.Delete(pngImgPath); | |||||
} | |||||
[TestMethod] | |||||
public void ImageFlipTest() | |||||
{ | |||||
var imgPath = TestHelper.GetFullPathFromDataDir("img001.bmp"); | |||||
var contents = tf.io.read_file(imgPath); | |||||
var bmp = tf.image.decode_image(contents); | |||||
// 左右翻转 | |||||
var lrImgPath = TestHelper.GetFullPathFromDataDir("img001_lr.png"); | |||||
File.Delete(lrImgPath); | |||||
var lr = tf.image.flip_left_right(bmp); | |||||
var png = tf.image.encode_png(lr); | |||||
var op = tf.io.write_file(lrImgPath, png); | |||||
this.session().run(op); | |||||
Assert.IsTrue(File.Exists(lrImgPath), "not find file:" + lrImgPath); | |||||
// 上下翻转 | |||||
var updownImgPath = TestHelper.GetFullPathFromDataDir("img001_updown.png"); | |||||
File.Delete(updownImgPath); | |||||
var updown = tf.image.flip_up_down(bmp); | |||||
var pngupdown = tf.image.encode_png(updown); | |||||
var op2 = tf.io.write_file(updownImgPath, pngupdown); | |||||
this.session().run(op2); | |||||
Assert.IsTrue(File.Exists(updownImgPath)); | |||||
// 暂时先人工观测图片是否翻转,观测时需要删除下面这两行代码 | |||||
File.Delete(lrImgPath); | |||||
File.Delete(updownImgPath); | |||||
// 多图翻转 | |||||
// 目前直接通过 bmp 拿到 shape ,这里先用默认定义图片大小来构建了 | |||||
var mImg = tf.stack(new[] { bmp, lr }, axis:0); | |||||
print(mImg.shape); | |||||
var up2 = tf.image.flip_up_down(mImg); | |||||
var updownImgPath_m1 = TestHelper.GetFullPathFromDataDir("img001_m_ud.png"); // 直接上下翻转 | |||||
File.Delete(updownImgPath_m1); | |||||
var img001_updown_m2 = TestHelper.GetFullPathFromDataDir("img001_m_lr_ud.png"); // 先左右再上下 | |||||
File.Delete(img001_updown_m2); | |||||
var png2 = tf.image.encode_png(up2[0]); | |||||
tf.io.write_file(updownImgPath_m1, png2); | |||||
png2 = tf.image.encode_png(up2[1]); | |||||
tf.io.write_file(img001_updown_m2, png2); | |||||
// 如果要测试图片正确性,需要注释下面两行代码 | |||||
File.Delete(updownImgPath_m1); | |||||
File.Delete(img001_updown_m2); | |||||
} | |||||
} | } | ||||
} | } |
@@ -33,6 +33,40 @@ namespace Tensorflow.Keras.UnitTest | |||||
return ret; | return ret; | ||||
} | } | ||||
public void AssertArray(int[] f1, int[] f2) | |||||
{ | |||||
bool ret = false; | |||||
for (var i = 0; i < f1.Length; i++) | |||||
{ | |||||
ret = f1[i] == f2[i]; | |||||
if (!ret) | |||||
break; | |||||
} | |||||
if (!ret) | |||||
{ | |||||
Assert.Fail($"Array not Equal:[{string.Join(",", f1)}] [{string.Join(",", f2)}]"); | |||||
} | |||||
} | |||||
public void AssertArray(float[] f1, float[] f2) | |||||
{ | |||||
bool ret = false; | |||||
var tolerance = .00001f; | |||||
for (var i = 0; i < f1.Length; i++) | |||||
{ | |||||
ret = Math.Abs(f1[i] - f2[i]) <= tolerance; | |||||
if (!ret) | |||||
break; | |||||
} | |||||
if (!ret) | |||||
{ | |||||
Assert.Fail($"Array float not Equal:[{string.Join(",", f1)}] [{string.Join(",", f2)}]"); | |||||
} | |||||
} | |||||
public bool Equal(double[] d1, double[] d2) | public bool Equal(double[] d1, double[] d2) | ||||
{ | { | ||||
bool ret = false; | bool ret = false; | ||||
@@ -1,6 +1,8 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
using System.Linq; | |||||
using Tensorflow.NumPy; | using Tensorflow.NumPy; | ||||
using static Tensorflow.KerasApi; | using static Tensorflow.KerasApi; | ||||
using static Tensorflow.Binding; | |||||
namespace Tensorflow.Keras.UnitTest.Layers | namespace Tensorflow.Keras.UnitTest.Layers | ||||
{ | { | ||||
@@ -193,5 +195,128 @@ namespace Tensorflow.Keras.UnitTest.Layers | |||||
Assert.AreEqual(x.dims[2], y.shape[2]); | Assert.AreEqual(x.dims[2], y.shape[2]); | ||||
Assert.AreEqual(filters, y.shape[3]); | Assert.AreEqual(filters, y.shape[3]); | ||||
} | } | ||||
[TestMethod] | |||||
public void BasicDepthwiseConv2D() | |||||
{ | |||||
var conv = keras.layers.DepthwiseConv2D(kernel_size:3, strides:1, activation: null, | |||||
padding:"same", depthwise_initializer: "ones"); | |||||
var x = np.arange(2 * 9* 9* 3).reshape((2, 9, 9, 3)); | |||||
var x2 = ops.convert_to_tensor(x, TF_DataType.TF_FLOAT); | |||||
var y = conv.Apply(x2); | |||||
print($"input:{x2.shape} DepthwiseConv2D.out: {y.shape}"); | |||||
Assert.AreEqual(4, y.shape.ndim); | |||||
var arr = y.numpy().reshape((2, 9, 9, 3)); | |||||
AssertArray(x[new int[] { 1, 1, 1 }].ToArray<int>(), new int[] { 273, 274, 275 }); | |||||
AssertArray(arr[new int[] { 1, 1, 1 }].ToArray<float>(), new float[] { 2457f, 2466f, 2475f }); | |||||
var bn = keras.layers.BatchNormalization(); | |||||
var y2 = bn.Apply(y); | |||||
arr = y2.numpy().ToArray<float>(); | |||||
double delta = 0.0001; // 误差范围 | |||||
Assert.AreEqual(arr[0], 59.97002f, delta); | |||||
Assert.AreEqual(arr[1], 63.96802f, delta); | |||||
} | |||||
[TestMethod] | |||||
public void BasicDepthwiseConv2D_strides_2() | |||||
{ | |||||
var conv = keras.layers.DepthwiseConv2D(kernel_size: 3, strides: (1, 2, 2, 1), activation: null, | |||||
padding: "same", depthwise_initializer: "ones"); | |||||
var x = np.arange(2 * 9 * 9 * 3).reshape((2, 9, 9, 3)); | |||||
var x2 = ops.convert_to_tensor(x, TF_DataType.TF_FLOAT); | |||||
var y = conv.Apply(x2); | |||||
print($"input:{x2.shape} DepthwiseConv2D.out: {y.shape}"); | |||||
Assert.AreEqual(4, y.shape.ndim); | |||||
var arr = y.numpy().reshape((2, 5, 5, 3)); | |||||
AssertArray(x[new int[] { 1, 1, 1 }].ToArray<int>(), new int[] { 273, 274, 275 }); | |||||
AssertArray(arr[new int[] { 1, 1, 1 }].ToArray<float>(), new float[] { 2727f, 2736f, 2745f }); | |||||
var bn = keras.layers.BatchNormalization(); | |||||
var y2 = bn.Apply(y); | |||||
arr = y2.numpy().ToArray<float>(); | |||||
double delta = 0.0001; // 误差范围 | |||||
Assert.AreEqual(arr[0], 59.97002f, delta); | |||||
Assert.AreEqual(arr[1], 63.96802f, delta); | |||||
} | |||||
[TestMethod] | |||||
public void BasicDepthwiseConv2D_strides_3() | |||||
{ | |||||
var conv = keras.layers.DepthwiseConv2D(kernel_size: 3, strides: 3, activation: null, | |||||
padding: "same", depthwise_initializer: "ones"); | |||||
var x = np.arange(2 * 9 * 9 * 3).reshape((2, 9, 9, 3)); | |||||
var x2 = ops.convert_to_tensor(x, TF_DataType.TF_FLOAT); | |||||
var y = conv.Apply(x2); | |||||
print($"input:{x2.shape} DepthwiseConv2D.out: {y.shape}"); | |||||
Assert.AreEqual(4, y.shape.ndim); | |||||
var arr = y.numpy().reshape((2, 3, 3, 3)); | |||||
AssertArray(x[new int[] { 1, 1, 1 }].ToArray<int>(), new int[] { 273, 274, 275 }); | |||||
AssertArray(arr[new int[] { 1, 1, 1 }].ToArray<float>(), new float[] { 3267f, 3276f, 3285f }); | |||||
var bn = keras.layers.BatchNormalization(); | |||||
var y2 = bn.Apply(y); | |||||
arr = y2.numpy().ToArray<float>(); | |||||
double delta = 0.0001; // 误差范围 | |||||
Assert.AreEqual(arr[0], 269.86508f, delta); | |||||
Assert.AreEqual(arr[1], 278.8606f, delta); | |||||
} | |||||
[TestMethod] | |||||
public void BasicDepthwiseConv2D_UseBias() | |||||
{ | |||||
var conv = keras.layers.DepthwiseConv2D(kernel_size: 3, strides: 1, activation: null, | |||||
use_bias: true, padding: "same", | |||||
depthwise_initializer: "ones", | |||||
bias_initializer:"ones" | |||||
); | |||||
var weight = conv.get_weights(); | |||||
var x = np.arange(9 * 9 * 3).reshape((1, 9, 9, 3)); | |||||
var x2 = ops.convert_to_tensor(x, TF_DataType.TF_FLOAT); | |||||
var y = conv.Apply(x2); | |||||
Assert.AreEqual(4, y.shape.ndim); | |||||
var arr = y.numpy().ToArray<float>(); | |||||
Assert.AreEqual(arr[0], 61f); | |||||
Assert.AreEqual(arr[1], 65f); | |||||
var bn = keras.layers.BatchNormalization(); | |||||
var y2 = bn.Apply(y); | |||||
arr = y2.numpy().ToArray<float>(); | |||||
double delta = 0.0001; // 误差范围 | |||||
Assert.AreEqual(arr[0], 60.96952f, delta); | |||||
Assert.AreEqual(arr[1], 64.96752f, delta); | |||||
} | |||||
} | } | ||||
} | } |
@@ -20,6 +20,20 @@ namespace TensorFlowNET.UnitTest | |||||
return Math.Abs(f1 - f2) <= tolerance; | return Math.Abs(f1 - f2) <= tolerance; | ||||
} | } | ||||
public bool Equal(long[] l1, long[] l2) | |||||
{ | |||||
if (l1.Length != l2.Length) | |||||
return false; | |||||
for (var i = 0; i < l1.Length; i++) | |||||
{ | |||||
if (l1[i] != l2[i]) | |||||
return false; | |||||
} | |||||
return true; | |||||
} | |||||
public bool Equal(float[] f1, float[] f2) | public bool Equal(float[] f1, float[] f2) | ||||
{ | { | ||||
bool ret = false; | bool ret = false; | ||||
@@ -3,6 +3,7 @@ using Tensorflow.NumPy; | |||||
using Tensorflow; | using Tensorflow; | ||||
using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
using System.Linq; | using System.Linq; | ||||
using Tensorflow.Operations; | |||||
namespace TensorFlowNET.UnitTest.ManagedAPI | namespace TensorFlowNET.UnitTest.ManagedAPI | ||||
{ | { | ||||
@@ -105,5 +106,321 @@ namespace TensorFlowNET.UnitTest.ManagedAPI | |||||
Assert.IsTrue(Equal(a[0].ToArray<float>().Reverse().ToArray(), b[0].ToArray<float>())); | Assert.IsTrue(Equal(a[0].ToArray<float>().Reverse().ToArray(), b[0].ToArray<float>())); | ||||
Assert.IsTrue(Equal(a[1].ToArray<float>().Reverse().ToArray(), b[1].ToArray<float>())); | Assert.IsTrue(Equal(a[1].ToArray<float>().Reverse().ToArray(), b[1].ToArray<float>())); | ||||
} | } | ||||
[TestMethod] | |||||
public void ReverseImgArray3D() | |||||
{ | |||||
// 创建 sourceImg 数组 | |||||
var sourceImgArray = new float[,,] { | |||||
{ | |||||
{ 237, 28, 36 }, | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 } | |||||
}, | |||||
{ | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 } | |||||
} | |||||
}; | |||||
var sourceImg = ops.convert_to_tensor(sourceImgArray); | |||||
// 创建 lrImg 数组 | |||||
var lrImgArray = new float[,,] { | |||||
{ | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 }, | |||||
{ 237, 28, 36 } | |||||
}, | |||||
{ | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 } | |||||
} | |||||
}; | |||||
var lrImg = ops.convert_to_tensor(lrImgArray); | |||||
var lr = tf.image.flip_left_right(sourceImg); | |||||
Assert.IsTrue(Equal(lrImg.numpy().ToArray<float>(), lr.numpy().ToArray<float>()), "tf.image.flip_left_right fail."); | |||||
var lr2 = tf.reverse(sourceImg, 1); | |||||
Assert.IsTrue(Equal(lrImg.numpy().ToArray<float>(), lr2.numpy().ToArray<float>()), "tf.reverse (axis=1) fail."); | |||||
var lr3 = gen_array_ops.reverse_v2(sourceImg, ops.convert_to_tensor(new[] { 1 })); | |||||
Assert.IsTrue(Equal(lrImg.numpy().ToArray<float>(), lr3.numpy().ToArray<float>()), "gen_array_ops.reverse_v2 axis=1 fail."); | |||||
// 创建 udImg 数组 | |||||
var udImgArray = new float[,,] { | |||||
{ | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 } | |||||
}, | |||||
{ | |||||
{ 237, 28, 36 }, | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 } | |||||
} | |||||
}; | |||||
var udImg = ops.convert_to_tensor(udImgArray); | |||||
var ud = tf.image.flip_up_down(sourceImg); | |||||
Assert.IsTrue(Equal(udImg.numpy().ToArray<float>(), ud.numpy().ToArray<float>()), "tf.image.flip_up_down fail."); | |||||
var ud2 = tf.reverse(sourceImg, new Axis(0)); | |||||
Assert.IsTrue(Equal(udImg.numpy().ToArray<float>(), ud2.numpy().ToArray<float>()), "tf.reverse (axis=0) fail."); | |||||
var ud3 = gen_array_ops.reverse_v2(sourceImg, ops.convert_to_tensor(new[] { 0 })); | |||||
Assert.IsTrue(Equal(udImg.numpy().ToArray<float>(), ud3.numpy().ToArray<float>()), "gen_array_ops.reverse_v2 axis=0 fail."); | |||||
} | |||||
[TestMethod] | |||||
public void ReverseImgArray4D() | |||||
{ | |||||
// 原图左上角,加一张左右翻转后的图片 | |||||
var m = new float[,,,] { | |||||
{ | |||||
{ | |||||
{ 237, 28, 36 }, | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 } | |||||
}, | |||||
{ | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 } | |||||
} | |||||
}, | |||||
{ | |||||
{ | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 }, | |||||
{ 237, 28, 36 } | |||||
}, | |||||
{ | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 } | |||||
} | |||||
} | |||||
}; | |||||
var sourceImg = ops.convert_to_tensor(m); | |||||
var lrArray = new float[,,,] { | |||||
{ | |||||
{ | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 }, | |||||
{ 237, 28, 36 }, | |||||
}, | |||||
{ | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 } | |||||
} | |||||
}, | |||||
{ | |||||
{ | |||||
{ 237, 28, 36 }, | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 }, | |||||
}, | |||||
{ | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 } | |||||
} | |||||
} | |||||
}; | |||||
var lrImg = ops.convert_to_tensor(lrArray); | |||||
// 创建 ud 数组 | |||||
var udArray = new float[,,,] { | |||||
{ | |||||
{ | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 } | |||||
}, | |||||
{ | |||||
{ 237, 28, 36 }, | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 } | |||||
} | |||||
}, | |||||
{ | |||||
{ | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 } | |||||
}, | |||||
{ | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 }, | |||||
{ 237, 28, 36 } | |||||
} | |||||
} | |||||
}; | |||||
var udImg = ops.convert_to_tensor(udArray); | |||||
var ud3 = gen_array_ops.reverse_v2(sourceImg, ops.convert_to_tensor(new[] { 1 })); | |||||
Assert.IsTrue(Equal(udImg.numpy().ToArray<float>(), ud3.numpy().ToArray<float>()), "gen_array_ops.reverse_v2 axis=1 fail."); | |||||
var ud2 = tf.reverse(sourceImg, new Axis(1)); | |||||
Assert.IsTrue(Equal(udImg.numpy().ToArray<float>(), ud2.numpy().ToArray<float>()), "tf.reverse (axis=1) fail."); | |||||
var ud = tf.image.flip_up_down(sourceImg); | |||||
Assert.IsTrue(Equal(udImg.numpy().ToArray<float>(), ud.numpy().ToArray<float>()), "tf.image.flip_up_down fail."); | |||||
// 左右翻转 | |||||
var lr = tf.image.flip_left_right(sourceImg); | |||||
Assert.IsTrue(Equal(lrImg.numpy().ToArray<float>(), lr.numpy().ToArray<float>()), "tf.image.flip_left_right fail."); | |||||
var lr2 = tf.reverse(sourceImg, 0); | |||||
Assert.IsTrue(Equal(lrImg.numpy().ToArray<float>(), lr2.numpy().ToArray<float>()), "tf.reverse (axis=1) fail."); | |||||
var lr3 = gen_array_ops.reverse_v2(sourceImg, ops.convert_to_tensor(new[] { 0 })); | |||||
Assert.IsTrue(Equal(lrImg.numpy().ToArray<float>(), lr3.numpy().ToArray<float>()), "gen_array_ops.reverse_v2 axis=1 fail."); | |||||
} | |||||
[TestMethod] | |||||
public void ReverseImgArray4D_3x3() | |||||
{ | |||||
// 原图左上角,加一张左右翻转后的图片 | |||||
var m = new float[,,,] { | |||||
{ | |||||
{ | |||||
{ 237, 28, 36 }, | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 } | |||||
}, | |||||
{ | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 } | |||||
}, | |||||
{ | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 } | |||||
} | |||||
}, | |||||
{ | |||||
{ | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 }, | |||||
{ 237, 28, 36 } | |||||
}, | |||||
{ | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 } | |||||
}, | |||||
{ | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 } | |||||
} | |||||
} | |||||
}; | |||||
var sourceImg = ops.convert_to_tensor(m); | |||||
var lrArray = new float[,,,] { | |||||
{ | |||||
{ | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 }, | |||||
{ 237, 28, 36 }, | |||||
}, | |||||
{ | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 } | |||||
}, | |||||
{ | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 } | |||||
} | |||||
}, | |||||
{ | |||||
{ | |||||
{ 237, 28, 36 }, | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 }, | |||||
}, | |||||
{ | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 } | |||||
}, | |||||
{ | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 } | |||||
} | |||||
} | |||||
}; | |||||
var lrImg = ops.convert_to_tensor(lrArray); | |||||
// 创建 ud 数组 | |||||
var udArray = new float[,,,] { | |||||
{ | |||||
{ | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 } | |||||
}, | |||||
{ | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 } | |||||
}, | |||||
{ | |||||
{ 237, 28, 36 }, | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 } | |||||
} | |||||
}, | |||||
{ { | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 } | |||||
}, | |||||
{ | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 } | |||||
}, | |||||
{ | |||||
{ 255, 255, 255 }, | |||||
{ 255, 255, 255 }, | |||||
{ 237, 28, 36 } | |||||
} | |||||
} | |||||
}; | |||||
var udImg = ops.convert_to_tensor(udArray); | |||||
var ud3 = gen_array_ops.reverse_v2(sourceImg, ops.convert_to_tensor(new[] { 1 })); | |||||
Assert.IsTrue(Equal(udImg.numpy().ToArray<float>(), ud3.numpy().ToArray<float>()), "gen_array_ops.reverse_v2 axis=1 fail."); | |||||
var ud2 = tf.reverse(sourceImg, new Axis(1)); | |||||
Assert.IsTrue(Equal(udImg.numpy().ToArray<float>(), ud2.numpy().ToArray<float>()), "tf.reverse (axis=1) fail."); | |||||
var ud = tf.image.flip_up_down(sourceImg); | |||||
Assert.IsTrue(Equal(udImg.numpy().ToArray<float>(), ud.numpy().ToArray<float>()), "tf.image.flip_up_down fail."); | |||||
// 左右翻转 | |||||
var lr = tf.image.flip_left_right(sourceImg); | |||||
Assert.IsTrue(Equal(lrImg.numpy().ToArray<float>(), lr.numpy().ToArray<float>()), "tf.image.flip_left_right fail."); | |||||
var lr2 = tf.reverse(sourceImg, 0); | |||||
Assert.IsTrue(Equal(lrImg.numpy().ToArray<float>(), lr2.numpy().ToArray<float>()), "tf.reverse (axis=1) fail."); | |||||
var lr3 = gen_array_ops.reverse_v2(sourceImg, ops.convert_to_tensor(new[] { 0 })); | |||||
Assert.IsTrue(Equal(lrImg.numpy().ToArray<float>(), lr3.numpy().ToArray<float>()), "gen_array_ops.reverse_v2 axis=1 fail."); | |||||
} | |||||
} | } | ||||
} | } |
@@ -0,0 +1,44 @@ | |||||
using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
using Tensorflow.NumPy; | |||||
using System; | |||||
using System.Linq; | |||||
using static Tensorflow.Binding; | |||||
using Tensorflow; | |||||
namespace TensorFlowNET.UnitTest.NumPy | |||||
{ | |||||
[TestClass] | |||||
public class ShapeTest : EagerModeTestBase | |||||
{ | |||||
[Ignore] | |||||
[TestMethod] | |||||
public unsafe void ShapeGetLastElements() | |||||
{ | |||||
// test code from function _CheckAtLeast3DImage | |||||
// 之前的 _CheckAtLeast3DImage 有bug,现在通过测试,下面的代码是正确的 | |||||
// todo: shape["-3:"] 的写法,目前有bug,需要修复,单元测试等修复后再放开,暂时先忽略测试 | |||||
var image_shape = new Shape(new[] { 32, 64, 3 }); | |||||
var image_shape_4d = new Shape(new[] { 4, 64, 32, 3 }); | |||||
var image_shape_last_three_elements = new Shape(new[] { | |||||
image_shape.dims[image_shape.dims.Length - 3], | |||||
image_shape.dims[image_shape.dims.Length - 2], | |||||
image_shape.dims[image_shape.dims.Length - 1]}); | |||||
var image_shape_last_three_elements2 = image_shape["-3:"]; | |||||
Assert.IsTrue(Equal(image_shape_last_three_elements.dims, image_shape_last_three_elements2.dims), "3dims get fail."); | |||||
var image_shape_last_three_elements_4d = new Shape(new[] { | |||||
image_shape_4d.dims[image_shape_4d.dims.Length - 3], | |||||
image_shape_4d.dims[image_shape_4d.dims.Length - 2], | |||||
image_shape_4d.dims[image_shape_4d.dims.Length - 1]}); | |||||
var image_shape_last_three_elements2_4d = image_shape_4d["-3:"]; | |||||
Assert.IsTrue(Equals(image_shape_last_three_elements_4d.dims, image_shape_last_three_elements2_4d.dims), "4dims get fail."); | |||||
} | |||||
} | |||||
} |