You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

VdCnn.cs 6.8 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using System.Text;
  5. using Tensorflow;
  6. using static Tensorflow.Python;
  7. namespace TensorFlowNET.Examples.Text
  8. {
/// <summary>
/// Character-level "Very Deep CNN" (VDCNN-style) text classifier built with
/// TensorFlow.NET. The constructor assembles the whole training graph:
/// character embedding -> conv-0 -> four conv blocks -> k-max pooling (k=8)
/// -> three dense layers -> softmax cross-entropy loss with an Adam optimizer.
/// </summary>
public class VdCnn : ITextModel
{
    // ----- Hyper-parameters, fixed in the constructor -----
    private int embedding_size;           // dimensionality of the character embedding (set to 16)
    private int[] filter_sizes;           // temporal kernel height for conv-0 and each block
    private int[] num_filters;            // output channels: conv-0 plus blocks 1..4
    private int[] num_blocks;             // NOTE(review): assigned below but never read in this file — confirm whether dead
    private float learning_rate;          // Adam learning rate (0.001)
    private IInitializer cnn_initializer; // He-normal init for convolution kernels
    private IInitializer fc_initializer;  // truncated normal (stddev 0.05) for dense kernels

    // ----- Graph inputs (placeholders to feed at run time) -----
    public Tensor x { get; private set; }           // int32 [batch, document_max_len] character ids
    public Tensor y { get; private set; }           // int32 [batch] class labels
    public Tensor is_training { get; private set; } // bool scalar; gates batch-norm training mode

    // ----- Graph nodes kept as fields -----
    private RefVariable global_step;      // incremented by the optimizer's minimize op
    private RefVariable embeddings;       // [alphabet_size, embedding_size] lookup table
    private Tensor x_emb;                 // embedded input, result of embedding_lookup
    private Tensor x_expanded;            // x_emb with a trailing channel dim added for conv2d
    private Tensor logits;               // pre-softmax class scores [batch, num_class]
    private Tensor predictions;          // argmax over logits, int32 [batch]
    private Tensor loss;                 // scalar mean softmax cross-entropy

    /// <summary>
    /// Builds the full graph for the model.
    /// </summary>
    /// <param name="alphabet_size">number of distinct characters (embedding rows)</param>
    /// <param name="document_max_len">fixed input sequence length</param>
    /// <param name="num_class">number of output classes</param>
    public VdCnn(int alphabet_size, int document_max_len, int num_class)
    {
        embedding_size = 16;
        filter_sizes = new int[] { 3, 3, 3, 3, 3 };
        num_filters = new int[] { 64, 64, 128, 256, 512 };
        num_blocks = new int[] { 2, 2, 2, 2 };
        learning_rate = 0.001f;
        cnn_initializer = tf.keras.initializers.he_normal();
        fc_initializer = tf.truncated_normal_initializer(stddev: 0.05f);

        x = tf.placeholder(tf.int32, new TensorShape(-1, document_max_len), name: "x");
        y = tf.placeholder(tf.int32, new TensorShape(-1), name: "y");
        is_training = tf.placeholder(tf.@bool, new TensorShape(), name: "is_training");
        global_step = tf.Variable(0, trainable: false);

        // Embedding Layer: random-uniform table in [-1, 1]; expand_dims adds the
        // channel axis so the result can feed tf.layers.conv2d.
        with(tf.name_scope("embedding"), delegate
        {
            var init_embeddings = tf.random_uniform(new int[] { alphabet_size, embedding_size }, -1.0f, 1.0f);
            embeddings = tf.get_variable("embeddings", initializer: init_embeddings);
            x_emb = tf.nn.embedding_lookup(embeddings, x);
            x_expanded = tf.expand_dims(x_emb, -1);
        });

        Tensor conv0 = null;
        Tensor conv1 = null;
        Tensor conv2 = null;
        Tensor conv3 = null;
        Tensor conv4 = null;
        Tensor h_flat = null;
        Tensor fc1_out = null;
        Tensor fc2_out = null;

        // First Convolution Layer: kernel spans the full embedding width, so the
        // output width collapses to 1; the transpose moves the filter axis into
        // the "width" position so later blocks can convolve across channels.
        with(tf.variable_scope("conv-0"), delegate
        {
            conv0 = tf.layers.conv2d(x_expanded,
                filters: num_filters[0],
                kernel_size: new int[] { filter_sizes[0], embedding_size },
                kernel_initializer: cnn_initializer,
                activation: tf.nn.relu());
            conv0 = tf.transpose(conv0, new int[] { 0, 1, 3, 2 });
        });

        // Four convolutional blocks; only the last skips its trailing max-pool.
        with(tf.name_scope("conv-block-1"), delegate {
            conv1 = conv_block(conv0, 1);
        });
        with(tf.name_scope("conv-block-2"), delegate {
            conv2 = conv_block(conv1, 2);
        });
        with(tf.name_scope("conv-block-3"), delegate {
            conv3 = conv_block(conv2, 3);
        });
        with(tf.name_scope("conv-block-4"), delegate
        {
            conv4 = conv_block(conv3, 4, max_pool: false);
        });

        // ============= k-max Pooling =============
        // Keep the 8 largest activations per filter (unsorted), then flatten to
        // a fixed-size [batch, 512 * 8] vector regardless of sequence length.
        with(tf.name_scope("k-max-pooling"), delegate
        {
            var h = tf.transpose(tf.squeeze(conv4, new int[] { -1 }), new int[] { 0, 2, 1 });
            var top_k = tf.nn.top_k(h, k: 8, sorted: false)[0];
            h_flat = tf.reshape(top_k, new int[] { -1, 512 * 8 });
        });

        // ============= Fully Connected Layers =============
        with(tf.name_scope("fc-1"), scope =>
        {
            fc1_out = tf.layers.dense(h_flat, 2048, activation: tf.nn.relu(), kernel_initializer: fc_initializer);
        });
        with(tf.name_scope("fc-2"), scope =>
        {
            fc2_out = tf.layers.dense(fc1_out, 2048, activation: tf.nn.relu(), kernel_initializer: fc_initializer);
        });
        with(tf.name_scope("fc-3"), scope =>
        {
            logits = tf.layers.dense(fc2_out, num_class, activation: null, kernel_initializer: fc_initializer);
            predictions = tf.argmax(logits, -1, output_type: tf.int32);
        });

        // ============= Loss and Accuracy =============
        // Labels are one-hot encoded; the UPDATE_OPS dependency makes the
        // batch-norm moving statistics update alongside each training step.
        with(tf.name_scope("loss"), delegate
        {
            var y_one_hot = tf.one_hot(y, num_class);
            loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits: logits, labels: y_one_hot));
            var update_ops = tf.get_collection(ops.GraphKeys.UPDATE_OPS) as List<object>;
            with(tf.control_dependencies(update_ops.Select(x => (Operation)x).ToArray()), delegate
            {
                var adam = tf.train.AdamOptimizer(learning_rate);
                // NOTE(review): the train op returned by minimize() is discarded;
                // presumably it is fetched later by name from the graph — confirm
                // against the training loop that consumes this model.
                adam.minimize(loss, global_step: global_step);
            });
        });
    }

    /// <summary>
    /// Builds convolutional block <paramref name="i"/> under variable scope
    /// "conv-block-{i}": two conv/batch-norm/relu layers, optionally followed
    /// by a stride-2 max-pool along the temporal axis.
    /// </summary>
    /// <param name="input">output of the previous block, shaped so its "width"
    /// axis equals num_filters[i - 1] (after the transpose trick)</param>
    /// <param name="i">1-based block index into filter_sizes / num_filters</param>
    /// <param name="max_pool">when false (block 4), skip the trailing pooling</param>
    /// <returns>the block's output tensor</returns>
    private Tensor conv_block(Tensor input, int i, bool max_pool = true)
    {
        return with(tf.variable_scope($"conv-block-{i}"), delegate
        {
            Tensor conv = null;
            // Two "conv-batch_norm-relu" layers.
            // NOTE(review): both iterations convolve `input`, not the previous
            // `conv` — the layers are built from the same source (with distinct
            // variables per "conv-{j}" scope) rather than chained. Chaining would
            // also break the kernel width (num_filters[i - 1]) whenever
            // num_filters[i] differs; confirm this matches the upstream example.
            foreach (var j in Enumerable.Range(0, 2))
            {
                with(tf.variable_scope($"conv-{j}"), delegate
                {
                    // convolution: kernel spans the full channel width, so the
                    // output width collapses to 1 before the transpose below.
                    conv = tf.layers.conv2d(
                        input,
                        filters: num_filters[i],
                        kernel_size: new int[] { filter_sizes[i], num_filters[i - 1] },
                        kernel_initializer: cnn_initializer,
                        activation: null);
                    // batch normalization (training mode gated by is_training)
                    conv = tf.layers.batch_normalization(conv, training: is_training);
                    // relu
                    conv = tf.nn.relu(conv);
                    // move the filter axis back into the "width" position
                    conv = tf.transpose(conv, new int[] { 0, 1, 3, 2 });
                });
            }
            if (max_pool)
            {
                // Max pooling: halve the temporal resolution (stride 2).
                return tf.layers.max_pooling2d(
                    conv,
                    pool_size: new int[] { 3, 1 },
                    strides: new int[] { 2, 1 },
                    padding: "SAME");
            }
            else
            {
                return conv;
            }
        });
    }
}
  154. }