You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

DataHelpers.cs 7.0 kB

6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. using NumSharp;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.IO;
  5. using System.Linq;
  6. using System.Text;
  7. using System.Text.RegularExpressions;
  namespace TensorFlowNET.Examples
  {
      /// <summary>
      /// Data-loading and padding helpers used by the text-classification examples.
      /// </summary>
      public class DataHelpers
      {
          private const string TRAIN_PATH = "text_classification/dbpedia_csv/train.csv";
          private const string TEST_PATH = "text_classification/dbpedia_csv/test.csv";

          // Folder and file names used by load_data_and_labels. The same constants are used
          // for downloading and for reading so the two can never drift apart again.
          private const string MR_DATA_DIR = "CnnTextClassification";
          private const string MR_POS_FILE = "rt-polarity.pos";
          private const string MR_NEG_FILE = "rt-polarity.neg";

          /// <summary>
          /// Builds a character-level dataset from the DBpedia CSV for the "vd_cnn" model.
          /// Each row is lower-cased, its content column is mapped character by character to ids
          /// (0 = &lt;pad&gt;, 1 = &lt;unk&gt;, 2.. = position in the built-in alphabet), and the
          /// result is truncated or padded to <paramref name="document_max_len"/> ids.
          /// </summary>
          /// <param name="step">"test" reads the test CSV; any other value reads the training CSV.</param>
          /// <param name="model">Model name; only "vd_cnn" is supported.</param>
          /// <param name="document_max_len">Number of character ids stored per document.</param>
          /// <param name="limit">Optional maximum number of rows to load; clamped to the rows available.</param>
          /// <returns>
          /// x with shape (rows, document_max_len), y with shape (rows) holding the parsed first CSV column,
          /// and the vocabulary size (alphabet plus the two special tokens).
          /// </returns>
          /// <exception cref="NotImplementedException"><paramref name="model"/> is not "vd_cnn".</exception>
          /// <exception cref="ArgumentOutOfRangeException"><paramref name="document_max_len"/> or <paramref name="limit"/> is negative.</exception>
          /// <exception cref="InvalidDataException">A row does not contain at least three <c>,"</c>-separated fields.</exception>
          public static (NDArray, NDArray, int) build_char_dataset(string step, string model, int document_max_len, int? limit = null)
          {
              if (model != "vd_cnn")
                  throw new NotImplementedException(model);
              if (document_max_len < 0)
                  throw new ArgumentOutOfRangeException(nameof(document_max_len), document_max_len, "document_max_len must be non-negative.");
              if (limit.HasValue && limit.Value < 0)
                  throw new ArgumentOutOfRangeException(nameof(limit), limit.Value, "limit must be non-negative.");

              string alphabet = "abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:’'\"/|_#$%ˆ&*˜‘+=<>()[]{} ";
              var char_dict = new Dictionary<string, int>();
              char_dict["<pad>"] = 0;
              char_dict["<unk>"] = 1;
              foreach (char c in alphabet)
                  char_dict[c.ToString()] = char_dict.Count;
              int pad_id = char_dict["<pad>"];
              int unk_id = char_dict["<unk>"];

              string path = step == "test" ? TEST_PATH : TRAIN_PATH;
              var contents = File.ReadAllLines(path);
              // Never read past the end of the file, whatever limit the caller asked for.
              int size = limit.HasValue ? Math.Min(limit.Value, contents.Length) : contents.Length;

              var x = new NDArray(np.int32, new Shape(size, document_max_len));
              var y = new NDArray(np.int32, new Shape(size));

              // Progress is derived from the row index rather than from size / 10, which is 0
              // for fewer than 10 rows (division by zero) and overshoots 100% for other sizes.
              int reported_percent = 0;
              for (int i = 0; i < size; i++)
              {
                  int percent = (int)((i + 1L) * 100 / size) / 10 * 10;
                  if (percent > reported_percent)
                  {
                      reported_percent = percent;
                      Console.WriteLine($"\t{percent}%");
                  }

                  // Rows look like: class,"title","content" - splitting on ," yields the class,
                  // the quoted title and the quoted content (each still carrying a closing quote).
                  string[] parts = contents[i].ToLowerInvariant().Split(new[] { ",\"" }, StringSplitOptions.None);
                  if (parts.Length < 3)
                      throw new InvalidDataException($"Row {i + 1} of '{path}' does not have the expected class,\"title\",\"content\" layout.");

                  string content = parts[2];
                  if (content.Length > 0 && content[content.Length - 1] == '"')
                      content = content.Substring(0, content.Length - 1);

                  var ids = new int[document_max_len];
                  for (int j = 0; j < document_max_len; j++)
                  {
                      if (j >= content.Length)
                          ids[j] = pad_id;
                      else
                          ids[j] = char_dict.TryGetValue(content[j].ToString(), out int id) ? id : unk_id;
                  }
                  x[i] = ids;
                  y[i] = int.Parse(parts[0], System.Globalization.CultureInfo.InvariantCulture);
              }

              // Equals alphabet.Length + 2 because the alphabet has no duplicate characters.
              return (x, y, char_dict.Count);
          }

          /// <summary>
          /// Loads MR polarity data from files, splits the data into words and generates labels.
          /// Both files are first downloaded into the "CnnTextClassification" folder.
          /// </summary>
          /// <param name="positive_data_file">URL passed to Utility.Web.Download for the positive examples.</param>
          /// <param name="negative_data_file">URL passed to Utility.Web.Download for the negative examples.</param>
          /// <returns>
          /// The cleaned sentences (positives first, then negatives) and the concatenated one-hot
          /// labels: {0, 1} for each positive example and {1, 0} for each negative one.
          /// </returns>
          public static (string[], NDArray) load_data_and_labels(string positive_data_file, string negative_data_file)
          {
              Directory.CreateDirectory(MR_DATA_DIR);
              // The positive file used to be saved as "rt -polarity.pos" (stray space) while being
              // read back as "rt-polarity.pos", so a fresh run could not find it.
              Utility.Web.Download(positive_data_file, MR_DATA_DIR, MR_POS_FILE);
              Utility.Web.Download(negative_data_file, MR_DATA_DIR, MR_NEG_FILE);

              var positive_examples = File.ReadAllLines(Path.Combine(MR_DATA_DIR, MR_POS_FILE))
                  .Select(line => line.Trim())
                  .ToArray();
              var negative_examples = File.ReadAllLines(Path.Combine(MR_DATA_DIR, MR_NEG_FILE))
                  .Select(line => line.Trim())
                  .ToArray();

              var x_text = positive_examples.Concat(negative_examples).Select(clean_str).ToArray();

              var positive_labels = positive_examples.Select(_ => new int[2] { 0, 1 }).ToArray();
              var negative_labels = negative_examples.Select(_ => new int[2] { 1, 0 }).ToArray();
              var y = np.concatenate(new int[][][] { positive_labels, negative_labels });
              return (x_text, y);
          }

          /// <summary>
          /// Replaces every character outside [A-Za-z0-9(),!?'`] with a space and inserts a space
          /// before each "'s". Case and repeated whitespace are left untouched.
          /// </summary>
          private static string clean_str(string str)
          {
              str = Regex.Replace(str, @"[^A-Za-z0-9(),!?\'\`]", " ");
              str = Regex.Replace(str, @"\'s", " \'s");
              return str;
          }

          /// <summary>
          /// Pads every sequence to the length of the longest one.
          /// The inner arrays of <paramref name="sequences"/> are replaced in place and the same outer array is returned.
          /// </summary>
          /// <param name="sequences">Sequences to pad.</param>
          /// <param name="pad_tok">The value used to fill padding positions.</param>
          /// <returns>The padded sequences (all of equal length) and each sequence's unpadded length.</returns>
          public static (int[][], int[]) pad_sequences(int[][] sequences, int pad_tok = 0)
          {
              if (sequences == null)
                  throw new ArgumentNullException(nameof(sequences));

              // DefaultIfEmpty keeps an empty input from throwing in Max().
              int max_length = sequences.Select(seq => seq.Length).DefaultIfEmpty(0).Max();
              return _pad_sequences(sequences, pad_tok, max_length);
          }

          /// <summary>
          /// Pads a batch of sentences, where each sentence is a list of words (character-id arrays).
          /// Every word is padded to the longest word in the batch. Every sentence is then padded to the
          /// longest sentence, using words made entirely of <paramref name="pad_tok"/>.
          /// Inner arrays of <paramref name="sequences"/> are replaced in place.
          /// </summary>
          /// <param name="sequences">Sentences of words of ids.</param>
          /// <param name="pad_tok">The value used to fill padding positions.</param>
          /// <returns>
          /// The padded batch and, per sentence, the unpadded length of each word
          /// (padding words report length 0).
          /// </returns>
          public static (int[][][], int[][]) pad_sequences(int[][][] sequences, int pad_tok = 0)
          {
              if (sequences == null)
                  throw new ArgumentNullException(nameof(sequences));

              // Flatten first so an empty sentence cannot make an inner Max() throw.
              int max_length_word = sequences.SelectMany(sentence => sentence)
                  .Select(word => word.Length)
                  .DefaultIfEmpty(0)
                  .Max();

              var word_lengths = new int[sequences.Length][];
              for (int i = 0; i < sequences.Length; i++)
              {
                  // Pads the words of sentence i in place; only the lengths are needed here.
                  var (_, lengths) = _pad_sequences(sequences[i], pad_tok, max_length_word);
                  word_lengths[i] = lengths;
              }

              int max_length_sentence = sequences.Select(sentence => sentence.Length).DefaultIfEmpty(0).Max();
              int[] pad_word = Enumerable.Repeat(pad_tok, max_length_word).ToArray();

              var (sequence_padded, _) = _pad_sequences(sequences, pad_word, max_length_sentence);
              var (sequence_length, _) = _pad_sequences(word_lengths, 0, max_length_sentence);
              return (sequence_padded, sequence_length);
          }

          /// <summary>
          /// Resizes each sequence in place to <paramref name="max_length"/> and fills the new
          /// slots with <paramref name="pad_tok"/>. Longer sequences are truncated.
          /// </summary>
          /// <returns>The same outer array and each sequence's length after truncation, before padding.</returns>
          private static (int[][], int[]) _pad_sequences(int[][] sequences, int pad_tok, int max_length)
          {
              var sequence_length = new int[sequences.Length];
              for (int i = 0; i < sequences.Length; i++)
              {
                  int original_length = sequences[i].Length;
                  sequence_length[i] = Math.Min(original_length, max_length);
                  Array.Resize(ref sequences[i], max_length);
                  // Array.Resize fills new slots with 0; overwrite them with the requested pad token.
                  for (int j = original_length; j < max_length; j++)
                      sequences[i][j] = pad_tok;
              }
              return (sequences, sequence_length);
          }

          /// <summary>
          /// Resizes each sentence in place to <paramref name="max_length"/> words and fills the new
          /// slots with copies of <paramref name="pad_tok"/>. Longer sentences are truncated.
          /// </summary>
          /// <returns>The same outer array and each sentence's word count after truncation, before padding.</returns>
          private static (int[][][], int[]) _pad_sequences(int[][][] sequences, int[] pad_tok, int max_length)
          {
              var sequence_length = new int[sequences.Length];
              for (int i = 0; i < sequences.Length; i++)
              {
                  int original_length = sequences[i].Length;
                  sequence_length[i] = Math.Min(original_length, max_length);
                  Array.Resize(ref sequences[i], max_length);
                  // Each padding slot gets its own copy, so mutating one padded word never affects another.
                  for (int j = original_length; j < max_length; j++)
                      sequences[i][j] = (int[])pad_tok.Clone();
              }
              return (sequences, sequence_length);
          }
      }
  }

TensorFlow 框架的 .NET 版本，提供了丰富的特性和 API，可以借此很方便地在 .NET 平台下搭建深度学习训练与推理流程。