Browse Source

Create get_hwf.py

pull/3/head
troyyyyy GitHub 2 years ago
parent
commit
05ae1517b9
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 36 additions and 0 deletions
  1. +36
    -0
      datasets/hwf/get_hwf.py

+ 36
- 0
datasets/hwf/get_hwf.py View File

@@ -0,0 +1,36 @@
import json
from PIL import Image
from torchvision.transforms import transforms

img_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (1,))
])

def get_data(file):
X = []
Y = []
img_dir = './data/Handwritten_Math_Symbols/'
with open(file) as f:
data = json.load(f)
for idx in range(len(data)):
imgs = []
for img_path in data[idx]['img_paths']:
img = Image.open(img_dir + img_path).convert('L')
img = img_transform(img)
imgs.append(img)
X.append(imgs)
Y.append(data[idx]['res'])
return X, Y

def get_hwf():
train_X, train_Y = get_data('./data/expr_train.json')
test_X, test_Y = get_data('./data/expr_test.json')
return train_X, train_Y, test_X, test_Y

if __name__ == "__main__":
train_X, train_Y, test_X, test_Y = get_hwf()
print(len(train_X), len(test_X))
print(len(train_X[0]), train_X[0][0].shape, train_Y[0])

Loading…
Cancel
Save