You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test-infer.py 3.6 kB

2 months ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. import torch
  2. import torch.nn as nn
  3. from torchvision import models, transforms
  4. from PIL import Image
  5. from fastapi import FastAPI, File, UploadFile
  6. from fastapi.responses import Response, JSONResponse
  7. import io
  8. import uvicorn
  9. import torch.nn.functional as F
  10. import os
  11. import argparse
  12. from c2net.context import prepare
  13. # import multiprocessing as mp
  14. # mp.set_start_method('spawn', True)
  15. c2net_context = prepare()
  16. files = [f for f in os.listdir(c2net_context.pretrain_model_path) if os.path.isdir(c2net_context.pretrain_model_path+ f"/{f}")]
  17. model_name = ""
  18. for file in files:
  19. if (not file.startswith(".")) and len(os.listdir(c2net_context.pretrain_model_path+ f"/{file}")) > 0:
  20. model_name = file
  21. break
  22. model_path = c2net_context.pretrain_model_path + f"/{model_name}"
  23. pth_file = ""
  24. for root, dirs, files in os.walk(model_path):
  25. for file in files:
  26. if file.endswith(".pth"):
  27. pth_file = file
  28. custom_model_path = c2net_context.pretrain_model_path+"/"+ model_name + "/" + pth_file
  29. batch_size = 10
  30. number_of_labels = 10
  31. classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
  32. base_url = os.getenv('OPENI_SELF_URL')
  33. class Network(nn.Module):
  34. def __init__(self):
  35. super(Network, self).__init__()
  36. self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=5, stride=1, padding=1)
  37. self.bn1 = nn.BatchNorm2d(12)
  38. self.conv2 = nn.Conv2d(in_channels=12, out_channels=12, kernel_size=5, stride=1, padding=1)
  39. self.bn2 = nn.BatchNorm2d(12)
  40. self.pool = nn.MaxPool2d(2, 2)
  41. self.conv4 = nn.Conv2d(in_channels=12, out_channels=24, kernel_size=5, stride=1, padding=1)
  42. self.bn4 = nn.BatchNorm2d(24)
  43. self.conv5 = nn.Conv2d(in_channels=24, out_channels=24, kernel_size=5, stride=1, padding=1)
  44. self.bn5 = nn.BatchNorm2d(24)
  45. self.fc1 = nn.Linear(24 * 10 * 10, 10)
  46. def forward(self, input):
  47. output = F.relu(self.bn1(self.conv1(input)))
  48. output = F.relu(self.bn2(self.conv2(output)))
  49. output = self.pool(output)
  50. output = F.relu(self.bn4(self.conv4(output)))
  51. output = F.relu(self.bn5(self.conv5(output)))
  52. output = output.view(-1, 24 * 10 * 10)
  53. output = self.fc1(output)
  54. return output
  55. app = FastAPI()
  56. TIMEOUT_KEEP_ALIVE = 20
  57. model = Network()
  58. transform_fn = transforms.Compose([
  59. transforms.Resize(32),
  60. transforms.CenterCrop(32),
  61. transforms.ToTensor(),
  62. transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
  63. ])
  64. def inference_one_image(image, card):
  65. image = image.to(card)
  66. model_path = custom_model_path
  67. model.load_state_dict(torch.load(model_path))
  68. model.eval()
  69. device = torch.device(card)
  70. model.to(device)
  71. output = model(image)
  72. _, index = torch.max(output, 1)
  73. return index
  74. @app.get(os.path.join(base_url,"test"))
  75. async def api_get():
  76. return "This is a GET request."
  77. @app.post(os.path.join(base_url,"infer"))
  78. async def classification_cnn(file: UploadFile = File(...)) -> Response:
  79. file_content = await file.read()
  80. image = Image.open(io.BytesIO(file_content))
  81. img = transform_fn(image)
  82. batch = torch.unsqueeze(img, 0)
  83. predicted = inference_one_image(batch, "cuda:0")
  84. if predicted is None:
  85. ret = {"error": "predicted is None"}
  86. else:
  87. ret = {"result": classes[predicted]}
  88. return JSONResponse(ret)
  89. if __name__ == '__main__':
  90. uvicorn.run(app, host='0.0.0.0', port=int(os.getenv('OPENI_SELF_PORT')), timeout_keep_alive=TIMEOUT_KEEP_ALIVE)

No Description