|
- import os
- import torch
- import numpy as np
- from PIL import Image
-
- def cutImage(file_name):
- img = Image.open(file_name)
- oring_size = img.size
- n = img.size[0] // 256
- img = img.resize((n * 256, n * 256))
- res = []
- for i in range(n):
- for j in range(n):
- pos = (256 * j, 256 * i, 256 * (j + 1), 256 * (i + 1))
- res.append(img.crop(pos))
- return n, res, oring_size, img.size
-
-
- def mergeImage(n, imgs, o_size, n_size):
- img = Image.new('L', n_size)
- for i in range(n):
- for j in range(n):
- pos = (256 * j, 256 * i, 256 * (j + 1), 256 * (i + 1))
- img.paste(imgs[i * n + j], pos)
- img = img.resize(o_size)
- a = np.array(img)
- a[a > 17] = 17
- return Image.fromarray(a)
-
-
- def predict(model, input_path, output_dir):
- name, _ = os.path.splitext(input_path)
- name = os.path.split(name)[-1] + ".png"
- n, imgs, o_size, n_size = cutImage(input_path)
- res = []
- for img in imgs:
- img = torch.from_numpy(np.array(img)).float().unsqueeze(0)
- img = img / 255
- if torch.cuda.is_available():
- img = img.cuda()
- img = img.permute(0, 3, 1, 2)
- label = model(img)
- label = torch.argmax(label,
- dim=1).cpu().squeeze().numpy().astype(np.uint8)
- res.append(Image.fromarray(label))
- img = mergeImage(n, res, o_size, n_size)
- img.save(os.path.join(output_dir, name))
-
-
- if __name__ == "__main__":
- from model_define import init_model
- model = init_model()
- predict(model, './test.tif', '')
|