You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

DataHelpers.cs 3.8 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. using NumSharp.Core;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.IO;
  5. using System.Linq;
  6. using System.Text;
  7. using System.Text.RegularExpressions;
  8. namespace TensorFlowNET.Examples.CnnTextClassification
  9. {
  10. public class DataHelpers
  11. {
  12. private const string TRAIN_PATH = "text_classification/dbpedia_csv/train.csv";
  13. private const string TEST_PATH = "text_classification/dbpedia_csv/test.csv";
  /// <summary>
  /// Reads the CSV file at TRAIN_PATH or TEST_PATH and encodes the content column of every
  /// line as a fixed-length sequence of character ids.
  /// </summary>
  /// <param name="step">
  /// "test" selects TEST_PATH; any other value selects TRAIN_PATH (the file this method
  /// previously read for every value of <paramref name="step"/>).
  /// </param>
  /// <param name="model">Not used by this method.</param>
  /// <param name="document_max_len">
  /// Length of every encoded document. Longer content is truncated, shorter content is
  /// padded with the pad id (0).
  /// </param>
  /// <returns>
  /// The encoded documents, the labels parsed from the first column of each line, and the
  /// vocabulary size (alphabet length plus the pad and unknown entries).
  /// </returns>
  /// <exception cref="FormatException">
  /// A line does not contain three <c>,"</c>-separated fields, or its first field is not an integer.
  /// </exception>
  public static (int[][], int[], int) build_char_dataset(string step, string model, int document_max_len)
  {
      string alphabet = "abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:’'\"/|_#$%ˆ&*˜‘+=<>()[]{} ";

      // Ids 0 and 1 are reserved; alphabet characters are numbered from 2 in order.
      const int padId = 0;
      const int unkId = 1;
      var charIds = new Dictionary<char, int>(alphabet.Length);
      foreach (char c in alphabet)
      {
          charIds[c] = charIds.Count + 2;
      }

      // Only "test" switches files so that every other caller keeps reading the training set.
      string path = step == "test" ? TEST_PATH : TRAIN_PATH;
      string[] lines = File.ReadAllLines(path);

      var x = new int[lines.Length][];
      var y = new int[lines.Length];
      var separator = new[] { ",\"" };

      for (int i = 0; i < lines.Length; i++)
      {
          // Invariant lower-casing: the data is machine text, not culture-specific UI text.
          // The split is capped at three fields so a `,"` sequence inside the content
          // does not cut the content short.
          string[] parts = lines[i].ToLowerInvariant().Split(separator, 3, StringSplitOptions.None);
          if (parts.Length < 3)
          {
              throw new FormatException($"Line {i + 1} of '{path}' does not have the expected class, title and content fields.");
          }

          // Drop the last character of the content field (the closing quote of the CSV field).
          string content = parts[2];
          if (content.Length > 0)
          {
              content = content.Substring(0, content.Length - 1);
          }

          var encoded = new int[document_max_len];
          for (int j = 0; j < document_max_len; j++)
          {
              if (j >= content.Length)
              {
                  encoded[j] = padId;
              }
              else
              {
                  encoded[j] = charIds.TryGetValue(content[j], out int id) ? id : unkId;
              }
          }

          x[i] = encoded;
          y[i] = int.Parse(parts[0], System.Globalization.CultureInfo.InvariantCulture);
      }

      return (x, y, alphabet.Length + 2);
  }
  /// <summary>
  /// Downloads the positive and negative polarity files into the "CnnTextClassification"
  /// directory, reads them line by line, cleans every line with clean_str and builds a
  /// two-element label for each line.
  /// </summary>
  /// <param name="positive_data_file">Location of the positive examples, handed to Utility.Web.Download.</param>
  /// <param name="negative_data_file">Location of the negative examples, handed to Utility.Web.Download.</param>
  /// <returns>
  /// The cleaned sentences (all positive lines followed by all negative lines) and the
  /// concatenated labels: { 0, 1 } for each positive line, { 1, 0 } for each negative line.
  /// </returns>
  public static (string[], NDArray) load_data_and_labels(string positive_data_file, string negative_data_file)
  {
      const string dataDir = "CnnTextClassification";
      const string positiveFileName = "rt-polarity.pos";
      const string negativeFileName = "rt-polarity.neg";

      Directory.CreateDirectory(dataDir);

      // The same file-name constants are used for download and read-back. The positive file
      // used to be downloaded as "rt -polarity.pos" (stray space) but read as
      // "rt-polarity.pos". NOTE(review): assumes the third Download argument is the
      // file name to save under the target directory - confirm against Utility.Web.
      Utility.Web.Download(positive_data_file, dataDir, positiveFileName);
      Utility.Web.Download(negative_data_file, dataDir, negativeFileName);

      // Load data from files
      string[] positive_examples = File.ReadAllLines(Path.Combine(dataDir, positiveFileName))
          .Select(line => line.Trim())
          .ToArray();
      string[] negative_examples = File.ReadAllLines(Path.Combine(dataDir, negativeFileName))
          .Select(line => line.Trim())
          .ToArray();

      // Positive sentences first, then negative ones; the labels below follow the same order.
      string[] x_text = positive_examples
          .Concat(negative_examples)
          .Select(line => clean_str(line))
          .ToArray();

      var positive_labels = positive_examples.Select(_ => new int[2] { 0, 1 }).ToArray();
      var negative_labels = negative_examples.Select(_ => new int[2] { 1, 0 }).ToArray();
      var y = np.concatenate(new int[][][] { positive_labels, negative_labels });

      return (x_text, y);
  }
  /// <summary>
  /// Normalizes a sentence: every character other than ASCII letters, digits and
  /// <c>( ) , ! ? ' `</c> becomes a space, and a space is inserted before each <c>'s</c>.
  /// </summary>
  /// <param name="str">The text to clean.</param>
  /// <returns>The cleaned text.</returns>
  private static string clean_str(string str)
  {
      str = Regex.Replace(str, @"[^A-Za-z0-9(),!?\'\`]", " ");

      // In a .NET replacement pattern a backslash is a literal character, so the previous
      // replacement @" \'s" wrote a backslash into the output ("it's" -> "it \'s").
      // The replacement must be the plain text " 's".
      str = Regex.Replace(str, @"\'s", " 's");
      return str;
  }
  78. }
  79. }

tensorflow框架的.NET版本,提供了丰富的特性和API,可以借此很方便地在.NET平台下搭建深度学习训练与推理流程。