You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

CnnTextTrain.cs 2.5 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. using NumSharp.Core;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using System.Text;
  6. using Tensorflow;
  namespace TensorFlowNET.Examples.CnnTextClassification
  {
      /// <summary>
      /// Example that holds the hyper-parameters for training a CNN text classifier
      /// and drives the data pre-processing step.
      /// </summary>
      /// <remarks>
      /// Work in progress: <see cref="Run"/> only calls <see cref="preprocess"/>, and
      /// <see cref="preprocess"/> currently ends by throwing
      /// <see cref="NotImplementedException"/>.
      /// </remarks>
      public class CnnTextTrain : Python, IExample
      {
          // Percentage of the training data to use for validation
          private float dev_sample_percentage = 0.1f;

          // Data source for the positive data.
          private string positive_data_file = "https://raw.githubusercontent.com/dennybritz/cnn-text-classification-tf/master/data/rt-polaritydata/rt-polarity.pos";

          // Data source for the negative data.
          private string negative_data_file = "https://raw.githubusercontent.com/dennybritz/cnn-text-classification-tf/master/data/rt-polaritydata/rt-polarity.neg";

          // Dimensionality of character embedding (default: 128)
          private int embedding_dim = 128;

          // Comma-separated filter sizes (default: '3,4,5')
          private string filter_sizes = "3,4,5";

          // Number of filters per filter size (default: 128)
          private int num_filters = 128;

          // Dropout keep probability (default: 0.5)
          private float dropout_keep_prob = 0.5f;

          // L2 regularization lambda (default: 0.0)
          private float l2_reg_lambda = 0.0f;

          // Batch size (default: 64)
          private int batch_size = 64;

          // Number of training epochs (default: 200)
          private int num_epochs = 200;

          // Evaluate model on dev set after this many steps (default: 100)
          private int evaluate_every = 100;

          // Save model after this many steps (default: 100)
          private int checkpoint_every = 100;

          // Number of checkpoints to store (default: 5)
          private int num_checkpoints = 5;

          // Allow soft device placement
          private bool allow_soft_placement = true;

          // Log placement of ops on devices
          private bool log_device_placement = false;

          /// <summary>
          /// Runs the example. At present this only invokes <see cref="preprocess"/> and
          /// deconstructs its result; no training is performed yet.
          /// </summary>
          /// <remarks>
          /// NOTE(review): presumably this is the entry point required by
          /// <c>IExample</c> - confirm against the interface declaration.
          /// </remarks>
          public void Run()
          {
              var (x_train, y_train, vocab_processor, x_dev, y_dev) = preprocess();
          }

          /// <summary>
          /// Loads the positive and negative data files, computes the longest document
          /// length in space-separated tokens, and creates the vocabulary processor.
          /// </summary>
          /// <returns>
          /// A five-element tuple that <see cref="Run"/> deconstructs as
          /// (x_train, y_train, vocab_processor, x_dev, y_dev). Nothing is returned yet
          /// because the method always throws before reaching a return statement.
          /// </returns>
          /// <exception cref="NotImplementedException">
          /// Always thrown once the vocabulary processor has been created; the remaining
          /// pre-processing steps are not implemented.
          /// </exception>
          public (NDArray, NDArray, NDArray, NDArray, NDArray) preprocess()
          {
              var (x_text, y) = DataHelpers.load_data_and_labels(positive_data_file, negative_data_file);

              // Build vocabulary: the longest document (counted in space-separated
              // tokens) is passed to the vocabulary processor.
              int max_document_length = x_text.Select(x => x.Split(' ').Length).Max();

              // The original statement lacked its terminating semicolon, which made
              // the whole file fail to compile.
              var vocab_processor = learn.preprocessing.VocabularyProcessor(max_document_length);

              throw new NotImplementedException("");
          }
      }
  }

TensorFlow 框架的 .NET 版本，提供了丰富的特性和 API，可以借此很方便地在 .NET 平台下搭建深度学习训练与推理流程。