You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

data_load.py 4.4 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
import random

import numpy as np
  2. class DataLoader(object):
  3. def __init__(self, fpath1, fpath2, maxlen1, maxlen2, vocab_fpath):
  4. self.sents1, self.sents2 = self.load_data(
  5. fpath1, fpath2, maxlen1, maxlen2)
  6. self.token2idx, self.idx2token = self.load_vocab(vocab_fpath)
  7. self.maxlen1 = maxlen1
  8. self.maxlen2 = maxlen2
  9. def load_vocab(self, vocab_fpath):
  10. '''Loads vocabulary file and returns idx<->token maps
  11. vocab_fpath: string. vocabulary file path.
  12. Note that these are reserved
  13. 0: <pad>, 1: <unk>, 2: <s>, 3: </s>
  14. Returns
  15. two dictionaries.
  16. '''
  17. vocab = [line.split()[0] for line in open(
  18. vocab_fpath, 'r', encoding='utf-8').read().splitlines()]
  19. token2idx = {token: idx for idx, token in enumerate(vocab)}
  20. idx2token = {idx: token for idx, token in enumerate(vocab)}
  21. return token2idx, idx2token
  22. def load_data(self, fpath1, fpath2, maxlen1, maxlen2):
  23. '''Loads source and target data and filters out too lengthy samples.
  24. fpath1: source file path. string.
  25. fpath2: target file path. string.
  26. maxlen1: source sent maximum length. scalar.
  27. maxlen2: target sent maximum length. scalar.
  28. Returns
  29. sents1: list of source sents
  30. sents2: list of target sents
  31. '''
  32. sents1, sents2 = [], []
  33. with open(fpath1, 'r', encoding='utf-8') as f1, open(fpath2, 'r', encoding='utf-8') as f2:
  34. for sent1, sent2 in zip(f1, f2):
  35. if len(sent1.split()) + 1 > maxlen1:
  36. continue # 1: </s>
  37. if len(sent2.split()) + 1 > maxlen2:
  38. continue # 1: </s>
  39. sents1.append(sent1.strip())
  40. sents2.append(sent2.strip())
  41. return sents1, sents2
  42. def encode(self, inp, type, dict):
  43. '''Converts string to number. Used for `generator_fn`.
  44. inp: 1d byte array.
  45. type: "x" (source side) or "y" (target side)
  46. dict: token2idx dictionary
  47. Returns
  48. list of numbers
  49. '''
  50. inp_str = inp
  51. if type == "x":
  52. tokens = inp_str.split() + ["</s>"]
  53. else:
  54. tokens = ["<s>"] + inp_str.split() + ["</s>"]
  55. x = [dict.get(t, dict["<unk>"]) for t in tokens]
  56. return x
  57. def make_epoch_data(self, batch_size, shuffle=False):
  58. import copy
  59. new_sents1 = copy.deepcopy(self.sents1)
  60. new_sents2 = copy.deepcopy(self.sents2)
  61. if shuffle:
  62. import random
  63. random.shuffle(new_sents1)
  64. random.shuffle(new_sents2)
  65. xs = [self.encode(sent1, "x", self.token2idx) for sent1 in new_sents1]
  66. ys = [self.encode(sent2, "y", self.token2idx) for sent2 in new_sents2]
  67. batch_xs = []
  68. batch_ys = []
  69. for i in range(0, len(xs), batch_size):
  70. start = i
  71. end = start + batch_size
  72. batch_xs.append(xs[start:end])
  73. batch_ys.append(ys[start:end])
  74. if len(batch_xs[-1]) != batch_size:
  75. batch_xs = batch_xs[:-1]
  76. batch_ys = batch_ys[:-1]
  77. self.cur_xs = batch_xs
  78. self.cur_ys = batch_ys
  79. self.batch_num = len(batch_xs)
  80. self.idx = 0
  81. def get_batch(self, fill_maxlen=True):
  82. if self.idx >= self.batch_num:
  83. assert False
  84. cur_batch_x = self.cur_xs[self.idx]
  85. cur_batch_y = self.cur_ys[self.idx]
  86. self.idx += 1
  87. if fill_maxlen:
  88. cur_largest_len_x = self.maxlen1
  89. cur_largest_len_y = self.maxlen2
  90. else:
  91. cur_largest_len_x = max([len(x) for x in cur_batch_x])
  92. cur_largest_len_y = max([len(y) for y in cur_batch_y])
  93. cur_batch_x = np.array([self.align(x, cur_largest_len_x)
  94. for x in cur_batch_x]).astype(np.float32)
  95. cur_batch_y = np.array([self.align(y, cur_largest_len_y)
  96. for y in cur_batch_y]).astype(np.float32)
  97. return (cur_batch_x, cur_largest_len_x), (cur_batch_y, cur_largest_len_y)
  98. def align(self, arr, length):
  99. ori_len = len(arr)
  100. if length > ori_len:
  101. return arr + [0] * (length - ori_len)
  102. else:
  103. return arr[:length]
  104. def get_pad(self):
  105. return self.token2idx["<pad>"]