You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

TextClassificationWithMovieReviews.cs 3.0 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. using System;
  2. using System.Collections.Generic;
  3. using System.IO;
  4. using System.Text;
  5. using Tensorflow;
  6. using NumSharp.Core;
  7. using Newtonsoft.Json;
  8. using System.Linq;
  9. using Keras;
  namespace TensorFlowNET.Examples
  {
      /// <summary>
      /// Text-classification example on the IMDB movie-review data set (archive <c>imdb.zip</c>).
      /// Only data preparation is implemented so far: the archive is downloaded and extracted,
      /// the train/test splits are loaded, and both splits are padded to a fixed length.
      /// </summary>
      public class TextClassificationWithMovieReviews : Python, IExample
      {
          // Working directory the data archive is downloaded to and extracted into.
          private readonly string dir = "text_classification_with_movie_reviews";

          // Archive file name; used both for the download URL and for the local copy.
          private readonly string dataFile = "imdb.zip";

          // Value passed as `maxlen` to pad_sequences for every split.
          private const int MaxLength = 256;

          /// <summary>
          /// Entry point: loads the data, prints the training-set sizes, then pads the
          /// training and test reviews with the id of the "&lt;PAD&gt;" token
          /// (post-padding, <see cref="MaxLength"/> entries).
          /// </summary>
          public void Run()
          {
              var ((train_data, train_labels), (test_data, test_labels)) = PrepareData();

              Console.WriteLine($"Training entries: {train_data.size}, labels: {train_labels.size}");

              // A dictionary mapping words to an integer index
              var word_index = GetWordIndex();
              int padValue = word_index["<PAD>"];

              train_data = keras.preprocessing.sequence.pad_sequences(train_data,
                  value: padValue,
                  padding: "post",
                  maxlen: MaxLength);

              // Pad the test split identically so both splits end up with the same sequence
              // length; previously only the training split was padded.
              test_data = keras.preprocessing.sequence.pad_sequences(test_data,
                  value: padValue,
                  padding: "post",
                  maxlen: MaxLength);
          }

          /// <summary>
          /// Downloads the data archive into <see cref="dir"/>, extracts it there and loads
          /// the train/test splits from the extracted text files.
          /// </summary>
          /// <returns>
          /// <c>((x_train, y_train), (x_test, y_test))</c>, each built from the lines of
          /// x_train.txt, y_train.txt, x_test.txt and y_test.txt respectively.
          /// </returns>
          private ((NDArray, NDArray), (NDArray, NDArray)) PrepareData()
          {
              Directory.CreateDirectory(dir);

              // Fetch the archive. The local file keeps the remote name instead of a
              // second hard-coded copy of it.
              string url = $"https://github.com/SciSharp/TensorFlow.NET/raw/master/data/{dataFile}";
              string zipFile = Path.Join(dir, dataFile);
              Utility.Web.Download(url, zipFile);
              Utility.Compress.UnZip(zipFile, dir);

              // prepare training dataset
              NDArray x_train = File.ReadAllLines(Path.Join(dir, "x_train.txt"));
              NDArray y_train = File.ReadAllLines(Path.Join(dir, "y_train.txt"));

              // prepare test dataset
              NDArray x_test = File.ReadAllLines(Path.Join(dir, "x_test.txt"));
              NDArray y_test = File.ReadAllLines(Path.Join(dir, "y_test.txt"));

              // NOTE(review): the pipeline is unfinished. indices_train.txt / indices_test.txt
              // (presumably shuffle permutations for each split; confirm) are neither loaded
              // nor applied here.
              return ((x_train, y_train), (x_test, y_test));
          }

          /// <summary>
          /// Loads imdb_word_index.json from <see cref="dir"/> and builds the word-to-id map.
          /// Every id from the file is shifted up by 3. The reserved tokens &lt;PAD&gt;,
          /// &lt;START&gt;, &lt;UNK&gt; and &lt;UNUSED&gt; are then assigned ids 0, 1, 2 and 3,
          /// overriding any same-named entry from the file.
          /// </summary>
          /// <exception cref="InvalidDataException">The JSON file deserializes to null.</exception>
          private Dictionary<string, int> GetWordIndex()
          {
              var json = File.ReadAllText(Path.Join(dir, "imdb_word_index.json"));
              var dict = JsonConvert.DeserializeObject<Dictionary<string, int>>(json)
                  ?? throw new InvalidDataException("imdb_word_index.json does not contain a word index.");

              var result = new Dictionary<string, int>(dict.Count + 4);

              // Shift file ids by 3 to make room for the reserved low ids below
              // (assumes the file's ids start at 1; confirm against the data).
              foreach (var pair in dict)
              {
                  result[pair.Key] = pair.Value + 3;
              }

              result["<PAD>"] = 0;
              result["<START>"] = 1;
              result["<UNK>"] = 2; // unknown
              result["<UNUSED>"] = 3;
              return result;
          }
      }
  }

TensorFlow 框架的 .NET 版本，提供了丰富的特性和 API，可以借此很方便地在 .NET 平台下搭建深度学习训练与推理流程。