You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

prepare_data.py 5.1 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. import os
  2. import wget
  3. import tarfile
  4. import errno
  5. import sentencepiece as spm
  6. import re
  7. from hparams import Hparams
  8. import logging
  9. logging.basicConfig(level=logging.INFO)
  10. def prepro(hp):
  11. """Load raw data -> Preprocessing -> Segmenting with sentencepice
  12. hp: hyperparams. argparse.
  13. """
  14. logging.info("# Check if raw files exist")
  15. train1 = "iwslt2016/de-en/train.tags.de-en.de"
  16. train2 = "iwslt2016/de-en/train.tags.de-en.en"
  17. eval1 = "iwslt2016/de-en/IWSLT16.TED.tst2013.de-en.de.xml"
  18. eval2 = "iwslt2016/de-en/IWSLT16.TED.tst2013.de-en.en.xml"
  19. test1 = "iwslt2016/de-en/IWSLT16.TED.tst2014.de-en.de.xml"
  20. test2 = "iwslt2016/de-en/IWSLT16.TED.tst2014.de-en.en.xml"
  21. for f in (train1, train2, eval1, eval2, test1, test2):
  22. if not os.path.isfile(f):
  23. raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), f)
  24. logging.info("# Preprocessing")
  25. # train
  26. def _prepro(x): return [line.strip() for line in open(x, 'r').read().split("\n")
  27. if not line.startswith("<")]
  28. prepro_train1, prepro_train2 = _prepro(train1), _prepro(train2)
  29. assert len(prepro_train1) == len(
  30. prepro_train2), "Check if train source and target files match."
  31. # eval
  32. def _prepro(x): return [re.sub("<[^>]+>", "", line).strip()
  33. for line in open(x, 'r').read().split("\n")
  34. if line.startswith("<seg id")]
  35. prepro_eval1, prepro_eval2 = _prepro(eval1), _prepro(eval2)
  36. assert len(prepro_eval1) == len(
  37. prepro_eval2), "Check if eval source and target files match."
  38. # test
  39. prepro_test1, prepro_test2 = _prepro(test1), _prepro(test2)
  40. assert len(prepro_test1) == len(
  41. prepro_test2), "Check if test source and target files match."
  42. logging.info("Let's see how preprocessed data look like")
  43. logging.info("prepro_train1:", prepro_train1[0])
  44. logging.info("prepro_train2:", prepro_train2[0])
  45. logging.info("prepro_eval1:", prepro_eval1[0])
  46. logging.info("prepro_eval2:", prepro_eval2[0])
  47. logging.info("prepro_test1:", prepro_test1[0])
  48. logging.info("prepro_test2:", prepro_test2[0])
  49. logging.info("# write preprocessed files to disk")
  50. os.makedirs("iwslt2016/prepro", exist_ok=True)
  51. def _write(sents, fname):
  52. with open(fname, 'w') as fout:
  53. fout.write("\n".join(sents))
  54. _write(prepro_train1, "iwslt2016/prepro/train.de")
  55. _write(prepro_train2, "iwslt2016/prepro/train.en")
  56. _write(prepro_train1+prepro_train2, "iwslt2016/prepro/train")
  57. _write(prepro_eval1, "iwslt2016/prepro/eval.de")
  58. _write(prepro_eval2, "iwslt2016/prepro/eval.en")
  59. _write(prepro_test1, "iwslt2016/prepro/test.de")
  60. _write(prepro_test2, "iwslt2016/prepro/test.en")
  61. logging.info("# Train a joint BPE model with sentencepiece")
  62. os.makedirs("iwslt2016/segmented", exist_ok=True)
  63. train = '--input=iwslt2016/prepro/train --pad_id=0 --unk_id=1 \
  64. --bos_id=2 --eos_id=3\
  65. --model_prefix=iwslt2016/segmented/bpe --vocab_size={} \
  66. --model_type=bpe'.format(hp.vocab_size)
  67. spm.SentencePieceTrainer.Train(train)
  68. logging.info("# Load trained bpe model")
  69. sp = spm.SentencePieceProcessor()
  70. sp.Load("iwslt2016/segmented/bpe.model")
  71. logging.info("# Segment")
  72. def _segment_and_write(sents, fname):
  73. with open(fname, "w") as fout:
  74. for sent in sents:
  75. pieces = sp.EncodeAsPieces(sent)
  76. fout.write(" ".join(pieces) + "\n")
  77. _segment_and_write(prepro_train1, "iwslt2016/segmented/train.de.bpe")
  78. _segment_and_write(prepro_train2, "iwslt2016/segmented/train.en.bpe")
  79. _segment_and_write(prepro_eval1, "iwslt2016/segmented/eval.de.bpe")
  80. _segment_and_write(prepro_eval2, "iwslt2016/segmented/eval.en.bpe")
  81. _segment_and_write(prepro_test1, "iwslt2016/segmented/test.de.bpe")
  82. logging.info("Let's see how segmented data look like")
  83. print("train1:", open("iwslt2016/segmented/train.de.bpe", 'r').readline())
  84. print("train2:", open("iwslt2016/segmented/train.en.bpe", 'r').readline())
  85. print("eval1:", open("iwslt2016/segmented/eval.de.bpe", 'r').readline())
  86. print("eval2:", open("iwslt2016/segmented/eval.en.bpe", 'r').readline())
  87. print("test1:", open("iwslt2016/segmented/test.de.bpe", 'r').readline())
  88. if __name__ == '__main__':
  89. if not os.path.exists('iwslt2016'):
  90. os.mkdir('iwslt2016')
  91. os.chdir('iwslt2016')
  92. file_name = 'de-en.tgz'
  93. if not os.path.exists(file_name):
  94. print('Downloading iwslt2016...')
  95. url = 'https://wit3.fbk.eu/archive/2016-01//texts/de/en/de-en.tgz'
  96. file_name = wget.download(url)
  97. print()
  98. if not os.path.exists('de-en'):
  99. print('Extracting iwslt2016...')
  100. with tarfile.open(file_name) as tar:
  101. tar.extractall('./')
  102. os.chdir('../')
  103. hparams = Hparams()
  104. parser = hparams.parser
  105. hp = parser.parse_args()
  106. print('Preprocessing iwslt2016...')
  107. prepro(hp)
  108. logging.info("Done")