Browse Source

Update get_hwf.py

pull/3/head
troyyyyy GitHub 2 years ago
parent
commit
649513bb79
No known key found for this signature in database. GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 24 additions and 9 deletions
  1. +24
    -9
      datasets/hwf/get_hwf.py

+ 24
- 9
datasets/hwf/get_hwf.py View File

@@ -7,30 +7,45 @@ img_transform = transforms.Compose([
transforms.Normalize((0.5,), (1,)) transforms.Normalize((0.5,), (1,))
]) ])


def get_data(file, precision_num = 2):
def get_data(file, get_pseudo_label, precision_num = 2):
X = [] X = []
if(get_pseudo_label):
Z = []
Y = [] Y = []
img_dir = './datasets/hwf/data/Handwritten_Math_Symbols/' img_dir = './datasets/hwf/data/Handwritten_Math_Symbols/'
with open(file) as f: with open(file) as f:
data = json.load(f) data = json.load(f)
for idx in range(len(data)): for idx in range(len(data)):
imgs = [] imgs = []
imgs_pseudo_label = []
for img_path in data[idx]['img_paths']: for img_path in data[idx]['img_paths']:
img = Image.open(img_dir + img_path).convert('L') img = Image.open(img_dir + img_path).convert('L')
img = img_transform(img) img = img_transform(img)
imgs.append(img) imgs.append(img)
X.append(imgs)
Y.append(round(data[idx]['res'], precision_num))
return X, Y
if(get_pseudo_label):
imgs_pseudo_label.append(img_path.split('/')[0])
if(len(imgs) == 3):
X.append(imgs)
if(get_pseudo_label):
Z.append(imgs_pseudo_label)
Y.append(round(data[idx]['res'], precision_num))
if(get_pseudo_label):
return X, Z, Y
else:
return X, None, Y


def get_hwf(precision_num = 2):
train_X, train_Y = get_data('./datasets/hwf/data/expr_train.json', precision_num)
test_X, test_Y = get_data('./datasets/hwf/data/expr_test.json', precision_num)
def get_hwf(train = True, get_pseudo_label = False, precision_num = 2):
if(train):
file = './datasets/hwf/data/expr_train.json'
else:
file = './datasets/hwf/data/expr_test.json'
return train_X, train_Y, test_X, test_Y
return get_data(file, get_pseudo_label, precision_num)


if __name__ == "__main__": if __name__ == "__main__":
train_X, train_Y, test_X, test_Y = get_hwf()
train_X, train_Y = get_hwf(train = True)
test_X, test_Y = get_hwf(train = False)
print(len(train_X), len(test_X)) print(len(train_X), len(test_X))
print(len(train_X[0]), train_X[0][0].shape, train_Y[0]) print(len(train_X[0]), train_X[0][0].shape, train_Y[0])

Loading…
Cancel
Save