You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

load_data.py 2.9 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. import numpy as np
  2. class DataLoader(object):
  3. def __init__(self, dataset='bookcorpus', doc_num=16000, save_gap=200, batch_size = 1024):
  4. self.data_names = ['input_ids','token_type_ids','attention_mask','masked_lm_labels','next_sentence_label']
  5. self.data = {'input_ids':[],
  6. 'token_type_ids':[],
  7. 'attention_mask':[],
  8. 'masked_lm_labels':[],
  9. 'next_sentence_label':[]}
  10. self.batch_size=batch_size
  11. self.batch_data = {'input_ids':[],
  12. 'token_type_ids':[],
  13. 'attention_mask':[],
  14. 'masked_lm_labels':[],
  15. 'next_sentence_label':[]}
  16. self.cur_batch_data = {'input_ids':[],
  17. 'token_type_ids':[],
  18. 'attention_mask':[],
  19. 'masked_lm_labels':[],
  20. 'next_sentence_label':[]}
  21. self.load_data(dataset=dataset, doc_num=doc_num, save_gap=save_gap)
  22. def load_data(self, dataset='bookcorpus', doc_num=16000, save_gap=200):
  23. print('Loading preprocessed dataset %s...'%dataset)
  24. data_dir = './preprocessed_data/%s/'%dataset
  25. for i in range(0,doc_num,save_gap):
  26. start, end = i, i+save_gap-1
  27. if end > doc_num-1:
  28. end = doc_num-1
  29. range_name = '_%d_%d.npy'%(start,end)
  30. print(start,end)
  31. for data_name in self.data_names:
  32. #print(data_dir+data_name+range_name)
  33. self.data[data_name].append(np.load(data_dir+data_name+range_name))
  34. for data_name in self.data_names:
  35. self.data[data_name] = np.concatenate(self.data[data_name],axis=0)
  36. self.data_len = self.data['input_ids'].shape[0]
  37. print(self.data['input_ids'].shape)
  38. print('Successfully loaded dataset %s!'%dataset)
  39. def make_epoch_data(self):
  40. batch_data = []
  41. for i in range(0, self.data_len, self.batch_size):
  42. start = i
  43. end = start + self.batch_size
  44. if end > self.data_len:
  45. end = self.data_len
  46. if end-start != self.batch_size:
  47. break
  48. for data_name in self.data_names:
  49. self.batch_data[data_name].append(self.data[data_name][start:end])
  50. self.batch_num = len(self.batch_data['input_ids'])
  51. def get_batch(self, idx):
  52. if idx >= self.batch_num:
  53. assert False
  54. for data_name in self.data_names:
  55. self.cur_batch_data[data_name] = self.batch_data[data_name][idx]
  56. return self.cur_batch_data.copy()
  57. def align(self, arr, length):
  58. ori_len = len(arr)
  59. if length > ori_len:
  60. return arr + [0] * (length - ori_len)
  61. else:
  62. return arr[:length]

分布式深度学习系统