processBertData.py
from datasets import load_dataset
import random
import hetu
import os
import numpy as np

''' Usage example:
Run from the directory Hetu/examples/nlp/bert/: python processBertData.py
'''
# https://the-eye.eu/public/AI/pile_preliminary_components/books1.tar.gz


class TrainingInstance(object):
    """A single training instance (sentence pair)."""

    def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels,
                 is_random_next):
        self.tokens = tokens
        self.segment_ids = segment_ids
        self.is_random_next = is_random_next
        self.masked_lm_positions = masked_lm_positions
        self.masked_lm_labels = masked_lm_labels

    def __str__(self):
        s = ""
        s += "tokens: %s\n" % (" ".join([str(x) for x in self.tokens]))
        s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids]))
        s += "is_random_next: %s\n" % self.is_random_next
        s += "masked_lm_positions: %s\n" % (" ".join([str(x) for x in self.masked_lm_positions]))
        s += "masked_lm_labels: %s\n" % (" ".join([str(x) for x in self.masked_lm_labels]))
        s += "\n"
        return s

    def __repr__(self):
        return self.__str__()
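
# Illustrative sketch (not part of the original script): a tiny TrainingInstance
# and how its fields fit together, e.g.
#
#   inst = TrainingInstance(
#       tokens=["[CLS]", "hello", "[MASK]", "[SEP]", "nice", "day", "[SEP]"],
#       segment_ids=[0, 0, 0, 0, 1, 1, 1],
#       masked_lm_positions=[2],
#       masked_lm_labels=["there"],
#       is_random_next=False)
#   print(inst)  # __str__ renders tokens, segment_ids, is_random_next, positions, labels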


def create_masked_lm_predictions(tokens, masked_lm_prob,
                                 max_predictions_per_seq, vocab_words, rng):
    """Creates the predictions for the masked LM objective."""
    cand_indexes = []
    for (i, token) in enumerate(tokens):
        if token == "[CLS]" or token == "[SEP]":
            continue
        cand_indexes.append(i)
    rng.shuffle(cand_indexes)
    output_tokens = list(tokens)
    num_to_predict = min(max_predictions_per_seq,
                         max(1, int(round(len(tokens) * masked_lm_prob))))
    masked_lms = []
    for index in cand_indexes:
        if len(masked_lms) >= num_to_predict:
            break
        masked_token = None
        # Replace with [MASK] 80% of the time.
        if rng.random() < 0.8:
            masked_token = "[MASK]"
        else:
            # Keep the original token 10% of the time.
            if rng.random() < 0.5:
                masked_token = tokens[index]
            # Replace with a random word 10% of the time.
            else:
                masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]
        output_tokens[index] = masked_token
        masked_lms.append([index, tokens[index]])
    masked_lms.sort(key=lambda x: x[0])
    masked_lm_positions = []
    masked_lm_labels = []
    for p in masked_lms:
        masked_lm_positions.append(p[0])
        masked_lm_labels.append(p[1])
    return (output_tokens, masked_lm_positions, masked_lm_labels)
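
# Illustrative sketch (not part of the original script): the function operates
# purely on token strings, e.g.
#
#   rng = random.Random(0)
#   vocab_words = ["the", "cat", "sat", "down", "[MASK]", "[CLS]", "[SEP]"]
#   tokens = ["[CLS]", "the", "cat", "sat", "[SEP]"]
#   out, positions, labels = create_masked_lm_predictions(
#       tokens, masked_lm_prob=0.15, max_predictions_per_seq=20,
#       vocab_words=vocab_words, rng=rng)
#   # `out` equals `tokens` except at `positions`, where roughly 80% of the
#   # selected indices become "[MASK]"; `labels` holds the original tokens.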


def create_data_from_document(all_document, doc_id, max_seq_length, short_seq_prob,
                              masked_lm_prob, max_predictions_per_seq, vocab_words, rng):
    """Creates training instances from one input document."""
    document = all_document[doc_id]
    max_num_tokens = max_seq_length - 3  # [CLS], [SEP], [SEP]
    target_seq_length = max_num_tokens
    # Generate a short sequence with probability short_seq_prob
    # in order to minimize the mismatch between pre-training and fine-tuning.
    if rng.random() < short_seq_prob:
        target_seq_length = rng.randint(2, max_num_tokens)
    instances = []
    current_chunk = []
    current_length = 0
    i = 0
    while i < len(document):
        segment = document[i]
        current_chunk.append(segment)
        current_length += len(segment)
        if i == len(document) - 1 or current_length >= target_seq_length:
            if current_chunk:
                # Create sentence A.
                a_end = 1
                if len(current_chunk) >= 2:
                    a_end = rng.randint(1, len(current_chunk) - 1)
                tokens_a = []
                for j in range(a_end):
                    tokens_a.extend([current_chunk[j]])
                tokens_b = []
                # Random next
                is_random_next = False
                if len(current_chunk) == 1 or rng.random() < 0.5:
                    is_random_next = True
                    target_b_length = target_seq_length - len(tokens_a)
                    for _ in range(10):
                        random_document_index = rng.randint(0, len(all_document) - 1)
                        if random_document_index != doc_id:
                            break
                    # If the randomly picked document is the same as the current
                    # document, fall back to treating the pair as "actual next".
                    if random_document_index == doc_id:
                        is_random_next = False
                    random_document = all_document[random_document_index]
                    random_start = rng.randint(0, len(random_document) - 1)
                    for j in range(random_start, len(random_document)):
                        tokens_b.extend([random_document[j]])
                        if len(tokens_b) >= target_b_length:
                            break
                    # We didn't actually use these segments, so we "put them back"
                    # so they don't go to waste.
                    num_unused_segments = len(current_chunk) - a_end
                    i -= num_unused_segments
                # Actual next
                else:
                    is_random_next = False
                    for j in range(a_end, len(current_chunk)):
                        tokens_b.extend([current_chunk[j]])
                truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)
                assert len(tokens_a) >= 1
                assert len(tokens_b) >= 1
                tokens = []
                segment_ids = []
                tokens.append("[CLS]")
                segment_ids.append(0)
                for token in tokens_a:
                    tokens.append(token)
                    segment_ids.append(0)
                tokens.append("[SEP]")
                segment_ids.append(0)
                for token in tokens_b:
                    tokens.append(token)
                    segment_ids.append(1)
                tokens.append("[SEP]")
                segment_ids.append(1)
                (tokens, masked_lm_positions, masked_lm_labels) = create_masked_lm_predictions(
                    tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng)
                instance = TrainingInstance(
                    tokens=tokens,
                    segment_ids=segment_ids,
                    is_random_next=is_random_next,
                    masked_lm_positions=masked_lm_positions,
                    masked_lm_labels=masked_lm_labels)
                instances.append(instance)
            current_chunk = []
            current_length = 0
        i += 1
    return instances


def convert_instances_to_data(instances, tokenizer, max_seq_length):
    num_instances = len(instances)
    input_ids_list = np.zeros([num_instances, max_seq_length], dtype="int32")
    input_mask_list = np.zeros([num_instances, max_seq_length], dtype="int32")
    segment_ids_list = np.zeros([num_instances, max_seq_length], dtype="int32")
    masked_lm_labels = np.full([num_instances, max_seq_length], -1, dtype="int32")
    next_sentence_labels_list = np.zeros(num_instances, dtype="int32")
    for (idx, instance) in enumerate(instances):
        input_ids = tokenizer.convert_tokens_to_ids(instance.tokens)
        input_mask = [1] * len(input_ids)
        segment_ids = list(instance.segment_ids)
        assert len(input_ids) <= max_seq_length
        # Pad input ids, attention mask and segment ids with zeros up to max_seq_length.
        padding_zero_list = [0] * int(max_seq_length - len(input_ids))
        input_ids += padding_zero_list
        input_mask += padding_zero_list
        segment_ids += padding_zero_list
        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length
        masked_lm_positions = list(instance.masked_lm_positions)
        masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels)
        input_ids_list[idx][:] = input_ids
        input_mask_list[idx][:] = input_mask
        segment_ids_list[idx][:] = segment_ids
        # Unmasked positions keep the label -1; masked positions hold the original token ids.
        masked_lm_labels[idx][masked_lm_positions] = masked_lm_ids
        next_sentence_labels_list[idx] = 1 if instance.is_random_next else 0
    return input_ids_list, input_mask_list, segment_ids_list, masked_lm_labels, next_sentence_labels_list


def create_pretrain_data(dataset, tokenizer, max_seq_length, short_seq_prob,
                         masked_lm_prob, max_predictions_per_seq, rng):
    documents, all_data = [], [[], [], [], [], []]
    vocab_words = list(tokenizer.vocab.keys())
    save_path = './preprocessed_data/bookcorpus/'
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    save_gap = 200
    for i in range(dataset['train'].shape[0]):
        tokens = tokenizer.tokenize(dataset['train'][i]['text'])
        documents.append(tokens)
        instance = create_data_from_document(documents, i,
                                             max_seq_length, short_seq_prob, masked_lm_prob,
                                             max_predictions_per_seq, vocab_words, rng)
        data = convert_instances_to_data(instance, tokenizer, max_seq_length)
        print(i, len(tokens), len(instance))
        for j in range(5):
            all_data[j].append(data[j])
        # Save the accumulated arrays every save_gap documents.
        if (i + 1) % save_gap == 0 and i:
            input_ids_list, input_mask_list, segment_ids_list, masked_lm_labels, \
                next_sentence_labels_list = [np.concatenate(all_data[j], axis=0) for j in range(5)]
            print('Saving data from %d to %d: doc_num = %d, input_ids_shape =' % (i + 1 - save_gap, i, i + 1),
                  input_ids_list.shape)
            save_data(input_ids_list, input_mask_list, segment_ids_list, masked_lm_labels,
                      next_sentence_labels_list, name='_%d_%d' % (i + 1 - save_gap, i))
            all_data = [[], [], [], [], []]
        # Save any remaining arrays after the last document.
        if i == dataset['train'].shape[0] - 1:
            input_ids_list, input_mask_list, segment_ids_list, masked_lm_labels, \
                next_sentence_labels_list = [np.concatenate(all_data[j], axis=0) for j in range(5)]
            print('Saving data from %d to %d: doc_num = %d, input_ids_shape =' % (save_gap * int(i / save_gap), i, i + 1),
                  input_ids_list.shape)
            save_data(input_ids_list, input_mask_list, segment_ids_list, masked_lm_labels,
                      next_sentence_labels_list, name='_%d_%d' % (save_gap * int(i / save_gap), i))


def save_data(input_ids_list, input_mask_list, segment_ids_list, masked_lm_labels,
              next_sentence_labels_list, name=''):
    save_path = './preprocessed_data/bookcorpus/'
    np.save(save_path + 'input_ids' + name, np.array(input_ids_list))
    np.save(save_path + 'token_type_ids' + name, np.array(segment_ids_list))
    np.save(save_path + 'attention_mask' + name, np.array(input_mask_list))
    np.save(save_path + 'masked_lm_labels' + name, np.array(masked_lm_labels))
    np.save(save_path + 'next_sentence_label' + name, np.array(next_sentence_labels_list))


def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng):
    """Truncates a pair of sequences to a maximum sequence length."""
    while True:
        total_length = len(tokens_a) + len(tokens_b)
        if total_length <= max_num_tokens:
            break
        trunc_tokens = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
        assert len(trunc_tokens) >= 1
        # Truncate from the front or the back at random to add randomness and avoid biases.
        if rng.random() < 0.5:
            del trunc_tokens[0]
        else:
            trunc_tokens.pop()


def show_dataset_detail(dataset):
    print(dataset.shape)
    print(dataset.column_names)
    print(dataset['train'].features)
    print(dataset['train'][0]['text'])


if __name__ == "__main__":
    max_seq_length = 512
    do_lower_case = True
    short_seq_prob = 0.1
    masked_lm_prob = 0.15
    max_predictions_per_seq = 20
    vocab_path = "./datasets/bert-base-uncased-vocab.txt"
    dataset = load_dataset('../bookcorpus', cache_dir="./cached_data")
    print("total number of documents {}".format(dataset['train'].shape[0]))
    random_seed = 123
    rng = random.Random(random_seed)
    tokenizer = hetu.BertTokenizer(vocab_file=vocab_path, do_lower_case=do_lower_case)
    print("vocab_size =", len(tokenizer.vocab))
    print("max_seq_len =", max_seq_length)
    create_pretrain_data(dataset, tokenizer, max_seq_length, short_seq_prob,
                         masked_lm_prob, max_predictions_per_seq, rng)
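
A minimal sketch of how the saved shards could be loaded back for training, assuming the first 200 documents produced the shard suffix '_0_199' (the array names below mirror save_data(); the loader itself is hypothetical, not part of the script):

    import numpy as np

    save_path = './preprocessed_data/bookcorpus/'
    suffix = '_0_199'  # assumed shard written for documents 0..199

    input_ids = np.load(save_path + 'input_ids' + suffix + '.npy')
    token_type_ids = np.load(save_path + 'token_type_ids' + suffix + '.npy')
    attention_mask = np.load(save_path + 'attention_mask' + suffix + '.npy')
    masked_lm_labels = np.load(save_path + 'masked_lm_labels' + suffix + '.npy')
    next_sentence_label = np.load(save_path + 'next_sentence_label' + suffix + '.npy')

    # The first four arrays are [num_instances, max_seq_length]; the last is [num_instances].
    print(input_ids.shape, next_sentence_label.shape)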

Distributed deep learning system