| @@ -0,0 +1,76 @@ | |||||
| import numpy as np | |||||
| class DataLoader(object): | |||||
| def __init__(self, dataset='bookcorpus', doc_num=16000, save_gap=200, batch_size = 1024): | |||||
| self.data_names = ['input_ids','token_type_ids','attention_mask','masked_lm_labels','next_sentence_label'] | |||||
| self.data = {'input_ids':[], | |||||
| 'token_type_ids':[], | |||||
| 'attention_mask':[], | |||||
| 'masked_lm_labels':[], | |||||
| 'next_sentence_label':[]} | |||||
| self.batch_size=batch_size | |||||
| self.batch_data = {'input_ids':[], | |||||
| 'token_type_ids':[], | |||||
| 'attention_mask':[], | |||||
| 'masked_lm_labels':[], | |||||
| 'next_sentence_label':[]} | |||||
| self.cur_batch_data = {'input_ids':[], | |||||
| 'token_type_ids':[], | |||||
| 'attention_mask':[], | |||||
| 'masked_lm_labels':[], | |||||
| 'next_sentence_label':[]} | |||||
| self.load_data(dataset=dataset, doc_num=doc_num, save_gap=save_gap) | |||||
| def load_data(self, dataset='bookcorpus', doc_num=16000, save_gap=200): | |||||
| print('Loading preprocessed dataset %s...'%dataset) | |||||
| data_dir = './preprocessed_data/%s/'%dataset | |||||
| for i in range(0,doc_num,save_gap): | |||||
| start, end = i, i+save_gap-1 | |||||
| if end > doc_num-1: | |||||
| end = doc_num-1 | |||||
| range_name = '_%d_%d.npy'%(start,end) | |||||
| print(start,end) | |||||
| for data_name in self.data_names: | |||||
| #print(data_dir+data_name+range_name) | |||||
| self.data[data_name].append(np.load(data_dir+data_name+range_name)) | |||||
| for data_name in self.data_names: | |||||
| self.data[data_name] = np.concatenate(self.data[data_name],axis=0) | |||||
| self.data_len = self.data['input_ids'].shape[0] | |||||
| print(self.data['input_ids'].shape) | |||||
| print('Successfully loaded dataset %s!'%dataset) | |||||
| def make_epoch_data(self): | |||||
| batch_data = [] | |||||
| for i in range(0, self.data_len, self.batch_size): | |||||
| start = i | |||||
| end = start + self.batch_size | |||||
| if end > self.data_len: | |||||
| end = self.data_len | |||||
| if end-start != self.batch_size: | |||||
| break | |||||
| for data_name in self.data_names: | |||||
| self.batch_data[data_name].append(self.data[data_name][start:end]) | |||||
| self.batch_num = len(self.batch_data['input_ids']) | |||||
| def get_batch(self, idx): | |||||
| if idx >= self.batch_num: | |||||
| assert False | |||||
| for data_name in self.data_names: | |||||
| self.cur_batch_data[data_name] = self.batch_data[data_name][idx] | |||||
| return self.cur_batch_data.copy() | |||||
| def align(self, arr, length): | |||||
| ori_len = len(arr) | |||||
| if length > ori_len: | |||||
| return arr + [0] * (length - ori_len) | |||||
| else: | |||||
| return arr[:length] | |||||