You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

TextClassificationWithMovieReviews.cs 4.2 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. using System;
  2. using System.Collections.Generic;
  3. using System.IO;
  4. using Tensorflow;
  5. using NumSharp.Core;
  6. using Newtonsoft.Json;
  7. using System.Linq;
  8. using System.Text.RegularExpressions;
  9. namespace TensorFlowNET.Examples
  10. {
  /// <summary>
  /// Example that downloads the IMDB movie-review data files, pads the review
  /// sequences to a fixed length and starts building a sequential model with an
  /// embedding layer. The model is only constructed here; it is not trained.
  /// </summary>
  public class TextClassificationWithMovieReviews : Python, IExample
  {
      // Working directory that receives the downloaded archive and its extracted files.
      string dir = "text_classification_with_movie_reviews";

      // Archive name, used both as the last URL segment and as the local file name.
      string dataFile = "imdb.zip";

      /// <summary>
      /// Loads the data, pads both review sets to 256 entries and creates the model.
      /// </summary>
      public void Run()
      {
          // The test labels are returned by PrepareData but not needed in this method.
          var ((train_data, train_labels), (test_data, _)) = PrepareData();

          Console.WriteLine($"Training entries: {train_data.size}, labels: {train_labels.size}");

          // A dictionary mapping words to an integer index
          var word_index = GetWordIndex();

          // Look the padding index up once; it is the same for both data sets.
          int padValue = word_index["<PAD>"];

          train_data = keras.preprocessing.sequence.pad_sequences(train_data,
              value: padValue,
              padding: "post",
              maxlen: 256);

          test_data = keras.preprocessing.sequence.pad_sequences(test_data,
              value: padValue,
              padding: "post",
              maxlen: 256);

          // input shape is the vocabulary count used for the movie reviews (10,000 words)
          int vocab_size = 10000;

          var model = keras.Sequential();
          model.add(keras.layers.Embedding(vocab_size, 16));
      }

      /// <summary>
      /// Downloads and unzips the data archive into <c>dir</c>, then loads the
      /// training and test splits.
      /// </summary>
      /// <returns>
      /// <c>((x_train, y_train), (x_test, y_test))</c>, each pair already reordered
      /// by the index file that belongs to its split.
      /// </returns>
      private ((NDArray, NDArray), (NDArray, NDArray)) PrepareData()
      {
          Directory.CreateDirectory(dir);

          // Fetch the data archive; the same file name is used remotely and locally.
          string url = $"https://github.com/SciSharp/TensorFlow.NET/raw/master/data/{dataFile}";
          Utility.Web.Download(url, dir, dataFile);
          Utility.Compress.UnZip(Path.Join(dir, dataFile), dir);

          // NOTE: the original code also concatenated the train and test arrays
          // (marked "not completed") but never used the results, so that work is omitted.
          var (x_train, y_train) = LoadSplit("train");
          var (x_test, y_test) = LoadSplit("test");

          return ((x_train, y_train), (x_test, y_test));
      }

      /// <summary>
      /// Reads <c>x_{split}.txt</c>, <c>y_{split}.txt</c> and <c>indices_{split}.txt</c>
      /// from <c>dir</c> and indexes the data and labels with the loaded indices.
      /// </summary>
      /// <param name="split">Split name embedded in the file names ("train" or "test").</param>
      /// <returns>The reordered data and labels of the split.</returns>
      private (NDArray, NDArray) LoadSplit(string split)
      {
          var x = ReadData(Path.Join(dir, $"x_{split}.txt"));
          var y = ReadData(Path.Join(dir, $"y_{split}.txt"));
          var indices = ReadData(Path.Join(dir, $"indices_{split}.txt"));

          return (x[indices], y[indices]);
      }

      /// <summary>
      /// Reads a text data file into an array with one element per line.
      /// </summary>
      /// <remarks>
      /// If the first line starts with <c>[</c>, every line is treated as a list of
      /// integers and stored as an <c>int[]</c> inside an object array. Otherwise every
      /// line is parsed as a single integer and stored in an int32 array.
      /// </remarks>
      /// <param name="file">Path of the file to read.</param>
      /// <returns>An array whose length equals the number of lines in the file.</returns>
      /// <exception cref="InvalidDataException">The file contains no lines.</exception>
      private NDArray ReadData(string file)
      {
          var lines = File.ReadAllLines(file);
          if (lines.Length == 0)
          {
              // The format is detected from the first line, so an empty file cannot be read.
              throw new InvalidDataException($"Data file '{file}' is empty.");
          }

          // Ordinal comparison: '[' is a format marker, not linguistic text.
          bool isSequenceFile = lines[0].StartsWith("[", StringComparison.Ordinal);

          // The files are machine-written, so parse with the invariant culture.
          var culture = System.Globalization.CultureInfo.InvariantCulture;

          var nd = new NDArray(isSequenceFile ? typeof(object) : np.int32, new Shape(lines.Length));

          if (isSequenceFile)
          {
              for (int i = 0; i < lines.Length; i++)
              {
                  // Each match is a run of digits plus optional trailing whitespace,
                  // which the integer parser accepts.
                  var matches = Regex.Matches(lines[i], @"\d+\s*");
                  var data = new int[matches.Count];
                  for (int j = 0; j < data.Length; j++)
                  {
                      data[j] = Convert.ToInt32(matches[j].Value, culture);
                  }

                  nd[i] = data;
              }
          }
          else
          {
              for (int i = 0; i < lines.Length; i++)
              {
                  nd[i] = Convert.ToInt32(lines[i], culture);
              }
          }

          return nd;
      }

      /// <summary>
      /// Loads <c>imdb_word_index.json</c> from <c>dir</c> and builds the word-to-index map.
      /// </summary>
      /// <returns>
      /// Every word from the file mapped to its stored index plus 3, together with the
      /// entries <c>&lt;PAD&gt;</c> = 0, <c>&lt;START&gt;</c> = 1, <c>&lt;UNK&gt;</c> = 2
      /// and <c>&lt;UNUSED&gt;</c> = 3.
      /// </returns>
      /// <exception cref="InvalidDataException">The JSON file does not contain an object.</exception>
      private Dictionary<string, int> GetWordIndex()
      {
          var path = Path.Join(dir, "imdb_word_index.json");
          var json = File.ReadAllText(path);
          var dict = JsonConvert.DeserializeObject<Dictionary<string, int>>(json);
          if (dict == null)
          {
              throw new InvalidDataException($"Word index file '{path}' does not contain a JSON object.");
          }

          // Reserve room for the file entries plus the four special tokens.
          var result = new Dictionary<string, int>(dict.Count + 4);
          foreach (var pair in dict)
          {
              result[pair.Key] = pair.Value + 3;
          }

          result["<PAD>"] = 0;
          result["<START>"] = 1;
          result["<UNK>"] = 2; // unknown
          result["<UNUSED>"] = 3;

          return result;
      }
  }
  95. }

tensorflow框架的.NET版本,提供了丰富的特性和API,可以借此很方便地在.NET平台下搭建深度学习训练与推理流程。