# LSTM Example
## Description
This example shows how to train and evaluate an LSTM model on the aclImdb (IMDB movie review) dataset with MindSpore.
## Requirements
- Install [MindSpore](https://www.mindspore.cn/install/en).
- Download the aclImdb_v1 dataset.
> Extract the aclImdb_v1 dataset to any path you like. The folder structure should be as follows:
> ```
> .
> ├── train # training dataset
> └── test # test dataset, used for evaluation
> ```
- Download the GloVe file.
> Unzip glove.6B.zip to any path you like. The folder structure should be as follows:
> ```
> .
> ├── glove.6B.100d.txt
> ├── glove.6B.200d.txt
> ├── glove.6B.300d.txt # used later in this example
> └── glove.6B.50d.txt
> ```
> Add the following line at the beginning of `glove.6B.300d.txt`. It states that the file contains a total of 400,000 words, each represented by a 300-dimensional word vector:
> ```
> 400000 300
> ```
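Editing a file of roughly 400,000 lines by hand is error-prone, so the header can also be added with a short script. The sketch below uses only the Python standard library; `add_glove_header` is a hypothetical helper, not part of this example's scripts. It derives the word count and vector width from the file instead of hard-coding `400000 300`, and leaves files that already start with a two-integer header untouched. The `<count> <dimension>` first line matches the word2vec text format read by loaders such as gensim's `KeyedVectors.load_word2vec_format`; whether this example's preprocessing relies on such a loader is an assumption.

```python
import os
import shutil
import tempfile


def add_glove_header(path):
    """Prepend a "<num_words> <dim>" line to a GloVe text file (hypothetical helper).

    Returns True if a header was written, False if one was already present.
    """
    # Skip the file if its first line is already two integers, e.g. "400000 300".
    with open(path, encoding="utf-8") as src:
        first = src.readline().split()
    if len(first) == 2 and all(tok.isdigit() for tok in first):
        return False

    # Each GloVe line is "word v1 v2 ... vN": count the lines, read N from line 1.
    num_words = 0
    dim = None
    with open(path, encoding="utf-8") as src:
        for line in src:
            if dim is None:
                dim = len(line.split()) - 1
            num_words += 1
    if num_words == 0:
        raise ValueError(f"{path} is empty")

    # Stream header + original lines into a temp file in the same directory,
    # keep the original permissions, then swap it in with os.replace.
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    with os.fdopen(fd, "w", encoding="utf-8") as dst, \
            open(path, encoding="utf-8") as src:
        dst.write(f"{num_words} {dim}\n")
        shutil.copyfileobj(src, dst)
    shutil.copymode(path, tmp_path)
    os.replace(tmp_path, path)
    return True
```

For example, `add_glove_header("your_glove_path/glove.6B.300d.txt")` updates the file in place. Because the file is streamed line by line, the roughly 1 GB of vectors is never held in memory at once.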
## Running the Example
### Training
```
python train.py --preprocess=true --aclimdb_path=your_imdb_path --glove_path=your_glove_path > out.train.log 2>&1 &
```
The command above runs in the background; you can view the results in the file `out.train.log`.
After training, the checkpoint files are saved under the script folder by default.
The loss values are logged as follows:
```
# grep "loss is " out.train.log
epoch: 1 step: 390, loss is 0.6003723
epoch: 2 step: 390, loss is 0.35312173
...
```
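To follow convergence without grepping by hand, the loss lines can be parsed into numbers. This is a minimal sketch that assumes every loss line has the `epoch: N step: M, loss is X` shape shown above; `parse_losses` and `LOSS_PATTERN` are hypothetical names, not part of the example.

```python
import re

# Matches lines such as "epoch: 1 step: 390, loss is 0.6003723".
LOSS_PATTERN = re.compile(
    r"epoch:\s*(\d+)\s+step:\s*(\d+),\s*loss is\s*"
    r"([-+]?\d+(?:\.\d*)?(?:[eE][-+]?\d+)?)"
)


def parse_losses(log_text):
    """Return a list of (epoch, step, loss) tuples found in a training log."""
    records = []
    for match in LOSS_PATTERN.finditer(log_text):
        epoch, step, loss = match.groups()
        records.append((int(epoch), int(step), float(loss)))
    return records
```

For instance, reading `out.train.log` and calling `parse_losses` on its contents returns one tuple per logged step, so the last entry can be compared with the first to see how far the loss has dropped.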
### Evaluation
```
python eval.py --ckpt_path=./lstm-20-390.ckpt > out.eval.log 2>&1 &
```
The command above runs in the background; you can view the results in the file `out.eval.log`.
The accuracy is reported as follows:
```
# grep "acc" out.eval.log
result: {'acc': 0.83}
```
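The evaluation log prints the metrics as a Python dict literal, so the accuracy can be read back with `ast.literal_eval`, which parses literals without executing code. This sketch assumes the `result: {...}` line stays on a single line as in the output above; `parse_eval_result` is a hypothetical helper.

```python
import ast
import re


def parse_eval_result(log_text):
    """Return the metrics dict from the last "result: {...}" line, or None."""
    result = None
    for line in log_text.splitlines():
        match = re.search(r"result:\s*(\{.*\})", line)
        if match:
            # literal_eval only accepts Python literals, so this is safe on log text.
            result = ast.literal_eval(match.group(1))
    return result
```

Passing the contents of `out.eval.log` returns a dict such as the one shown above, from which `["acc"]` gives the accuracy as a float.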
## Usage
### Training
```
usage: train.py [--preprocess {true,false}] [--aclimdb_path ACLIMDB_PATH]
                [--glove_path GLOVE_PATH] [--preprocess_path PREPROCESS_PATH]
                [--ckpt_path CKPT_PATH] [--device_target {GPU,CPU}]

parameters/options:
    --preprocess       whether to preprocess the data.
    --aclimdb_path     path where the aclImdb dataset is stored.
    --glove_path       path where the GloVe files are stored.
    --preprocess_path  path where the preprocessed data is stored.
    --ckpt_path        path where the checkpoint files are saved.
    --device_target    target device to run on; supports "GPU" and "CPU".
```
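`--preprocess` takes the strings `true` or `false` rather than acting as a bare flag, which is why the training command passes `--preprocess=true`. The `argparse` sketch below is a hypothetical parser that mirrors the options listed above; it is not the actual `train.py` code, and it sets no defaults because the README does not document them.

```python
import argparse


def build_train_parser():
    """Hypothetical argparse parser mirroring the train.py usage text."""
    parser = argparse.ArgumentParser(description="LSTM training (sketch)")
    # String choices, not store_true: the value must be spelled out.
    parser.add_argument("--preprocess", choices=["true", "false"],
                        help="whether to preprocess the data")
    parser.add_argument("--aclimdb_path", help="path where the aclImdb dataset is stored")
    parser.add_argument("--glove_path", help="path where the GloVe files are stored")
    parser.add_argument("--preprocess_path", help="path where the preprocessed data is stored")
    parser.add_argument("--ckpt_path", help="path where the checkpoint files are saved")
    parser.add_argument("--device_target", choices=["GPU", "CPU"],
                        help="target device to run on")
    return parser


# Flags in the style of the training command in "Running the Example".
args = build_train_parser().parse_args(
    ["--preprocess=true", "--aclimdb_path=your_imdb_path", "--glove_path=your_glove_path"])
# Convert the string choice to a bool before using it.
do_preprocess = args.preprocess == "true"
```

Using `choices` makes argparse reject misspelled values such as `--preprocess=yes` or `--device_target=TPU` with a usage error instead of silently treating them as false.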
### Evaluation
```
usage: eval.py [--preprocess {true,false}] [--aclimdb_path ACLIMDB_PATH]
               [--glove_path GLOVE_PATH] [--preprocess_path PREPROCESS_PATH]
               [--ckpt_path CKPT_PATH] [--device_target {GPU,CPU}]

parameters/options:
    --preprocess       whether to preprocess the data.
    --aclimdb_path     path where the aclImdb dataset is stored.
    --glove_path       path where the GloVe files are stored.
    --preprocess_path  path where the preprocessed data is stored.
    --ckpt_path        path of the checkpoint file used to evaluate the model.
    --device_target    target device to run on; supports "GPU" and "CPU".
```