You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

VdCnn.cs 2.2 kB

6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Text;
  4. using Tensorflow;
  namespace TensorFlowNET.Examples.TextClassification
  {
      /// <summary>
      /// Graph builder for a character-level convolutional text classifier
      /// (the class name and the <c>num_blocks</c> hyper-parameter suggest the
      /// VDCNN architecture — TODO confirm).
      /// The constructor declares the input placeholders, a trainable embedding
      /// table, and the first convolution layer. The remaining layers are not
      /// built yet.
      /// </summary>
      public class VdCnn : Python
      {
          // Hyper-parameters; assigned once in the constructor.
          private readonly int embedding_size;
          private readonly int[] filter_sizes;   // kernel height per conv stage; only [0] is used so far
          private readonly int[] num_filters;    // output channels per conv stage; only [0] is used so far
          private readonly int[] num_blocks;     // not used yet
          private readonly float learning_rate;  // not used yet
          private readonly IInitializer cnn_initializer;

          // Graph inputs and training counter; assigned once in the constructor.
          private readonly Tensor x;
          private readonly Tensor y;
          private readonly Tensor is_training;
          private readonly RefVariable global_step;

          // Graph nodes created inside the anonymous `with` delegates below.
          // C# does not allow a readonly field to be assigned from a lambda,
          // so these stay mutable.
          private RefVariable embeddings;
          private Tensor x_emb;
          private Tensor x_expanded;
          private Tensor conv0;

          /// <summary>
          /// Builds the model graph.
          /// </summary>
          /// <param name="alphabet_size">Number of rows in the embedding table, i.e. the number of distinct ids that may appear in <c>x</c>.</param>
          /// <param name="document_max_len">Size of the second dimension of the <c>x</c> placeholder.</param>
          /// <param name="num_class">Not used yet; presumably the number of output classes for a later output layer — TODO.</param>
          public VdCnn(int alphabet_size, int document_max_len, int num_class)
          {
              embedding_size = 16;
              filter_sizes = new int[] { 3, 3, 3, 3, 3 };
              num_filters = new int[] { 64, 64, 128, 256, 512 };
              num_blocks = new int[] { 2, 2, 2, 2 };
              learning_rate = 0.001f;
              cnn_initializer = tf.keras.initializers.he_normal();

              // The leading -1 leaves the batch dimension unspecified.
              x = tf.placeholder(tf.int32, new TensorShape(-1, document_max_len), name: "x");
              y = tf.placeholder(tf.int32, new TensorShape(-1), name: "y");
              is_training = tf.placeholder(tf.boolean, new TensorShape(), name: "is_training");
              global_step = tf.Variable(0, trainable: false);

              // Embedding layer: a trainable alphabet_size x embedding_size table,
              // initialized uniformly between -1 and 1, looked up for every id in x.
              with(tf.name_scope("embedding"), delegate
              {
                  var init_embeddings = tf.random_uniform(new int[] { alphabet_size, embedding_size }, -1.0f, 1.0f);
                  embeddings = tf.get_variable("embeddings", initializer: init_embeddings);
                  x_emb = tf.nn.embedding_lookup(embeddings, x);
                  // Append a trailing axis of size 1, presumably so conv2d
                  // (channels-last) treats the embedded text as a 1-channel image.
                  x_expanded = tf.expand_dims(x_emb, -1);
              });

              // First convolution layer. Its kernel is filter_sizes[0] high and
              // embedding_size wide, i.e. it spans the full embedding width.
              with(tf.variable_scope("conv-0"), delegate
              {
                  // Store the output in a field so subsequent layers can consume it;
                  // a lambda-local would make the layer's output unreachable.
                  conv0 = tf.layers.conv2d(x_expanded,
                      filters: num_filters[0],
                      kernel_size: new int[] { filter_sizes[0], embedding_size },
                      kernel_initializer: cnn_initializer,
                      activation: tf.nn.relu);
              });
          }
      }
  }

TensorFlow 框架的 .NET 版本，提供了丰富的特性和 API，可以借此很方便地在 .NET 平台下搭建深度学习训练与推理流程。