import os

import numpy as np


class BatchIndexError(IndexError, AssertionError):
    """Raised by ``DataLoader.get_batch`` for an index past the last batch.

    Inherits from ``IndexError`` (the idiomatic type for a bad index) and from
    ``AssertionError``. The second base exists because older code signalled
    this with ``assert False``, so existing ``except AssertionError`` handlers
    keep working.
    """


class DataLoader(object):
    """Load pre-chunked ``.npy`` feature files and serve fixed-size batches.

    Each feature named in ``data_names`` is stored as a series of files
    ``<data_root>/<dataset>/<feature>_<start>_<end>.npy``, one per chunk of
    ``save_gap`` documents. ``start`` and ``end`` are inclusive document indices.
    All chunks of a feature are concatenated along axis 0.
    """

    def __init__(self, dataset='bookcorpus', doc_num=16000, save_gap=200, batch_size=1024,
                 data_root='./preprocessed_data'):
        """Load ``dataset`` immediately; call ``make_epoch_data`` before ``get_batch``.

        Args:
            dataset: Sub-directory of ``data_root`` holding the chunk files.
            doc_num: Total number of documents covered by the chunk files.
            save_gap: Number of documents per chunk file.
            batch_size: Number of rows per batch.
            data_root: Directory that contains the per-dataset directories.
        """
        self.data_names = ['input_ids', 'token_type_ids', 'attention_mask',
                           'masked_lm_labels', 'next_sentence_label']
        # Feature name -> full concatenated array (filled by load_data).
        self.data = {name: [] for name in self.data_names}
        self.batch_size = batch_size
        # Feature name -> list of per-batch slices (filled by make_epoch_data).
        self.batch_data = {name: [] for name in self.data_names}
        # Scratch dict reused by get_batch; callers receive a shallow copy.
        self.cur_batch_data = {name: [] for name in self.data_names}
        # No batches exist until make_epoch_data runs.
        self.batch_num = 0
        self.load_data(dataset=dataset, doc_num=doc_num, save_gap=save_gap,
                       data_root=data_root)

    def load_data(self, dataset='bookcorpus', doc_num=16000, save_gap=200,
                  data_root='./preprocessed_data'):
        """Read every chunk file for every feature and concatenate on axis 0.

        Sets ``self.data[name]`` for each feature and ``self.data_len`` (the row
        count of ``input_ids``). Safe to call again; the previous data is
        replaced.

        Raises:
            ValueError: if ``doc_num`` or ``save_gap`` is not positive.
        """
        if doc_num <= 0 or save_gap <= 0:
            raise ValueError('doc_num and save_gap must be positive, got %r and %r'
                             % (doc_num, save_gap))
        print('Loading preprocessed dataset %s...' % dataset)
        data_dir = os.path.join(data_root, dataset)
        # Collect chunks locally so a reload never appends onto an ndarray.
        chunks = {name: [] for name in self.data_names}
        for start in range(0, doc_num, save_gap):
            # The last chunk may be shorter; end is inclusive.
            end = min(start + save_gap - 1, doc_num - 1)
            range_name = '_%d_%d.npy' % (start, end)
            print(start, end)
            for name in self.data_names:
                chunks[name].append(np.load(os.path.join(data_dir, name + range_name)))
        for name in self.data_names:
            self.data[name] = np.concatenate(chunks[name], axis=0)
        self.data_len = self.data['input_ids'].shape[0]
        print(self.data['input_ids'].shape)
        print('Successfully loaded dataset %s!' % dataset)

    def make_epoch_data(self):
        """Slice the loaded data into consecutive full batches of ``batch_size`` rows.

        The trailing partial batch, if any, is dropped. Previous batches are
        discarded first, so calling this once per epoch does not accumulate
        duplicates. Sets ``self.batch_num``.

        Raises:
            ValueError: if ``batch_size`` is not positive.
        """
        if self.batch_size <= 0:
            raise ValueError('batch_size must be positive, got %r' % self.batch_size)
        for name in self.data_names:
            # Reset in place so any external reference to the dict stays valid.
            self.batch_data[name] = []
        num_full = self.data_len // self.batch_size
        for b in range(num_full):
            start = b * self.batch_size
            end = start + self.batch_size
            for name in self.data_names:
                self.batch_data[name].append(self.data[name][start:end])
        self.batch_num = len(self.batch_data['input_ids'])

    def get_batch(self, idx):
        """Return a shallow-copied dict mapping feature name to batch ``idx``.

        Raises:
            BatchIndexError: if ``idx >= self.batch_num`` (also an
                ``IndexError`` and an ``AssertionError``).
        """
        if idx >= self.batch_num:
            raise BatchIndexError('batch index %d out of range (batch_num=%d)'
                                  % (idx, self.batch_num))
        for name in self.data_names:
            self.cur_batch_data[name] = self.batch_data[name][idx]
        return self.cur_batch_data.copy()

    def align(self, arr, length):
        """Right-pad list ``arr`` with zeros, or truncate it, to exactly ``length`` items.

        Expects a Python list: padding uses list concatenation.
        """
        ori_len = len(arr)
        if length > ori_len:
            return arr + [0] * (length - ori_len)
        return arr[:length]