You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

VdCnn.cs 6.9 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using System.Text;
  5. using Tensorflow;
  6. using TensorFlowNET.Examples.Text.cnn_models;
  7. using static Tensorflow.Python;
  8. namespace TensorFlowNET.Examples.TextClassification
  9. {
  10. public class VdCnn : ITextClassificationModel
  11. {
  12. private int embedding_size;
  13. private int[] filter_sizes;
  14. private int[] num_filters;
  15. private int[] num_blocks;
  16. private float learning_rate;
  17. private IInitializer cnn_initializer;
  18. private IInitializer fc_initializer;
  19. public Tensor x { get; private set; }
  20. public Tensor y { get; private set; }
  21. public Tensor is_training { get; private set; }
  22. private RefVariable global_step;
  23. private RefVariable embeddings;
  24. private Tensor x_emb;
  25. private Tensor x_expanded;
  26. private Tensor logits;
  27. private Tensor predictions;
  28. private Tensor loss;
  29. public VdCnn(int alphabet_size, int document_max_len, int num_class)
  30. {
  31. embedding_size = 16;
  32. filter_sizes = new int[] { 3, 3, 3, 3, 3 };
  33. num_filters = new int[] { 64, 64, 128, 256, 512 };
  34. num_blocks = new int[] { 2, 2, 2, 2 };
  35. learning_rate = 0.001f;
  36. cnn_initializer = tf.keras.initializers.he_normal();
  37. fc_initializer = tf.truncated_normal_initializer(stddev: 0.05f);
  38. x = tf.placeholder(tf.int32, new TensorShape(-1, document_max_len), name: "x");
  39. y = tf.placeholder(tf.int32, new TensorShape(-1), name: "y");
  40. is_training = tf.placeholder(tf.boolean, new TensorShape(), name: "is_training");
  41. global_step = tf.Variable(0, trainable: false);
  42. // Embedding Layer
  43. with(tf.name_scope("embedding"), delegate
  44. {
  45. var init_embeddings = tf.random_uniform(new int[] { alphabet_size, embedding_size }, -1.0f, 1.0f);
  46. embeddings = tf.get_variable("embeddings", initializer: init_embeddings);
  47. x_emb = tf.nn.embedding_lookup(embeddings, x);
  48. x_expanded = tf.expand_dims(x_emb, -1);
  49. });
  50. Tensor conv0 = null;
  51. Tensor conv1 = null;
  52. Tensor conv2 = null;
  53. Tensor conv3 = null;
  54. Tensor conv4 = null;
  55. Tensor h_flat = null;
  56. Tensor fc1_out = null;
  57. Tensor fc2_out = null;
  58. // First Convolution Layer
  59. with(tf.variable_scope("conv-0"), delegate
  60. {
  61. conv0 = tf.layers.conv2d(x_expanded,
  62. filters: num_filters[0],
  63. kernel_size: new int[] { filter_sizes[0], embedding_size },
  64. kernel_initializer: cnn_initializer,
  65. activation: tf.nn.relu());
  66. conv0 = tf.transpose(conv0, new int[] { 0, 1, 3, 2 });
  67. });
  68. with(tf.name_scope("conv-block-1"), delegate {
  69. conv1 = conv_block(conv0, 1);
  70. });
  71. with(tf.name_scope("conv-block-2"), delegate {
  72. conv2 = conv_block(conv1, 2);
  73. });
  74. with(tf.name_scope("conv-block-3"), delegate {
  75. conv3 = conv_block(conv2, 3);
  76. });
  77. with(tf.name_scope("conv-block-4"), delegate
  78. {
  79. conv4 = conv_block(conv3, 4, max_pool: false);
  80. });
  81. // ============= k-max Pooling =============
  82. with(tf.name_scope("k-max-pooling"), delegate
  83. {
  84. var h = tf.transpose(tf.squeeze(conv4, new int[] { -1 }), new int[] { 0, 2, 1 });
  85. var top_k = tf.nn.top_k(h, k: 8, sorted: false)[0];
  86. h_flat = tf.reshape(top_k, new int[] { -1, 512 * 8 });
  87. });
  88. // ============= Fully Connected Layers =============
  89. with(tf.name_scope("fc-1"), scope =>
  90. {
  91. fc1_out = tf.layers.dense(h_flat, 2048, activation: tf.nn.relu(), kernel_initializer: fc_initializer);
  92. });
  93. with(tf.name_scope("fc-2"), scope =>
  94. {
  95. fc2_out = tf.layers.dense(fc1_out, 2048, activation: tf.nn.relu(), kernel_initializer: fc_initializer);
  96. });
  97. with(tf.name_scope("fc-3"), scope =>
  98. {
  99. logits = tf.layers.dense(fc2_out, num_class, activation: null, kernel_initializer: fc_initializer);
  100. predictions = tf.argmax(logits, -1, output_type: tf.int32);
  101. });
  102. // ============= Loss and Accuracy =============
  103. with(tf.name_scope("loss"), delegate
  104. {
  105. var y_one_hot = tf.one_hot(y, num_class);
  106. loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits: logits, labels: y_one_hot));
  107. var update_ops = tf.get_collection(ops.GraphKeys.UPDATE_OPS) as List<object>;
  108. with(tf.control_dependencies(update_ops.Select(x => (Operation)x).ToArray()), delegate
  109. {
  110. var adam = tf.train.AdamOptimizer(learning_rate);
  111. adam.minimize(loss, global_step: global_step);
  112. });
  113. });
  114. }
  115. private Tensor conv_block(Tensor input, int i, bool max_pool = true)
  116. {
  117. return with(tf.variable_scope($"conv-block-{i}"), delegate
  118. {
  119. Tensor conv = null;
  120. // Two "conv-batch_norm-relu" layers.
  121. foreach (var j in Enumerable.Range(0, 2))
  122. {
  123. with(tf.variable_scope($"conv-{j}"), delegate
  124. {
  125. // convolution
  126. conv = tf.layers.conv2d(
  127. input,
  128. filters: num_filters[i],
  129. kernel_size: new int[] { filter_sizes[i], num_filters[i - 1] },
  130. kernel_initializer: cnn_initializer,
  131. activation: null);
  132. // batch normalization
  133. conv = tf.layers.batch_normalization(conv, training: is_training);
  134. // relu
  135. conv = tf.nn.relu(conv);
  136. conv = tf.transpose(conv, new int[] { 0, 1, 3, 2 });
  137. });
  138. }
  139. if (max_pool)
  140. {
  141. // Max pooling
  142. return tf.layers.max_pooling2d(
  143. conv,
  144. pool_size: new int[] { 3, 1 },
  145. strides: new int[] { 2, 1 },
  146. padding: "SAME");
  147. }
  148. else
  149. {
  150. return conv;
  151. }
  152. });
  153. }
  154. }
  155. }