You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

DataHelpers.cs 9.2 kB

6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. using NumSharp;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.IO;
  5. using System.Linq;
  6. using System.Security.Cryptography;
  7. using System.Text;
  8. using System.Text.RegularExpressions;
  9. using TensorFlowNET.Examples.Utility;
  10. namespace TensorFlowNET.Examples
  11. {
  12. public class DataHelpers
  13. {
  14. public static Dictionary<string, int> build_word_dict(string path)
  15. {
  16. var contents = File.ReadAllLines(path);
  17. var words = new List<string>();
  18. foreach (var content in contents)
  19. words.AddRange(clean_str(content).Split(' ').Where(x => x.Length > 1));
  20. var word_counter = words.GroupBy(x => x)
  21. .Select(x => new { Word = x.Key, Count = x.Count() })
  22. .OrderByDescending(x => x.Count)
  23. .ToArray();
  24. var word_dict = new Dictionary<string, int>();
  25. word_dict["<pad>"] = 0;
  26. word_dict["<unk>"] = 1;
  27. word_dict["<eos>"] = 2;
  28. foreach (var word in word_counter)
  29. word_dict[word.Word] = word_dict.Count;
  30. return word_dict;
  31. }
  32. public static (int[][], int[]) build_word_dataset(string path, Dictionary<string, int> word_dict, int document_max_len)
  33. {
  34. var contents = File.ReadAllLines(path);
  35. var x = contents.Select(c => (clean_str(c) + " <eos>")
  36. .Split(' ').Take(document_max_len)
  37. .Select(w => word_dict.ContainsKey(w) ? word_dict[w] : word_dict["<unk>"]).ToArray())
  38. .ToArray();
  39. for (int i = 0; i < x.Length; i++)
  40. if (x[i].Length == document_max_len)
  41. x[i][document_max_len - 1] = word_dict["<eos>"];
  42. else
  43. Array.Resize(ref x[i], document_max_len);
  44. var y = contents.Select(c => int.Parse(c.Substring(0, c.IndexOf(','))) - 1).ToArray();
  45. return (x, y);
  46. }
  47. public static (int[][], int[], int) build_char_dataset(string path, string model, int document_max_len, int? limit = null, bool shuffle=true)
  48. {
  49. if (model != "vd_cnn")
  50. throw new NotImplementedException(model);
  51. string alphabet = "abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:’'\"/|_#$%ˆ&*˜‘+=<>()[]{} ";
  52. /*if (step == "train")
  53. df = pd.read_csv(TRAIN_PATH, names =["class", "title", "content"]);*/
  54. var char_dict = new Dictionary<string, int>();
  55. char_dict["<pad>"] = 0;
  56. char_dict["<unk>"] = 1;
  57. foreach (char c in alphabet)
  58. char_dict[c.ToString()] = char_dict.Count;
  59. var contents = File.ReadAllLines(path);
  60. if (shuffle)
  61. new Random(17).Shuffle(contents);
  62. //File.WriteAllLines("text_classification/dbpedia_csv/train_6400.csv", contents.Take(6400));
  63. var size = limit == null ? contents.Length : limit.Value;
  64. var x = new int[size][];
  65. var y = new int[size];
  66. var tenth = size / 10;
  67. var percent = 0;
  68. for (int i = 0; i < size; i++)
  69. {
  70. if ((i + 1) % tenth == 0)
  71. {
  72. percent += 10;
  73. Console.WriteLine($"\t{percent}%");
  74. }
  75. string[] parts = contents[i].ToLower().Split(",\"").ToArray();
  76. string content = parts[2];
  77. content = content.Substring(0, content.Length - 1);
  78. var a = new int[document_max_len];
  79. for (int j = 0; j < document_max_len; j++)
  80. {
  81. if (j >= content.Length)
  82. a[j] = char_dict["<pad>"];
  83. else
  84. a[j] = char_dict.ContainsKey(content[j].ToString()) ? char_dict[content[j].ToString()] : char_dict["<unk>"];
  85. }
  86. x[i] = a;
  87. y[i] = int.Parse(parts[0]);
  88. }
  89. return (x, y, alphabet.Length + 2);
  90. }
  91. /// <summary>
  92. /// Loads MR polarity data from files, splits the data into words and generates labels.
  93. /// Returns split sentences and labels.
  94. /// </summary>
  95. /// <param name="positive_data_file"></param>
  96. /// <param name="negative_data_file"></param>
  97. /// <returns></returns>
  98. public static (string[], NDArray) load_data_and_labels(string positive_data_file, string negative_data_file)
  99. {
  100. Directory.CreateDirectory("CnnTextClassification");
  101. Utility.Web.Download(positive_data_file, "CnnTextClassification", "rt -polarity.pos");
  102. Utility.Web.Download(negative_data_file, "CnnTextClassification", "rt-polarity.neg");
  103. // Load data from files
  104. var positive_examples = File.ReadAllLines("CnnTextClassification/rt-polarity.pos")
  105. .Select(x => x.Trim())
  106. .ToArray();
  107. var negative_examples = File.ReadAllLines("CnnTextClassification/rt-polarity.neg")
  108. .Select(x => x.Trim())
  109. .ToArray();
  110. var x_text = new List<string>();
  111. x_text.AddRange(positive_examples);
  112. x_text.AddRange(negative_examples);
  113. x_text = x_text.Select(x => clean_str(x)).ToList();
  114. var positive_labels = positive_examples.Select(x => new int[2] { 0, 1 }).ToArray();
  115. var negative_labels = negative_examples.Select(x => new int[2] { 1, 0 }).ToArray();
  116. var y = np.concatenate(new int[][][] { positive_labels, negative_labels });
  117. return (x_text.ToArray(), y);
  118. }
  119. private static string clean_str(string str)
  120. {
  121. str = Regex.Replace(str, "[^A-Za-z0-9(),!?]", " ");
  122. str = Regex.Replace(str, ",", " ");
  123. return str;
  124. }
  125. /// <summary>
  126. /// Padding
  127. /// </summary>
  128. /// <param name="sequences"></param>
  129. /// <param name="pad_tok">the char to pad with</param>
  130. /// <returns>a list of list where each sublist has same length</returns>
  131. public static (int[][], int[]) pad_sequences(int[][] sequences, int pad_tok = 0)
  132. {
  133. int max_length = sequences.Select(x => x.Length).Max();
  134. return _pad_sequences(sequences, pad_tok, max_length);
  135. }
  136. public static (int[][][], int[][]) pad_sequences(int[][][] sequences, int pad_tok = 0)
  137. {
  138. int max_length_word = sequences.Select(x => x.Select(w => w.Length).Max()).Max();
  139. int[][][] sequence_padded;
  140. var sequence_length = new int[sequences.Length][];
  141. for (int i = 0; i < sequences.Length; i++)
  142. {
  143. // all words are same length now
  144. var (sp, sl) = _pad_sequences(sequences[i], pad_tok, max_length_word);
  145. sequence_length[i] = sl;
  146. }
  147. int max_length_sentence = sequences.Select(x => x.Length).Max();
  148. (sequence_padded, _) = _pad_sequences(sequences, np.repeat(pad_tok, max_length_word).Data<int>(), max_length_sentence);
  149. (sequence_length, _) = _pad_sequences(sequence_length, 0, max_length_sentence);
  150. return (sequence_padded, sequence_length);
  151. }
  152. private static (int[][], int[]) _pad_sequences(int[][] sequences, int pad_tok, int max_length)
  153. {
  154. var sequence_length = new int[sequences.Length];
  155. for (int i = 0; i < sequences.Length; i++)
  156. {
  157. sequence_length[i] = sequences[i].Length;
  158. Array.Resize(ref sequences[i], max_length);
  159. }
  160. return (sequences, sequence_length);
  161. }
  162. private static (int[][][], int[]) _pad_sequences(int[][][] sequences, int[] pad_tok, int max_length)
  163. {
  164. var sequence_length = new int[sequences.Length];
  165. for (int i = 0; i < sequences.Length; i++)
  166. {
  167. sequence_length[i] = sequences[i].Length;
  168. Array.Resize(ref sequences[i], max_length);
  169. for (int j = 0; j < max_length - sequence_length[i]; j++)
  170. {
  171. sequences[i][max_length - j - 1] = new int[pad_tok.Length];
  172. Array.Copy(pad_tok, sequences[i][max_length - j - 1], pad_tok.Length);
  173. }
  174. }
  175. return (sequences, sequence_length);
  176. }
  177. public static string CalculateMD5Hash(string input)
  178. {
  179. // step 1, calculate MD5 hash from input
  180. MD5 md5 = System.Security.Cryptography.MD5.Create();
  181. byte[] inputBytes = System.Text.Encoding.ASCII.GetBytes(input);
  182. byte[] hash = md5.ComputeHash(inputBytes);
  183. // step 2, convert byte array to hex string
  184. StringBuilder sb = new StringBuilder();
  185. for (int i = 0; i < hash.Length; i++)
  186. {
  187. sb.Append(hash[i].ToString("X2"));
  188. }
  189. return sb.ToString();
  190. }
  191. }
  192. }