| @@ -0,0 +1,76 @@ | |||
| import numpy as np | |||
| class DataLoader(object): | |||
| def __init__(self, dataset='bookcorpus', doc_num=16000, save_gap=200, batch_size = 1024): | |||
| self.data_names = ['input_ids','token_type_ids','attention_mask','masked_lm_labels','next_sentence_label'] | |||
| self.data = {'input_ids':[], | |||
| 'token_type_ids':[], | |||
| 'attention_mask':[], | |||
| 'masked_lm_labels':[], | |||
| 'next_sentence_label':[]} | |||
| self.batch_size=batch_size | |||
| self.batch_data = {'input_ids':[], | |||
| 'token_type_ids':[], | |||
| 'attention_mask':[], | |||
| 'masked_lm_labels':[], | |||
| 'next_sentence_label':[]} | |||
| self.cur_batch_data = {'input_ids':[], | |||
| 'token_type_ids':[], | |||
| 'attention_mask':[], | |||
| 'masked_lm_labels':[], | |||
| 'next_sentence_label':[]} | |||
| self.load_data(dataset=dataset, doc_num=doc_num, save_gap=save_gap) | |||
| def load_data(self, dataset='bookcorpus', doc_num=16000, save_gap=200): | |||
| print('Loading preprocessed dataset %s...'%dataset) | |||
| data_dir = './preprocessed_data/%s/'%dataset | |||
| for i in range(0,doc_num,save_gap): | |||
| start, end = i, i+save_gap-1 | |||
| if end > doc_num-1: | |||
| end = doc_num-1 | |||
| range_name = '_%d_%d.npy'%(start,end) | |||
| print(start,end) | |||
| for data_name in self.data_names: | |||
| #print(data_dir+data_name+range_name) | |||
| self.data[data_name].append(np.load(data_dir+data_name+range_name)) | |||
| for data_name in self.data_names: | |||
| self.data[data_name] = np.concatenate(self.data[data_name],axis=0) | |||
| self.data_len = self.data['input_ids'].shape[0] | |||
| print(self.data['input_ids'].shape) | |||
| print('Successfully loaded dataset %s!'%dataset) | |||
| def make_epoch_data(self): | |||
| batch_data = [] | |||
| for i in range(0, self.data_len, self.batch_size): | |||
| start = i | |||
| end = start + self.batch_size | |||
| if end > self.data_len: | |||
| end = self.data_len | |||
| if end-start != self.batch_size: | |||
| break | |||
| for data_name in self.data_names: | |||
| self.batch_data[data_name].append(self.data[data_name][start:end]) | |||
| self.batch_num = len(self.batch_data['input_ids']) | |||
| def get_batch(self, idx): | |||
| if idx >= self.batch_num: | |||
| assert False | |||
| for data_name in self.data_names: | |||
| self.cur_batch_data[data_name] = self.batch_data[data_name][idx] | |||
| return self.cur_batch_data.copy() | |||
| def align(self, arr, length): | |||
| ori_len = len(arr) | |||
| if length > ori_len: | |||
| return arr + [0] * (length - ori_len) | |||
| else: | |||
| return arr[:length] | |||