You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

model_predict.py 1.6 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. import os
  2. import torch
  3. import numpy as np
  4. from PIL import Image
  5. def cutImage(file_name):
  6. img = Image.open(file_name)
  7. oring_size = img.size
  8. n = img.size[0] // 256
  9. img = img.resize((n * 256, n * 256))
  10. res = []
  11. for i in range(n):
  12. for j in range(n):
  13. pos = (256 * j, 256 * i, 256 * (j + 1), 256 * (i + 1))
  14. res.append(img.crop(pos))
  15. return n, res, oring_size, img.size
  16. def mergeImage(n, imgs, o_size, n_size):
  17. img = Image.new('L', n_size)
  18. for i in range(n):
  19. for j in range(n):
  20. pos = (256 * j, 256 * i, 256 * (j + 1), 256 * (i + 1))
  21. img.paste(imgs[i * n + j], pos)
  22. img = img.resize(o_size)
  23. a = np.array(img)
  24. a[a > 17] = 17
  25. return Image.fromarray(a)
  26. def predict(model, input_path, output_dir):
  27. name, _ = os.path.splitext(input_path)
  28. name = os.path.split(name)[-1] + ".png"
  29. n, imgs, o_size, n_size = cutImage(input_path)
  30. res = []
  31. for img in imgs:
  32. img = torch.from_numpy(np.array(img)).float().unsqueeze(0)
  33. img = img / 255
  34. if torch.cuda.is_available():
  35. img = img.cuda()
  36. img = img.permute(0, 3, 1, 2)
  37. label = model(img)
  38. label = torch.argmax(label,
  39. dim=1).cpu().squeeze().numpy().astype(np.uint8)
  40. res.append(Image.fromarray(label))
  41. img = mergeImage(n, res, o_size, n_size)
  42. img.save(os.path.join(output_dir, name))
  43. if __name__ == "__main__":
  44. from model_define import init_model
  45. model = init_model()
  46. predict(model, './test.tif', '')