You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

common.cs 2.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Text;
  4. using Tensorflow;
  5. using static Tensorflow.Binding;
  6. namespace TensorFlowNET.Examples.ImageProcessing.YOLO
  7. {
  8. class common
  9. {
  10. public static Tensor convolutional(Tensor input_data, int[] filters_shape, Tensor trainable,
  11. string name, bool downsample = false, bool activate = true,
  12. bool bn = true)
  13. {
  14. return tf_with(tf.variable_scope(name), scope =>
  15. {
  16. int[] strides;
  17. string padding;
  18. if (downsample)
  19. {
  20. throw new NotImplementedException("");
  21. }
  22. else
  23. {
  24. strides = new int[] { 1, 1, 1, 1 };
  25. padding = "SAME";
  26. }
  27. var weight = tf.get_variable(name: "weight", dtype: tf.float32, trainable: true,
  28. shape: filters_shape, initializer: tf.random_normal_initializer(stddev: 0.01f));
  29. var conv = tf.nn.conv2d(input: input_data, filter: weight, strides: strides, padding: padding);
  30. if (bn)
  31. {
  32. conv = tf.layers.batch_normalization(conv, beta_initializer: tf.zeros_initializer,
  33. gamma_initializer: tf.ones_initializer,
  34. moving_mean_initializer: tf.zeros_initializer,
  35. moving_variance_initializer: tf.ones_initializer, training: trainable);
  36. }
  37. else
  38. {
  39. throw new NotImplementedException("");
  40. }
  41. if (activate)
  42. conv = tf.nn.leaky_relu(conv, alpha: 0.1f);
  43. return conv;
  44. });
  45. }
  46. public static Tensor residual_block(Tensor input_data, int input_channel, int filter_num1,
  47. int filter_num2, Tensor trainable, string name)
  48. {
  49. var short_cut = input_data;
  50. return tf_with(tf.variable_scope(name), scope =>
  51. {
  52. input_data = convolutional(input_data, filters_shape: new int[] { 1, 1, input_channel, filter_num1 },
  53. trainable: trainable, name: "conv1");
  54. input_data = convolutional(input_data, filters_shape: new int[] { 3, 3, filter_num1, filter_num2 },
  55. trainable: trainable, name: "conv2");
  56. var residual_output = input_data + short_cut;
  57. return residual_output;
  58. });
  59. }
  60. }
  61. }